Expand and fix resend infrastructure
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index 1a7722c..63a7528 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -644,6 +644,12 @@
     unsigned char retransmit_state;     /*!<  Retransmission state           */
     ssl_flight_item *flight;            /*!<  Current outgoing flight        */
     ssl_flight_item *cur_msg;           /*!<  Current message in flight      */
+    unsigned int in_flight_start_seq;   /*!<  Minimum message sequence in the
+                                              flight being received          */
+    ssl_transform *alt_transform_out;   /*!<  Alternative transform for
+                                              resending messages             */
+    unsigned char alt_out_ctr[8];       /*!<  Alternative record epoch/counter
+                                              for resending messages         */
 #endif
 
     /*
@@ -719,7 +725,8 @@
 struct _ssl_flight_item
 {
     unsigned char *p;       /*!< message, including handshake headers   */
-    size_t len;             /*!< length of hs_msg                       */
+    size_t len;             /*!< length of p                            */
+    unsigned char type;     /*!< type of the message: handshake or CCS  */
     ssl_flight_item *next;  /*!< next handshake message(s)              */
 };
 #endif /* POLARSSL_SSL_PROTO_DTLS */
@@ -2031,6 +2038,11 @@
     return( 4 );
 }
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+void ssl_recv_flight_completed( ssl_context *ssl );
+int ssl_resend( ssl_context *ssl );
+#endif
+
 /* constant-time buffer comparison */
 static inline int safer_memcmp( const void *a, const void *b, size_t n )
 {
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index ce14f58..b5dae23 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1011,6 +1011,8 @@
     ssl->state = SSL_CLIENT_HELLO;
     ssl_reset_checksum( ssl );
 
+    ssl_recv_flight_completed( ssl );
+
     SSL_DEBUG_MSG( 2, ( "<= parse hello verify request" ) );
 
     return( 0 );
@@ -2229,6 +2231,11 @@
 
     ssl->state++;
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
+        ssl_recv_flight_completed( ssl );
+#endif
+
     SSL_DEBUG_MSG( 2, ( "<= parse server hello done" ) );
 
     return( 0 );
@@ -2734,6 +2741,16 @@
     if( ( ret = ssl_flush_output( ssl ) ) != 0 )
         return( ret );
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM &&
+        ssl->handshake != NULL &&
+        ssl->handshake->retransmit_state == SSL_RETRANS_SENDING )
+    {
+        if( ( ret = ssl_resend( ssl ) ) != 0 )
+            return( ret );
+    }
+#endif
+
     switch( ssl->state )
     {
         case SSL_HELLO_REQUEST:
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index c839ea7..a0bb653 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -1805,6 +1805,11 @@
 
     ssl->state++;
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
+        ssl_recv_flight_completed( ssl );
+#endif
+
     SSL_DEBUG_MSG( 2, ( "<= parse client hello" ) );
 
     return( 0 );
@@ -3485,6 +3490,16 @@
     if( ( ret = ssl_flush_output( ssl ) ) != 0 )
         return( ret );
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM &&
+        ssl->handshake != NULL &&
+        ssl->handshake->retransmit_state == SSL_RETRANS_SENDING )
+    {
+        if( ( ret = ssl_resend( ssl ) ) != 0 )
+            return( ret );
+    }
+#endif
+
     switch( ssl->state )
     {
         case SSL_HELLO_REQUEST:
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 423bc0b..9160baa 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2013,6 +2013,9 @@
     return( 0 );
 }
 
+/*
+ * Functions to handle the DTLS retransmission state machine
+ */
 #if defined(POLARSSL_SSL_PROTO_DTLS)
 /*
  * Append current handshake message to current outgoing flight
@@ -2038,14 +2041,12 @@
     /* Copy current handshake message with headers */
     memcpy( msg->p, ssl->out_msg, ssl->out_msglen );
     msg->len = ssl->out_msglen;
+    msg->type = ssl->out_msgtype;
     msg->next = NULL;
 
     /* Append to the current flight */
     if( ssl->handshake->flight == NULL )
-    {
-        ssl->handshake->flight  = msg;
-        ssl->handshake->cur_msg = msg;
-    }
+        ssl->handshake->flight = msg;
     else
     {
         ssl_flight_item *cur = ssl->handshake->flight;
@@ -2077,30 +2078,72 @@
 }
 
 /*
- * Send current flight of messages.
+ * Swap transform_out and out_ctr with the alternative ones
+ */
+static void ssl_swap_epochs( ssl_context *ssl )
+{
+    ssl_transform *tmp_transform;
+    unsigned char tmp_out_ctr[8];
+
+    if( ssl->transform_out == ssl->handshake->alt_transform_out )
+    {
+        SSL_DEBUG_MSG( 3, ( "skip swap epochs" ) );
+        return;
+    }
+
+    SSL_DEBUG_MSG( 3, ( "swap epochs" ) );
+
+    tmp_transform                     = ssl->transform_out;
+    ssl->transform_out                = ssl->handshake->alt_transform_out;
+    ssl->handshake->alt_transform_out = tmp_transform;
+
+    memcpy( tmp_out_ctr,                 ssl->out_ctr,                8 );
+    memcpy( ssl->out_ctr,                ssl->handshake->alt_out_ctr, 8 );
+    memcpy( ssl->handshake->alt_out_ctr, tmp_out_ctr,                 8 );
+}
+
+/*
+ * Retransmit the current flight of messages.
  *
  * Need to remember the current message in case flush_output returns
  * WANT_WRITE, causing us to exit this function and come back later.
+ * This function must be called until state is no longer SENDING.
  */
