Track arg defaults in ArgTypeSelector
This ensures that arg types that require defaults (e.g. bool) always
have one defined. This also simplifies the default logic.
Also add more converters and extend ArgMap to reduce boiler plate.
Change-Id: I3af3a4434e057754a8fa5960ee854c5f6f8b93e3
Reviewed-on: https://pigweed-review.googlesource.com/c/open-dice/+/260071
Reviewed-by: Andrew Scull <ascull@google.com>
Lint: Lint 🤖 <android-build-ayeaye@system.gserviceaccount.com>
Commit-Queue: Darren Krahn <dkrahn@google.com>
Reviewed-by: Marshall Pierce <marshallpierce@google.com>
diff --git a/dpe-rs/src/args.rs b/dpe-rs/src/args.rs
index 92c090e..8a964c7 100644
--- a/dpe-rs/src/args.rs
+++ b/dpe-rs/src/args.rs
@@ -28,16 +28,34 @@
/// Indicates an argument was not recognized, so its type is unknown.
#[default]
Unknown,
- /// Indicates an argument is encoded as a CBOR byte string.
+ /// Indicates an argument is encoded as a CBOR byte string. The default
+ /// value is an empty slice.
Bytes,
- /// Indicates an argument is encoded as a CBOR unsigned integer.
+ /// Indicates an argument is encoded as a CBOR unsigned integer. The
+ /// default value is zero.
Int,
/// Indicates an argument is encoded as a CBOR true or false simple value.
- Bool,
- /// Indicates an argument needs additional custom decoding.
+ /// The value holds a default value.
+ Bool(bool),
+ /// Indicates an argument is encoded as a generic CBOR data item and needs
+ /// additional custom decoding. The default value is an empty slice.
Other,
}
+impl<'a> ArgTypeSelector {
+ pub(crate) fn default_value(self) -> DpeResult<ArgValue<'a>> {
+ match self {
+ Self::Bool(value) => Ok(ArgValue::Bool(value)),
+ Self::Int => Ok(ArgValue::Int(0)),
+ Self::Bytes | Self::Other => Ok(ArgValue::Bytes(&[])),
+ Self::Unknown => {
+ error!("No default for unknown types");
+ Err(ErrCode::InternalError)
+ }
+ }
+ }
+}
+
/// Represents a command or response argument value.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) enum ArgValue<'a> {
@@ -51,88 +69,100 @@
Bool(bool),
}
-impl<'a> ArgValue<'a> {
- /// Creates a new `Bytes` from a slice, borrowing the slice.
- pub(crate) fn from_slice(value: &'a [u8]) -> Self {
- ArgValue::Bytes(value)
- }
+impl TryFrom<&ArgValue<'_>> for bool {
+ type Error = ErrCode;
- /// Returns the borrowed slice if this is a Bytes.
- ///
- /// # Errors
- ///
- /// Returns an InternalError error if this is not a Bytes.
- pub(crate) fn try_into_slice(&self) -> DpeResult<&'a [u8]> {
- match self {
- ArgValue::Int(_) | ArgValue::Bool(_) => {
- error!("ArgValue::try_info_slice called on {:?}", self);
- Err(ErrCode::InternalError)
- }
- ArgValue::Bytes(value) => Ok(value),
- }
- }
-
- /// Returns the value held by an Int as a u32.
- ///
- /// # Errors
- ///
- /// Returns an InternalError error if this is not an Int.
- pub(crate) fn try_into_u32(&self) -> DpeResult<u32> {
- match self {
- ArgValue::Int(i) => Ok((*i).try_into()?),
- _ => {
- error!("ArgValue::try_into_u32 called on {:?}", self);
- Err(ErrCode::InternalError)
- }
- }
- }
-
- /// Creates a new `Int` holding the given u32 `value`.
- pub(crate) fn from_u32(value: u32) -> Self {
- ArgValue::Int(value as u64)
- }
-
- /// Returns the value held by an Int as a u64.
- ///
- /// # Errors
- ///
- /// Returns an InternalError error if this is not an Int.
- pub(crate) fn try_into_u64(&self) -> DpeResult<u64> {
- match self {
- ArgValue::Int(i) => Ok(*i),
- _ => {
- error!("ArgValue::try_into_u64 called on {:?}", self);
- Err(ErrCode::InternalError)
- }
- }
- }
-
- /// Creates a new `Int` holding the given u64 `value`.
- pub(crate) fn from_u64(value: u64) -> Self {
- ArgValue::Int(value)
- }
-
- /// Returns the value held by a Bool.
- ///
- /// # Errors
- ///
- /// Returns an InternalError error if this is not a Bool.
- pub(crate) fn try_into_bool(&self) -> DpeResult<bool> {
- match self {
+ fn try_from(value: &ArgValue) -> Result<Self, Self::Error> {
+ match value {
ArgValue::Bool(b) => Ok(*b),
_ => {
- error!("ArgValue::try_into_bool called on {:?}", self);
+ error!("{:?} cannot be converted to bool", value);
Err(ErrCode::InternalError)
}
}
}
+}
+impl TryFrom<&ArgValue<'_>> for u32 {
+ type Error = ErrCode;
- /// Creates a new `Bool` holding the given `value`.
- pub(crate) fn from_bool(value: bool) -> Self {
+ fn try_from(value: &ArgValue) -> Result<Self, Self::Error> {
+ match value {
+ ArgValue::Int(i) => Ok((*i).try_into()?),
+ _ => {
+ error!("{:?} cannot be converted to u32", value);
+ Err(ErrCode::InternalError)
+ }
+ }
+ }
+}
+
+impl TryFrom<&ArgValue<'_>> for u64 {
+ type Error = ErrCode;
+
+ fn try_from(value: &ArgValue) -> Result<Self, Self::Error> {
+ match value {
+ ArgValue::Int(i) => Ok(*i),
+ _ => {
+ error!("{:?} cannot be converted to u64", value);
+ Err(ErrCode::InternalError)
+ }
+ }
+ }
+}
+
+impl TryFrom<&ArgValue<'_>> for usize {
+ type Error = ErrCode;
+
+ fn try_from(value: &ArgValue) -> Result<Self, Self::Error> {
+ let tmp: u32 = value.try_into()?;
+ tmp.try_into().or(Err(ErrCode::InternalError))
+ }
+}
+
+impl<'a> TryFrom<&ArgValue<'a>> for &'a [u8] {
+ type Error = ErrCode;
+
+ fn try_from(value: &ArgValue<'a>) -> Result<Self, Self::Error> {
+ match value {
+ ArgValue::Int(_) | ArgValue::Bool(_) => {
+ error!("{:?} cannot be converted to a slice", value);
+ Err(ErrCode::InternalError)
+ }
+ ArgValue::Bytes(bytes) => Ok(bytes),
+ }
+ }
+}
+
+impl<'a> From<bool> for ArgValue<'a> {
+ fn from(value: bool) -> Self {
ArgValue::Bool(value)
}
}
+impl<'a> From<u32> for ArgValue<'a> {
+ fn from(value: u32) -> Self {
+ ArgValue::Int(value.into())
+ }
+}
+
+impl<'a> From<u64> for ArgValue<'a> {
+ fn from(value: u64) -> Self {
+ ArgValue::Int(value)
+ }
+}
+
+impl<'a> From<usize> for ArgValue<'a> {
+ fn from(value: usize) -> Self {
+ ArgValue::Int(value as u64)
+ }
+}
+
+impl<'a> From<&'a [u8]> for ArgValue<'a> {
+ fn from(value: &'a [u8]) -> Self {
+ ArgValue::Bytes(value)
+ }
+}
+
impl<'a, const S: usize> From<&'a SizedMessage<S>> for ArgValue<'a> {
fn from(message: &'a SizedMessage<S>) -> Self {
Self::Bytes(message.as_slice())
@@ -146,3 +176,35 @@
/// Contains a set of argument types in the form of a map from ArgId to
/// [`ArgTypeSelector`].
pub(crate) type ArgTypeMap = FnvIndexMap<ArgId, ArgTypeSelector, 16>;
+
+pub(crate) trait ArgMapExt<'a> {
+ /// Get the value for `id` as type `V`, or [`ErrCode::InternalError`] if
+ /// not found.
+ fn get_or_err<V>(&self, id: ArgId) -> DpeResult<V>
+ where
+ for<'b> &'b ArgValue<'a>: TryInto<V, Error = ErrCode>;
+
+ /// Insert a value with type 'V' for `id`, or [`ErrCode::OutOfMemory`] if
+ /// not possible.
+ fn insert_or_err<V>(&mut self, id: ArgId, value: V) -> DpeResult<()>
+ where
+ ArgValue<'a>: From<V>;
+}
+
+impl<'a> ArgMapExt<'a> for ArgMap<'a> {
+ fn get_or_err<V>(&self, id: ArgId) -> DpeResult<V>
+ where
+ for<'b> &'b ArgValue<'a>: TryInto<V, Error = ErrCode>,
+ {
+ self.get(&id).ok_or(ErrCode::InternalError)?.try_into()
+ }
+
+ fn insert_or_err<V>(&mut self, id: ArgId, value: V) -> DpeResult<()>
+ where
+ ArgValue<'a>: From<V>,
+ {
+ let _ =
+ self.insert(id, value.into()).map_err(|_| ErrCode::OutOfMemory)?;
+ Ok(())
+ }
+}
diff --git a/dpe-rs/src/encode.rs b/dpe-rs/src/encode.rs
index 15b4e7a..581f2ad 100644
--- a/dpe-rs/src/encode.rs
+++ b/dpe-rs/src/encode.rs
@@ -14,7 +14,7 @@
//! Types and functions for encoding and decoding command/response messages.
-use crate::args::{ArgId, ArgMap, ArgTypeMap, ArgTypeSelector, ArgValue};
+use crate::args::{ArgMap, ArgTypeMap, ArgTypeSelector, ArgValue};
use crate::byte_array_wrapper;
use crate::cbor::{
cbor_decoder_from_message, cbor_encoder_from_message, encode_bytes_prefix,
@@ -27,7 +27,7 @@
DiceInputConfig, DiceInputMode, InternalInputType, Uds,
};
use crate::error::{DpeResult, ErrCode};
-use crate::memory::{Message, SizedMessage};
+use crate::memory::{Message, SmallMessage};
use heapless::Vec;
use log::{debug, error};
use minicbor::Decoder;
@@ -68,10 +68,6 @@
}
}
-/// A message type with a smaller buffer. This saves memory when we're confident
-/// the contents will fit.
-pub(crate) type SmallMessage = SizedMessage<DPE_MAX_SMALL_MESSAGE_SIZE>;
-
/// A Vec wrapper to represent a certificate chain.
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
pub(crate) struct CertificateChain(
@@ -593,7 +589,7 @@
let _ = encoder.bool(*value)?;
}
_ => {
- let _ = encoder.bytes(arg_value.try_into_slice()?)?;
+ let _ = encoder.bytes(arg_value.try_into()?)?;
}
}
}
@@ -602,7 +598,7 @@
pub(crate) fn decode_args_internal<'a>(
encoded_args: &'a [u8],
- arg_types: &[(ArgId, ArgTypeSelector)],
+ arg_types: &ArgTypeMap,
) -> DpeResult<ArgMap<'a>> {
debug!("Decoding arguments of type: {:?}", arg_types);
let mut decoder = Decoder::new(encoded_args);
@@ -617,33 +613,27 @@
}
Ok(Some(num)) => num,
};
- let mut arg_types_map: ArgTypeMap = Default::default();
- for (id, value) in arg_types {
- let _ = arg_types_map
- .insert(*id, *value)
- .map_err(|_| ErrCode::OutOfMemory)?;
- }
let mut args: ArgMap = Default::default();
for _ in 0..num_pairs {
let arg_id = decoder.u32()?;
- match arg_types_map.get(&arg_id).cloned().unwrap_or_default() {
+ match arg_types.get(&arg_id).cloned().unwrap_or_default() {
ArgTypeSelector::Unknown => {
error!("Unknown argument id");
return Err(ErrCode::InvalidArgument);
}
ArgTypeSelector::Bytes => {
let _ = args
- .insert(arg_id, ArgValue::from_slice(decoder.bytes()?))
+ .insert(arg_id, ArgValue::Bytes(decoder.bytes()?))
.map_err(|_| ErrCode::OutOfMemory)?;
}
ArgTypeSelector::Int => {
let _ = args
- .insert(arg_id, ArgValue::from_u64(decoder.u64()?))
+ .insert(arg_id, ArgValue::Int(decoder.u64()?))
.map_err(|_| ErrCode::OutOfMemory)?;
}
- ArgTypeSelector::Bool => {
+ ArgTypeSelector::Bool(_) => {
let _ = args
- .insert(arg_id, ArgValue::from_bool(decoder.bool()?))
+ .insert(arg_id, ArgValue::Bool(decoder.bool()?))
.map_err(|_| ErrCode::OutOfMemory)?;
}
ArgTypeSelector::Other => {
@@ -653,11 +643,10 @@
let _ = args
.insert(
arg_id,
- ArgValue::from_slice(
- encoded_args
- .get(start..end)
- .ok_or(ErrCode::InvalidCommand)?,
- ),
+ encoded_args
+ .get(start..end)
+ .ok_or(ErrCode::InvalidCommand)?
+ .into(),
)
.map_err(|_| ErrCode::OutOfMemory)?;
}
@@ -678,30 +667,14 @@
/// * Returns `InternalError` if a default value is missing.
pub(crate) fn decode_args<'a>(
encoded_args: &'a [u8],
- arg_types: &[(ArgId, ArgTypeSelector)],
- arg_defaults: &[(ArgId, ArgValue<'a>)],
+ arg_types: &ArgTypeMap,
) -> DpeResult<ArgMap<'a>> {
let mut args = decode_args_internal(encoded_args, arg_types)?;
- for (arg_id, default_value) in arg_defaults {
- if !args.contains_key(arg_id) {
- let _ = args
- .insert(*arg_id, default_value.clone())
- .map_err(|_| ErrCode::OutOfMemory)?;
- }
- }
for (arg_id, arg_type) in arg_types {
if !args.contains_key(arg_id) {
- if *arg_type == ArgTypeSelector::Bytes
- || *arg_type == ArgTypeSelector::Other
- {
- // The default value for any byte array is the empty array.
- let _ = args
- .insert(*arg_id, ArgValue::from_slice(&[]))
- .map_err(|_| ErrCode::OutOfMemory)?;
- } else {
- error!("No default value found for argument {}", arg_id);
- return Err(ErrCode::InternalError);
- }
+ let _ = args
+ .insert(*arg_id, arg_type.default_value()?)
+ .map_err(|_| ErrCode::OutOfMemory)?;
}
}
Ok(args)
diff --git a/dpe-rs/src/encode_test.rs b/dpe-rs/src/encode_test.rs
index 681febf..8fd903f 100644
--- a/dpe-rs/src/encode_test.rs
+++ b/dpe-rs/src/encode_test.rs
@@ -14,7 +14,7 @@
//! Tests for the encode module.
-use crate::args::{ArgId, ArgMap, ArgTypeSelector, ArgValue};
+use crate::args::{ArgMap, ArgTypeMap, ArgTypeSelector, ArgValue};
use crate::cbor::tokenize_cbor_for_debug;
use crate::cbor::{
cbor_decoder_from_message, cbor_encoder_from_message, encode_bytes_prefix,
@@ -27,7 +27,7 @@
};
use crate::encode::*;
use crate::error::{DpeResult, ErrCode};
-use crate::memory::Message;
+use crate::memory::{Message, SmallMessage};
use heapless::Vec;
use log::debug;
use minicbor::Decoder;
@@ -69,7 +69,7 @@
#[allow(clippy::unwrap_used)]
pub(crate) fn decode_response_for_testing<'a>(
message: &'a Message,
- arg_types: &[(ArgId, ArgTypeSelector)],
+ arg_types: &ArgTypeMap,
) -> DpeResult<ArgMap<'a>> {
debug!("Raw response: {:?}", tokenize_cbor_for_debug(message.as_slice()));
let mut decoder = cbor_decoder_from_message(message);
@@ -893,50 +893,48 @@
let large = [0xAA; 2000];
let arg_map = ArgMap::from_iter(
[
- (1, ArgValue::from_bool(true)),
- (2, ArgValue::from_slice(&empty)),
- (3, ArgValue::from_slice(&small)),
- (4, ArgValue::from_slice(&large)),
- (5, ArgValue::from_u32(5)),
- (6, ArgValue::from_u64(2)),
+ (1, ArgValue::Bool(true)),
+ (2, ArgValue::Bytes(&empty)),
+ (3, ArgValue::Bytes(&small)),
+ (4, ArgValue::Bytes(&large)),
+ (5, ArgValue::Int(5)),
+ (6, ArgValue::Int(2)),
]
.into_iter(),
);
- let arg_types = [
- (4, ArgTypeSelector::Bytes),
- (5, ArgTypeSelector::Int),
- (6, ArgTypeSelector::Int),
- (1, ArgTypeSelector::Bool),
- (2, ArgTypeSelector::Bytes),
- (3, ArgTypeSelector::Bytes),
- ];
- let arg_defaults = [
- (1, ArgValue::from_bool(false)),
- (5, ArgValue::from_u32(12)),
- (6, ArgValue::from_u64(25000)),
- ];
+ let arg_types = ArgTypeMap::from_iter(
+ [
+ (4, ArgTypeSelector::Bytes),
+ (5, ArgTypeSelector::Int),
+ (6, ArgTypeSelector::Int),
+ (1, ArgTypeSelector::Bool(false)),
+ (2, ArgTypeSelector::Bytes),
+ (3, ArgTypeSelector::Bytes),
+ ]
+ .into_iter(),
+ );
encode_args(&arg_map, &mut buffer).unwrap();
{
let decoded_arg_map =
- decode_args(buffer.as_slice(), &arg_types, &arg_defaults).unwrap();
+ decode_args(buffer.as_slice(), &arg_types).unwrap();
assert_eq!(arg_map, decoded_arg_map);
// Since all fields are defined, defaults are superfluous.
let decoded_arg_map =
- decode_args(buffer.as_slice(), &arg_types, &[]).unwrap();
+ decode_args(buffer.as_slice(), &arg_types).unwrap();
assert_eq!(arg_map, decoded_arg_map);
}
let arg_map_empty = ArgMap::new();
- let mut arg_map_with_defaults =
- ArgMap::from_iter(arg_defaults.clone().into_iter());
- // The Bytes args should default to empty.
- let _ = arg_map_with_defaults.insert(2, ArgValue::from_slice(&[])).unwrap();
- let _ = arg_map_with_defaults.insert(3, ArgValue::from_slice(&[])).unwrap();
- let _ = arg_map_with_defaults.insert(4, ArgValue::from_slice(&[])).unwrap();
+ let arg_map_with_defaults = ArgMap::from_iter(
+ arg_types
+ .clone()
+ .into_iter()
+ .map(|(id, arg_type)| (id, arg_type.default_value().unwrap())),
+ );
encode_args(&arg_map_empty, &mut buffer).unwrap();
{
let decoded_arg_map =
- decode_args(buffer.as_slice(), &arg_types, &arg_defaults).unwrap();
+ decode_args(buffer.as_slice(), &arg_types).unwrap();
assert_eq!(arg_map_with_defaults, decoded_arg_map);
}
}
@@ -946,7 +944,7 @@
// Empty bytes -> not valid CBOR map
assert_eq!(
ErrCode::InvalidCommand,
- decode_args(&[], &[], &[]).unwrap_err()
+ decode_args(&[], &Default::default()).unwrap_err()
);
let mut buffer = Message::new();
// CBOR array -> not valid CBOR map
@@ -957,10 +955,11 @@
.unwrap()
.u16(7)
.unwrap();
- let arg_types = [(1, ArgTypeSelector::Bool)];
+ let arg_types =
+ ArgTypeMap::from_iter([(1, ArgTypeSelector::Bool(false))].into_iter());
assert_eq!(
ErrCode::InvalidCommand,
- decode_args(buffer.as_slice(), &arg_types, &[]).unwrap_err()
+ decode_args(buffer.as_slice(), &arg_types).unwrap_err()
);
buffer.clear();
// CBOR indefinite map -> should be not supported
@@ -971,22 +970,14 @@
.unwrap();
assert_eq!(
ErrCode::InvalidCommand,
- decode_args(buffer.as_slice(), &arg_types, &[]).unwrap_err()
+ decode_args(buffer.as_slice(), &arg_types).unwrap_err()
);
// All args must be represented in arg types
- let unknown_arg =
- ArgMap::from_iter([(17, ArgValue::from_u32(17))].into_iter());
+ let unknown_arg = ArgMap::from_iter([(17, ArgValue::Int(17))].into_iter());
encode_args(&unknown_arg, &mut buffer).unwrap();
assert_eq!(
ErrCode::InvalidArgument,
- decode_args(buffer.as_slice(), &arg_types, &[]).unwrap_err()
- );
- // All known args must be present or have a default value
- let arg_map_empty = ArgMap::new();
- encode_args(&arg_map_empty, &mut buffer).unwrap();
- assert_eq!(
- ErrCode::InternalError,
- decode_args(buffer.as_slice(), &arg_types, &[]).unwrap_err()
+ decode_args(buffer.as_slice(), &arg_types).unwrap_err()
);
}
diff --git a/dpe-rs/src/memory.rs b/dpe-rs/src/memory.rs
index 2911a45..a8d4f1a 100644
--- a/dpe-rs/src/memory.rs
+++ b/dpe-rs/src/memory.rs
@@ -293,3 +293,7 @@
/// Represents a DPE command/response message. This type is large and should not
/// be instantiated unnecessarily.
pub(crate) type Message = SizedMessage<MAX_MESSAGE_SIZE>;
+
+/// A message type with a smaller buffer. This saves memory when we're confident
+/// the contents will fit.
+pub(crate) type SmallMessage = SizedMessage<DPE_MAX_SMALL_MESSAGE_SIZE>;