/* BEGIN_HEADER */
#include "mbedtls/asn1write.h"

/* Number of guard bytes checked on each side of the work area in the
 * mbedtls_asn1_write_len() test case to detect out-of-bounds writes. */
#define GUARD_LEN 4
/* Byte value used to fill the whole buffer (guard bytes included) before
 * the function under test runs in the mbedtls_asn1_write_len() test case. */
#define GUARD_VAL 0x2a

/* State shared between generic_write_start_step() and
 * generic_write_finish_step() while a write test is repeated over a
 * range of output buffer sizes. */
typedef struct
{
    /* Heap buffer for the current step; released with mbedtls_free(). */
    unsigned char *output;
    /* Start of the writable area (end - size); passed to the write
     * functions as their `start` argument. */
    unsigned char *start;
    /* One past the last byte of the writable area (output + size). */
    unsigned char *end;
    /* Write cursor; set to `end` at the beginning of each step. */
    unsigned char *p;
    /* Buffer size tried in the current step. */
    size_t size;
} generic_write_data_t;

/* Prepare `data` for one iteration of a write test.
 *
 * Records data->size as the current test step, replaces the output buffer
 * with a freshly allocated one and positions the write cursor at the end
 * of a data->size byte area.
 *
 * Returns 1 on success, 0 if the allocation check jumped to the exit label.
 */
int generic_write_start_step( generic_write_data_t *data )
{
    int ok = 0;
    /* A zero-sized step still allocates one byte (presumably so that the
     * buffer pointers are never NULL). */
    size_t alloc_size = data->size;

    if( alloc_size == 0 )
        alloc_size = 1;

    mbedtls_test_set_step( data->size );

    /* Drop the buffer of the previous step, if any. */
    mbedtls_free( data->output );
    data->output = NULL;
    ASSERT_ALLOC( data->output, alloc_size );

    /* The writable area is [start, end); writing begins at its end. */
    data->start = data->output;
    data->end = data->start + data->size;
    data->p = data->end;
    ok = 1;

exit:
    return( ok );
}

/* Check the outcome of one iteration of a write test.
 *
 * data:     state set up by generic_write_start_step() and then passed
 *           to the write function under test.
 * expected: the encoding that the write function should have produced.
 * ret:      the value returned by the write function.
 *
 * When the buffer is smaller than the expected output, the write function
 * must have reported MBEDTLS_ERR_ASN1_BUF_TOO_SMALL. Otherwise its return
 * value must be the number of bytes between the cursor and the end of the
 * buffer, the cursor must lie inside the buffer, and those bytes must
 * match `expected`.
 *
 * Returns 1 if all checks passed, 0 otherwise.
 */
int generic_write_finish_step( generic_write_data_t *data,
                               const data_t *expected, int ret )
{
    if( data->size >= expected->len )
    {
        /* Enough room: the output occupies [p, end). */
        TEST_EQUAL( ret, data->end - data->p );
        TEST_ASSERT( data->p >= data->start );
        TEST_ASSERT( data->p <= data->end );
        ASSERT_COMPARE( data->p, (size_t)( data->end - data->p ),
                        expected->x, expected->len );
    }
    else
    {
        /* Not enough room for the expected encoding. */
        TEST_EQUAL( ret, MBEDTLS_ERR_ASN1_BUF_TOO_SMALL );
    }
    return( 1 );

exit:
    return( 0 );
}

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_ASN1_WRITE_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
void mbedtls_asn1_write_null( data_t *expected )
{
    /* Write an ASN.1 NULL into buffers of every size from 0 up to one
     * byte more than the expected encoding, and check each outcome
     * against `expected`. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    const size_t max_size = expected->len + 1;
    int written;

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;
        written = mbedtls_asn1_write_null( &data.p, data.start );
        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
        /* No parse-back check: there is no parsing function for NULL. */
        data.size++;
    }

