blob: 8347625d5c76a8455e10032a4df522f8d935dbc5 [file] [log] [blame]
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides helpers for working with IR data elements.
Historical note: At one point protocol buffers were used for IR data. The
codebase still expects the IR data classes to behave similarly, particularly
with respect to "autovivification" where accessing an undefined field will
create it temporarily and add it if assigned to. Although this does not fully
follow the Pythonic ethos, we provide this behavior via the `builder` and
`reader` helpers to remain compatible with the rest of the codebase.
builder
-------
Instead of:
```
def set_function_name_end(function: Function):
if not function.function_name:
function.function_name = Word()
if not function.function_name.source_location:
function.function_name.source_location = Location()
function.function_name.source_location.end = Position(line=1,column=2)
```
We can do:
```
def set_function_name_end(function: Function):
builder(function).function_name.source_location.end = Position(line=1,
column=2)
```
reader
------
Instead of:
```
def is_leaf_synthetic(data):
if data:
if data.attribute:
if data.attribute.value:
if data.attribute.value.is_synthetic is not None:
return data.attribute.value.is_synthetic
return False
```
We can do:
```
def is_leaf_synthetic(data):
return reader(data).attribute.value.is_synthetic
```
IrDataSerializer
----------------
Provides methods for serializing and deserializing an IR data object.
"""
import enum
import json
from typing import (
Any,
Callable,
Generic,
MutableMapping,
MutableSequence,
Optional,
Tuple,
TypeVar,
Union,
cast,
)
from compiler.util import ir_data
from compiler.util import ir_data_fields
# Type variable for IR data classes; bound to `ir_data.Message` so helpers in
# this module can be annotated as returning the same type they were given.
MessageT = TypeVar("MessageT", bound=ir_data.Message)
def field_specs(ir: Union[MessageT, type[MessageT]]):
    """Returns the field specs of an IR data class.

    Args:
        ir: Either an IR data class itself or an instance of one.

    Returns:
        The `all_field_specs` of the `IrDataclassSpecs` looked up for the class.
    """
    if isinstance(ir, type):
        data_type = ir
    else:
        data_type = type(ir)
    class_specs = ir_data_fields.IrDataclassSpecs.get_specs(data_type)
    return class_specs.all_field_specs
class IrDataSerializer:
    """Serializes an IR data object to dicts/JSON and deserializes it back."""

    def __init__(self, ir: MessageT):
        """Binds the serializer to `ir`, which must not be None."""
        assert ir is not None
        self.ir = ir

    def _to_dict(
        self,
        ir: MessageT,
        field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]],
    ) -> MutableMapping[str, Any]:
        """Recursively converts `ir` into a dict keyed by field name.

        Args:
            ir: The IR data object to convert; must not be None.
            field_func: Returns the `(spec, value)` pairs to emit for an IR data
                object. It is applied at every level of the recursion.

        Returns:
            A mapping from field name to value, where values of dataclass fields
            (and the elements of dataclass sequences) are themselves dicts.
        """
        assert ir is not None
        result: MutableMapping[str, Any] = {}
        for spec, value in field_func(ir):
            # Only non-None dataclass values need to be descended into.
            nested = value is not None and spec.is_dataclass
            if nested and spec.is_sequence:
                value = [self._to_dict(item, field_func) for item in value]
            elif nested:
                value = self._to_dict(value, field_func)
            result[spec.name] = value
        return result

    def to_dict(self, exclude_none: bool = False):
        """Converts the IR data class to a dictionary.

        Args:
            exclude_none: When true, fields whose value is None or an empty list
                are left out of the result.
        """
        # It's tempting to use `dataclasses.asdict` here, but that does a deep
        # copy which is overkill for the current usage; mainly as an intermediary
        # for `to_json` and `repr`.
        if not exclude_none:
            return self._to_dict(self.ir, fields_and_values)

        def is_populated(value):
            if value is None:
                return False
            if isinstance(value, list):
                return len(value)
            return True

        def populated_fields(ir):
            return fields_and_values(ir, is_populated)

        return self._to_dict(self.ir, populated_fields)

    def to_json(self, *args, **kwargs):
        """Converts the IR data class to a JSON string.

        Extra positional and keyword arguments are forwarded to `json.dumps`.
        """
        as_dict = self.to_dict(exclude_none=True)
        return json.dumps(as_dict, *args, **kwargs)

    @staticmethod
    def from_json(data_cls, data):
        """Constructs an instance of `data_cls` from the given JSON string."""
        return IrDataSerializer.from_dict(data_cls, json.loads(data))

    def copy_from_dict(self, data):
        """Deserializes `data` and overwrites every field of the bound IR object."""
        target_cls = type(self.ir)
        parsed = IrDataSerializer.from_dict(target_cls, data)
        for field_name in field_specs(target_cls):
            setattr(self.ir, field_name, getattr(parsed, field_name))

    @staticmethod
    def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
        """Converts `val` to a member of `enum_cls`.

        Strings are looked up as attribute names on the enum class; any other
        value is passed to the enum class constructor.
        """
        if not isinstance(val, str):
            return enum_cls(val)
        return getattr(enum_cls, val)

    @staticmethod
    def _enum_type_hook(enum_cls: type[enum.Enum]):
        """Returns a one-argument converter bound to `enum_cls`."""

        def convert(val):
            return IrDataSerializer._enum_type_converter(enum_cls, val)

        return convert

    @staticmethod
    def _from_dict(data_cls: type[MessageT], data):
        """Recursively builds an instance of `data_cls` from a dict.

        Entries that are missing from `data`, or whose value is None, are not
        passed to the `data_cls` constructor.
        """
        init_args: MutableMapping[str, Any] = {}
        for name, spec in ir_data_fields.field_specs(data_cls).items():
            raw = data.get(name)
            if raw is None:
                continue
            if spec.is_dataclass:
                if spec.is_sequence:
                    converted = [
                        IrDataSerializer._from_dict(spec.data_type, item)
                        for item in raw
                    ]
                else:
                    converted = IrDataSerializer._from_dict(spec.data_type, raw)
            elif spec.data_type in (
                ir_data.FunctionMapping,
                ir_data.AddressableUnit,
            ):
                # These enum fields may be serialized by member name or value.
                converted = IrDataSerializer._enum_type_converter(
                    spec.data_type, raw
                )
            elif spec.is_sequence:
                # Sequences of non-dataclass values are used as given.
                converted = raw
            else:
                converted = spec.data_type(raw)
            init_args[name] = converted
        return data_cls(**init_args)

    @staticmethod
    def from_dict(data_cls: type[MessageT], data):
        """Creates a new instance of `data_cls` from a serialized dict."""
        return IrDataSerializer._from_dict(data_cls, data)
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
    """Wrapper for a list of IR elements.

    Elements handed out by indexed access and iteration are wrapped in
    `_IrDataBuilder`s; a slice is handed out as another
    `_IrDataSequenceBuilder`. All other operations are forwarded to the wrapped
    sequence unchanged.
    """

    def __init__(self, target: MutableSequence[MessageT]):
        self._target = target

    def __delitem__(self, key):
        del self._target[key]

    def __getitem__(self, key):
        selected = self._target[key]
        if isinstance(key, slice):
            # A slice selects a sequence of elements, not a single element, so
            # it must not be wrapped as if it were one IR element.
            return _IrDataSequenceBuilder(selected)
        return _IrDataBuilder(selected)

    def __setitem__(self, key, value):
        self._target[key] = value

    def __iter__(self):
        for element in self._target:
            yield _IrDataBuilder(element)

    def __repr__(self):
        return repr(self._target)

    def __len__(self):
        return len(self._target)

    def __eq__(self, other):
        # Comparison is delegated to the wrapped sequence.
        return self._target == other

    def __ne__(self, other):
        return self._target != other

    def insert(self, index, value):
        self._target.insert(index, value)

    def extend(self, values):
        self._target.extend(values)
class _IrDataBuilder(Generic[MessageT]):
    """Wrapper for an IR element that creates missing fields on access.

    The wrapped element is stored in the `ir` attribute. Attribute writes are
    passed through to the wrapped element, and attribute reads of unset fields
    fill the field in with a default value first.
    """

    def __init__(self, ir: MessageT) -> None:
        assert ir is not None
        self.ir: MessageT = ir

    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name == "ir":
            # `ir` is the builder's own attribute holding the wrapped element.
            object.__setattr__(self, __name, __value)
            return
        # Every other attribute is written through to the wrapped element.
        target: MessageT = object.__getattribute__(self, "ir")
        setattr(target, __name, __value)

    def __getattribute__(self, name: str) -> Any:
        """Hook for `getattr` that handles adding missing fields.

        If the field is unset, a default is built and stored on the wrapped
        element. The result is then either the raw value for basic types, or a
        new builder wrapping the field to handle the next field access in a
        longer chain.

        Raises:
            AttributeError: If `name` is not a field of the wrapped element.
        """
        # Attributes that belong to the builder itself.
        if name in ("CopyFrom", "ir"):
            return object.__getattribute__(self, name)

        # Fetch the wrapped element while bypassing this hook.
        ir: MessageT = object.__getattribute__(self, "ir")
        if ir is None:
            return object.__getattribute__(self, name)

        # These are forwarded to the wrapped element without any wrapping.
        if name in ("HasField", "WhichOneof"):
            return getattr(ir, name)

        field_spec = field_specs(ir).get(name)
        if field_spec is None:
            raise AttributeError(
                f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
            )

        value = getattr(ir, name, None)
        if value is None:
            # Create a default and store it so later accesses see the same object.
            value = ir_data_fields.build_default(field_spec)
            setattr(ir, name, value)

        if not field_spec.is_dataclass:
            return value
        if field_spec.is_sequence:
            return _IrDataSequenceBuilder(value)
        return _IrDataBuilder(value)

    def CopyFrom(self, template: MessageT):  # pylint:disable=invalid-name
        """Updates the fields of this class with values set in the template."""
        # The builder stands in for an instance of the wrapped type, so it is
        # cast to `MessageT` (an instance), not `type[MessageT]` (a class).
        update(cast(MessageT, self), template)
def builder(target: MessageT) -> MessageT:
    """Wraps `target` in a builder that helps construct an IR data structure.

    Args:
        target: An IR data class instance, or an existing builder.

    Returns:
        `target` itself if it is already a builder; otherwise a new
        `_IrDataBuilder` around it, cast to the target's type so that type
        hints for the wrapped type remain available.

    Raises:
        TypeError: If the type of `target` has no `IR_DATACLASS` attribute.
    """
    already_wrapped = isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder))
    if already_wrapped:
        return target

    # Builders are only valid for IR data classes.
    if hasattr(type(target), "IR_DATACLASS"):
        return cast(MessageT, _IrDataBuilder(target))
    raise TypeError(f"Builder target {type(target)} is not an ir_data.message")
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
    """Builds the stand-in value that a reader returns for an unset field.

    Sequence fields yield an empty list, dataclass fields yield a
    `_ReadOnlyFieldChecker` wrapping the spec, and all other fields yield the
    result of `ir_data_fields.build_default`.
    """
    if spec.is_sequence:
        return []
    if not spec.is_dataclass:
        return ir_data_fields.build_default(spec)
    return _ReadOnlyFieldChecker(spec)
def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type[Any]:
    """Returns the spec's `data_type` for a field spec, else the object's type."""
    is_spec = isinstance(ir_or_spec, ir_data_fields.FieldSpec)
    return ir_or_spec.data_type if is_spec else type(ir_or_spec)
class _ReadOnlyFieldChecker:
    """Read-only wrapper used to chain accesses through fields that may be unset.

    Wraps either a real IR data object or, as a placeholder for an unset
    dataclass field, the `FieldSpec` describing that field.
    """

    def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
        self.ir_or_spec = ir_or_spec

    def __setattr__(self, name: str, value: Any) -> None:
        # Only the wrapper's own slot may be assigned.
        if name != "ir_or_spec":
            raise AttributeError(f"Cannot set {name} on read-only wrapper")
        object.__setattr__(self, name, value)

    def __getattribute__(
        self, name: str
    ) -> Any:  # pylint:disable=too-many-return-statements
        """Reads `name` from the wrapped object, substituting for unset fields."""
        wrapped = object.__getattribute__(self, "ir_or_spec")
        if name == "ir_or_spec":
            return wrapped

        spec = field_specs(_field_type(wrapped)).get(name)
        is_placeholder = isinstance(wrapped, ir_data_fields.FieldSpec)

        if not spec:
            # Not a declared field of the wrapped type.
            if is_placeholder:
                # A placeholder has nothing set, so answer accordingly.
                stubs = {
                    "HasField": lambda x: False,
                    "WhichOneof": lambda x: None,
                }
                if name in stubs:
                    return stubs[name]
            return object.__getattribute__(wrapped, name)

        # Placeholders and unset fields both produce a stand-in value. The
        # real value is only read when wrapping an actual IR object.
        if is_placeholder or (value := getattr(wrapped, name)) is None:
            return _field_checker_from_spec(spec)

        if not spec.is_dataclass:
            return value
        if spec.is_sequence:
            return [_ReadOnlyFieldChecker(item) for item in value]
        return _ReadOnlyFieldChecker(value)

    def __eq__(self, other):
        # Compare wrapped objects so wrappers are transparent to equality.
        rhs = other.ir_or_spec if isinstance(other, _ReadOnlyFieldChecker) else other
        return self.ir_or_spec == rhs

    def __ne__(self, other):
        return not (self == other)
def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
    """Builds a read-only wrapper for checking chains of possibly unset fields.

    The wrapper explicitly does not alter the wrapped object and is only
    intended for reading contents.

    For example, a `reader` lets you do:
    ```
    def get_function_name_end_column(function: ir_data.Function):
        return reader(function).function_name.source_location.end.column
    ```

    Instead of:
    ```
    def get_function_name_end_column(function: ir_data.Function):
        if function.function_name:
            if function.function_name.source_location:
                if function.function_name.source_location.end:
                    return function.function_name.source_location.end.column
        return 0
    ```

    Args:
        obj: An IR data object, or an existing read-only wrapper.

    Returns:
        `obj` itself if it is already a read-only wrapper, otherwise a new
        wrapper around it; in both cases cast back to the original type.
    """
    if isinstance(obj, _ReadOnlyFieldChecker):
        wrapper = obj
    else:
        wrapper = _ReadOnlyFieldChecker(obj)
    return cast(MessageT, wrapper)
def _extract_ir(
    ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None],
) -> Optional[ir_data_fields.IrDataclassInstance]:
    """Unwraps a reader or builder to get at the underlying IR data object.

    Returns:
        The wrapped object for a reader or builder, the argument itself for
        anything else, or None for a reader that is only a placeholder for an
        unset field.
    """
    if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
        wrapped = ir_or_wrapper.ir_or_spec
        if isinstance(wrapped, ir_data_fields.FieldSpec):
            # This is a placeholder entry, no fields are set.
            return None
        return cast(ir_data_fields.IrDataclassInstance, wrapped)
    if isinstance(ir_or_wrapper, _IrDataBuilder):
        return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper.ir)
    return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
def fields_and_values(
    ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
    value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
    """Retrieves the fields and their values for a given IR data class.

    Args:
        ir_wrapper: The IR data class or a read-only wrapper of an IR data class.
        value_filt: Optional filter used to exclude values.

    Returns:
        The result of `ir_data_fields.fields_and_values` for the unwrapped
        object, or an empty list when there is no underlying object.
    """
    ir = _extract_ir(ir_wrapper)
    if ir is None:
        return []
    return ir_data_fields.fields_and_values(ir, value_filt)
def get_set_fields(ir: MessageT):
    """Retrieves the field spec and value of every set field of `ir`.

    A value is considered "set" if it is not None.
    """

    def is_set(value):
        return value is not None

    return fields_and_values(ir, is_set)
def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
    """Creates a copy of the given IR data class.

    Readers and builders are unwrapped first. Returns None when there is no
    underlying IR data object to copy.
    """
    ir = _extract_ir(ir_wrapper)
    if ir is None:
        return None
    return cast(MessageT, ir_data_fields.copy(ir))
def update(ir: MessageT, template: MessageT):
    """Updates `ir`s fields with all set fields in the template.

    The template is unwrapped first; if that yields nothing truthy, `ir` is left
    untouched. Otherwise the work is delegated to `ir_data_fields.update`.
    """
    template_ir = _extract_ir(template)
    if template_ir:
        ir_data_fields.update(
            cast(ir_data_fields.IrDataclassInstance, ir), template_ir
        )