blob: c74bd8f557b38eceef75cb0c35f4cd42514dad58 [file] [log] [blame]
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides helpers for defining and working with the fields of IR data classes.

Field specs
-----------
Various utilities are provided for working with `dataclasses.Field`s:
  - Classes:
    - `FieldSpec` - used to track data about an IR data field
    - `FieldContainer` - used to track the field container type
  - Functions that work with `FieldSpec`s:
    - `make_field_spec`, `build_default`
  - Functions for retrieving a set of `FieldSpec`s for a given class:
    - `field_specs`
  - Functions for retrieving fields and their values:
    - `fields_and_values`
  - Functions for copying and updating IR data classes:
    - `copy`, `update`
  - Functions to help define IR data fields:
    - `oneof_field`, `list_field`, `str_field`
"""
import dataclasses
import enum
import sys
import types
from typing import (
    Any,
    Callable,
    ClassVar,
    ForwardRef,
    Iterable,
    Mapping,
    MutableMapping,
    NamedTuple,
    Optional,
    Protocol,
    SupportsIndex,
    Tuple,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)
class IrDataclassInstance(Protocol):
    """Type bound for an IR dataclass instance.

    Structural description of what the helpers in this module expect from an
    IR dataclass: the standard dataclass field table, the `IR_DATACLASS`
    marker class variable, and the cached `field_specs` view that
    `cache_message_specs` installs on each IR dataclass.
    """

    # Field table that `dataclasses.dataclass` adds to every dataclass.
    __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
    # Marker attribute; `_is_ir_dataclass` only checks that it is present.
    IR_DATACLASS: ClassVar[object]
    # Cached per-class views of the fields, assigned by `cache_message_specs`.
    field_specs: ClassVar["FilteredIrFieldSpecs"]


# Generic IR dataclass instance type; ties a function's inputs to its outputs.
IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance)
# Element type stored in a `CopyValuesList`.
CopyValuesListT = TypeVar("CopyValuesListT", bound=type)
# Name of the marker attribute that identifies IR dataclasses (and their
# instances). The identifier's spelling is kept as-is since other code may
# reference it.
_IR_DATACLASSS_ATTR = "IR_DATACLASS"


def _is_ir_dataclass(obj):
    """Reports whether `obj` (a class or an instance) carries the IR marker.

    Only the presence of the attribute matters, not its value. As with
    `hasattr`, only `AttributeError` counts as "absent"; any other exception
    raised while looking the attribute up propagates.
    """
    try:
        getattr(obj, _IR_DATACLASSS_ATTR)
    except AttributeError:
        return False
    return True
class CopyValuesList(list[CopyValuesListT]):
    """A list that makes copies of any value that is inserted.

    Values added through `append`, `insert`, `extend`, `+=`, or item/slice
    assignment are copied before being stored: IR dataclass values with
    `copy()`, anything else by calling `value_type` on the value. Values
    passed to the constructor or to `shallow_copy` are stored as-is.
    """

    def __init__(
        self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
    ):
        # `value_type` converts non-IR values on insertion; the initial
        # `iterable` contents are stored without copying.
        if iterable:
            super().__init__(iterable)
        else:
            super().__init__()
        self.value_type = value_type

    def _copy(self, obj: Any):
        """Returns the copy of `obj` that will actually be stored."""
        if _is_ir_dataclass(obj):
            return copy(obj)
        return self.value_type(obj)

    def extend(self, iterable: Iterable) -> None:
        return super().extend([self._copy(i) for i in iterable])

    def shallow_copy(self, iterable: Iterable) -> None:
        """Appends the items of `iterable` without copying them."""
        return super().extend(iterable)

    def append(self, obj: Any) -> None:
        return super().append(self._copy(obj))

    def insert(self, index: SupportsIndex, obj: Any) -> None:
        return super().insert(index, self._copy(obj))

    def __iadd__(self, iterable: Iterable):
        # The built-in `list.__iadd__` does not dispatch to the overridden
        # `extend`, so without this `values += items` would store the caller's
        # objects uncopied.
        self.extend(iterable)
        return self

    def __setitem__(self, index, value) -> None:
        # Item and slice assignment likewise bypass the copying methods.
        if isinstance(index, slice):
            super().__setitem__(index, [self._copy(v) for v in value])
        else:
            super().__setitem__(index, self._copy(value))
class TemporaryCopyValuesList(NamedTuple):
    """Class used to temporarily hold a CopyValuesList while copying and
    constructing an IR dataclass.

    `_copy_set_fields` wraps each freshly copied list in this type; `update`
    unwraps `temp_list` and assigns it directly instead of copying it again.
    """

    # The already-copied list; the receiver may take ownership of it.
    temp_list: CopyValuesList
class FieldContainer(enum.Enum):
    """Indicates a field's container type."""

    NONE = 0  # Bare `T` annotation.
    OPTIONAL = 1  # `Optional[T]` annotation.
    LIST = 2  # `list[T]` annotation.
