diff options
Diffstat (limited to 'modules/mbedtls')
-rw-r--r-- | modules/mbedtls/crypto_mbedtls.cpp | 18 | ||||
-rw-r--r-- | modules/mbedtls/packet_peer_mbed_dtls.cpp | 4 | ||||
-rw-r--r-- | modules/mbedtls/stream_peer_mbedtls.cpp | 4 |
3 files changed, 18 insertions, 8 deletions
diff --git a/modules/mbedtls/crypto_mbedtls.cpp b/modules/mbedtls/crypto_mbedtls.cpp index 47c0dc9bb6..381ed42fe1 100644 --- a/modules/mbedtls/crypto_mbedtls.cpp +++ b/modules/mbedtls/crypto_mbedtls.cpp @@ -256,7 +256,7 @@ Error HMACContextMbedTLS::start(HashingContext::HashType p_hash_type, PackedByte } Error HMACContextMbedTLS::update(PackedByteArray p_data) { - ERR_FAIL_COND_V_MSG(ctx == nullptr, ERR_INVALID_DATA, "Start must be called before update."); + ERR_FAIL_NULL_V_MSG(ctx, ERR_INVALID_DATA, "Start must be called before update."); ERR_FAIL_COND_V_MSG(p_data.is_empty(), ERR_INVALID_PARAMETER, "Src must not be empty."); @@ -265,7 +265,7 @@ Error HMACContextMbedTLS::update(PackedByteArray p_data) { } PackedByteArray HMACContextMbedTLS::finish() { - ERR_FAIL_COND_V_MSG(ctx == nullptr, PackedByteArray(), "Start must be called before finish."); + ERR_FAIL_NULL_V_MSG(ctx, PackedByteArray(), "Start must be called before finish."); ERR_FAIL_COND_V_MSG(hash_len == 0, PackedByteArray(), "Unsupported hash type."); PackedByteArray out; @@ -342,7 +342,7 @@ void CryptoMbedTLS::load_default_certificates(String p_path) { ERR_FAIL_COND(default_certs != nullptr); default_certs = memnew(X509CertificateMbedTLS); - ERR_FAIL_COND(default_certs == nullptr); + ERR_FAIL_NULL(default_certs); if (!p_path.is_empty()) { // Use certs defined in project settings. @@ -419,9 +419,19 @@ Ref<X509Certificate> CryptoMbedTLS::generate_self_signed_certificate(Ref<CryptoK } PackedByteArray CryptoMbedTLS::generate_random_bytes(int p_bytes) { + ERR_FAIL_COND_V(p_bytes < 0, PackedByteArray()); PackedByteArray out; out.resize(p_bytes); - mbedtls_ctr_drbg_random(&ctr_drbg, out.ptrw(), p_bytes); + int left = p_bytes; + int pos = 0; + // Ensure we generate random in chunks of no more than MBEDTLS_CTR_DRBG_MAX_REQUEST bytes or mbedtls_ctr_drbg_random will fail. + while (left > 0) { + int to_read = MIN(left, MBEDTLS_CTR_DRBG_MAX_REQUEST); + int ret = mbedtls_ctr_drbg_random(&ctr_drbg, out.ptrw() + pos, to_read); + ERR_FAIL_COND_V_MSG(ret != 0, PackedByteArray(), vformat("Failed to generate %d random bytes(s). Error: %d.", p_bytes, ret)); + left -= to_read; + pos += to_read; + } return out; } diff --git a/modules/mbedtls/packet_peer_mbed_dtls.cpp b/modules/mbedtls/packet_peer_mbed_dtls.cpp index ed1a97cc2c..c7373481ca 100644 --- a/modules/mbedtls/packet_peer_mbed_dtls.cpp +++ b/modules/mbedtls/packet_peer_mbed_dtls.cpp @@ -40,7 +40,7 @@ int PacketPeerMbedDTLS::bio_send(void *ctx, const unsigned char *buf, size_t len PacketPeerMbedDTLS *sp = static_cast<PacketPeerMbedDTLS *>(ctx); - ERR_FAIL_COND_V(sp == nullptr, 0); + ERR_FAIL_NULL_V(sp, 0); Error err = sp->base->put_packet((const uint8_t *)buf, len); if (err == ERR_BUSY) { @@ -58,7 +58,7 @@ int PacketPeerMbedDTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) { PacketPeerMbedDTLS *sp = static_cast<PacketPeerMbedDTLS *>(ctx); - ERR_FAIL_COND_V(sp == nullptr, 0); + ERR_FAIL_NULL_V(sp, 0); int pc = sp->base->get_available_packet_count(); if (pc == 0) { diff --git a/modules/mbedtls/stream_peer_mbedtls.cpp b/modules/mbedtls/stream_peer_mbedtls.cpp index a9d187bd64..a359b42041 100644 --- a/modules/mbedtls/stream_peer_mbedtls.cpp +++ b/modules/mbedtls/stream_peer_mbedtls.cpp @@ -40,7 +40,7 @@ int StreamPeerMbedTLS::bio_send(void *ctx, const unsigned char *buf, size_t len) StreamPeerMbedTLS *sp = static_cast<StreamPeerMbedTLS *>(ctx); - ERR_FAIL_COND_V(sp == nullptr, 0); + ERR_FAIL_NULL_V(sp, 0); int sent; Error err = sp->base->put_partial_data((const uint8_t *)buf, len, sent); @@ -60,7 +60,7 @@ int StreamPeerMbedTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) { StreamPeerMbedTLS *sp = static_cast<StreamPeerMbedTLS *>(ctx); - ERR_FAIL_COND_V(sp == nullptr, 0); + ERR_FAIL_NULL_V(sp, 0); int got; Error err = sp->base->get_partial_data((uint8_t *)buf, len, got); |