PK: add nice interface functions
Also fix a const-corectness issue.
diff --git a/include/polarssl/error.h b/include/polarssl/error.h
index 45a6640..10e68f8 100644
--- a/include/polarssl/error.h
+++ b/include/polarssl/error.h
@@ -84,7 +84,7 @@
* ECP 4 4 (Started from top)
* MD 5 4
* CIPHER 6 5
- * SSL 6 5 (Started from top)
+ * SSL 6 6 (Started from top)
* SSL 7 31
*
* Module dependent error code (5 bits 0x.08.-0x.F8.)
diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h
index 83dc2d0..4f9fdb1 100644
--- a/include/polarssl/pk.h
+++ b/include/polarssl/pk.h
@@ -93,7 +93,7 @@
const char *name;
/** Get key size in bits */
- size_t (*get_size)( void * );
+ size_t (*get_size)( const void * );
/** Tell if the context implements this type (eg ECKEY can do ECDSA) */
int (*can_do)( pk_type_t type );
@@ -146,6 +146,42 @@
*/
int pk_set_type( pk_context *ctx, pk_type_t type );
+/**
+ * \brief Get the size in bits of the underlying key
+ *
+ * \param ctx Context to use
+ *
+ * \return Key size in bits, or 0 on error
+ */
+size_t pk_get_size( const pk_context *ctx );
+
+/**
+ * \brief Tell if a context can do the operation given by type
+ *
+ * \param ctx Context to test
+ * \param type Target type
+ *
+ * \return 0 if context can't do the operations,
+ * 1 otherwise.
+ */
+int pk_can_do( pk_context *ctx, pk_type_t type );
+
+/**
+ * \brief Verify signature
+ *
+ * \param ctx PK context to use
+ * \param hash Hash of the message to sign
+ * \param md_info Information about the hash function used
+ * \param sig Signature to verify
+ * \param sig_len Signature length
+ *
+ * \return 0 on success (signature is valid),
+ * or a specific error code.
+ */
+int pk_verify( pk_context *ctx,
+ const unsigned char *hash, const md_info_t *md_info,
+ const unsigned char *sig, size_t sig_len );
+
#ifdef __cplusplus
}
#endif
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index 7a468c4..d5a2fc0 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -110,7 +110,7 @@
#define POLARSSL_ERR_SSL_BAD_HS_PROTOCOL_VERSION -0x6E80 /**< Handshake protocol not within min/max boundaries */
#define POLARSSL_ERR_SSL_BAD_HS_NEW_SESSION_TICKET -0x6E00 /**< Processing of the NewSessionTicket handshake message failed. */
#define POLARSSL_ERR_SSL_SESSION_TICKET_EXPIRED -0x6D80 /**< Session ticket has expired. */
-
+#define POLARSSL_ERR_SSL_PK_TYPE_MISMATCH -0x6D00 /**< Public key type mismatch (eg, asked for RSA key exchange and presented EC key) */
/*
* Various constants
diff --git a/library/error.c b/library/error.c
index 23f4a85..333a8f8 100644
--- a/library/error.c
+++ b/library/error.c
@@ -373,6 +373,8 @@
snprintf( buf, buflen, "SSL - Processing of the NewSessionTicket handshake message failed" );
if( use_ret == -(POLARSSL_ERR_SSL_SESSION_TICKET_EXPIRED) )
snprintf( buf, buflen, "SSL - Session ticket has expired" );
+ if( use_ret == -(POLARSSL_ERR_SSL_PK_TYPE_MISMATCH) )
+ snprintf( buf, buflen, "SSL - Public key type mismatch (eg, asked for RSA key exchange and presented EC key)" );
#endif /* POLARSSL_SSL_TLS_C */
#if defined(POLARSSL_X509_PARSE_C)
diff --git a/library/pk.c b/library/pk.c
index 9f33641..ce3b88a 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -124,3 +124,39 @@
return( 0 );
}
+
+/*
+ * Tell if a PK can do the operations of the given type
+ */
+int pk_can_do( pk_context *ctx, pk_type_t type )
+{
+ /* null of NONE context can't do anything */
+ if( ctx == NULL || ctx->info == NULL )
+ return( 0 );
+
+ return( ctx->info->can_do( type ) );
+}
+
+/*
+ * Verify a signature
+ */
+int pk_verify( pk_context *ctx,
+ const unsigned char *hash, const md_info_t *md_info,
+ const unsigned char *sig, size_t sig_len )
+{
+ if( ctx == NULL || ctx->info == NULL )
+ return( POLARSSL_ERR_PK_TYPE_MISMATCH ); // TODO
+
+ return( ctx->info->verify_func( ctx->data, hash, md_info, sig, sig_len ) );
+}
+
+/*
+ * Get key size in bits
+ */
+size_t pk_get_size( const pk_context *ctx )
+{
+ if( ctx == NULL || ctx->info == NULL )
+ return( 0 );
+
+ return( ctx->info->get_size( ctx->data ) );
+}
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 50e8db5..239ff78 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -53,9 +53,9 @@
return( type == POLARSSL_PK_RSA );
}
-static size_t rsa_get_size( void * ctx )
+static size_t rsa_get_size( const void * ctx )
{
- return( mpi_size( &((rsa_context *) ctx)->N ) * 8 );
+ return( 8 * ((rsa_context *) ctx)->len );
}
static int rsa_verify_wrap( void *ctx,
@@ -101,7 +101,7 @@
return( type == POLARSSL_PK_ECDSA );
}
-static size_t ecdsa_get_size( void *ctx )
+static size_t ecdsa_get_size( const void *ctx )
{
return( ((ecdsa_context *) ctx)->grp.pbits );
}
@@ -152,7 +152,7 @@
type == POLARSSL_PK_ECDSA );
}
-static size_t eckey_get_size( void *ctx )
+static size_t eckey_get_size( const void *ctx )
{
return( ((ecp_keypair *) ctx)->grp.pbits );
}
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 6674348..1c2c395 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1346,12 +1346,15 @@
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
}
- /* EC NOT IMPLEMENTED YET */
- if( ssl->session_negotiate->peer_cert->pk.type != POLARSSL_PK_RSA )
- return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
+ if( ! pk_can_do( &ssl->session_negotiate->peer_cert->pk,
+ POLARSSL_PK_RSA ) )
+ {
+ SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+ return( POLARSSL_ERR_SSL_PK_TYPE_MISMATCH );
+ }
- if( (unsigned int)( end - p ) !=
- pk_rsa( ssl->session_negotiate->peer_cert->pk )->len )
+ if( 8 * (unsigned int)( end - p ) !=
+ pk_get_size( &ssl->session_negotiate->peer_cert->pk ) )
{
SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
@@ -1795,12 +1798,15 @@
if( ret != 0 )
return( ret );
- /* EC NOT IMPLEMENTED YET */
- if( ssl->session_negotiate->peer_cert->pk.type != POLARSSL_PK_RSA )
- return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
+ if( ! pk_can_do( &ssl->session_negotiate->peer_cert->pk,
+ POLARSSL_PK_RSA ) )
+ {
+ SSL_DEBUG_MSG( 1, ( "certificate key type mismatch" ) );
+ return( POLARSSL_ERR_SSL_PK_TYPE_MISMATCH );
+ }
i = 4;
- n = pk_rsa( ssl->session_negotiate->peer_cert->pk )->len;
+ n = pk_get_size( &ssl->session_negotiate->peer_cert->pk ) / 8;
if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
{
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 36c4f2f..0780da5 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2517,10 +2517,13 @@
}
/* EC NOT IMPLEMENTED YET */
- if( ssl->session_negotiate->peer_cert->pk.type != POLARSSL_PK_RSA )
+ if( ! pk_can_do( &ssl->session_negotiate->peer_cert->pk,
+ POLARSSL_PK_RSA ) )
+ {
return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
+ }
- n1 = pk_rsa( ssl->session_negotiate->peer_cert->pk )->len;
+ n1 = pk_get_size( &ssl->session_negotiate->peer_cert->pk ) / 8;
n2 = ( ssl->in_msg[4 + n] << 8 ) | ssl->in_msg[5 + n];
if( n + n1 + 6 != ssl->in_hslen || n1 != n2 )
diff --git a/library/x509parse.c b/library/x509parse.c
index a8fcc0b..225f45d 100644
--- a/library/x509parse.c
+++ b/library/x509parse.c
@@ -3147,7 +3147,7 @@
}
ret = snprintf( p, n, "\n%s%-" BC "s: %d bits\n", prefix, key_size_str,
- (int) crt->pk.info->get_size( crt->pk.data ) );
+ (int) pk_get_size( &crt->pk ) );
SAFE_SNPRINTF();
return( (int) ( size - n ) );
@@ -3399,9 +3399,9 @@
md( md_info, crl_list->tbs.p, crl_list->tbs.len, hash );
- if( ca->pk.info->can_do( crl_list->sig_pk ) == 0 ||
- ca->pk.info->verify_func( ca->pk.data, hash, md_info,
- crl_list->sig.p, crl_list->sig.len ) != 0 )
+ if( pk_can_do( &ca->pk, crl_list->sig_pk ) == 0 ||
+ pk_verify( &ca->pk, hash, md_info,
+ crl_list->sig.p, crl_list->sig.len ) != 0 )
{
flags |= BADCRL_NOT_TRUSTED;
break;
@@ -3516,9 +3516,9 @@
md( md_info, child->tbs.p, child->tbs.len, hash );
- if( trust_ca->pk.info->can_do( child->sig_pk ) == 0 ||
- trust_ca->pk.info->verify_func( trust_ca->pk.data, hash, md_info,
- child->sig.p, child->sig.len ) != 0 )
+ if( pk_can_do( &trust_ca->pk, child->sig_pk ) == 0 ||
+ pk_verify( &trust_ca->pk, hash, md_info,
+ child->sig.p, child->sig.len ) != 0 )
{
trust_ca = trust_ca->next;
continue;
@@ -3593,9 +3593,9 @@
{
md( md_info, child->tbs.p, child->tbs.len, hash );
- if( parent->pk.info->can_do( child->sig_pk ) == 0 ||
- parent->pk.info->verify_func( parent->pk.data, hash, md_info,
- child->sig.p, child->sig.len ) != 0 )
+ if( pk_can_do( &parent->pk, child->sig_pk ) == 0 ||
+ pk_verify( &parent->pk, hash, md_info,
+ child->sig.p, child->sig.len ) != 0 )
{
*flags |= BADCERT_NOT_TRUSTED;
}