exit:
    mbedtls_free( data.output );
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_asn1_write_bool( int val, data_t *expected )
{
    /* Write the boolean `val` into buffers of every size from 0 up to
     * one byte more than the expected encoding, check each outcome
     * against `expected`, and parse successful outputs back when the
     * ASN.1 parser is available. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    const size_t max_size = expected->len + 1;
    int written;

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;
        written = mbedtls_asn1_write_bool( &data.p, data.start, val );
        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
#if defined(MBEDTLS_ASN1_PARSE_C)
        if( written >= 0 )
        {
            /* Round trip: start from a sentinel so that a parser which
             * leaves the output untouched is noticed. */
            int read = 0xdeadbeef;
            TEST_EQUAL( mbedtls_asn1_get_bool( &data.p, data.end, &read ), 0 );
            TEST_EQUAL( val, read );
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
        data.size++;
    }

exit:
    mbedtls_free( data.output );
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_asn1_write_int( int val, data_t *expected )
{
    /* Write the integer `val` into buffers of every size from 0 up to
     * one byte more than the expected encoding, check each outcome
     * against `expected`, and parse successful outputs back when the
     * ASN.1 parser is available. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    const size_t max_size = expected->len + 1;
    int written;

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;
        written = mbedtls_asn1_write_int( &data.p, data.start, val );
        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
#if defined(MBEDTLS_ASN1_PARSE_C)
        if( written >= 0 )
        {
            /* Round trip: start from a sentinel so that a parser which
             * leaves the output untouched is noticed. */
            int read = 0xdeadbeef;
            TEST_EQUAL( mbedtls_asn1_get_int( &data.p, data.end, &read ), 0 );
            TEST_EQUAL( val, read );
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
        data.size++;
    }

exit:
    mbedtls_free( data.output );
}
/* END_CASE */


/* BEGIN_CASE */
void mbedtls_asn1_write_enum( int val, data_t *expected )
{
    /* Write the enumerated value `val` into buffers of every size from 0
     * up to one byte more than the expected encoding, check each outcome
     * against `expected`, and parse successful outputs back when the
     * ASN.1 parser is available. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    const size_t max_size = expected->len + 1;
    int written;

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;
        written = mbedtls_asn1_write_enum( &data.p, data.start, val );
        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
#if defined(MBEDTLS_ASN1_PARSE_C)
        if( written >= 0 )
        {
            /* Round trip: start from a sentinel so that a parser which
             * leaves the output untouched is noticed. */
            int read = 0xdeadbeef;
            TEST_EQUAL( mbedtls_asn1_get_enum( &data.p, data.end, &read ), 0 );
            TEST_EQUAL( val, read );
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
        data.size++;
    }

exit:
    mbedtls_free( data.output );
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_BIGNUM_C */
void mbedtls_asn1_write_mpi( data_t *val, data_t *expected )
{
    /* Load `val` as a big-endian multi-precision integer, write it into
     * buffers of increasing size (up to one byte more than the expected
     * encoding), check each outcome against `expected`, and parse
     * successful outputs back when the ASN.1 parser is available. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    mbedtls_mpi mpi, read;
    const size_t max_size = expected->len + 1;
    int written;

    mbedtls_mpi_init( &mpi );
    mbedtls_mpi_init( &read );
    TEST_ASSERT( mbedtls_mpi_read_binary( &mpi, val->x, val->len ) == 0 );

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;
        written = mbedtls_asn1_write_mpi( &data.p, data.start, &mpi );
        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
#if defined(MBEDTLS_ASN1_PARSE_C)
        if( written >= 0 )
        {
            TEST_EQUAL( mbedtls_asn1_get_mpi( &data.p, data.end, &read ), 0 );
            TEST_EQUAL( 0, mbedtls_mpi_cmp_mpi( &mpi, &read ) );
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
        /* For long outputs, jump from size 8 to just below the expected
         * length: the sizes in between add nothing of interest. The
         * increment below then resumes at expected->len - 1. */
        if( data.size == 8 && expected->len > 10 )
            data.size = expected->len - 2;
        data.size++;
    }

