blob: 25dc041843bde4b26d0bbe10afdecedad0dddcf5 [file] [log] [blame]
// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
use std::fmt::Display;
use rustpython::vm;
use serde::{
ser::{self, Error as _},
Serialize,
};
use crate::{Error, Result};
/// Serialize `value` to a Python object.
///
/// Serialize `value` to a [`PyObjectRef`](rustpython::vm::PyObjectRef) using
/// specified `vm`.
///
/// `T` may be unsized (for example `str` or a slice), matching the bounds used
/// by the serializer's own helper methods.
///
/// # Errors
///
/// Returns an error when `value`'s `Serialize` implementation fails or when it
/// uses a part of the Serde data model this serializer rejects (`i64`, `u32`,
/// `u64`, and enum variants that carry data).
pub fn to_python_object<T>(vm: &vm::VirtualMachine, value: &T) -> Result<vm::PyObjectRef>
where
    T: ?Sized + Serialize,
{
    let mut ser = Serializer { vm };
    value.serialize(&mut ser)
}
/// Serde serializer whose output values are Python objects
/// ([`vm::PyObjectRef`]) created through the wrapped virtual machine.
pub struct Serializer<'a> {
/// VM whose context (`vm.ctx`) is used to construct every Python object.
vm: &'a vm::VirtualMachine,
}
/// In-progress state for serializing a sequence, tuple, or tuple struct.
///
/// Elements are serialized one at a time and buffered in `values`; the buffer
/// is turned into a Python list or tuple when the compound value is ended.
pub struct Sequence<'a> {
/// VM used to serialize each element and to build the final container.
vm: &'a vm::VirtualMachine,
/// Already-serialized elements, in insertion order.
values: Vec<vm::PyObjectRef>,
}
/// In-progress state for serializing a map or struct into a Python dict.
pub struct Dict<'a> {
/// VM used to serialize keys/values and to insert them into `dict`.
vm: &'a vm::VirtualMachine,
/// Key that has been serialized but not yet paired with a value.
current_key: Option<vm::PyObjectRef>,
/// The Python dict being populated.
dict: vm::builtins::PyDictRef,
}
impl ser::Error for Error {
    /// Builds a [`Error::Serialization`] carrying the `Display` rendering of
    /// `msg`, as required by Serde for serializer error types.
    fn custom<T>(msg: T) -> Self
    where
        T: Display,
    {
        Self::Serialization(msg.to_string())
    }
}
impl<'a> ser::Serializer for &'a mut Serializer<'a> {
// The output type produced by this `Serializer` during successful
// serialization. Most serializers that produce text or binary output should
// set `Ok = ()` and serialize into an `io::Write` or buffer contained
// within the `Serializer` instance, as happens here. Serializers that build
// in-memory data structures may be simplified by using `Ok` to propagate
// the data structure around.
type Ok = vm::PyObjectRef;
// The error type when some error occurs during serialization.
type Error = Error;
// Associated types for keeping track of additional state while serializing
// compound data structures like sequences and maps. In this case no
// additional state is required beyond what is already stored in the
// Serializer struct.
type SerializeSeq = Sequence<'a>;
type SerializeTuple = Sequence<'a>;
type SerializeTupleStruct = Sequence<'a>;
type SerializeTupleVariant = Dict<'a>;
type SerializeMap = Dict<'a>;
type SerializeStruct = Dict<'a>;
type SerializeStructVariant = Dict<'a>;
fn serialize_bool(self, v: bool) -> Result<Self::Ok> {
Ok(self.vm.ctx.new_bool(v).into())
}
fn serialize_i8(self, v: i8) -> Result<Self::Ok> {
self.serialize_i32(i32::from(v))
}
fn serialize_i16(self, v: i16) -> Result<Self::Ok> {
self.serialize_i32(i32::from(v))
}
fn serialize_i32(self, v: i32) -> Result<Self::Ok> {
Ok(self.vm.ctx.new_int(v).into())
}
fn serialize_i64(self, _v: i64) -> Result<Self::Ok> {
Err(Error::custom("BigInts are not supported"))
}
fn serialize_u8(self, v: u8) -> Result<Self::Ok> {
self.serialize_i32(i32::from(v))
}
fn serialize_u16(self, v: u16) -> Result<Self::Ok> {
self.serialize_i32(i32::from(v))
}
fn serialize_u32(self, _v: u32) -> Result<Self::Ok> {
Err(Error::custom("BigInts are not supported"))
}
fn serialize_u64(self, _v: u64) -> Result<Self::Ok> {
Err(Error::custom("BigInts are not supported"))
}
fn serialize_f32(self, v: f32) -> Result<Self::Ok> {
self.serialize_f64(f64::from(v))
}
fn serialize_f64(self, v: f64) -> Result<Self::Ok> {
Ok(self.vm.ctx.new_float(v).into())
}
fn serialize_char(self, v: char) -> Result<Self::Ok> {
self.serialize_str(&v.to_string())
}
fn serialize_str(self, v: &str) -> Result<Self::Ok> {
Ok(self.vm.ctx.new_str(v).into())
}
fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok> {
use serde::ser::SerializeSeq;
let mut seq = self.serialize_seq(Some(v.len()))?;
for byte in v {
seq.serialize_element(byte)?;
}
seq.end()
}
fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
Ok(Self::SerializeSeq {
values: Vec::with_capacity(len),
vm: self.vm,
})
}
fn serialize_none(self) -> Result<Self::Ok> {
Ok(self.vm.ctx.none())
}
fn serialize_some<T>(self, value: &T) -> Result<Self::Ok>
where
T: ?Sized + Serialize,
{
value.serialize(self)
}
fn serialize_unit(self) -> Result<Self::Ok> {
self.serialize_none()
}
fn serialize_unit_struct(self, _name: &'static str) -> Result<Self::Ok> {
self.serialize_none()
}
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
) -> Result<Self::Ok> {
self.serialize_str(variant)
}
fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<Self::Ok>
where
T: ?Sized + Serialize,
{
value.serialize(self)
}
fn serialize_newtype_variant<T>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result<Self::Ok>
where
T: ?Sized + Serialize,
{
Err(Error::custom("Newtype variants are not supported"))
}
fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
self.serialize_seq(Some(len))
}
fn serialize_tuple_struct(
self,
_name: &'static str,
len: usize,
) -> Result<Self::SerializeTupleStruct> {
self.serialize_seq(Some(len))
}
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeTupleVariant> {
Err(Error::custom("TupleVariants are not supported"))
}
fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
Ok(Dict::new(self.vm))
}
fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
self.serialize_map(Some(len))
}
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(Error::custom("StructVariants are not supported"))
}
}
impl<'a> Sequence<'a> {
fn into_list(self) -> vm::PyObjectRef {
self.vm.ctx.new_list(self.values).into()
}
fn into_tuple(self) -> vm::PyObjectRef {
self.vm.ctx.new_tuple(self.values).into()
}
fn add_element<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
let mut ser = Serializer { vm: self.vm };
let object = value.serialize(&mut ser)?;
self.values.push(object);
Ok(())
}
}
// Sequence support: elements are buffered and the result is a Python list.
impl<'a> ser::SerializeSeq for Sequence<'a> {
// Must match the `Ok` type of the serializer.
type Ok = vm::PyObjectRef;
// Must match the `Error` type of the serializer.
type Error = Error;
// Serialize a single element of the sequence and append it to the buffer.
fn serialize_element<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_element(value)
}
// Close the sequence, producing a Python list of the buffered elements.
fn end(self) -> Result<Self::Ok> {
Ok(self.into_list())
}
}
// Tuple support: elements are buffered and the result is a Python tuple.
impl<'a> ser::SerializeTuple for Sequence<'a> {
// Must match the `Ok` type of the serializer.
type Ok = vm::PyObjectRef;
// Must match the `Error` type of the serializer.
type Error = Error;
// Serialize a single tuple element and append it to the buffer.
fn serialize_element<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_element(value)
}
// Close the tuple, producing a Python tuple of the buffered elements.
fn end(self) -> Result<Self::Ok> {
Ok(self.into_tuple())
}
}
// Tuple-struct support: fields are buffered positionally and the result is a
// Python tuple, the same representation used for plain tuples.
impl<'a> ser::SerializeTupleStruct for Sequence<'a> {
// Must match the `Ok` type of the serializer.
type Ok = vm::PyObjectRef;
// Must match the `Error` type of the serializer.
type Error = Error;
// Serialize a single positional field and append it to the buffer.
fn serialize_field<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_element(value)
}
// Close the tuple struct, producing a Python tuple of the buffered fields.
fn end(self) -> Result<Self::Ok> {
Ok(self.into_tuple())
}
}
impl<'a> Dict<'a> {
fn new(vm: &'a vm::VirtualMachine) -> Self {
Dict {
vm,
current_key: None,
dict: vm.ctx.new_dict(),
}
}
fn add_key<T>(&mut self, key: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
let mut ser = Serializer { vm: self.vm };
// TODO: validation that the value is a valid Dictionary key is needed.
self.current_key = Some(key.serialize(&mut ser)?);
Ok(())
}
fn add_value<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
let mut ser = Serializer { vm: self.vm };
let object = value.serialize(&mut ser)?;
let key = self
.current_key
.take()
.ok_or_else(|| Error::custom("Value added without key"))?;
self.dict
.set_item(&*key, object, self.vm)
.map_err(|e| Error::custom(format!("Can't add key to dict: {:?}", e)))?;
Ok(())
}
}
// Tuple variants are not supported. `serialize_tuple_variant` on the
// serializer in this file already returns an error, so this impl exists to
// satisfy the `SerializeTupleVariant` associated type; both methods fail
// unconditionally.
impl<'a> ser::SerializeTupleVariant for Dict<'a> {
// Must match the `Ok` type of the serializer.
type Ok = vm::PyObjectRef;
// Must match the `Error` type of the serializer.
type Error = Error;
// Always fails: tuple variant fields cannot be serialized.
fn serialize_field<T>(&mut self, _value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
Err(Error::custom("unsupported"))
}
// Always fails: there is no Python representation for a tuple variant.
fn end(self) -> Result<Self::Ok> {
Err(Error::custom("unsupported"))
}
}
// Map support: entries are inserted into a Python dict as they arrive.
impl<'a> ser::SerializeMap for Dict<'a> {
// Must match the `Ok` type of the serializer.
type Ok = vm::PyObjectRef;
// Must match the `Error` type of the serializer.
type Error = Error;
// Serialize a key and hold it until the matching value is supplied.
fn serialize_key<T>(&mut self, key: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_key(key)
}
// Serialize a value and insert it under the most recently supplied key.
// Fails if no key is pending.
fn serialize_value<T>(&mut self, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_value(value)
}
// Close the map and return the dict. A key that was supplied without a
// following value is not inserted.
fn end(self) -> Result<Self::Ok> {
Ok(self.dict.into())
}
}
impl<'a> ser::SerializeStruct for Dict<'a> {
type Ok = vm::PyObjectRef;
type Error = Error;
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_key(key)?;
self.add_value(value)?;
Ok(())
}
fn end(self) -> Result<Self::Ok> {
Ok(self.dict.into())
}
}
impl<'a> ser::SerializeStructVariant for Dict<'a> {
type Ok = vm::PyObjectRef;
type Error = Error;
fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
where
T: ?Sized + Serialize,
{
self.add_key(key)?;
self.add_value(value)?;
Ok(())
}
fn end(self) -> Result<Self::Ok> {
Ok(self.dict.into())
}
}
#[cfg(test)]
mod test {
    use num_traits::Float;
    use std::{
        collections::{BTreeMap, BTreeSet},
        str::FromStr,
        sync::{Arc, Mutex},
    };