class FieldSpec(NamedTuple):
    """Indicates the container and type of a field.

    `FieldSpec` objects are accessed millions of times during runs so we cache as
    many operations as possible.
    - `is_dataclass`: `_is_ir_dataclass(data_type)` (i.e. `data_type` carries
      the `IR_DATACLASS` marker)
    - `is_sequence`: `container is FieldContainer.LIST`
    - `is_enum`: `issubclass(data_type, enum.Enum)`
    - `is_oneof`: `oneof is not None`
    Use `make_field_spec` to automatically fill in the cached operations.
    """

    # Name of the dataclass field.
    name: str
    # Element type: the `T` of `T` / `Optional[T]` / `list[T]`.
    data_type: type
    # How `data_type` is wrapped in the field's annotation.
    container: FieldContainer
    # Oneof group name taken from the field's metadata, or None.
    oneof: Optional[str]
    # Cached type queries, as listed in the docstring.
    is_dataclass: bool
    is_sequence: bool
    is_enum: bool
    is_oneof: bool
def make_field_spec(
    name: str, data_type: type, container: FieldContainer, oneof: Optional[str]
):
    """Creates a `FieldSpec`, precomputing its cached type-query flags.

    Raises `TypeError` (from `issubclass`) if `data_type` is not a class.
    """
    ir_typed = _is_ir_dataclass(data_type)
    listed = container is FieldContainer.LIST
    enum_typed = issubclass(data_type, enum.Enum)
    grouped = oneof is not None
    return FieldSpec(
        name=name,
        data_type=data_type,
        container=container,
        oneof=oneof,
        is_dataclass=ir_typed,
        is_sequence=listed,
        is_enum=enum_typed,
        is_oneof=grouped,
    )
def build_default(field_spec: FieldSpec):
    """Returns a fresh default value for the field described by `field_spec`.

    - list fields: an empty `CopyValuesList` of the element type;
    - enum fields: `data_type(0)`, the member whose value is 0;
    - anything else: `data_type()` called with no arguments.
    """
    value_type = field_spec.data_type
    if field_spec.is_sequence:
        return CopyValuesList(value_type)
    return value_type(0) if field_spec.is_enum else value_type()
class FilteredIrFieldSpecs:
    """Precomputed, filtered views over one IR dataclass's field specs.

    Attributes:
      all_field_specs: The mapping passed in (name -> `FieldSpec`).
      field_specs: Tuple of every spec, in mapping order.
      dataclass_field_specs: name -> spec for IR-dataclass-typed fields.
      oneof_field_specs: name -> spec for fields belonging to a oneof group.
      sequence_field_specs: Tuple of the list-typed specs.
      oneof_mappings: Tuple of `(field name, oneof group)` pairs for oneof
        fields whose group name is non-empty.
    """

    def __init__(self, specs: Mapping[str, FieldSpec]):
        self.all_field_specs = specs
        entries = list(specs.items())
        self.field_specs = tuple(spec for _, spec in entries)

        by_dataclass = {}
        by_oneof = {}
        in_sequence = []
        for field_name, spec in entries:
            if spec.is_dataclass:
                by_dataclass[field_name] = spec
            if spec.is_oneof:
                by_oneof[field_name] = spec
            if spec.is_sequence:
                in_sequence.append(spec)

        self.dataclass_field_specs = by_dataclass
        self.oneof_field_specs = by_oneof
        self.sequence_field_specs = tuple(in_sequence)
        self.oneof_mappings = tuple(
            (field_name, spec.oneof)
            for field_name, spec in by_oneof.items()
            if spec.oneof
        )