exit:
    mbedtls_mpi_free( &mpi );
    mbedtls_mpi_free( &read );
    mbedtls_free( data.output );
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_asn1_write_string( int tag, data_t *content, data_t *expected )
{
    /* Write `content` as a string of type `tag` into buffers of
     * increasing size (up to one byte more than the expected encoding)
     * and check each outcome against `expected`. Tags that have a
     * dedicated writer use it; any other tag goes through
     * mbedtls_asn1_write_tagged_string(). */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    const size_t max_size = expected->len + 1;
    /* The character-string writers take their input as `const char *`. */
    const char *text = (const char *) content->x;
    int written;

    data.size = 0;
    while( data.size <= max_size )
    {
        if( generic_write_start_step( &data ) == 0 )
            goto exit;

        if( tag == MBEDTLS_ASN1_OCTET_STRING )
        {
            written = mbedtls_asn1_write_octet_string(
                &data.p, data.start, content->x, content->len );
        }
        else if( tag == MBEDTLS_ASN1_OID )
        {
            written = mbedtls_asn1_write_oid(
                &data.p, data.start, text, content->len );
        }
        else if( tag == MBEDTLS_ASN1_UTF8_STRING )
        {
            written = mbedtls_asn1_write_utf8_string(
                &data.p, data.start, text, content->len );
        }
        else if( tag == MBEDTLS_ASN1_PRINTABLE_STRING )
        {
            written = mbedtls_asn1_write_printable_string(
                &data.p, data.start, text, content->len );
        }
        else if( tag == MBEDTLS_ASN1_IA5_STRING )
        {
            written = mbedtls_asn1_write_ia5_string(
                &data.p, data.start, text, content->len );
        }
        else
        {
            written = mbedtls_asn1_write_tagged_string(
                &data.p, data.start, tag, text, content->len );
        }

        if( generic_write_finish_step( &data, expected, written ) == 0 )
            goto exit;
        /* No parse-back check: there is no parsing function for octet
         * or character strings. */

        /* For long outputs, jump from size 8 to just below the expected
         * length: the sizes in between add nothing of interest. The
         * increment below then resumes at expected->len - 1. */
        if( data.size == 8 && expected->len > 10 )
            data.size = expected->len - 2;
        data.size++;
    }

exit:
    mbedtls_free( data.output );
}
/* END_CASE */