    use crate::py::{self, Py};

    use super::*;

    /// Asserts that serializing `value` is rejected with an error.
    fn unimplemented_test<T: Serialize>(value: T) {
        let mut py = Py::new();
        py.run(
            |vm, _scope| {
                let ret = py::to_python_object(vm, &value);
                assert!(ret.is_err());
                Ok(())
            },
            "",
        )
        .unwrap();
    }

    /// Serializes `value`, exposes it to Python as the global `data`, and
    /// returns the result of `'{}'.format(data)` as reported back through the
    /// `qg_test_func` callback.
    fn run_test<T: Serialize>(value: T) -> String {
        // The Python VM is treated as another thread context so types must
        // implement Send and Sync.
        let lines = Arc::new(Mutex::new(Vec::new()));

        // A clone for the python side
        let py_lines = lines.clone();

        // We create a fresh Python VM for each test step. Even in release mode,
        // this can incur ~ 0.25 seconds per step. This could be sped up by
        // sharing VMs between steps.
        let mut py = Py::new();

        // We're assuming a stable output of Python structures via `.format()`.
        let code = "qg_test_func('{}'.format(data))";
        py.run(
            |vm, scope| {
                scope
                    .globals
                    .set_item("data", py::to_python_object(vm, &value).unwrap(), vm)?;

                // Add our test logging function to the scope before running the code.
                scope.globals.set_item(
                    "qg_test_func",
                    vm.ctx
                        .new_function("qg_test_func", move |s: String| {
                            println!("qg_test_func: {}", s);
                            py_lines.lock().unwrap().push(s);
                        })
                        .into(),
                    vm,
                )?;
                Ok(())
            },
            code,
        )
        .unwrap();

        // Bind the result before returning so the mutex guard is released
        // before `lines` goes out of scope.
        let test_output = lines.lock().unwrap()[0].clone();
        test_output
    }

