Merge pull request #4347 from hanno-arm/ssl_session_cache_3_0

Add session ID as an explicit parameter to SSL session cache API
diff --git a/ChangeLog.d/session-cache-api.txt b/ChangeLog.d/session-cache-api.txt
new file mode 100644
index 0000000..75cc943
--- /dev/null
+++ b/ChangeLog.d/session-cache-api.txt
@@ -0,0 +1,5 @@
+API changes
+    * The getter and setter API of the SSL session cache (used for
+      session-ID based session resumption) has changed to that of
+      a key-value store with keys being session IDs and values
+      being opaque instances of `mbedtls_ssl_session`.
diff --git a/docs/3.0-migration-guide.d/session-cache-api.md b/docs/3.0-migration-guide.d/session-cache-api.md
new file mode 100644
index 0000000..b28ce19
--- /dev/null
+++ b/docs/3.0-migration-guide.d/session-cache-api.md
@@ -0,0 +1,28 @@
+Session Cache API Change
+-----------------------------------------------------------------
+
+This affects users who use `mbedtls_ssl_conf_session_cache()`
+to configure a custom session cache implementation different
+from the one Mbed TLS implements in `library/ssl_cache.c`.
+
+Those users will need to modify the API of their session cache
+implementation to that of a key-value store with keys being
+session IDs and values being instances of `mbedtls_ssl_session`:
+
+```
+typedef int mbedtls_ssl_cache_get_t( void *data,
+                                     unsigned char const *session_id,
+                                     size_t session_id_len,
+                                     mbedtls_ssl_session *session );
+typedef int mbedtls_ssl_cache_set_t( void *data,
+                                     unsigned char const *session_id,
+                                     size_t session_id_len,
+                                     const mbedtls_ssl_session *session );
+```
+
+Since the structure of `mbedtls_ssl_session` is no longer public from 3.0
+onwards, portable session cache implementations must not access fields of
+`mbedtls_ssl_session`. See the corresponding migration guide. Users that
+find themselves unable to migrate their session cache functionality without
+accessing fields of `mbedtls_ssl_session` should describe their use case
+on the Mbed TLS mailing list.
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index 25e25d5..88a599c 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -480,6 +480,7 @@
    MBEDTLS_SSL_TLS_PRF_SHA256
 }
 mbedtls_tls_prf_types;
+
 /**
  * \brief          Callback type: send data on the network.
  *
@@ -605,6 +606,56 @@
 typedef struct mbedtls_ssl_flight_item mbedtls_ssl_flight_item;
 #endif
 
+/**
+ * \brief          Callback type: server-side session cache getter
+ *
+ *                 The session cache is logically a key-value store, with
+ *                 keys being session IDs and values being instances of
+ *                 mbedtls_ssl_session.
+ *
+ *                 This callback retrieves an entry in this key-value store.
+ *
+ * \param data            The address of the session cache structure to query.
+ * \param session_id      The buffer holding the session ID to query.
+ * \param session_id_len  The length of \p session_id in Bytes.
+ * \param session         The address of the session structure to populate.
+ *                        It is initialized with mbedtls_ssl_session_init(),
+ *                        and the callback must always leave it in a state
+ *                        where it can safely be freed via
+ *                        mbedtls_ssl_session_free() independent of the
+ *                        return code of this function.
+ *
+ * \return                \c 0 on success
+ * \return                A non-zero return value on failure.
+ *
+ */
+typedef int mbedtls_ssl_cache_get_t( void *data,
+                                     unsigned char const *session_id,
+                                     size_t session_id_len,
+                                     mbedtls_ssl_session *session );
+/**
+ * \brief          Callback type: server-side session cache setter
+ *
+ *                 The session cache is logically a key-value store, with
+ *                 keys being session IDs and values being instances of
+ *                 mbedtls_ssl_session.
+ *
+ *                 This callback sets an entry in this key-value store.
+ *
+ * \param data            The address of the session cache structure to modify.
+ * \param session_id      The buffer holding the session ID to use as the key.
+ * \param session_id_len  The length of \p session_id in Bytes.
+ * \param session         The address of the session to be stored in the
+ *                        session cache.
+ *
+ * \return                \c 0 on success
+ * \return                A non-zero return value on failure.
+ */
+typedef int mbedtls_ssl_cache_set_t( void *data,
+                                     unsigned char const *session_id,
+                                     size_t session_id_len,
+                                     const mbedtls_ssl_session *session );
+
 #if defined(MBEDTLS_SSL_ASYNC_PRIVATE)
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
 /**
@@ -950,9 +1001,9 @@
     void *p_rng;                    /*!< context for the RNG function       */
 
     /** Callback to retrieve a session from the cache                       */