def all_ir_classes(mod):
    """Lazily yields the IR dataclass definitions found at the top level of `mod`.

    The module's namespace is captured immediately; filtering happens as the
    returned generator is consumed.
    """
    members = mod.__dict__.values()
    # `isinstance(type, m.__class__)` holds for classes whose metaclass is
    # exactly `type`; it is false for instances of IR dataclasses and for
    # classes with a custom metaclass such as enum classes.
    return (
        member
        for member in members
        if isinstance(type, member.__class__) and _is_ir_dataclass(member)
    )
class IrDataclassSpecs:
    """Class-level cache mapping each IR dataclass to its `FilteredIrFieldSpecs`."""

    # Shared by all callers; keyed by the IR dataclass type itself.
    spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {}

    @classmethod
    def get_mod_specs(cls, mod):
        """Builds `FilteredIrFieldSpecs` for every IR dataclass in `mod`."""
        module_specs = {}
        for ir_class in all_ir_classes(mod):
            module_specs[ir_class] = FilteredIrFieldSpecs(_field_specs(ir_class))
        return module_specs

    @classmethod
    def get_specs(cls, data_class):
        """Returns the cached specs for `data_class`.

        On a cache miss, every IR dataclass in `data_class`'s module is
        processed and added to the cache in one pass.

        Raises:
          KeyError: if `data_class` is still absent after its module has been
            scanned (e.g. it was not found among the module's IR dataclasses).
        """
        cache = cls.spec_cache
        if data_class in cache:
            return cache[data_class]
        owning_module = sys.modules[data_class.__module__]
        cache.update(cls.get_mod_specs(owning_module))
        return cache[data_class]
def cache_message_specs(mod, cls):
    """Attaches cached `field_specs` to every IR dataclass in `mod` except `cls`.

    Must run after the dataclass decorators have executed and produced the
    final (wrapped) classes, so the specs describe those classes.

    Args:
      mod: Module whose top-level IR dataclasses receive the attribute.
      cls: Base class to leave untouched.
    """
    for ir_class in all_ir_classes(mod):
        if ir_class is cls:
            continue
        ir_class.field_specs = IrDataclassSpecs.get_specs(ir_class)
def _module_class_namespace(cls: type) -> Mapping[str, type]:
    """Maps names to the classes bound at the top level of `cls`'s module.

    Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
    `builtins.Expression` for 'Expression' rather than `ir_data.Expression`,
    so forward references are resolved manually against this namespace.
    Every class is included (enum classes and classes with custom metaclasses
    too), so any class the module exposes can be referenced by name.
    """
    return {
        name: value
        for name, value in sys.modules[cls.__module__].__dict__.items()
        if isinstance(value, type)
    }


def _unwrap_container(type_hint: Any) -> Tuple[Any, FieldContainer]:
    """Splits a field annotation into its element type and container kind.

    `T` -> (T, NONE); `Optional[T]` or `T | None` -> (T, OPTIONAL);
    `list[T]` -> (T, LIST).

    Raises:
      TypeError: for a union that is not exactly `T` plus `None`, or for any
        other generic container.
    """
    origin = get_origin(type_hint)
    if origin is None:
        return type_hint, FieldContainer.NONE
    args = get_args(type_hint)
    # `Optional[T]` is `Union[T, None]` under the hood, and `T | None` has the
    # origin `types.UnionType`; both must be a two-member union ending in None.
    if origin is Union or origin is types.UnionType:
        if len(args) != 2 or args[1] is not type(None):
            raise TypeError(f"Field has invalid optional type: {type_hint}")
        return args[0], FieldContainer.OPTIONAL
    if origin is list:
        return args[0], FieldContainer.LIST
    raise TypeError(f"Field has invalid container type: {origin}")


def _resolve_forward_ref(type_hint: Any, mod_ns: Mapping[str, type]) -> Any:
    """Replaces a string or `ForwardRef` annotation with the class it names."""
    if isinstance(type_hint, str):
        return mod_ns[type_hint]
    if isinstance(type_hint, ForwardRef):
        return mod_ns[type_hint.__forward_arg__]
    return type_hint


