| # Copyright 2024 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Provides helpers for working with IR data elements. |
| |
| Historical note: At one point protocol buffers were used for IR data. The |
| codebase still expects the IR data classes to behave similarly, particularly |
| with respect to "autovivification" where accessing an undefined field will |
| create it temporarily and add it if assigned to. Though, perhaps not fully |
| following the Pythonic ethos, we provide this behavior via the `builder` and |
| `reader` helpers to remain compatible with the rest of the codebase. |
| |
| builder |
| ------- |
| Instead of: |
| ``` |
| def set_function_name_end(function: Function): |
| if not function.function_name: |
| function.function_name = Word() |
| if not function.function_name.source_location: |
| function.function_name.source_location = Location() |
| function.function_name.source_location.end = Position(line=1,column=2) |
| ``` |
| |
| We can do: |
| ``` |
| def set_function_name_end(function: Function): |
| builder(function).function_name.source_location.end = Position(line=1, |
| column=2) |
| ``` |
| |
| reader |
| ------ |
| Instead of: |
| ``` |
| def is_leaf_synthetic(data): |
| if data: |
| if data.attribute: |
| if data.attribute.value: |
| if data.attribute.value.is_synthetic is not None: |
| return data.attribute.value.is_synthetic |
| return False |
| ``` |
| We can do: |
| ``` |
| def is_leaf_synthetic(data): |
| return reader(data).attribute.value.is_synthetic |
| ``` |
| |
| IrDataSerializer |
| ---------------- |
| Provides methods for serializing and deserializing an IR data object. |
| """ |
| import enum |
| import json |
| from typing import ( |
| Any, |
| Callable, |
| Generic, |
| MutableMapping, |
| MutableSequence, |
| Optional, |
| Tuple, |
| TypeVar, |
| cast, |
| ) |
| |
| from compiler.util import ir_data |
| from compiler.util import ir_data_fields |
| |
| |
# Type variable for IR data classes. It is bound to `ir_data.Message` and is
# used in the annotations of the helpers in this module.
MessageT = TypeVar("MessageT", bound=ir_data.Message)
| |
| |
def field_specs(ir: MessageT | type[MessageT]):
    """Looks up the field specs of an IR data class.

    Args:
        ir: Either an IR data class or an instance of one.

    Returns:
        The `all_field_specs` of the `IrDataclassSpecs` registered for the
        class.
    """
    if isinstance(ir, type):
        data_cls = ir
    else:
        data_cls = type(ir)
    class_specs = ir_data_fields.IrDataclassSpecs.get_specs(data_cls)
    return class_specs.all_field_specs
| |
| |
class IrDataSerializer:
    """Converts IR data objects to and from plain dicts and JSON strings."""

    def __init__(self, ir: MessageT):
        assert ir is not None
        self.ir = ir

    def _to_dict(
        self,
        ir: MessageT,
        field_func: Callable[
            [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
        ],
    ) -> MutableMapping[str, Any]:
        """Recursively converts `ir` into a dict keyed by field name.

        Args:
            ir: The IR data object to convert; must not be None.
            field_func: Callable returning the `(spec, value)` pairs to emit
                for a given IR object. It is applied at every nesting level.

        Returns:
            A mapping from field name to value. Values of dataclass fields
            are converted recursively; all other values are stored as-is.
        """
        assert ir is not None
        result: MutableMapping[str, Any] = {}
        for spec, value in field_func(ir):
            if value is None or not spec.is_dataclass:
                # Plain values (and unset fields) are stored without copying.
                result[spec.name] = value
            elif spec.is_sequence:
                result[spec.name] = [
                    self._to_dict(item, field_func) for item in value
                ]
            else:
                result[spec.name] = self._to_dict(value, field_func)
        return result

    def to_dict(self, exclude_none: bool = False):
        """Converts the wrapped IR data object to a dictionary.

        Args:
            exclude_none: When true, only fields whose value is neither None
                nor an empty list are requested from `fields_and_values`.
        """

        def is_populated(value):
            return value is not None and (
                not isinstance(value, list) or len(value)
            )

        def non_empty(ir):
            return fields_and_values(ir, is_populated)

        # It's tempting to use `dataclasses.asdict` here, but that does a deep
        # copy which is overkill for the current usage; mainly as an
        # intermediary for `to_json` and `repr`.
        field_func = non_empty if exclude_none else fields_and_values
        return self._to_dict(self.ir, field_func)

    def to_json(self, *args, **kwargs):
        """Serializes the wrapped IR data object to a JSON string.

        Positional and keyword arguments are forwarded to `json.dumps`.
        """
        as_dict = self.to_dict(exclude_none=True)
        return json.dumps(as_dict, *args, **kwargs)

    @staticmethod
    def from_json(data_cls, data):
        """Builds a `data_cls` instance from the JSON string `data`."""
        return IrDataSerializer.from_dict(data_cls, json.loads(data))

    def copy_from_dict(self, data):
        """Deserializes `data` and overwrites the wrapped IR object with it.

        Every field listed in the field specs of the wrapped object's class
        is assigned from a freshly deserialized instance.
        """
        ir_cls = type(self.ir)
        fresh = IrDataSerializer.from_dict(ir_cls, data)
        for field_name in field_specs(ir_cls):
            setattr(self.ir, field_name, getattr(fresh, field_name))

    @staticmethod
    def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
        """Converts `val` to a member of `enum_cls`.

        Strings are looked up as member names; anything else is passed to the
        enum constructor (lookup by value).
        """
        if not isinstance(val, str):
            return enum_cls(val)
        return getattr(enum_cls, val)

    @staticmethod
    def _enum_type_hook(enum_cls: type[enum.Enum]):
        """Returns a one-argument converter bound to `enum_cls`."""

        def convert(val):
            return IrDataSerializer._enum_type_converter(enum_cls, val)

        return convert

    @staticmethod
    def _from_dict(data_cls: type[MessageT], data):
        """Recursively builds a `data_cls` instance from the dict `data`.

        Fields that are absent from `data`, or whose value is None, are not
        passed to the `data_cls` constructor.
        """
        kwargs: MutableMapping[str, Any] = {}
        for name, spec in ir_data_fields.field_specs(data_cls).items():
            value = data.get(name)
            if value is None:
                continue

            if spec.is_dataclass:
                if spec.is_sequence:
                    converted = [
                        IrDataSerializer._from_dict(spec.data_type, item)
                        for item in value
                    ]
                else:
                    converted = IrDataSerializer._from_dict(
                        spec.data_type, value
                    )
            elif spec.data_type in (
                ir_data.FunctionMapping,
                ir_data.AddressableUnit,
            ):
                converted = IrDataSerializer._enum_type_converter(
                    spec.data_type, value
                )
            elif spec.is_sequence:
                # Sequences of non-dataclass values are used as-is; the same
                # list object that is in `data` ends up in the new instance.
                converted = value
            else:
                converted = spec.data_type(value)
            kwargs[name] = converted
        return data_cls(**kwargs)

    @staticmethod
    def from_dict(data_cls: type[MessageT], data):
        """Creates a new `data_cls` instance from a serialized dict."""
        return IrDataSerializer._from_dict(data_cls, data)
| |
| |
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
    """Wrapper for a list of IR elements.

    Reads hand back wrapped values: indexing and iteration return each element
    wrapped in an `_IrDataBuilder`, and slicing returns a new
    `_IrDataSequenceBuilder` over the sliced list.

    Writes always store raw IR elements in the underlying list. Any
    `_IrDataBuilder` passed in is unwrapped first. This matters because the
    inherited `MutableSequence` mixin methods (for example `reverse()`) write
    back values they previously read through `__getitem__`, which would
    otherwise leave wrappers inside the IR list.
    """

    def __init__(self, target: MutableSequence[MessageT]):
        self._target = target

    @staticmethod
    def _unwrap(value):
        """Returns the wrapped IR element for a builder, else `value` itself."""
        if isinstance(value, _IrDataBuilder):
            return value.ir
        return value

    def __delitem__(self, key):
        del self._target[key]

    def __getitem__(self, key):
        item = self._target[key]
        if isinstance(key, slice):
            # A slice yields a list of elements, not a single element, so it
            # needs the sequence wrapper rather than the element wrapper.
            return _IrDataSequenceBuilder(item)
        return _IrDataBuilder(item)

    def __setitem__(self, key, value):
        if isinstance(key, slice):
            value = [self._unwrap(item) for item in value]
        else:
            value = self._unwrap(value)
        self._target[key] = value

    def __iter__(self):
        for item in self._target:
            yield _IrDataBuilder(item)

    def __repr__(self):
        return repr(self._target)

    def __len__(self):
        return len(self._target)

    def __eq__(self, other):
        return self._target == other

    def __ne__(self, other):
        return self._target != other

    def insert(self, index, value):
        self._target.insert(index, self._unwrap(value))

    def extend(self, values):
        # Materialize first so that extending with (a view of) the underlying
        # list itself cannot iterate over a list that is growing.
        self._target.extend([self._unwrap(item) for item in values])
| |
| |
class _IrDataBuilder(Generic[MessageT]):
    """Proxy around an IR element that fills in unset fields when accessed."""

    def __init__(self, ir: MessageT) -> None:
        assert ir is not None
        self.ir: MessageT = ir

    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name != "ir":
            # Anything other than `ir` is forwarded to the wrapped element.
            wrapped: MessageT = object.__getattribute__(self, "ir")
            setattr(wrapped, __name, __value)
            return
        # `ir` is the proxy's own attribute holding the wrapped element.
        object.__setattr__(self, __name, __value)

    def __getattribute__(self, name: str) -> Any:
        """Hook for `getattr` that creates missing fields on the wrapped IR.

        If the field is unset, a default is built and stored on the wrapped
        element. Dataclass fields are returned wrapped in a new builder (or
        a sequence builder for sequence fields) so that access chains keep
        working; other fields are returned as raw values.

        Raises:
            AttributeError: If `name` is not a field of the wrapped element.
        """
        # The builder's own attributes are served directly.
        if name in ("CopyFrom", "ir"):
            return object.__getattribute__(self, name)

        # Fetch the wrapped element without re-entering this hook.
        target: MessageT = object.__getattribute__(self, "ir")
        if target is None:
            return object.__getattribute__(self, name)

        if name in ("HasField", "WhichOneof"):
            return getattr(target, name)

        spec = field_specs(target).get(name)
        if spec is None:
            raise AttributeError(
                f"No field {name} on {type(target).__module__}.{type(target).__name__}."
            )

        value = getattr(target, name, None)
        if value is None:
            # Build a default and store it on the wrapped element.
            value = ir_data_fields.build_default(spec)
            setattr(target, name, value)

        if not spec.is_dataclass:
            return value
        if spec.is_sequence:
            return _IrDataSequenceBuilder(value)
        return _IrDataBuilder(value)

    def CopyFrom(self, template: MessageT):  # pylint:disable=invalid-name
        """Updates the fields of this builder with values set in `template`."""
        update(cast(type[MessageT], self), template)
| |
| |
def builder(target: MessageT) -> MessageT:
    """Wraps `target` in a builder that helps construct an IR data structure.

    Args:
        target: An IR data class instance, or an existing builder.

    Returns:
        `target` itself if it is already a builder; otherwise a new
        `_IrDataBuilder` around it, cast to the target's type so that type
        hints for the wrapped type remain available.

    Raises:
        TypeError: If the type of `target` has no `IR_DATACLASS` attribute.
    """
    already_wrapped = isinstance(
        target, (_IrDataBuilder, _IrDataSequenceBuilder)
    )
    if already_wrapped:
        return target

    # Only IR data classes may be wrapped.
    is_ir_dataclass = hasattr(type(target), "IR_DATACLASS")
    if not is_ir_dataclass:
        raise TypeError(f"Builder target {type(target)} is not an ir_data.message")

    wrapper = _IrDataBuilder(target)
    return cast(MessageT, wrapper)
| |
| |
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
    """Builds the stand-in value a reader returns for an unset field.

    Sequence fields yield an empty list, dataclass fields yield a
    `_ReadOnlyFieldChecker` that pretends to be the IR class, and all other
    fields yield the default built by `ir_data_fields.build_default`.
    """
    if spec.is_sequence:
        placeholder = []
    elif spec.is_dataclass:
        placeholder = _ReadOnlyFieldChecker(spec)
    else:
        placeholder = ir_data_fields.build_default(spec)
    return placeholder
| |
| |
def _field_type(ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> type:
    """Returns the IR class described by a field spec, or the type of an IR object."""
    if not isinstance(ir_or_spec, ir_data_fields.FieldSpec):
        return type(ir_or_spec)
    return ir_or_spec.data_type
| |
| |
class _ReadOnlyFieldChecker:
    """Read-only proxy used to chain accesses through fields that may be unset.

    Wraps either a real IR object or, for an unset dataclass field, the
    `FieldSpec` describing it.
    """

    def __init__(self, ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> None:
        self.ir_or_spec = ir_or_spec

    def __setattr__(self, name: str, value: Any) -> None:
        # Only the proxy's own slot may be assigned; everything else is
        # rejected to keep the wrapped object untouched.
        if name != "ir_or_spec":
            raise AttributeError(f"Cannot set {name} on read-only wrapper")
        object.__setattr__(self, name, value)

    def __getattribute__(self, name: str) -> Any:  # pylint:disable=too-many-return-statements
        """Resolves `name` against the wrapped IR object or field spec.

        Unset fields, and every field of a spec placeholder, resolve to the
        stand-in from `_field_checker_from_spec`. Set dataclass fields are
        returned wrapped in further read-only checkers.
        """
        wrapped = object.__getattribute__(self, "ir_or_spec")
        if name == "ir_or_spec":
            return wrapped

        spec = field_specs(_field_type(wrapped)).get(name)
        is_placeholder = isinstance(wrapped, ir_data_fields.FieldSpec)

        if not spec:
            # Not an IR field: emulate the query helpers on placeholders and
            # otherwise defer to the wrapped object itself.
            if is_placeholder:
                if name == "HasField":
                    return lambda x: False
                if name == "WhichOneof":
                    return lambda x: None
            return object.__getattribute__(wrapped, name)

        if is_placeholder:
            # Nothing is really there; keep pretending.
            return _field_checker_from_spec(spec)

        value = getattr(wrapped, name)
        if value is None:
            return _field_checker_from_spec(spec)

        if not spec.is_dataclass:
            return value
        if spec.is_sequence:
            return [_ReadOnlyFieldChecker(item) for item in value]
        return _ReadOnlyFieldChecker(value)

    def __eq__(self, other):
        # Compare wrapped contents, unwrapping the other side if needed.
        if isinstance(other, _ReadOnlyFieldChecker):
            rhs = other.ir_or_spec
        else:
            rhs = other
        return self.ir_or_spec == rhs

    def __ne__(self, other):
        return not (self == other)
| |
| |
def reader(obj: MessageT | _ReadOnlyFieldChecker) -> MessageT:
    """Builds a read-only wrapper for checking chains of possibly unset fields.

    The wrapper explicitly does not alter the wrapped object and is only
    intended for reading contents. If `obj` is already such a wrapper it is
    returned unchanged.

    With a `reader` you can write:
    ```
    def get_function_name_end_column(function: ir_data.Function):
      return reader(function).function_name.source_location.end.column
    ```

    rather than:
    ```
    def get_function_name_end_column(function: ir_data.Function):
      if function.function_name:
        if function.function_name.source_location:
          if function.function_name.source_location.end:
            return function.function_name.source_location.end.column
      return 0
    ```
    """
    if isinstance(obj, _ReadOnlyFieldChecker):
        checker = obj
    else:
        checker = _ReadOnlyFieldChecker(obj)

    # Present the wrapper under the wrapped type for type-hinting purposes.
    return cast(MessageT, checker)
| |
| |
def _extract_ir(
    ir_or_wrapper: MessageT | _ReadOnlyFieldChecker | _IrDataBuilder | None,
) -> ir_data_fields.IrDataclassInstance | None:
    """Strips a reader or builder wrapper and returns the underlying IR object.

    Returns None for a reader that wraps only a `FieldSpec` placeholder.
    Anything that is not a wrapper is returned unchanged.
    """
    unwrapped = ir_or_wrapper
    if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
        unwrapped = ir_or_wrapper.ir_or_spec
        if isinstance(unwrapped, ir_data_fields.FieldSpec):
            # Placeholder entry: there is no real object, so no fields are set.
            return None
    elif isinstance(ir_or_wrapper, _IrDataBuilder):
        unwrapped = ir_or_wrapper.ir
    return cast(ir_data_fields.IrDataclassInstance, unwrapped)
| |
| |
def fields_and_values(
    ir_wrapper: MessageT | _ReadOnlyFieldChecker,
    value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
    """Retrieves the fields and their values for a given IR data class.

    Args:
        ir_wrapper: The IR data class or a read-only wrapper of an IR data
            class.
        value_filt: Optional filter used to exclude values; forwarded to
            `ir_data_fields.fields_and_values`.

    Returns:
        The `(field spec, value)` pairs reported by
        `ir_data_fields.fields_and_values`, or an empty list when there is no
        underlying IR object.
    """
    ir = _extract_ir(ir_wrapper)
    if ir is None:
        return []
    return ir_data_fields.fields_and_values(ir, value_filt)
| |
| |
def get_set_fields(ir: MessageT):
    """Retrieves the field spec and value of every set field of `ir`.

    A value is considered "set" if it is not None.
    """

    def is_set(value):
        return value is not None

    return fields_and_values(ir, is_set)
| |
| |
def copy(ir_wrapper: MessageT | None) -> MessageT | None:
    """Creates a copy of the given IR data class.

    Returns None when `ir_wrapper` has no underlying IR object.
    """
    ir = _extract_ir(ir_wrapper)
    if ir is None:
        return None
    return cast(MessageT, ir_data_fields.copy(ir))
| |
| |
def update(ir: MessageT, template: MessageT):
    """Updates the fields of `ir` with all set fields in `template`.

    Does nothing when the IR object extracted from `template` is falsy
    (for example a reader that only wraps a placeholder).
    """
    template_ir = _extract_ir(template)
    if template_ir:
        ir_data_fields.update(
            cast(ir_data_fields.IrDataclassInstance, ir), template_ir
        )