/* BEGIN_CASE */
void mbedtls_asn1_write_algorithm_identifier( data_t *oid,
                                              int par_len,
                                              data_t *expected )
{
    /* Write an AlgorithmIdentifier for `oid` with a declared parameter
     * length of `par_len` into buffers of every size from 0 up to one
     * byte more than the expected encoding, and check each outcome
     * against `expected` (which does not include the parameters
     * themselves). When the ASN.1 parser is available, successful outputs
     * are completed with made-up parameters and parsed back. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    int ret;
#if defined(MBEDTLS_ASN1_PARSE_C)
    /* Copy of the written data followed by made-up parameters; declared
     * at function scope so that the cleanup code can free it. */
    unsigned char *buf_complete = NULL;
#endif /* MBEDTLS_ASN1_PARSE_C */

    for( data.size = 0; data.size <= expected->len + 1; data.size++ )
    {
        if( ! generic_write_start_step( &data ) )
            goto exit;
        ret = mbedtls_asn1_write_algorithm_identifier(
            &data.p, data.start,
            (const char *) oid->x, oid->len, par_len );
        /* If par_len != 0, mbedtls_asn1_write_algorithm_identifier()
         * assumes that the parameters are already present in the buffer
         * and returns a length that accounts for this, but our test
         * data omits the parameters. */
        if( ret >= 0 )
            ret -= par_len;
        if( ! generic_write_finish_step( &data, expected, ret ) )
            goto exit;

#if defined(MBEDTLS_ASN1_PARSE_C)
        /* Only do a parse-back test if the parameters aren't too large for
         * a small-heap environment. The boundary is somewhat arbitrary. */
        if( ret >= 0 && par_len <= 1234 )
        {
            mbedtls_asn1_buf alg = {0, 0, NULL};
            mbedtls_asn1_buf params = {0, 0, NULL};
            /* The writing function doesn't write the parameters unless
             * they're null: it only takes their length as input. But the
             * parsing function requires the parameters to be present.
             * Thus make up parameters. */
            size_t data_len = data.end - data.p;
            size_t len_complete = data_len + par_len;
            unsigned char expected_params_tag;
            size_t expected_params_len;
            ASSERT_ALLOC( buf_complete, len_complete );
            unsigned char *end_complete = buf_complete + len_complete;
            /* Start with what the function under test wrote; the header
             * of the made-up parameters (if any) is appended after it. */
            memcpy( buf_complete, data.p, data_len );
            if( par_len == 0 )
            {
                /* mbedtls_asn1_write_algorithm_identifier() wrote a NULL */
                expected_params_tag = 0x05;
                expected_params_len = 0;
            }
            else if( par_len >= 2 && par_len < 2 + 128 )
            {
                /* Write an OCTET STRING with a short length encoding */
                /* 2 header bytes (tag, length) + up to 127 content bytes. */
                expected_params_tag = buf_complete[data_len] = 0x04;
                expected_params_len = par_len - 2;
                buf_complete[data_len + 1] = (unsigned char) expected_params_len;
            }
            else if( par_len >= 4 + 128 && par_len < 3 + 256 * 256 )
            {
                /* Write an OCTET STRING with a two-byte length encoding */
                /* 4 header bytes: tag, 0x82, then the content length as
                 * two bytes, most significant first. */
                expected_params_tag = buf_complete[data_len] = 0x04;
                expected_params_len = par_len - 4;
                buf_complete[data_len + 1] = 0x82;
                buf_complete[data_len + 2] = (unsigned char) ( expected_params_len >> 8 );
                buf_complete[data_len + 3] = (unsigned char) ( expected_params_len );
            }
            else
            {
                /* Any other par_len (1, 130, 131, or too large for the
                 * branches above) is not supported by this test. */
                TEST_ASSERT( ! "Bad test data: invalid length of ASN.1 element" );
            }
            unsigned char *p = buf_complete;
            TEST_EQUAL( mbedtls_asn1_get_alg( &p, end_complete,
                                              &alg, &params ), 0 );
            TEST_EQUAL( alg.tag, MBEDTLS_ASN1_OID );
            ASSERT_COMPARE( alg.p, alg.len, oid->x, oid->len );
            TEST_EQUAL( params.tag, expected_params_tag );
            TEST_EQUAL( params.len, expected_params_len );
            /* Release the buffer now and reset the pointer so that the
             * next iteration can allocate again and the final cleanup
             * does not free it twice. */
            mbedtls_free( buf_complete );
            buf_complete = NULL;
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
    }

exit:
    mbedtls_free( data.output );
#if defined(MBEDTLS_ASN1_PARSE_C)
    mbedtls_free( buf_complete );
#endif /* MBEDTLS_ASN1_PARSE_C */
}
/* END_CASE */

/* BEGIN_CASE depends_on:MBEDTLS_ASN1_PARSE_C */
void mbedtls_asn1_write_len( int len, data_t * asn1, int buf_len,
                             int result )
{
    /* Write the ASN.1 length `len` at the end of a work area of `buf_len`
     * bytes and check that the function returns `result`. The work area
     * sits inside a larger buffer with GUARD_LEN guard bytes on each side
     * which must remain untouched. On success, the bytes written must
     * match `asn1` and must parse back to `len`. */
    int ret;
    unsigned char buf[150];
    unsigned char *p;
    size_t i;
    /* Initialised so that a parse failure which leaves it untouched makes
     * the comparison below fail instead of reading an indeterminate
     * value. */
    size_t read_len = 0;

    /* The work area plus both guard zones must fit in buf: reject bad
     * test data rather than indexing outside the array. */
    TEST_ASSERT( buf_len >= 0 );
    TEST_ASSERT( (size_t) buf_len <= sizeof( buf ) - 2 * GUARD_LEN );

    memset( buf, GUARD_VAL, sizeof( buf ) );

    /* Writing starts at the end of the work area. */
    p = buf + GUARD_LEN + buf_len;

    ret = mbedtls_asn1_write_len( &p, buf + GUARD_LEN, (size_t) len );

    TEST_EQUAL( ret, result );

    /* Check for buffer overwrite on both sides */
    for( i = 0; i < GUARD_LEN; i++ )
    {
        TEST_ASSERT( buf[i] == GUARD_VAL );
        TEST_ASSERT( buf[GUARD_LEN + buf_len + i] == GUARD_VAL );
    }

    if( result >= 0 )
    {
        /* The output must occupy exactly the last asn1->len bytes of the
         * work area. */
        TEST_ASSERT( p + asn1->len == buf + GUARD_LEN + buf_len );

        TEST_ASSERT( memcmp( p, asn1->x, asn1->len ) == 0 );

        /* Read back with mbedtls_asn1_get_len() to check */
        ret = mbedtls_asn1_get_len( &p, buf + GUARD_LEN + buf_len, &read_len );

        if( len == 0 )
        {
            TEST_EQUAL( ret, 0 );
        }
        else
        {
            /* Return will be MBEDTLS_ERR_ASN1_OUT_OF_DATA because the rest of
             * the buffer is missing
             */
            TEST_EQUAL( ret, MBEDTLS_ERR_ASN1_OUT_OF_DATA );
        }
        TEST_EQUAL( read_len, (size_t) len );
        TEST_ASSERT( p == buf + GUARD_LEN + buf_len );
    }
}
/* END_CASE */

/* BEGIN_CASE */
void test_asn1_write_bitstrings( data_t *bitstring, int bits,
                                 data_t *expected, int is_named )
{
    /* Write the first `bits` bits of `bitstring` as a bit string (or as a
     * named bit string if `is_named` is nonzero) into buffers of every
     * size from 0 up to one byte more than the expected encoding, and
     * check each outcome against `expected`. When the ASN.1 parser is
     * available, successful outputs are parsed back and compared with
     * the value computed independently below. */
    generic_write_data_t data = { NULL, NULL, NULL, NULL, 0 };
    int ret;
    /* Function under test, selected by is_named. */
    int ( *func )( unsigned char **p, const unsigned char *start,
                   const unsigned char *buf, size_t bits ) =
        ( is_named ? mbedtls_asn1_write_named_bitstring :
          mbedtls_asn1_write_bitstring );
#if defined(MBEDTLS_ASN1_PARSE_C)
    /* Copy of the input with the bits beyond `bits` cleared; this is the
     * content that the parse-back is compared against. */
    unsigned char *masked_bitstring = NULL;
#endif /* MBEDTLS_ASN1_PARSE_C */

    /* The API expects `bitstring->x` to contain `bits` bits. */
    size_t byte_length = ( bits + 7 ) / 8;
    TEST_ASSERT( bitstring->len >= byte_length );

#if defined(MBEDTLS_ASN1_PARSE_C)
    ASSERT_ALLOC( masked_bitstring, byte_length );
    if( byte_length != 0 )
    {
        memcpy( masked_bitstring, bitstring->x, byte_length );
        /* Keep only the (bits % 8) most significant bits of the last
         * byte; the remaining low-order bits are not part of the value. */
        if( bits % 8 != 0 )
            masked_bitstring[byte_length - 1] &= ~( 0xff >> ( bits % 8 ) );
    }
    /* Number of value bits expected in the parsed-back bit string. */
    size_t value_bits = bits;
    if( is_named )
    {
        /* In a named bit string, all trailing 0 bits are removed. */
        /* First drop whole trailing zero bytes... */
        while( byte_length > 0 && masked_bitstring[byte_length - 1] == 0 )
            --byte_length;
        value_bits = 8 * byte_length;
        if( byte_length > 0 )
        {
            /* ...then discount the zero bits below the lowest set bit of
             * the last remaining byte, which is nonzero at this point. */
            unsigned char last_byte = masked_bitstring[byte_length - 1];
            for( unsigned b = 1; b < 0xff && ( last_byte & b ) == 0; b <<= 1 )
                --value_bits;
        }
    }
#endif /* MBEDTLS_ASN1_PARSE_C */

    for( data.size = 0; data.size <= expected->len + 1; data.size++ )
    {
        if( ! generic_write_start_step( &data ) )
            goto exit;
        ret = ( *func )( &data.p, data.start, bitstring->x, bits );
        if( ! generic_write_finish_step( &data, expected, ret ) )
            goto exit;
#if defined(MBEDTLS_ASN1_PARSE_C)
        if( ret >= 0 )
        {
            mbedtls_asn1_bitstring read = {0, 0, NULL};
            TEST_EQUAL( mbedtls_asn1_get_bitstring( &data.p, data.end,
                                                    &read ), 0 );
            ASSERT_COMPARE( read.p, read.len,
                            masked_bitstring, byte_length );
            /* Unused bits are the padding between the value bits and
             * the end of the last content byte. */
            TEST_EQUAL( read.unused_bits, 8 * byte_length - value_bits );
        }
#endif /* MBEDTLS_ASN1_PARSE_C */
    }

exit:
    mbedtls_free( data.output );
#if defined(MBEDTLS_ASN1_PARSE_C)
    mbedtls_free( masked_bitstring );
#endif /* MBEDTLS_ASN1_PARSE_C */
}
/* END_CASE */

/* BEGIN_CASE */
void store_named_data_find( data_t *oid0, data_t *oid1,
                            data_t *oid2, data_t *oid3,
                            data_t *needle, int from, int position )
{
    /* Build a four-element named-data list whose OIDs are oid0..oid3,
     * then call mbedtls_asn1_store_named_data() for the OID `needle` on
     * the sub-list starting at index `from` (4 selects the empty list).
     * If `position` is non-negative, the element at that index is
     * expected to be returned; otherwise a new element is expected to be
     * created at the head of the list. */
    data_t *oid[4] = {oid0, oid1, oid2, oid3};
    mbedtls_asn1_named_data nd[] ={
        { {0x06, 0, NULL}, {0, 0, NULL}, NULL, 0 },
        { {0x06, 0, NULL}, {0, 0, NULL}, NULL, 0 },
        { {0x06, 0, NULL}, {0, 0, NULL}, NULL, 0 },
        { {0x06, 0, NULL}, {0, 0, NULL}, NULL, 0 },
    };
    /* pointers[i] is the address of nd[i]; the extra last entry is NULL
     * so that pointers[i+1] is always a valid `next` value. */
    mbedtls_asn1_named_data *pointers[ARRAY_LENGTH( nd ) + 1];
    size_t i;
    mbedtls_asn1_named_data *head = NULL;
    mbedtls_asn1_named_data *found = NULL;

    for( i = 0; i < ARRAY_LENGTH( nd ); i++ )
        pointers[i] = &nd[i];
    pointers[ARRAY_LENGTH( nd )] = NULL;
    /* Give each element a heap copy of its OID and chain the elements. */
    for( i = 0; i < ARRAY_LENGTH( nd ); i++ )
    {
        ASSERT_ALLOC( nd[i].oid.p, oid[i]->len );
        memcpy( nd[i].oid.p, oid[i]->x, oid[i]->len );
        nd[i].oid.len = oid[i]->len;
        nd[i].next = pointers[i+1];
    }

    head = pointers[from];
    found = mbedtls_asn1_store_named_data( &head,
                                           (const char *) needle->x,
                                           needle->len,
                                           NULL, 0 );

    /* In any case, the existing list structure must be unchanged. */
    for( i = 0; i < ARRAY_LENGTH( nd ); i++ )
        TEST_ASSERT( nd[i].next == pointers[i+1] );

    if( position >= 0 )
    {
        /* position should have been found and modified. */
        TEST_ASSERT( head == pointers[from] );
        TEST_ASSERT( found == pointers[position] );
    }
    else
    {
        /* A new entry should have been created. */
        TEST_ASSERT( found == head );
        TEST_ASSERT( head->next == pointers[from] );
        for( i = 0; i < ARRAY_LENGTH( nd ); i++ )
            TEST_ASSERT( found != &nd[i] );
    }

exit:
    /* Free the returned element only if it is a new list head, i.e. not
     * one of the stack-allocated nd[] entries. */
    if( found != NULL && found == head && found != pointers[from] )
    {
        mbedtls_free( found->oid.p );
        mbedtls_free( found );
    }
    for( i = 0; i < ARRAY_LENGTH( nd ); i++ )
        mbedtls_free( nd[i].oid.p );
}
/* END_CASE */

/* BEGIN_CASE */
void store_named_data_val_found( int old_len, int new_len )
{
    /* Call mbedtls_asn1_store_named_data() for an OID that is already in
     * a one-element list whose value is `old_len` bytes long, and check
     * how the value buffer is updated. A positive `new_len` stores that
     * many bytes of "new value"; a zero or negative `new_len` passes a
     * NULL value pointer with the length -new_len. */
    mbedtls_asn1_named_data nd =
        { {0x06, 3, (unsigned char *) "OID"}, {0, 0, NULL}, NULL, 0 };
    mbedtls_asn1_named_data *head = &nd;
    mbedtls_asn1_named_data *found = NULL;
    unsigned char *old_val = NULL;
    unsigned char *new_val = (unsigned char *) "new value";

    /* Give the existing element a recognisable value, if requested. */
    if( old_len != 0 )
    {
        ASSERT_ALLOC( nd.val.p, (size_t) old_len );
        old_val = nd.val.p;
        nd.val.len = old_len;
        memset( old_val, 'x', old_len );
    }

    if( new_len <= 0 )
    {
        new_val = NULL;
        new_len = - new_len;
    }

    found = mbedtls_asn1_store_named_data( &head, "OID", 3,
                                           new_val, new_len );
    /* The existing element must be reused, not replaced. */
    TEST_ASSERT( head == &nd );
    TEST_ASSERT( found == head );

    if( new_val != NULL)
    {
        ASSERT_COMPARE( found->val.p, found->val.len,
                        new_val, (size_t) new_len );
    }

    if( new_len != 0 )
    {
        /* The value buffer is kept if and only if the length is
         * unchanged. */
        if( new_len != old_len )
            TEST_ASSERT( found->val.p != old_val );
        else
            TEST_ASSERT( found->val.p == old_val );
    }
    else
    {
        /* An empty value has no buffer. */
        TEST_ASSERT( found->val.p == NULL );
    }

exit:
    mbedtls_free( nd.val.p );
}
/* END_CASE */

/* BEGIN_CASE */
void store_named_data_val_new( int new_len, int set_new_val )
{
    /* Call mbedtls_asn1_store_named_data() on an empty list and check the
     * element it creates. The value is `new_len` bytes long, taken from
     * "new value" if `set_new_val` is nonzero and passed as NULL
     * otherwise. */
    mbedtls_asn1_named_data *head = NULL;
    mbedtls_asn1_named_data *found = NULL;
    const unsigned char *oid = (unsigned char *) "OID";
    size_t oid_len = strlen( (const char *) oid );
    const unsigned char *new_val =
        ( set_new_val != 0 ? (unsigned char *) "new value" : NULL );

    found = mbedtls_asn1_store_named_data( &head,
                                           (const char *) oid, oid_len,
                                           new_val, (size_t) new_len );
    TEST_ASSERT( found != NULL );
    TEST_ASSERT( found == head );
    /* The new element must hold its own copy of the OID. */
    TEST_ASSERT( found->oid.p != oid );
    ASSERT_COMPARE( found->oid.p, found->oid.len, oid, oid_len );

    if( new_len != 0 )
    {
        if( new_val != NULL )
        {
            /* The value must have been copied, not aliased. */
            TEST_ASSERT( found->val.p != new_val );
            ASSERT_COMPARE( found->val.p, found->val.len,
                            new_val, (size_t) new_len );
        }
        else
        {
            /* No input data, but a buffer must still exist. */
            TEST_ASSERT( found->val.p != NULL );
        }
    }
    else
    {
        TEST_ASSERT( found->val.p == NULL );
    }

exit:
    if( found != NULL )
    {
        mbedtls_free( found->oid.p );
        mbedtls_free( found->val.p );
    }
    mbedtls_free( found );
}
/* END_CASE */
