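Simplify SSL session cache implementation via session serialization

Instead of keeping a copy of the session structure plus the raw peer
certificate in each cache entry, store the session in serialized form
(via mbedtls_ssl_session_save()) together with its session ID, and
restore it on lookup via mbedtls_ssl_session_load(). This removes the
special-case handling of the peer certificate in the cache.
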
Signed-off-by: Hanno Becker <hanno.becker@arm.com>
diff --git a/library/ssl_cache.c b/library/ssl_cache.c
index 3fdab5b..3677e51 100644
--- a/library/ssl_cache.c
+++ b/library/ssl_cache.c
@@ -69,9 +69,9 @@
             continue;
 #endif
 
-        if( session_id_len != cur->session.id_len ||
-            memcmp( session_id, cur->session.id,
-                    cur->session.id_len ) != 0 )
+        if( session_id_len != cur->session_id_len ||
+            memcmp( session_id, cur->session_id,
+                    cur->session_id_len ) != 0 )
         {
             continue;
         }
@@ -107,40 +107,12 @@
     if( ret != 0 )
         goto exit;
 
-    ret = mbedtls_ssl_session_copy( session, &entry->session );
+    ret = mbedtls_ssl_session_load( session,
+                                    entry->session,
+                                    entry->session_len );
     if( ret != 0 )
         goto exit;
 
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-    /*
-     * Restore peer certificate (without rest of the original chain)
-     */
-    if( entry->peer_cert.p != NULL )
-    {
-        /* `session->peer_cert` is NULL after the call to
-         * mbedtls_ssl_session_copy(), because cache entries
-         * have the `peer_cert` field set to NULL. */
-
-        if( ( session->peer_cert = mbedtls_calloc( 1,
-                             sizeof(mbedtls_x509_crt) ) ) == NULL )
-        {
-            ret = 1;
-            goto exit;
-        }
-
-        mbedtls_x509_crt_init( session->peer_cert );
-        if( mbedtls_x509_crt_parse( session->peer_cert, entry->peer_cert.p,
-                            entry->peer_cert.len ) != 0 )
-        {
-            mbedtls_free( session->peer_cert );
-            session->peer_cert = NULL;
-            ret = 1;
-            goto exit;
-        }
-    }
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
-
     ret = 0;
 
 exit:
@@ -181,8 +153,8 @@
         }
 #endif
 
-        if( session_id_len == cur->session.id_len &&
-            memcmp( session_id, cur->session.id, cur->session.id_len ) == 0 )
+        if( session_id_len == cur->session_id_len &&
+            memcmp( session_id, cur->session_id, cur->session_id_len ) == 0 )
         {
             break; /* client reconnected, keep timestamp for session id */
         }
@@ -262,12 +234,18 @@
         *dst = cur;
 
         /*
-         * If we're reusing an entry, free its certificate first
+         * If we're reusing an entry, free it first. The serialized
+         * session is wiped before being freed since it may contain
+         * secret key material.
          */
-        if( cur->peer_cert.p != NULL )
+        if( cur->session != NULL )
         {
-            mbedtls_free( cur->peer_cert.p );
-            memset( &cur->peer_cert, 0, sizeof(mbedtls_x509_buf) );
+            mbedtls_platform_zeroize( cur->session, cur->session_len );
+            mbedtls_free( cur->session );
+            cur->session = NULL;
+            cur->session_len = 0;
+            memset( cur->session_id, 0, sizeof( cur->session_id ) );
+            cur->session_id_len = 0;
         }
 
         ret = 0;
@@ -287,6 +262,9 @@
     mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
     mbedtls_ssl_cache_entry *cur;
 
+    size_t session_serialized_len;
+    unsigned char *session_serialized = NULL;
+
 #if defined(MBEDTLS_THREADING_C)
     if( ( ret = mbedtls_mutex_lock( &cache->mutex ) ) != 0 )
         return( ret );
@@ -298,41 +276,42 @@
     if( ret != 0 )
         goto exit;
 
-    /* Copy the entire session; this temporarily makes a copy of the
-     * X.509 CRT structure even though we only want to store the raw CRT.
-     * This inefficiency will go away as soon as we implement on-demand
-     * parsing of CRTs, in which case there's no need for the `peer_cert`
-     * field anymore in the first place, and we're done after this call. */
-    ret = mbedtls_ssl_session_copy( &cur->session, session );
-    if( ret != 0 )
+    /* Check how much space we need to serialize the session
+     * and allocate a sufficiently large buffer. */
+    ret = mbedtls_ssl_session_save( session, NULL, 0, &session_serialized_len );
+    if( ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL )
     {
         ret = 1;
         goto exit;
     }
 
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-    /* If present, free the X.509 structure and only store the raw CRT data. */
-    if( cur->session.peer_cert != NULL )
+    session_serialized = mbedtls_calloc( 1, session_serialized_len );
+    if( session_serialized == NULL )
     {
-        cur->peer_cert.p =
-            mbedtls_calloc( 1, cur->session.peer_cert->raw.len );
-        if( cur->peer_cert.p == NULL )
-        {
-            ret = 1;
-            goto exit;
-        }
-
-        memcpy( cur->peer_cert.p,
-                cur->session.peer_cert->raw.p,
-                cur->session.peer_cert->raw.len );
-        cur->peer_cert.len = session->peer_cert->raw.len;
-
-        mbedtls_x509_crt_free( cur->session.peer_cert );
-        mbedtls_free( cur->session.peer_cert );
-        cur->session.peer_cert = NULL;
+        ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        goto exit;
     }
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
+
+    /* Now serialize the session into the allocated buffer. */
+    ret = mbedtls_ssl_session_save( session,
+                                    session_serialized,
+                                    session_serialized_len,
+                                    &session_serialized_len );
+    if( ret != 0 )
+        goto exit;
+
+    /* The session ID is copied into the entry's fixed-size buffer. */
+    if( session_id_len > sizeof( cur->session_id ) )
+    {
+        ret = 1;
+        goto exit;
+    }
+    cur->session_id_len = session_id_len;
+    memcpy( cur->session_id, session_id, session_id_len );
+
+    cur->session = session_serialized;
+    cur->session_len = session_serialized_len;
+    session_serialized = NULL;
 
     ret = 0;
 
@@ -342,6 +320,14 @@
         ret = 1;
 #endif
 
+    /* Non-NULL only on error paths, where ownership of the buffer was
+     * not transferred to the cache entry: wipe and release it. */
+    if( session_serialized != NULL )
+    {
+        mbedtls_platform_zeroize( session_serialized, session_serialized_len );
+        mbedtls_free( session_serialized );
+    }
+
     return( ret );
 }
 
@@ -372,13 +353,7 @@
         prv = cur;
         cur = cur->next;
 
-        mbedtls_ssl_session_free( &prv->session );
-
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-        mbedtls_free( prv->peer_cert.p );
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
-
+        mbedtls_free( prv->session );
         mbedtls_free( prv );
     }