Make ssl_cookie.c thread-safe
diff --git a/library/ssl_cookie.c b/library/ssl_cookie.c
index 8b993ac..cc88905 100644
--- a/library/ssl_cookie.c
+++ b/library/ssl_cookie.c
@@ -83,6 +83,10 @@
ctx->serial = 0;
#endif
ctx->timeout = MBEDTLS_SSL_COOKIE_TIMEOUT;
+
+#if defined(MBEDTLS_THREADING_C)
+ mbedtls_mutex_init( &ctx->mutex );
+#endif
}
void mbedtls_ssl_cookie_set_timeout( mbedtls_ssl_cookie_ctx *ctx, unsigned long delay )
@@ -93,6 +97,12 @@
void mbedtls_ssl_cookie_free( mbedtls_ssl_cookie_ctx *ctx )
{
mbedtls_md_free( &ctx->hmac_ctx );
+
+#if defined(MBEDTLS_THREADING_C)
+ mbedtls_mutex_free( &ctx->mutex ); /* release, not re-init, before the wipe */
+#endif
+
+ mbedtls_zeroize( ctx, sizeof( mbedtls_ssl_cookie_ctx ) );
}
int mbedtls_ssl_cookie_setup( mbedtls_ssl_cookie_ctx *ctx,
@@ -152,6 +162,7 @@
unsigned char **p, unsigned char *end,
const unsigned char *cli_id, size_t cli_id_len )
{
+ int ret;
mbedtls_ssl_cookie_ctx *ctx = (mbedtls_ssl_cookie_ctx *) p_ctx;
unsigned long t;
@@ -173,8 +184,21 @@
(*p)[3] = (unsigned char)( t );
*p += 4;
- return( ssl_cookie_hmac( &ctx->hmac_ctx, *p - 4,
- p, end, cli_id, cli_id_len ) );
+#if defined(MBEDTLS_THREADING_C)
+ if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
+ return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + ret );
+#endif
+
+ ret = ssl_cookie_hmac( &ctx->hmac_ctx, *p - 4,
+ p, end, cli_id, cli_id_len );
+
+#if defined(MBEDTLS_THREADING_C)
+ if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 )
+ return( MBEDTLS_ERR_SSL_INTERNAL_ERROR +
+ MBEDTLS_ERR_THREADING_MUTEX_ERROR );
+#endif
+
+ return( ret );
}
/*
@@ -185,6 +209,7 @@
const unsigned char *cli_id, size_t cli_id_len )
{
unsigned char ref_hmac[COOKIE_HMAC_LEN];
+ int ret = 0;
unsigned char *p = ref_hmac;
mbedtls_ssl_cookie_ctx *ctx = (mbedtls_ssl_cookie_ctx *) p_ctx;
unsigned long cur_time, cookie_time;
@@ -195,10 +220,24 @@
if( cookie_len != COOKIE_LEN )
return( -1 );
+#if defined(MBEDTLS_THREADING_C)
+ if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 )
+ return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + ret );
+#endif
+
if( ssl_cookie_hmac( &ctx->hmac_ctx, cookie,
&p, p + sizeof( ref_hmac ),
cli_id, cli_id_len ) != 0 )
- return( -1 );
+ ret = -1;
+
+#if defined(MBEDTLS_THREADING_C)
+ if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 )
+ return( MBEDTLS_ERR_SSL_INTERNAL_ERROR +
+ MBEDTLS_ERR_THREADING_MUTEX_ERROR );
+#endif
+
+ if( ret != 0 )
+ return( ret );
if( mbedtls_ssl_safer_memcmp( cookie + 4, ref_hmac, sizeof( ref_hmac ) ) != 0 )
return( -1 );