    /// Asserts that the Python rendering of `value` equals `expected`.
    fn serialize_test<T: Serialize>(value: T, expected: &str) {
        let test_output = run_test(value);
        assert_eq!(test_output, expected);
    }

    /// Asserts that the Python rendering of `value`, parsed back as `T`, is
    /// within `epsilon` of `value`.
    fn serialize_test_float<T: Float + FromStr + Serialize>(value: T, epsilon: T)
    where
        <T as FromStr>::Err: std::fmt::Debug,
    {
        let test_output = run_test(value);
        let parsed: T = test_output.parse().unwrap();
        assert!((value - parsed).abs() < epsilon);
    }

    // Tests for each data type in the Serde Data Model:
    // https://serde.rs/data-model.html

    #[test]
    fn integer_serialization() {
        serialize_test(100i8, "100");
        serialize_test(100u8, "100");
        serialize_test(10000i16, "10000");
        serialize_test(10000u16, "10000");
        serialize_test(1000000i32, "1000000");
        unimplemented_test(1000000u32);
        unimplemented_test(1000000i64);
        unimplemented_test(1000000u64);
    }

    #[test]
    fn float_serialization() {
        serialize_test_float(3.141593f32, 0.0000005);
        // Exercise the `f64` path with a value and tolerance that need
        // double precision.
        serialize_test_float(3.141592653589793f64, 0.0000000000000005);
    }