-    int (*f_get_cache)(void *, mbedtls_ssl_session *);
+    mbedtls_ssl_cache_get_t *f_get_cache;
     /** Callback to store a session into the cache                          */
-    int (*f_set_cache)(void *, const mbedtls_ssl_session *);
+    mbedtls_ssl_cache_set_t *f_set_cache;
     void *p_cache;                  /*!< context for cache callbacks        */
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -2360,9 +2411,9 @@
  * \param f_set_cache    session set callback
  */
 void mbedtls_ssl_conf_session_cache( mbedtls_ssl_config *conf,
-        void *p_cache,
-        int (*f_get_cache)(void *, mbedtls_ssl_session *),
-        int (*f_set_cache)(void *, const mbedtls_ssl_session *) );
+                                     void *p_cache,
+                                     mbedtls_ssl_cache_get_t *f_get_cache,
+                                     mbedtls_ssl_cache_set_t *f_set_cache );
 #endif /* MBEDTLS_SSL_SRV_C */
 
 #if defined(MBEDTLS_SSL_CLI_C)
diff --git a/include/mbedtls/ssl_cache.h b/include/mbedtls/ssl_cache.h
index c6ef296..ac7b77c 100644
--- a/include/mbedtls/ssl_cache.h
+++ b/include/mbedtls/ssl_cache.h
@@ -67,11 +67,13 @@
 #if defined(MBEDTLS_HAVE_TIME)
     mbedtls_time_t timestamp;           /*!< entry timestamp    */
 #endif
-    mbedtls_ssl_session session;        /*!< entry session      */
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-    mbedtls_x509_buf peer_cert;         /*!< entry peer_cert    */
-#endif
+
+    unsigned char session_id[32];       /*!< session ID         */
+    size_t session_id_len;
+
+    unsigned char *session;             /*!< serialized session */
+    size_t session_len;
+
     mbedtls_ssl_cache_entry *next;      /*!< chain pointer      */
 };
 
@@ -99,19 +101,32 @@
  * \brief          Cache get callback implementation
  *                 (Thread-safe if MBEDTLS_THREADING_C is enabled)
  *
- * \param data     SSL cache context
- * \param session  session to retrieve entry for
+ * \param data            The SSL cache context to use.
+ * \param session_id      The pointer to the buffer holding the session ID
+ *                        for the session to load.
+ * \param session_id_len  The length of \p session_id in bytes.
+ * \param session         The address at which to store the session
+ *                        associated with \p session_id, if present.
  */
-int mbedtls_ssl_cache_get( void *data, mbedtls_ssl_session *session );
+int mbedtls_ssl_cache_get( void *data,
+                           unsigned char const *session_id,
+                           size_t session_id_len,
+                           mbedtls_ssl_session *session );
 
 /**
  * \brief          Cache set callback implementation
  *                 (Thread-safe if MBEDTLS_THREADING_C is enabled)
  *
- * \param data     SSL cache context
- * \param session  session to store entry for
+ * \param data            The SSL cache context to use.
+ * \param session_id      The pointer to the buffer holding the session ID
+ *                        associated with \p session.
+ * \param session_id_len  The length of \p session_id in bytes.
+ * \param session         The session to store.
  */
