Use a handshake-specific function, mbedtls_ssl_set_hs_psk(), in the PSK callback
diff --git a/library/ssl_tls.c b/library/ssl_tls.c index c599761..d3ec5dc 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c
@@ -1066,6 +1066,15 @@ { unsigned char *p = ssl->handshake->premaster; unsigned char *end = p + sizeof( ssl->handshake->premaster ); + const unsigned char *psk = ssl->conf->psk; + size_t psk_len = ssl->conf->psk_len; + + /* If the psk callback was called, use its result */ + if( ssl->handshake->psk != NULL ) + { + psk = ssl->handshake->psk; + psk_len = ssl->handshake->psk_len; + } /* * PMS = struct { @@ -1077,12 +1086,12 @@ #if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED) if( key_ex == MBEDTLS_KEY_EXCHANGE_PSK ) { - if( end - p < 2 + (int) ssl->conf->psk_len ) + if( end - p < 2 + (int) psk_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 ); - *(p++) = (unsigned char)( ssl->conf->psk_len ); - p += ssl->conf->psk_len; + *(p++) = (unsigned char)( psk_len >> 8 ); + *(p++) = (unsigned char)( psk_len ); + p += psk_len; } else #endif /* MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */ @@ -1149,13 +1158,13 @@ } /* opaque psk<0..2^16-1>; */ - if( end - p < 2 + (int) ssl->conf->psk_len ) + if( end - p < 2 + (int) psk_len ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 ); - *(p++) = (unsigned char)( ssl->conf->psk_len ); - memcpy( p, ssl->conf->psk, ssl->conf->psk_len ); - p += ssl->conf->psk_len; + *(p++) = (unsigned char)( psk_len >> 8 ); + *(p++) = (unsigned char)( psk_len ); + memcpy( p, psk, psk_len ); + p += psk_len; ssl->handshake->pmslen = p - ssl->handshake->premaster; @@ -5353,8 +5362,9 @@ #endif /* MBEDTLS_X509_CRT_PARSE_C */ #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) -int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len, - const unsigned char *psk_identity, size_t psk_identity_len ) +int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, + const unsigned char *psk, size_t psk_len, + const unsigned char *psk_identity, size_t psk_identity_len ) { if( psk == NULL || psk_identity == NULL ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); @@ -5385,6 
+5395,43 @@
     return( 0 );
 }
 
+/*
+ * Set the PSK for the current handshake only; intended for use from the
+ * PSK callback. The key is copied and takes precedence over ssl->conf->psk
+ * when the premaster secret is derived. A PSK set earlier in the same
+ * handshake is wiped and released first.
+ *
+ * Returns 0, MBEDTLS_ERR_SSL_BAD_INPUT_DATA or MBEDTLS_ERR_SSL_MALLOC_FAILED.
+ */
+int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
+                            const unsigned char *psk, size_t psk_len )
+{
+    if( psk == NULL || ssl->handshake == NULL )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    if( psk_len > MBEDTLS_PSK_MAX_LEN )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    /* Release the handshake's own previous key (not ssl->conf->psk, which
+     * belongs to the shared configuration), wiping it since it is secret. */
+    if( ssl->handshake->psk != NULL )
+    {
+        mbedtls_zeroize( ssl->handshake->psk, ssl->handshake->psk_len );
+        mbedtls_free( ssl->handshake->psk );
+    }
+
+    if( ( ssl->handshake->psk = mbedtls_malloc( psk_len ) ) == NULL )
+    {
+        ssl->handshake->psk_len = 0;
+        return( MBEDTLS_ERR_SSL_MALLOC_FAILED );
+    }
+
+    ssl->handshake->psk_len = psk_len;
+    memcpy( ssl->handshake->psk, psk, ssl->handshake->psk_len );
+
+    return( 0 );
+}
+
 void mbedtls_ssl_set_psk_cb( mbedtls_ssl_config *conf,
                      int (*f_psk)(void *, mbedtls_ssl_context *,
                                   const unsigned char *, size_t),
@@ -6441,6 +6488,14 @@
     mbedtls_free( (void *) handshake->curves );
 #endif
 
+#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
+    if( handshake->psk != NULL )
+    {
+        mbedtls_zeroize( handshake->psk, handshake->psk_len );
+        mbedtls_free( handshake->psk );
+    }
+#endif
+
 #if defined(MBEDTLS_X509_CRT_PARSE_C) && \
     defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
     /*