[crypto] Add support for single-part AEAD operation (#40984)

This commit adds support for single-part AEAD operations.

Signed-off-by: Łukasz Duda <lukasz.duda@nordicsemi.no>
diff --git a/config/nrfconnect/chip-module/CMakeLists.txt b/config/nrfconnect/chip-module/CMakeLists.txt
index 8d212c9..881e532 100644
--- a/config/nrfconnect/chip-module/CMakeLists.txt
+++ b/config/nrfconnect/chip-module/CMakeLists.txt
@@ -191,8 +191,9 @@
 endif()
 
 if (CONFIG_CHIP_CRYPTO_PSA)
-    matter_add_gn_arg_string("chip_crypto"             "psa")
-    matter_add_gn_arg_bool  ("chip_crypto_psa_spake2p" CONFIG_PSA_WANT_ALG_SPAKE2P_MATTER)
+    matter_add_gn_arg_string("chip_crypto"                      "psa")
+    matter_add_gn_arg_bool  ("chip_crypto_psa_spake2p"          CONFIG_PSA_WANT_ALG_SPAKE2P_MATTER)
+    matter_add_gn_arg_bool  ("chip_crypto_psa_aead_single_part" CONFIG_CHIP_CRYPTO_PSA_AEAD_SINGLE_PART)
 endif()
 
 if (BOARD STREQUAL "native_sim")
diff --git a/config/zephyr/Kconfig b/config/zephyr/Kconfig
index 3a6f906..583ed9d 100644
--- a/config/zephyr/Kconfig
+++ b/config/zephyr/Kconfig
@@ -282,6 +282,13 @@
 	  based on the PSA crypto API (instead of the default implementation, which
 	  is based on the legacy mbedTLS APIs).
 
+config CHIP_CRYPTO_PSA_AEAD_SINGLE_PART
+    bool "Use PSA AEAD single-part API"
+    depends on CHIP_CRYPTO_PSA
+    help
+      When enabled, the single-part AEAD API is used with a large buffer
+      allocated on the stack. Otherwise, the multipart API path is used.
+
 config CHIP_PERSISTENT_SUBSCRIPTIONS
 	bool "Persistent subscriptions"
 	help
diff --git a/src/crypto/BUILD.gn b/src/crypto/BUILD.gn
index 7461268..26e8829 100644
--- a/src/crypto/BUILD.gn
+++ b/src/crypto/BUILD.gn
@@ -67,6 +67,7 @@
     "CHIP_CRYPTO_MBEDTLS=${chip_crypto_mbedtls}",
     "CHIP_CRYPTO_PSA=${chip_crypto_psa}",
     "CHIP_CRYPTO_PSA_SPAKE2P=${chip_crypto_psa_spake2p}",
+    "CHIP_CRYPTO_PSA_AEAD_SINGLE_PART=${chip_crypto_psa_aead_single_part}",
     "CHIP_CRYPTO_KEYSTORE_PSA=${chip_crypto_keystore_psa}",
     "CHIP_CRYPTO_KEYSTORE_RAW=${chip_crypto_keystore_raw}",
     "CHIP_CRYPTO_KEYSTORE_APP=${chip_crypto_keystore_app}",
diff --git a/src/crypto/CHIPCryptoPALPSA.cpp b/src/crypto/CHIPCryptoPALPSA.cpp
index 7aa33c1..b53a588 100644
--- a/src/crypto/CHIPCryptoPALPSA.cpp
+++ b/src/crypto/CHIPCryptoPALPSA.cpp
@@ -43,6 +43,12 @@
 #include <string.h>
 #include <type_traits>
 
+#if CHIP_CRYPTO_PSA_AEAD_SINGLE_PART
+constexpr size_t kAeadMaxPlaintextSize = CHIP_CONFIG_DEFAULT_UDP_MTU_SIZE;
+constexpr size_t kAeadMaxTagSize       = 16;
+constexpr size_t kAeadTempBufferSize   = kAeadMaxPlaintextSize + kAeadMaxTagSize;
+#endif
+
 namespace chip {
 namespace Crypto {
 
@@ -71,9 +77,30 @@
 
     const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG(PSA_ALG_CCM, tag_length);
     psa_status_t status             = PSA_SUCCESS;
-    psa_aead_operation_t operation  = PSA_AEAD_OPERATION_INIT;
     size_t out_length               = 0;
-    size_t tag_out_length           = 0;
+
+#if CHIP_CRYPTO_PSA_AEAD_SINGLE_PART
+    // Temporary buffer is required since during encryption `plaintext` and `tag` are encrypted as one unit
+    // and then `ciphertext` and `tag` are separated out. `ciphertext` is only guranteed to have
+    // `plaintext_length` size and we need a buffer of `plaintext_length` + `tag_length` for output.
+    uint8_t temp_buf[kAeadTempBufferSize];
+
+    VerifyOrReturnError(plaintext_length + tag_length <= kAeadTempBufferSize, CHIP_ERROR_INVALID_ARGUMENT);
+
+    status = psa_aead_encrypt(key.As<psa_key_id_t>(), algorithm, nonce, nonce_length, aad, aad_length, plaintext, plaintext_length,
+                              temp_buf, sizeof(temp_buf), &out_length);
+
+    VerifyOrReturnError(status == PSA_SUCCESS && out_length == plaintext_length + tag_length, CHIP_ERROR_INTERNAL);
+
+    if (plaintext_length)
+    {
+        memcpy(ciphertext, temp_buf, plaintext_length);
+    }
+
+    memcpy(tag, temp_buf + plaintext_length, tag_length);
+#else
+    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
+    size_t tag_out_length          = 0;
 
     status = psa_aead_encrypt_setup(&operation, key.As<psa_key_id_t>(), algorithm);
     VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
@@ -149,6 +176,7 @@
         status = psa_aead_finish(&operation, nullptr, 0, &out_length, tag, tag_length, &tag_out_length);
     }
     VerifyOrReturnError(status == PSA_SUCCESS && tag_length == tag_out_length, CHIP_ERROR_INTERNAL);
+#endif
 
     return CHIP_NO_ERROR;
 }
@@ -164,8 +192,29 @@
 
     const psa_algorithm_t algorithm = PSA_ALG_AEAD_WITH_SHORTENED_TAG(PSA_ALG_CCM, tag_length);
     psa_status_t status             = PSA_SUCCESS;
-    psa_aead_operation_t operation  = PSA_AEAD_OPERATION_INIT;
-    size_t outLength                = 0;
+    size_t out_length               = 0;
+
+#if CHIP_CRYPTO_PSA_AEAD_SINGLE_PART
+    // Temporary buffer is required since during decryption `ciphertext` and `tag` are decrypted as one unit.
+    // `ciphertext` is only guranteed to have `ciphertext_length` size and we need a buffer of
+    // `ciphertext_length` + `tag_length` for input.
+    uint8_t temp_buf[kAeadTempBufferSize];
+
+    VerifyOrReturnError(ciphertext_length + tag_length <= kAeadTempBufferSize, CHIP_ERROR_INVALID_ARGUMENT);
+
+    if (ciphertext_length)
+    {
+        memcpy(temp_buf, ciphertext, ciphertext_length);
+    }
+
+    memcpy(temp_buf + ciphertext_length, tag, tag_length);
+
+    status = psa_aead_decrypt(key.As<psa_key_id_t>(), algorithm, nonce, nonce_length, aad, aad_length, temp_buf,
+                              ciphertext_length + tag_length, plaintext, ciphertext_length, &out_length);
+
+    VerifyOrReturnError(status == PSA_SUCCESS && out_length == ciphertext_length, CHIP_ERROR_INTERNAL);
+#else
+    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
 
     status = psa_aead_decrypt_setup(&operation, key.As<psa_key_id_t>(), algorithm);
     VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
@@ -200,10 +249,10 @@
 
     if (block_aligned_length > 0)
     {
-        status = psa_aead_update(&operation, ciphertext, block_aligned_length, plaintext, block_aligned_length, &outLength);
+        status = psa_aead_update(&operation, ciphertext, block_aligned_length, plaintext, block_aligned_length, &out_length);
         VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
 
-        plaintext += outLength;
+        plaintext += out_length;
     }
 
     if (partial_block_length > 0)
@@ -213,30 +262,31 @@
 
         VerifyOrReturnError(rounded_up_length <= sizeof(temp_buffer), CHIP_ERROR_BUFFER_TOO_SMALL);
 
-        outLength = 0;
-        status    = psa_aead_update(&operation, ciphertext + block_aligned_length, partial_block_length, &temp_buffer[0],
-                                    rounded_up_length, &outLength);
+        out_length = 0;
+        status     = psa_aead_update(&operation, ciphertext + block_aligned_length, partial_block_length, &temp_buffer[0],
+                                     rounded_up_length, &out_length);
         VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
 
-        VerifyOrReturnError(partial_block_length == outLength, CHIP_ERROR_INTERNAL);
+        VerifyOrReturnError(partial_block_length == out_length, CHIP_ERROR_INTERNAL);
         memcpy(plaintext, &temp_buffer[0], partial_block_length);
 
-        plaintext += outLength;
+        plaintext += out_length;
     }
 
     if (ciphertext_length != 0)
     {
-        outLength = 0;
-        status = psa_aead_verify(&operation, plaintext, PSA_AEAD_VERIFY_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm), &outLength, tag,
+        out_length = 0;
+        status = psa_aead_verify(&operation, plaintext, PSA_AEAD_VERIFY_OUTPUT_SIZE(PSA_KEY_TYPE_AES, algorithm), &out_length, tag,
                                  tag_length);
     }
     else
     {
-        outLength = 0;
-        status    = psa_aead_verify(&operation, nullptr, 0, &outLength, tag, tag_length);
+        out_length = 0;
+        status     = psa_aead_verify(&operation, nullptr, 0, &out_length, tag, tag_length);
     }
 
     VerifyOrReturnError(status == PSA_SUCCESS, CHIP_ERROR_INTERNAL);
+#endif
 
     return CHIP_NO_ERROR;
 }
diff --git a/src/crypto/crypto.gni b/src/crypto/crypto.gni
index 36c873e..a77ec43 100644
--- a/src/crypto/crypto.gni
+++ b/src/crypto/crypto.gni
@@ -23,6 +23,9 @@
   # Use PSA Spake2+ implementation. Only used if chip_crypto == "psa"
   chip_crypto_psa_spake2p = false
 
+  # Use PSA AEAD single-part implementation. Only used if chip_crypto == "psa"
+  chip_crypto_psa_aead_single_part = false
+
   # Crypto storage: psa, raw, app.
   #   app: includes zero new files and disables the unit tests for the keystore.
   chip_crypto_keystore = ""
diff --git a/src/test_driver/nrfconnect/flash.bin b/src/test_driver/nrfconnect/flash.bin
new file mode 100644
index 0000000..0a3d980
--- /dev/null
+++ b/src/test_driver/nrfconnect/flash.bin
Binary files differ