-int mbedtls_ssl_cache_set( void *data, const mbedtls_ssl_session *session );
+int mbedtls_ssl_cache_set( void *data,
+                           unsigned char const *session_id,
+                           size_t session_id_len,
+                           const mbedtls_ssl_session *session );
 
 #if defined(MBEDTLS_HAVE_TIME)
 /**
diff --git a/library/ssl_cache.c b/library/ssl_cache.c
index bb5007b..fe4f30c 100644
--- a/library/ssl_cache.c
+++ b/library/ssl_cache.c
@@ -50,83 +50,70 @@
 #endif
 }
 
-int mbedtls_ssl_cache_get( void *data, mbedtls_ssl_session *session )
+static int ssl_cache_find_entry( mbedtls_ssl_cache_context *cache,
+                                 unsigned char const *session_id,
+                                 size_t session_id_len,
+                                 mbedtls_ssl_cache_entry **dst )
 {
     int ret = 1;
 #if defined(MBEDTLS_HAVE_TIME)
     mbedtls_time_t t = mbedtls_time( NULL );
 #endif
+    mbedtls_ssl_cache_entry *cur;
+
+    for( cur = cache->chain; cur != NULL; cur = cur->next )
+    {
+#if defined(MBEDTLS_HAVE_TIME)
+        if( cache->timeout != 0 &&
+            (int) ( t - cur->timestamp ) > cache->timeout )
+            continue;
+#endif
+
+        if( session_id_len != cur->session_id_len ||
+            memcmp( session_id, cur->session_id,
+                    cur->session_id_len ) != 0 )
+        {
+            continue;
+        }
+
+        break;
+    }
+
+    if( cur != NULL )
+    {
+        *dst = cur;
+        ret = 0;
+    }
+
+    return( ret );
+}
+
+
+int mbedtls_ssl_cache_get( void *data,
+                           unsigned char const *session_id,
+                           size_t session_id_len,
+                           mbedtls_ssl_session *session )
+{
+    int ret = 1;
     mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
-    mbedtls_ssl_cache_entry *cur, *entry;
+    mbedtls_ssl_cache_entry *entry;
 
 #if defined(MBEDTLS_THREADING_C)
     if( mbedtls_mutex_lock( &cache->mutex ) != 0 )
         return( 1 );
 #endif
 
-    cur = cache->chain;
-    entry = NULL;
-
-    while( cur != NULL )
-    {
-        entry = cur;
-        cur = cur->next;
-
-#if defined(MBEDTLS_HAVE_TIME)
-        if( cache->timeout != 0 &&
-            (int) ( t - entry->timestamp ) > cache->timeout )
-            continue;
-#endif
-
-        if( session->ciphersuite != entry->session.ciphersuite ||
-            session->compression != entry->session.compression ||
-            session->id_len != entry->session.id_len )
-            continue;
-
-        if( memcmp( session->id, entry->session.id,
-                    entry->session.id_len ) != 0 )
-            continue;
-
-        ret = mbedtls_ssl_session_copy( session, &entry->session );
-        if( ret != 0 )
-        {
-            ret = 1;
-            goto exit;
-        }
-
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-        /*
-         * Restore peer certificate (without rest of the original chain)
-         */
-        if( entry->peer_cert.p != NULL )
-        {
-            /* `session->peer_cert` is NULL after the call to
-             * mbedtls_ssl_session_copy(), because cache entries
-             * have the `peer_cert` field set to NULL. */
-
-            if( ( session->peer_cert = mbedtls_calloc( 1,
-                                 sizeof(mbedtls_x509_crt) ) ) == NULL )
-            {
-                ret = 1;
-                goto exit;
-            }
-
-            mbedtls_x509_crt_init( session->peer_cert );
-            if( mbedtls_x509_crt_parse( session->peer_cert, entry->peer_cert.p,
-                                entry->peer_cert.len ) != 0 )
-            {
-                mbedtls_free( session->peer_cert );
-                session->peer_cert = NULL;
-                ret = 1;
-                goto exit;
-            }
-        }
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
-
-        ret = 0;
+    ret = ssl_cache_find_entry( cache, session_id, session_id_len, &entry );
+    if( ret != 0 )
         goto exit;
-    }
+
+    ret = mbedtls_ssl_session_load( session,
+                                    entry->session,
+                                    entry->session_len );
+    if( ret != 0 )
+        goto exit;
+
+    ret = 0;
 
 exit:
 #if defined(MBEDTLS_THREADING_C)
@@ -137,158 +124,184 @@
     return( ret );
 }
 
