summaryrefslogtreecommitdiffstats
path: root/modules/mbedtls
diff options
context:
space:
mode:
Diffstat (limited to 'modules/mbedtls')
-rw-r--r--modules/mbedtls/crypto_mbedtls.cpp18
-rw-r--r--modules/mbedtls/packet_peer_mbed_dtls.cpp4
-rw-r--r--modules/mbedtls/stream_peer_mbedtls.cpp4
3 files changed, 18 insertions, 8 deletions
diff --git a/modules/mbedtls/crypto_mbedtls.cpp b/modules/mbedtls/crypto_mbedtls.cpp
index 47c0dc9bb6..381ed42fe1 100644
--- a/modules/mbedtls/crypto_mbedtls.cpp
+++ b/modules/mbedtls/crypto_mbedtls.cpp
@@ -256,7 +256,7 @@ Error HMACContextMbedTLS::start(HashingContext::HashType p_hash_type, PackedByte
}
Error HMACContextMbedTLS::update(PackedByteArray p_data) {
- ERR_FAIL_COND_V_MSG(ctx == nullptr, ERR_INVALID_DATA, "Start must be called before update.");
+ ERR_FAIL_NULL_V_MSG(ctx, ERR_INVALID_DATA, "Start must be called before update.");
ERR_FAIL_COND_V_MSG(p_data.is_empty(), ERR_INVALID_PARAMETER, "Src must not be empty.");
@@ -265,7 +265,7 @@ Error HMACContextMbedTLS::update(PackedByteArray p_data) {
}
PackedByteArray HMACContextMbedTLS::finish() {
- ERR_FAIL_COND_V_MSG(ctx == nullptr, PackedByteArray(), "Start must be called before finish.");
+ ERR_FAIL_NULL_V_MSG(ctx, PackedByteArray(), "Start must be called before finish.");
ERR_FAIL_COND_V_MSG(hash_len == 0, PackedByteArray(), "Unsupported hash type.");
PackedByteArray out;
@@ -342,7 +342,7 @@ void CryptoMbedTLS::load_default_certificates(String p_path) {
ERR_FAIL_COND(default_certs != nullptr);
default_certs = memnew(X509CertificateMbedTLS);
- ERR_FAIL_COND(default_certs == nullptr);
+ ERR_FAIL_NULL(default_certs);
if (!p_path.is_empty()) {
// Use certs defined in project settings.
@@ -419,9 +419,19 @@ Ref<X509Certificate> CryptoMbedTLS::generate_self_signed_certificate(Ref<CryptoK
}
PackedByteArray CryptoMbedTLS::generate_random_bytes(int p_bytes) {
+ ERR_FAIL_COND_V(p_bytes < 0, PackedByteArray());
PackedByteArray out;
out.resize(p_bytes);
- mbedtls_ctr_drbg_random(&ctr_drbg, out.ptrw(), p_bytes);
+ int left = p_bytes;
+ int pos = 0;
+ // Ensure we generate random in chunks of no more than MBEDTLS_CTR_DRBG_MAX_REQUEST bytes or mbedtls_ctr_drbg_random will fail.
+ while (left > 0) {
+ int to_read = MIN(left, MBEDTLS_CTR_DRBG_MAX_REQUEST);
+ int ret = mbedtls_ctr_drbg_random(&ctr_drbg, out.ptrw() + pos, to_read);
+ ERR_FAIL_COND_V_MSG(ret != 0, PackedByteArray(), vformat("Failed to generate %d random bytes(s). Error: %d.", p_bytes, ret));
+ left -= to_read;
+ pos += to_read;
+ }
return out;
}
diff --git a/modules/mbedtls/packet_peer_mbed_dtls.cpp b/modules/mbedtls/packet_peer_mbed_dtls.cpp
index ed1a97cc2c..c7373481ca 100644
--- a/modules/mbedtls/packet_peer_mbed_dtls.cpp
+++ b/modules/mbedtls/packet_peer_mbed_dtls.cpp
@@ -40,7 +40,7 @@ int PacketPeerMbedDTLS::bio_send(void *ctx, const unsigned char *buf, size_t len
PacketPeerMbedDTLS *sp = static_cast<PacketPeerMbedDTLS *>(ctx);
- ERR_FAIL_COND_V(sp == nullptr, 0);
+ ERR_FAIL_NULL_V(sp, 0);
Error err = sp->base->put_packet((const uint8_t *)buf, len);
if (err == ERR_BUSY) {
@@ -58,7 +58,7 @@ int PacketPeerMbedDTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) {
PacketPeerMbedDTLS *sp = static_cast<PacketPeerMbedDTLS *>(ctx);
- ERR_FAIL_COND_V(sp == nullptr, 0);
+ ERR_FAIL_NULL_V(sp, 0);
int pc = sp->base->get_available_packet_count();
if (pc == 0) {
diff --git a/modules/mbedtls/stream_peer_mbedtls.cpp b/modules/mbedtls/stream_peer_mbedtls.cpp
index a9d187bd64..a359b42041 100644
--- a/modules/mbedtls/stream_peer_mbedtls.cpp
+++ b/modules/mbedtls/stream_peer_mbedtls.cpp
@@ -40,7 +40,7 @@ int StreamPeerMbedTLS::bio_send(void *ctx, const unsigned char *buf, size_t len)
StreamPeerMbedTLS *sp = static_cast<StreamPeerMbedTLS *>(ctx);
- ERR_FAIL_COND_V(sp == nullptr, 0);
+ ERR_FAIL_NULL_V(sp, 0);
int sent;
Error err = sp->base->put_partial_data((const uint8_t *)buf, len, sent);
@@ -60,7 +60,7 @@ int StreamPeerMbedTLS::bio_recv(void *ctx, unsigned char *buf, size_t len) {
StreamPeerMbedTLS *sp = static_cast<StreamPeerMbedTLS *>(ctx);
- ERR_FAIL_COND_V(sp == nullptr, 0);
+ ERR_FAIL_NULL_V(sp, 0);
int got;
Error err = sp->base->get_partial_data((uint8_t *)buf, len, got);