def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]:
    """Gets the IR data field names and types for the given IR data class.

    Fields whose names start with "_" are skipped. Every other field must be
    annotated as `T`, `Optional[T]`, `T | None` or `list[T]`, where `T` may be
    a string or `ForwardRef` naming a class in `cls`'s module.

    Returns:
      A mapping from field name to `FieldSpec`, in dataclass field order.

    Raises:
      TypeError: if an annotation uses an unsupported container or union.
      KeyError: if a forward reference names no class in `cls`'s module.
    """
    class_fields = dataclasses.fields(cast(Any, cls))
    mod_ns = _module_class_namespace(cls)
    result: MutableMapping[str, FieldSpec] = {}
    for class_field in class_fields:
        if class_field.name.startswith("_"):
            continue
        element_type, container_type = _unwrap_container(class_field.type)
        element_type = _resolve_forward_ref(element_type, mod_ns)
        result[class_field.name] = make_field_spec(
            class_field.name,
            element_type,
            container_type,
            class_field.metadata.get("oneof"),
        )
    return result
def field_specs(obj: IrDataT | type[IrDataT]) -> Mapping[str, FieldSpec]:
    """Returns the name -> `FieldSpec` mapping for an IR dataclass or instance.

    Lookups go through `IrDataclassSpecs`, which caches the result per class.

    Raises:
      TypeError: if `obj` is `None` (or `NoneType` itself).
    """
    if obj is None or obj is type(None):
        raise TypeError("field_specs called with invalid type: NoneType")
    data_class = obj if isinstance(obj, type) else type(obj)
    return IrDataclassSpecs.get_specs(data_class).all_field_specs
def fields_and_values(
    ir: IrDataT,
    value_filt: Optional[Callable[[Any], bool]] = None,
):
    """Lists `(FieldSpec, value)` pairs for the fields of an IR data class.

    Args:
      ir: The IR data class or a read-only wrapper of an IR data class.
      value_filt: Optional predicate; when given, only pairs whose value it
        accepts are kept. When omitted, every field is returned.

    Returns:
      A list of `(FieldSpec, value)` tuples in the order of
      `ir.field_specs.field_specs`.
    """
    pairs = ((spec, getattr(ir, spec.name)) for spec in ir.field_specs.field_specs)
    return [
        (spec, value)
        for spec, value in pairs
        if not value_filt or value_filt(value)
    ]
# `copy` is one of the hottest paths of embossc. We've taken steps to
# optimize this path at the expense of code clarity and modularization.
#
# 1. `FilteredIrFieldSpecs` are cached in the class definition of each IR
#    dataclass under the `ir_data.Message.field_specs` class attribute. We
#    just assume the passed-in object has that attribute.
# 2. We cache a `field_specs` entry that is just the `values()` of the
#    `all_field_specs` dict.
# 3. Copied lists are wrapped in a `TemporaryCopyValuesList`. This is used to
#    signal to consumers that they can take ownership of the contained list
#    rather than copying it again. See `ir_data.Message()` and `update()` for
#    where this is used.
# 4. `FieldSpec` checks are cached, including `is_dataclass` and `is_sequence`.
# 5. Only `copy()` checks its argument for None. Internally,
#    `_copy_set_fields` calls `_copy()` directly (it has already skipped None
#    field values), avoiding a redundant check.
def _copy_set_fields(ir: IrDataT):
    """Returns copies of every non-None field of `ir`, keyed by field name.

    Nested IR dataclasses are copied with `_copy`. List fields are rebuilt as
    new `CopyValuesList`s (elements copied with `_copy` when they are IR
    dataclasses, shared otherwise) and wrapped in `TemporaryCopyValuesList`
    so the receiver may take ownership without copying again. Any other value
    is passed through unchanged.
    """
    copied: MutableMapping[str, Any] = {}
    for spec in ir.field_specs.field_specs:
        field_value = getattr(ir, spec.name)
        if field_value is None:
            continue
        if spec.is_sequence:
            elements = (
                (_copy(v) for v in field_value) if spec.is_dataclass else field_value
            )
            field_value = TemporaryCopyValuesList(
                CopyValuesList(spec.data_type, elements)
            )
        elif spec.is_dataclass:
            field_value = _copy(field_value)
        copied[spec.name] = field_value
    return copied
def _copy(ir: IrDataT) -> IrDataT:
    """Copies a non-None IR dataclass instance by rebuilding it from its fields."""
    ir_class = type(ir)
    field_values = _copy_set_fields(ir)
    return ir_class(**field_values)  # type: ignore[misc]