-int mbedtls_ssl_cache_set( void *data, const mbedtls_ssl_session *session )
+static int ssl_cache_pick_writing_slot( mbedtls_ssl_cache_context *cache,
+                                        unsigned char const *session_id,
+                                        size_t session_id_len,
+                                        mbedtls_ssl_cache_entry **dst )
 {
-    int ret = 1;
 #if defined(MBEDTLS_HAVE_TIME)
     mbedtls_time_t t = mbedtls_time( NULL ), oldest = 0;
+#endif /* MBEDTLS_HAVE_TIME */
+
     mbedtls_ssl_cache_entry *old = NULL;
-#endif
-    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
-    mbedtls_ssl_cache_entry *cur, *prv;
     int count = 0;
+    mbedtls_ssl_cache_entry *cur, *last;
+
+    /* Check 1: Is there already an entry with the given session ID?
+     *
+     * If yes, overwrite it.
+     *
+     * If not, `count` will hold the size of the session cache
+     * at the end of this loop, and `last` will point to the last
+     * entry, both of which will be used later. */
+
+    last = NULL;
+    for( cur = cache->chain; cur != NULL; cur = cur->next )
+    {
+        count++;
+        if( session_id_len == cur->session_id_len &&
+            memcmp( session_id, cur->session_id, cur->session_id_len ) == 0 )
+        {
+            goto found;
+        }
+        last = cur;
+    }
+
+    /* Check 2: Is there an outdated entry in the cache?
+     *
+     * If so, overwrite it.
+     *
+     * If not, remember the oldest entry in `old` for later.
+     */
+
+#if defined(MBEDTLS_HAVE_TIME)
+    for( cur = cache->chain; cur != NULL; cur = cur->next )
+    {
+        if( cache->timeout != 0 &&
+            (int) ( t - cur->timestamp ) > cache->timeout )
+        {
+            goto found;
+        }
+
+        if( oldest == 0 || cur->timestamp < oldest )
+        {
+            oldest = cur->timestamp;
+            old = cur;
+        }
+    }
+#endif /* MBEDTLS_HAVE_TIME */
+
+    /* Check 3: Is there free space in the cache? */
+
+    if( count < cache->max_entries )
+    {
+        /* Create new entry */
+        cur = mbedtls_calloc( 1, sizeof(mbedtls_ssl_cache_entry) );
+        if( cur == NULL )
+            return( 1 );
+
+        /* Append to the end of the linked list. */
+        if( last == NULL )
+            cache->chain = cur;
+        else
+            last->next = cur;
+
+        goto found;
+    }
+
+    /* Last resort: The cache is full and doesn't contain any outdated
+     * elements. In this case, we evict the oldest one, judged by timestamp
+     * (if present) or cache-order. */
+
+#if defined(MBEDTLS_HAVE_TIME)
+    if( old == NULL )
+    {
+        /* This should only happen on an ill-configured cache
+         * with max_entries == 0. */
+        return( 1 );
+    }
+#else /* MBEDTLS_HAVE_TIME */
+    /* Reuse first entry in chain, but move to last place. */
+    if( cache->chain == NULL )
+        return( 1 );
+
+    old = cache->chain;
+    cache->chain = old->next;
+    old->next = NULL;
+    last->next = old;
+#endif /* MBEDTLS_HAVE_TIME */
+
+    /* Now `old` points to the oldest entry to be overwritten. */
+    cur = old;
+
+found:
+
+#if defined(MBEDTLS_HAVE_TIME)
+    cur->timestamp = t;
+#endif
+
+    /* If we're reusing an entry, free it first. */
+    if( cur->session != NULL )
+    {
+        mbedtls_free( cur->session );
+        cur->session = NULL;
+        cur->session_len = 0;
+        memset( cur->session_id, 0, sizeof( cur->session_id ) );
+        cur->session_id_len = 0;
+    }
+
+    *dst = cur;
+    return( 0 );
+}
+
+int mbedtls_ssl_cache_set( void *data,
+                           unsigned char const *session_id,
+                           size_t session_id_len,
+                           const mbedtls_ssl_session *session )
+{
+    int ret = 1;
+    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
+    mbedtls_ssl_cache_entry *cur;
+
+    size_t session_serialized_len;
+    unsigned char *session_serialized = NULL;
 
 #if defined(MBEDTLS_THREADING_C)
     if( ( ret = mbedtls_mutex_lock( &cache->mutex ) ) != 0 )
         return( ret );
 #endif
 
-    cur = cache->chain;
-    prv = NULL;
-
-    while( cur != NULL )
-    {
-        count++;
-
-#if defined(MBEDTLS_HAVE_TIME)
-        if( cache->timeout != 0 &&
-            (int) ( t - cur->timestamp ) > cache->timeout )
-        {
-            cur->timestamp = t;
-            break; /* expired, reuse this slot, update timestamp */
-        }
-#endif
-
-        if( memcmp( session->id, cur->session.id, cur->session.id_len ) == 0 )
-            break; /* client reconnected, keep timestamp for session id */
-
-#if defined(MBEDTLS_HAVE_TIME)
-        if( oldest == 0 || cur->timestamp < oldest )
-        {
-            oldest = cur->timestamp;
-            old = cur;
-        }
-#endif
-
-        prv = cur;
-        cur = cur->next;
-    }
-
-    if( cur == NULL )
-    {
-#if defined(MBEDTLS_HAVE_TIME)
-        /*
-         * Reuse oldest entry if max_entries reached
-         */
-        if( count >= cache->max_entries )
-        {
-            if( old == NULL )
-            {
-                ret = 1;
-                goto exit;
-            }
-
-            cur = old;
-        }
-#else /* MBEDTLS_HAVE_TIME */
-        /*
-         * Reuse first entry in chain if max_entries reached,
-         * but move to last place
-         */
-        if( count >= cache->max_entries )
-        {
-            if( cache->chain == NULL )
-            {
-                ret = 1;
-                goto exit;
-            }
-
-            cur = cache->chain;
-            cache->chain = cur->next;
-            cur->next = NULL;
-            prv->next = cur;
-        }
-#endif /* MBEDTLS_HAVE_TIME */
-        else
-        {
-            /*
-             * max_entries not reached, create new entry
-             */
-            cur = mbedtls_calloc( 1, sizeof(mbedtls_ssl_cache_entry) );
-            if( cur == NULL )
-            {
-                ret = 1;
-                goto exit;
-            }
-
-            if( prv == NULL )
-                cache->chain = cur;
-            else
-                prv->next = cur;
-        }
-
-#if defined(MBEDTLS_HAVE_TIME)
-        cur->timestamp = t;
-#endif
-    }
-
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-    /*
-     * If we're reusing an entry, free its certificate first
-     */
-    if( cur->peer_cert.p != NULL )
-    {
-        mbedtls_free( cur->peer_cert.p );
-        memset( &cur->peer_cert, 0, sizeof(mbedtls_x509_buf) );
-    }
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
-
-    /* Copy the entire session; this temporarily makes a copy of the
-     * X.509 CRT structure even though we only want to store the raw CRT.
-     * This inefficiency will go away as soon as we implement on-demand
-     * parsing of CRTs, in which case there's no need for the `peer_cert`
-     * field anymore in the first place, and we're done after this call. */
-    ret = mbedtls_ssl_session_copy( &cur->session, session );
+    ret = ssl_cache_pick_writing_slot( cache,
+                                       session_id, session_id_len,
+                                       &cur );
     if( ret != 0 )
+        goto exit;
+
+    /* Check how much space we need to serialize the session
+     * and allocate a sufficiently large buffer. */
+    ret = mbedtls_ssl_session_save( session, NULL, 0, &session_serialized_len );
+    if( ret != MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL )
     {
         ret = 1;
         goto exit;
     }
 
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-    /* If present, free the X.509 structure and only store the raw CRT data. */
-    if( cur->session.peer_cert != NULL )
+    session_serialized = mbedtls_calloc( 1, session_serialized_len );
+    if( session_serialized == NULL )
     {
-        cur->peer_cert.p =
-            mbedtls_calloc( 1, cur->session.peer_cert->raw.len );
-        if( cur->peer_cert.p == NULL )
-        {
-            ret = 1;
-            goto exit;
-        }
-
-        memcpy( cur->peer_cert.p,
-                cur->session.peer_cert->raw.p,
-                cur->session.peer_cert->raw.len );
-        cur->peer_cert.len = session->peer_cert->raw.len;
-
-        mbedtls_x509_crt_free( cur->session.peer_cert );
-        mbedtls_free( cur->session.peer_cert );
-        cur->session.peer_cert = NULL;
+        ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        goto exit;
     }
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
+
+    /* Now serialize the session into the allocated buffer. */
+    ret = mbedtls_ssl_session_save( session,
+                                    session_serialized,
+                                    session_serialized_len,
+                                    &session_serialized_len );
+    if( ret != 0 )
+        goto exit;
+
+    if( session_id_len > sizeof( cur->session_id ) )
+    {
+        ret = 1;
+        goto exit;
+    }
+    cur->session_id_len = session_id_len;
+    memcpy( cur->session_id, session_id, session_id_len );
+
+    cur->session = session_serialized;
+    cur->session_len = session_serialized_len;
+    session_serialized = NULL;
 
     ret = 0;
 
@@ -298,6 +311,9 @@
         ret = 1;
 #endif
 
+    if( session_serialized != NULL )
+        mbedtls_platform_zeroize( session_serialized, session_serialized_len );
+    mbedtls_free( session_serialized ); /* NULL if the entry took ownership */
     return( ret );
 }
 
