blob: 09501f06eab8353a13eb0d3b86c821da95256f35 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adds auto-generated virtual fields to the IR."""
from compiler.front_end import attributes
from compiler.front_end import expression_bounds
from compiler.util import ir_pb2
from compiler.util import ir_util
from compiler.util import traverse_ir
def _find_field_reference_path(expression):
    """Locates the sole field reference inside an expression, if there is one.

    The result is a list of argument indexes: starting from `expression`,
    repeatedly selecting `function.args[index]` for each index in the list
    arrives at the field_reference.  For instance, given:

        5 + (x * 2)

    the result is [1, 0], because the `x * 2` expression is arg 1 of the
    top-level `+`, and `x` is arg 0 of the `*`.

    Arguments:
      expression: an ir_pb2.Expression to walk

    Returns:
      The list of indexes leading to the field_reference when the expression
      contains exactly one field_reference; None otherwise.
    """
    reference_count, index_path = _recursively_find_field_reference_path(
        expression)
    # Zero references means there is nothing to find; two or more means there
    # is no single path to report.
    return index_path if reference_count == 1 else None
def _recursively_find_field_reference_path(expression):
    """Counts field references in an expression and tracks a path to one.

    Arguments:
      expression: an ir_pb2.Expression to walk

    Returns:
      A (count, path) pair.  `count` is the total number of field_references
      found in `expression`.  `path` is the list of `function.args` indexes
      leading to the field_reference when `count` is exactly 1, and an empty
      list otherwise.
    """
    kind = expression.WhichOneof("expression")
    if kind == "field_reference":
        # The expression itself is the reference, so the path to it is empty.
        return 1, []
    if kind != "function":
        # Anything that is neither a reference nor a function has no
        # references inside it.
        return 0, []

    total = 0
    candidate = []
    for position, argument in enumerate(expression.function.args):
        count, sub_path = _recursively_find_field_reference_path(argument)
        # Only record a path while no earlier argument has contributed a
        # reference; if a later argument adds more, the path is discarded
        # below.
        if count == 1 and total == 0:
            candidate = [position] + sub_path
        total += count
    return total, (candidate if total == 1 else [])
def _invert_expression(expression, ir):
    """Tries to solve `$logical_value = expression` for its single field.

    When `expression` contains exactly one `field_reference`, this treats

        $logical_value = expression

    as an equation and attempts to isolate the referenced field.  For
    example, `x + 1` is handled as:

        $logical_value = x + 1
        $logical_value - 1 = x

    yielding `x` and `$logical_value - 1`.

    The result is an assignment that can be used to write back through
    certain virtual fields.  E.g., given:

        struct Foo:
          0 [+1]  UInt  raw_value
          let actual_value = raw_value + 100

    writing to `actual_value` can set `raw_value` to the appropriate value.

    Arguments:
      expression: an ir_pb2.Expression to be inverted.
      ir: the full IR, for looking up symbols.

    Returns:
      (field_reference, inverse_expression) if expression can be inverted,
      otherwise None.
    """
    path_to_field = _find_field_reference_path(expression)
    if path_to_field is None:
        return None

    def synthetic_location():
        # Everything built here is compiler-generated, not user source.
        return ir_pb2.Location(is_synthetic=True)

    def integer_operation(operator, left, right):
        # Builds the integer-typed expression `left <operator> right`.
        return ir_pb2.Expression(
            function=ir_pb2.Function(function=operator, args=[left, right]),
            type=ir_pb2.ExpressionType(integer=ir_pb2.IntegerType())
        )

    logical_value_name = ir_pb2.Word(
        text="$logical_value",
        source_location=synthetic_location()
    )
    logical_value = ir_pb2.Reference(
        canonical_name=ir_pb2.CanonicalName(
            module_file="",
            object_path=["$logical_value"]
        ),
        source_name=[logical_value_name],
        source_location=synthetic_location()
    )
    inverse = ir_pb2.Expression(
        builtin_reference=logical_value,
        type=expression.type,
        source_location=synthetic_location()
    )

    # Peel one function layer off `remaining` per step, applying the matching
    # inverse operation to `inverse`.  Conceptually this goes from
    #
    #     f(g(x)) == $logical_value
    #
    # to
    #
    #     x == g_inv(f_inv($logical_value))
    #
    # For example, starting from
    #
    #     2 + ((3 - x) - 10) == $logical_value
    #
    # the successive states are:
    #
    #     (3 - x) - 10 == $logical_value - 2          [subtract 2]
    #     (3 - x) == ($logical_value - 2) + 10        [add 10]
    #     x == 3 - (($logical_value - 2) + 10)        [subtract from 3]
    #
    # Only addition and subtraction are handled.  Every equation solvable
    # here becomes part of Emboss's contract, forever, so be conservative
    # about teaching this solver new tricks!
    remaining = expression
    for arg_index in path_to_field:
        operator = remaining.function.function
        operands = remaining.function.args
        if operator == ir_pb2.Function.ADDITION:
            # a + b == v  =>  the side holding the field == v - (other side).
            inverse = integer_operation(
                ir_pb2.Function.SUBTRACTION, inverse, operands[1 - arg_index])
        elif operator == ir_pb2.Function.SUBTRACTION and arg_index == 0:
            # a - b == v  =>  a == v + b.
            inverse = integer_operation(
                ir_pb2.Function.ADDITION, inverse, operands[1])
        elif operator == ir_pb2.Function.SUBTRACTION:
            # a - b == v  =>  b == a - v.
            inverse = integer_operation(
                ir_pb2.Function.SUBTRACTION, operands[0], inverse)
        else:
            # Any other function is not invertible by this solver.
            return None
        remaining = operands[arg_index]

    expression_bounds.compute_constraints_of_expression(inverse, ir)
    return remaining, inverse
def _add_write_method(field, ir):
    """Fills in field.write_method, unless it is already set.

    The write_method is chosen as follows:

    *   "physical" for non-virtual fields.
    *   "alias" for virtual fields of the form `let v = some_field_reference`
        with no additional requirement, when `some_field_reference` is a
        field that is itself writeable.
    *   "transform" for other virtual fields whose value is an
        easily-invertible function of a single writeable field.
    *   "read_only" for everything else; such fields are not writeable.

    Arguments:
      field: an ir_pb2.Field to which to add a write_method.
      ir: The IR in which to look up field_references.

    Returns:
      None
    """
    if field.HasField("write_method"):
        # Already computed; leave it alone.
        return

    if not ir_util.field_is_virtual(field):
        # Non-virtual fields are written physically.
        field.write_method.physical = True
        return

    has_requirement = ir_util.get_attribute(
        field.attribute, attributes.REQUIRES) is not None
    read_transform = field.read_transform
    is_bare_reference = (
        read_transform.WhichOneof("expression") == "field_reference")

    if is_bare_reference and not has_requirement:
        # Candidate for a direct alias: a bare field reference with no
        # additional requirement attached.
        target = ir_util.find_object(
            read_transform.field_reference.path[-1], ir)
        if isinstance(target, ir_pb2.Field):
            # Make sure the target's own write_method is known first.
            _add_write_method(target, ir)
            if not target.write_method.read_only:
                field.write_method.alias.CopyFrom(
                    read_transform.field_reference)
                return
        # Aliasing a non-field (i.e., a parameter) or a read-only field makes
        # this field read-only as well.
        field.write_method.read_only = True
        return

    # Not a direct alias: see whether the expression can be solved for the
    # single field it references.
    inverse = _invert_expression(read_transform, ir)
    if not inverse:
        # A non-invertible expression cannot be written through.
        field.write_method.read_only = True
        return

    destination, function_body = inverse
    target = ir_util.find_object(destination.field_reference.path[-1], ir)
    if isinstance(target, ir_pb2.Field):
        _add_write_method(target, ir)
        if not target.write_method.read_only:
            transform = field.write_method.transform
            transform.destination.CopyFrom(destination.field_reference)
            transform.function_body.CopyFrom(function_body)
            return
    # The expression is invertible, but what it refers to cannot be written,
    # so this field cannot be written either.
    field.write_method.read_only = True
def set_write_methods(ir):
    """Populates write_method on every ir_pb2.Field found in ir.

    Arguments:
      ir: The IR to which to add write_methods.

    Returns:
      A list of errors, or an empty list.
    """
    visited_types = [ir_pb2.Field]
    traverse_ir.fast_traverse_ir_top_down(ir, visited_types, _add_write_method)
    # This function itself never reports errors.
    errors = []
    return errors