def copy(ir: IrDataT) -> IrDataT | None:
    """Creates a copy of the given IR data class.

    Nested IR dataclasses and lists are copied as well (see
    `_copy_set_fields`).

    Args:
      ir: The IR dataclass instance to copy, or None.

    Returns:
      A new instance of the same class, or None if `ir` is None.
    """
    # An explicit None check: a truthiness test would also turn any falsy
    # (e.g. `__len__`/`__bool__`-defining) IR instance into None.
    if ir is None:
        return None
    return _copy(ir)
def update(ir: IrDataT, template: IrDataT):
    """Overwrites fields of `ir` with copies of the set fields of `template`.

    Fields that are None in `template` leave `ir` untouched. Copied lists
    arrive wrapped in `TemporaryCopyValuesList` and are unwrapped so `ir`
    takes ownership of the fresh list.
    """
    copied = _copy_set_fields(template)
    for field_name in copied:
        new_value = copied[field_name]
        setattr(
            ir,
            field_name,
            new_value.temp_list
            if isinstance(new_value, TemporaryCopyValuesList)
            else new_value,
        )
class OneOfField:
    """Descriptor for a "oneof" field.

    Tracks when the field is set and will unset the other fields in the
    associated oneof group. The value is stored in a companion "_<name>"
    attribute.

    Note: Descriptors only work if dataclass slots aren't used.
    """

    def __init__(self, oneof: str) -> None:
        super().__init__()
        # Name of the oneof group this field belongs to.
        self.oneof = oneof
        # The following are filled in by `__set_name__` when the owning class
        # is created.
        self.owner_type = None
        self.proxy_name: str = ""
        self.name: str = ""

    def __set_name__(self, owner, name):
        self.name = name
        self.proxy_name = f"_{name}"
        self.owner_type = owner
        # Add our empty proxy field to the class.
        setattr(owner, self.proxy_name, None)

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Accessed on the class itself (e.g. `Owner.field`): follow the
            # descriptor convention and return the descriptor instead of
            # failing with `AttributeError` on `None._<name>`.
            return self
        return getattr(obj, self.proxy_name)

    def __set__(self, obj, value):
        if value is self:
            # This will happen if the dataclass uses the default value, we just
            # default to None.
            value = None
        if value is not None:
            # Clear the other members of the same oneof group.
            for name, oneof in IrDataclassSpecs.get_specs(
                self.owner_type
            ).oneof_mappings:
                if oneof == self.oneof and name != self.name:
                    setattr(obj, name, None)
        setattr(obj, self.proxy_name, value)
def oneof_field(name: str):
    """Alternative for `dataclasses.field` that sets up a oneof variable.

    Args:
      name: The oneof group this field belongs to. Setting the field to a
        non-None value clears the other fields in the same group.

    Returns:
      A `dataclasses.Field` whose default is a fresh `OneOfField` descriptor
      and whose metadata records the group under the "oneof" key.
    """
    descriptor = OneOfField(name)
    group_metadata = {"oneof": name}
    return dataclasses.field(  # pylint:disable=invalid-field-call
        default=descriptor, metadata=group_metadata, init=True
    )
def str_field():
    """Returns a `dataclasses.Field` that defaults to the empty string.

    `str` itself serves as the default factory, so each instance gets `str()`.
    """
    empty_string_factory = str
    return dataclasses.field(  # pylint:disable=invalid-field-call
        default_factory=empty_string_factory
    )
def list_field(cls_or_fn):
    """Helper used to define a field that defaults to an empty `CopyValuesList`.

    A lambda can be used to defer resolution of a field type that references its
    container type, for example:
    ```
    class Foo:
      subtypes: list['Foo'] = list_field(lambda: Foo)
      names: list[str] = list_field(str)
    ```
    The element type is resolved each time a default list is created.

    Args:
      cls_or_fn: The class type or a function that resolves to the class type.
    """

    def make_empty_list():
        if isinstance(cls_or_fn, type):
            element_type = cls_or_fn
        else:
            element_type = cls_or_fn()
        return CopyValuesList(element_type)

    return dataclasses.field(  # pylint:disable=invalid-field-call
        default_factory=make_empty_list
    )