@@ -328,13 +344,7 @@
         prv = cur;
         cur = cur->next;
 
-        mbedtls_ssl_session_free( &prv->session );
-
-#if defined(MBEDTLS_X509_CRT_PARSE_C) && \
-    defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
-        mbedtls_free( prv->peer_cert.p );
-#endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_KEEP_PEER_CERTIFICATE */
-
+        mbedtls_free( prv->session );
         mbedtls_free( prv );
     }
 
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 5c07e3e..73b79da 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2461,6 +2461,54 @@
 }
 #endif /* MBEDTLS_SSL_DTLS_HELLO_VERIFY */
 
+static void ssl_handle_id_based_session_resumption( mbedtls_ssl_context *ssl )
+{
+    int ret;
+    mbedtls_ssl_session session_tmp;
+    mbedtls_ssl_session * const session = ssl->session_negotiate;
+
+    /* Resume is 0 by default, see ssl_handshake_init().
+     * It may be already set to 1 by ssl_parse_session_ticket_ext(). */
+    if( ssl->handshake->resume == 1 )
+        return;
+    if( session->id_len == 0 )
+        return;
+    if( ssl->conf->f_get_cache == NULL )
+        return;
+#if defined(MBEDTLS_SSL_RENEGOTIATION)
+    if( ssl->renego_status != MBEDTLS_SSL_INITIAL_HANDSHAKE )
+        return;
+#endif
+
+    mbedtls_ssl_session_init( &session_tmp );
+
+    ret = ssl->conf->f_get_cache( ssl->conf->p_cache,
+                                  session->id,
+                                  session->id_len,
+                                  &session_tmp );
+    if( ret != 0 )
+        goto exit;
+
+    if( session->ciphersuite != session_tmp.ciphersuite ||
+        session->compression != session_tmp.compression )
+    {
+        /* Mismatch between cached and negotiated session */
+        goto exit;
+    }
+
+    /* Move semantics */
+    mbedtls_ssl_session_free( session );
+    *session = session_tmp;
+    memset( &session_tmp, 0, sizeof( session_tmp ) );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "session successfully restored from cache" ) );
+    ssl->handshake->resume = 1;
+
+exit:
+
+    mbedtls_ssl_session_free( &session_tmp );
+}
+
 static int ssl_write_server_hello( mbedtls_ssl_context *ssl )
 {
 #if defined(MBEDTLS_HAVE_TIME)
@@ -2531,22 +2579,7 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "server hello, random bytes", buf + 6, 32 );
 
