Check key type against selected key exchange
diff --git a/include/polarssl/ssl_ciphersuites.h b/include/polarssl/ssl_ciphersuites.h
index 714cdcd..85392c1 100644
--- a/include/polarssl/ssl_ciphersuites.h
+++ b/include/polarssl/ssl_ciphersuites.h
@@ -27,6 +27,7 @@
#ifndef POLARSSL_SSL_CIPHERSUITES_H
#define POLARSSL_SSL_CIPHERSUITES_H
+#include "pk.h"
#include "cipher.h"
#include "md.h"
@@ -197,6 +198,8 @@
const ssl_ciphersuite_t *ssl_ciphersuite_from_string( const char *ciphersuite_name );
const ssl_ciphersuite_t *ssl_ciphersuite_from_id( int ciphersuite_id );
+pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info );
+
#ifdef __cplusplus
}
#endif
diff --git a/library/ssl_ciphersuites.c b/library/ssl_ciphersuites.c
index 63601f6..759845e 100644
--- a/library/ssl_ciphersuites.c
+++ b/library/ssl_ciphersuites.c
@@ -916,4 +916,20 @@
return( cur->id );
}
+pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info )
+{
+ switch( info->key_exchange )
+ {
+ case POLARSSL_KEY_EXCHANGE_DHE_RSA:
+ case POLARSSL_KEY_EXCHANGE_ECDHE_RSA:
+ return( POLARSSL_PK_RSA );
+
+ case POLARSSL_KEY_EXCHANGE_ECDHE_ECDSA:
+ return( POLARSSL_PK_ECDSA );
+
+ default:
+ return( POLARSSL_PK_NONE );
+ }
+}
+
#endif
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 267e385..605d466 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1394,6 +1394,19 @@
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
}
+ if( pk_alg != POLARSSL_PK_NONE )
+ {
+ if( pk_alg != ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info ) )
+ {
+ SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+ return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+ }
+ }
+ else
+ {
+ pk_alg = ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info );
+ }
+
sig_len = ( p[0] << 8 ) | p[1];
p += 2;