aes: xts: Rewrite to avoid use of goto
The flow was a bit hard to follow with the `goto` everywhere. Rewrite the
XTS implementation to avoid the use of `goto`.
diff --git a/library/aes.c b/library/aes.c
index 2b64387..80447b7 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -1135,129 +1135,92 @@
const unsigned char *input,
unsigned char *output )
{
- union xts_buf128 {
- uint8_t u8[16];
- uint64_t u64[2];
- };
+ int ret;
+ size_t blocks = length / 16;
+ size_t leftover = length % 16;
+ unsigned char tweak[16];
+ unsigned char prev_tweak[16];
+ unsigned char tmp[16];
- union xts_buf128 scratch;
- union xts_buf128 cts_scratch;
- union xts_buf128 t_buf;
- union xts_buf128 cts_t_buf;
- union xts_buf128 *inbuf;
- union xts_buf128 *outbuf;
-
- size_t nblk = length / 16;
- size_t remn = length % 16;
-
- inbuf = (union xts_buf128*)input;
- outbuf = (union xts_buf128*)output;
-
- /* For performing the ciphertext-stealing operation, we have to get at least
- * one complete block */
+ /* Sectors must be at least 16 bytes. */
if( length < 16 )
- return( MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
+ return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
/* NIST SP 80-38E disallows data units larger than 2**20 blocks. */
if( length > ( 1 << 20 ) * 16 )
return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
- mbedtls_aes_crypt_ecb( &ctx->tweak, MBEDTLS_AES_ENCRYPT, iv, t_buf.u8 );
+ /* Compute the tweak. */
+ ret = mbedtls_aes_crypt_ecb( &ctx->tweak, MBEDTLS_AES_ENCRYPT, iv, tweak );
+ if( ret != 0 )
+ return( ret );
- if( mode == MBEDTLS_AES_DECRYPT && remn )
+ while( blocks-- )
{
- if( nblk == 1 )
- goto decrypt_only_one_full_block;
- nblk--;
+ size_t i;
+
+ if( leftover && ( mode == MBEDTLS_AES_DECRYPT ) && blocks == 0 )
+ {
+ /* We are on the last block in a decrypt operation that has
+ * leftover bytes, so we need to use the next tweak for this block,
+ * and this tweak for the lefover bytes. Save the current tweak for
+ * the leftovers and then update the current tweak for use on this,
+ * the last full block. */
+ memcpy( prev_tweak, tweak, sizeof( tweak ) );
+ mbedtls_gf128mul_x_ble( tweak, tweak );
+ }
+
+ for( i = 0; i < 16; i++ )
+ tmp[i] = input[i] ^ tweak[i];
+
+ ret = mbedtls_aes_crypt_ecb( &ctx->crypt, mode, tmp, tmp );
+ if( ret != 0 )
+ return( ret );
+
+ for( i = 0; i < 16; i++ )
+ output[i] = tmp[i] ^ tweak[i];
+
+ /* Update the tweak for the next block. */
+ mbedtls_gf128mul_x_ble( tweak, tweak );
+
+ output += 16;
+ input += 16;
}
- goto first;
-
- do
+ if( leftover )
{
- mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
+ /* If we are on the leftover bytes in a decrypt operation, we need to
+ * use the previous tweak for these bytes (as saved in prev_tweak). */
+ unsigned char *t = mode == MBEDTLS_AES_DECRYPT ? prev_tweak : tweak;
-first:
- /* PP <- T xor P */
- scratch.u64[0] = (uint64_t)( inbuf->u64[0] ^ t_buf.u64[0] );
- scratch.u64[1] = (uint64_t)( inbuf->u64[1] ^ t_buf.u64[1] );
+ /* We are now on the final part of the data unit, which doesn't divide
+ * evenly by 16. It's time for ciphertext stealing. */
+ size_t i;
+ unsigned char *prev_output = output - 16;
- /* CC <- E(Key2,PP) */
- mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, outbuf->u8 );
-
- /* C <- T xor CC */
- outbuf->u64[0] = (uint64_t)( outbuf->u64[0] ^ t_buf.u64[0] );
- outbuf->u64[1] = (uint64_t)( outbuf->u64[1] ^ t_buf.u64[1] );
-
- inbuf += 1;
- outbuf += 1;
- nblk -= 1;
- } while( nblk > 0 );
-
- /* Ciphertext stealing, if necessary */
- if( remn != 0 )
- {
- outbuf = (union xts_buf128*)output;
- inbuf = (union xts_buf128*)input;
- nblk = length / 16;
-
- if( mode == MBEDTLS_AES_ENCRYPT )
+ /* Copy ciphertext bytes from the previous block to our output for each
+ * byte of cyphertext we won't steal. At the same time, copy the
+ * remainder of the input for this final round (since the loop bounds
+ * are the same). */
+ for( i = 0; i < leftover; i++ )
{
- memcpy( cts_scratch.u8, (uint8_t*)&inbuf[nblk], remn );
- memcpy( cts_scratch.u8 + remn, ((uint8_t*)&outbuf[nblk - 1]) + remn, 16 - remn );
- memcpy( (uint8_t*)&outbuf[nblk], (uint8_t*)&outbuf[nblk - 1], remn );
-
- mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
-
- /* PP <- T xor P */
- scratch.u64[0] = (uint64_t)( cts_scratch.u64[0] ^ t_buf.u64[0] );
- scratch.u64[1] = (uint64_t)( cts_scratch.u64[1] ^ t_buf.u64[1] );
-
- /* CC <- E(Key2,PP) */
- mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
-
- /* C <- T xor CC */
- outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
- outbuf[nblk - 1].u64[1] = (uint64_t)( scratch.u64[1] ^ t_buf.u64[1] );
+ output[i] = prev_output[i];
+ tmp[i] = input[i] ^ t[i];
}
- else /* AES_DECRYPT */
- {
- mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
-decrypt_only_one_full_block:
- cts_t_buf.u64[0] = t_buf.u64[0];
- cts_t_buf.u64[1] = t_buf.u64[1];
+ /* Copy ciphertext bytes from the previous block for input in this
+ * round. */
+ for( ; i < 16; i++ )
+ tmp[i] = prev_output[i] ^ t[i];
- mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
+ ret = mbedtls_aes_crypt_ecb( &ctx->crypt, mode, tmp, tmp );
+ if( ret != 0 )
+ return ret;
- /* PP <- T xor P */
- scratch.u64[0] = (uint64_t)( inbuf[nblk - 1].u64[0] ^ t_buf.u64[0] );
- scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ t_buf.u64[1] );
-
- /* CC <- E(Key2,PP) */
- mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
-
- /* C <- T xor CC */
- cts_scratch.u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
- cts_scratch.u64[1] = (uint64_t)( scratch.u64[1] ^ t_buf.u64[1] );
-
-
- memcpy( (uint8_t*)&inbuf[nblk - 1], (uint8_t*)&inbuf[nblk], remn );
- memcpy( (uint8_t*)&inbuf[nblk - 1] + remn, cts_scratch.u8 + remn, 16 - remn );
- memcpy( (uint8_t*)&outbuf[nblk], cts_scratch.u8, remn );
-
-
- /* PP <- T xor P */
- scratch.u64[0] = (uint64_t)( inbuf[nblk - 1].u64[0] ^ cts_t_buf.u64[0] );
- scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ cts_t_buf.u64[1] );
-
- /* CC <- E(Key2,PP) */
- mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
-
- /* C <- T xor CC */
- outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ cts_t_buf.u64[0] );
- outbuf[nblk - 1].u64[1] = (uint64_t)( scratch.u64[1] ^ cts_t_buf.u64[1] );
- }
+ /* Write the result back to the previous block, overriding the previous
+ * output we copied. */
+ for( i = 0; i < 16; i++ )
+ prev_output[i] = tmp[i] ^ t[i];
}
return( 0 );