-static int ssl_send_current_flight( ssl_context *ssl )
+int ssl_resend( ssl_context *ssl )
 {
-    ssl->handshake->retransmit_state = SSL_RETRANS_SENDING;
+    SSL_DEBUG_MSG( 2, ( "=> ssl_resend" ) );
 
-    SSL_DEBUG_MSG( 2, ( "=> ssl_send_current_flight" ) );
+    if( ssl->handshake->retransmit_state != SSL_RETRANS_SENDING )
+    {
+        SSL_DEBUG_MSG( 2, ( "initialise resending" ) );
+
+        ssl->handshake->cur_msg = ssl->handshake->flight;
+        ssl_swap_epochs( ssl );
+
+        ssl->handshake->retransmit_state = SSL_RETRANS_SENDING;
+    }
 
     while( ssl->handshake->cur_msg != NULL )
     {
         int ret;
         ssl_flight_item *cur = ssl->handshake->cur_msg;
 
+        memcpy( ssl->out_msg, cur->p, cur->len );
         ssl->out_msglen = cur->len;
-        memcpy( ssl->out_msg, cur->p, ssl->out_msglen );
-        ssl->out_msgtype = SSL_MSG_HANDSHAKE;
+        ssl->out_msgtype = cur->type;
 
         ssl->handshake->cur_msg = cur->next;
 
         SSL_DEBUG_BUF( 3, "resent handshake message header", ssl->out_msg, 12 );
 
+        /* Swap epochs before sending Finished: we can't do it right after
+         * sending ChangeCipherSpec, in case write returns WANT_READ */
+        if( ssl->out_msgtype == SSL_MSG_HANDSHAKE &&
+            ssl->out_msg[0] == SSL_HS_FINISHED )
+        {
+            ssl_swap_epochs( ssl );
+        }
+
         if( ( ret = ssl_write_record( ssl ) ) != 0 )
         {
             SSL_DEBUG_RET( 1, "ssl_write_record", ret );
@@ -2110,10 +2153,32 @@
 
     ssl->handshake->retransmit_state = SSL_RETRANS_WAITING;
 
-    SSL_DEBUG_MSG( 2, ( "<= ssl_send_current_flight" ) );
+    SSL_DEBUG_MSG( 2, ( "<= ssl_resend" ) );
 
     return( 0 );
 }
+
+/*
+ * To be called when the last message of an incoming flight is received.
+ */
+void ssl_recv_flight_completed( ssl_context *ssl )
+{
+    /* We won't need to resend that one any more */
+    ssl_flight_free( ssl->handshake->flight );
+    ssl->handshake->flight = NULL;
+    ssl->handshake->cur_msg = NULL;
+
+    /* The next incoming flight will start with this msg_seq */
+    ssl->handshake->in_flight_start_seq = ssl->handshake->in_msg_seq;
+
+    if( ssl->in_msgtype == SSL_MSG_HANDSHAKE &&
+        ssl->in_msg[0] == SSL_HS_FINISHED )
+    {
+        ssl->handshake->retransmit_state = SSL_RETRANS_FINISHED;
+    }
+    else
+        ssl->handshake->retransmit_state = SSL_RETRANS_PREPARING;
+}
 #endif /* POLARSSL_SSL_PROTO_DTLS */
 
 /*
@@ -3818,14 +3883,16 @@
      * data.
      */
     SSL_DEBUG_MSG( 3, ( "switching to new transform spec for outbound data" ) );
-    ssl->transform_out = ssl->transform_negotiate;
-    ssl->session_out = ssl->session_negotiate;
 
 #if defined(POLARSSL_SSL_PROTO_DTLS)
     if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
     {
         unsigned char i;
 
+        /* Remember current epoch settings for resending */
+        ssl->handshake->alt_transform_out = ssl->transform_out;
+        memcpy( ssl->handshake->alt_out_ctr, ssl->out_ctr, 8 );
+
         /* Set sequence_number to zero */
         memset( ssl->out_ctr + 2, 0, 6 );
 
@@ -3845,6 +3912,9 @@
 #endif /* POLARSSL_SSL_PROTO_DTLS */
     memset( ssl->out_ctr, 0, 8 );
 
+    ssl->transform_out = ssl->transform_negotiate;
+    ssl->session_out = ssl->session_negotiate;
+
 #if defined(POLARSSL_SSL_HW_RECORD_ACCEL)
     if( ssl_hw_record_activate != NULL )
     {
@@ -3985,6 +4055,11 @@
     else
         ssl->state++;
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
+        ssl_recv_flight_completed( ssl );
+#endif
+
     SSL_DEBUG_MSG( 2, ( "<= parse finished" ) );
 
     return( 0 );
@@ -4098,6 +4173,18 @@
     ssl->handshake->key_cert = ssl->key_cert;
 #endif
 
+#if defined(POLARSSL_SSL_PROTO_DTLS)
+    if( ssl->transport == SSL_TRANSPORT_DATAGRAM )
+    {
+        ssl->handshake->alt_transform_out = ssl->transform_out;
+
+        if( ssl->endpoint == SSL_IS_CLIENT )
+            ssl->handshake->retransmit_state = SSL_RETRANS_PREPARING;
+        else
+            ssl->handshake->retransmit_state = SSL_RETRANS_WAITING;
+    }
+#endif
+
     return( 0 );
 }