    #[test]
    fn char_serialization() {
        serialize_test('q', "q");
    }

    #[test]
    fn str_serialization() {
        serialize_test("qg", "qg");
    }

    #[test]
    fn bytes_serialization() {
        let bytes: &[u8] = &[0x00, 0x55, 0xaa, 0xff][..];
        serialize_test(bytes, "[0, 85, 170, 255]");
    }

    #[test]
    fn option_serialization() {
        serialize_test(Some(12345678i32), "12345678");
        serialize_test(None::<i32>, "None");
    }

    #[test]
    fn unit_serialization() {
        serialize_test((), "None");
    }

    #[test]
    fn unit_struct_serialization() {
        #[derive(Serialize)]
        struct Unit;
        serialize_test(Unit {}, "None");
    }

    #[test]
    fn unit_variant_serialization() {
        #[derive(Serialize)]
        enum E {
            A,
            B,
        }
        serialize_test(E::A, "A");
        serialize_test(E::B, "B");
    }

    #[test]
    fn newtype_struct_serialization() {
        #[derive(Serialize)]
        struct S(u8);
        serialize_test(S(123u8), "123");
    }

    #[test]
    fn newtype_variant_serialization() {
        #[derive(Serialize)]
        enum E {
            A(u8),
            B(String),
        }
        unimplemented_test(E::A(123u8));
        unimplemented_test(E::B("qg".into()));
    }

    #[test]
    fn seq_serialization() {
        let vec = vec![1i32, 2, 3];
        let set: BTreeSet<i32> = vec.clone().into_iter().collect();
        serialize_test(vec, "[1, 2, 3]");
        serialize_test(set, "[1, 2, 3]");
    }

    #[test]
    fn tuple_serialization() {
        serialize_test(('q', 'g', '!'), "('q', 'g', '!')");
    }

    #[test]
    fn tuple_struct_serialization() {
        #[derive(Serialize)]
        struct Abc(u8, u8, u8);
        serialize_test(Abc(1, 2, 3), "(1, 2, 3)");
    }

    #[test]
    fn tuple_variant_serialization() {
        #[derive(Serialize)]
        enum E {
            A(u8, u8),
            B(u8, String),
        }
        unimplemented_test(E::A(1, 2));
        unimplemented_test(E::B(1, "qg".into()));
    }

    #[test]
    fn map_serialization() {
        let map: BTreeMap<String, i32> =
            vec![("one".into(), 1), ("two".into(), 2), ("three".into(), 3)]
                .into_iter()
                .collect();
        serialize_test(map, "{'one': 1, 'three': 3, 'two': 2}");
    }

    #[test]
    fn struct_serialization() {
        #[derive(Debug, Serialize)]
        pub struct TestStruct {
            i: i32,
            s: String,
        }
        serialize_test(
            &TestStruct {
                i: 1,
                s: "Hello Serde!".to_string(),
            },
            "{'i': 1, 's': 'Hello Serde!'}",
        );
    }

    #[test]
    fn struct_variant_serialization() {
        #[derive(Serialize)]
        enum E {
            A { a1: u8, a2: u8 },
            B { b1: u8, b2: u8 },
        }
        unimplemented_test(E::A { a1: 1, a2: 2 });
        unimplemented_test(E::B { b1: 1, b2: 2 });
    }
}