-    /*
-     * Resume is 0  by default, see ssl_handshake_init().
-     * It may be already set to 1 by ssl_parse_session_ticket_ext().
-     * If not, try looking up session ID in our cache.
-     */
-    if( ssl->handshake->resume == 0 &&
-#if defined(MBEDTLS_SSL_RENEGOTIATION)
-        ssl->renego_status == MBEDTLS_SSL_INITIAL_HANDSHAKE &&
-#endif
-        ssl->session_negotiate->id_len != 0 &&
-        ssl->conf->f_get_cache != NULL &&
-        ssl->conf->f_get_cache( ssl->conf->p_cache, ssl->session_negotiate ) == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 3, ( "session successfully restored from cache" ) );
-        ssl->handshake->resume = 1;
-    }
+    ssl_handle_id_based_session_resumption( ssl );
 
     if( ssl->handshake->resume == 0 )
     {
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index b81f8f5..170d563 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3036,7 +3036,10 @@
         ssl->session->id_len != 0 &&
         resume == 0 )
     {
-        if( ssl->conf->f_set_cache( ssl->conf->p_cache, ssl->session ) != 0 )
+        if( ssl->conf->f_set_cache( ssl->conf->p_cache,
+                                    ssl->session->id,
+                                    ssl->session->id_len,
+                                    ssl->session ) != 0 )
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "cache did not store session" ) );
     }
 
@@ -3758,9 +3761,9 @@
 
 #if defined(MBEDTLS_SSL_SRV_C)
 void mbedtls_ssl_conf_session_cache( mbedtls_ssl_config *conf,
-        void *p_cache,
-        int (*f_get_cache)(void *, mbedtls_ssl_session *),
-        int (*f_set_cache)(void *, const mbedtls_ssl_session *) )
+                                     void *p_cache,
+                                     mbedtls_ssl_cache_get_t *f_get_cache,
+                                     mbedtls_ssl_cache_set_t *f_set_cache )
 {
     conf->p_cache = p_cache;
     conf->f_get_cache = f_get_cache;