# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adds auto-generated virtual fields to the IR."""

from compiler.front_end import attributes
from compiler.front_end import expression_bounds
from compiler.util import ir_pb2
from compiler.util import ir_util
from compiler.util import traverse_ir


def _find_field_reference_path(expression):
  """Locates the sole field_reference inside an expression, if there is one.

  The result is a list of argument indexes.  Starting at `expression` and
  repeatedly taking `function.args[i]` for each `i` in the list leads to the
  field_reference.  For instance, given:

      5 + (x * 2)

  the result is [1, 0]: argument 1 of the outer `+` is `x * 2`, and argument 0
  of that `*` is `x`.  An expression that is itself a field_reference yields
  an empty list.

  Arguments:
    expression: the ir_pb2.Expression to search.

  Returns:
    The list of indexes leading to the field_reference when the expression
    contains exactly one field_reference; None when it contains zero or more
    than one.
  """
  reference_count, path = _recursively_find_field_reference_path(expression)
  return path if reference_count == 1 else None


def _recursively_find_field_reference_path(expression):
  """Counts field_references in expression and tracks the path to a lone one.

  Arguments:
    expression: the expression node to examine.

  Returns:
    A (count, path) pair.  `count` is the total number of field_reference
    nodes found in `expression`.  `path` is the list of `function.args`
    indexes leading to the field_reference when `count` is exactly 1, and an
    empty list otherwise.
  """
  kind = expression.WhichOneof("expression")
  if kind == "field_reference":
    return 1, []
  if kind != "function":
    # Any other kind of leaf contributes no field references.
    return 0, []

  total = 0
  path_to_reference = []
  for position, argument in enumerate(expression.function.args):
    count, sub_path = _recursively_find_field_reference_path(argument)
    # Only remember a path while no reference has been seen in an earlier
    # argument; the path is discarded below if more references turn up.
    if count == 1 and total == 0:
      path_to_reference = [position] + sub_path
    total += count

  if total != 1:
    return total, []
  return total, path_to_reference


def _invert_expression(expression, ir):
  """Tries to solve `$logical_value = expression` for its single field.

  When `expression` contains exactly one `field_reference`, this function
  peels operations off of `expression` one at a time, applying the opposite
  operation to the other side of the notional equation.  For `x + 1`:

      $logical_value = x + 1
      $logical_value - 1 = x

  so the function returns `x` together with `$logical_value - 1`.

  The result is an assignment that can be used to write through a virtual
  field to the field underneath it.  E.g., given:

      struct Foo:
        0 [+1]  UInt  raw_value
        let actual_value = raw_value + 100

  writing to `actual_value` can store `$logical_value - 100` in `raw_value`.

  Only addition and subtraction are handled; any other function on the path
  to the field reference makes the expression non-invertible.

  Arguments:
    expression: the ir_pb2.Expression to invert.
    ir: the complete IR, passed through to the bounds computation.

  Returns:
    A (field_reference, inverse_expression) pair when the expression can be
    inverted, or None when it cannot.
  """
  path_to_reference = _find_field_reference_path(expression)
  if path_to_reference is None:
    return None

  def synthetic_location():
    # A fresh Location is built for every use rather than sharing one.
    return ir_pb2.Location(is_synthetic=True)

  def integer_operation(operator, left, right):
    # Builds the integer-typed expression `left <operator> right`.
    return ir_pb2.Expression(
        function=ir_pb2.Function(function=operator, args=[left, right]),
        type=ir_pb2.ExpressionType(integer=ir_pb2.IntegerType())
    )

  # The right-hand side of the equation begins as a bare `$logical_value`.
  inverse = ir_pb2.Expression(
      builtin_reference=ir_pb2.Reference(
          canonical_name=ir_pb2.CanonicalName(
              module_file="",
              object_path=["$logical_value"]
          ),
          source_name=[ir_pb2.Word(
              text="$logical_value",
              source_location=synthetic_location()
          )],
          source_location=synthetic_location()
      ),
      type=expression.type,
      source_location=synthetic_location()
  )

  # Conceptually, the loop turns
  #
  #     f(g(x)) == $logical_value
  #
  # into
  #
  #     x == g_inv(f_inv($logical_value))
  #
  # Each pass strips the outermost operation from `remaining` and wraps
  # `inverse` in the matching inverse operation.  Starting from
  #
  #     2 + ((3 - x) - 10)  ==  $logical_value
  #
  # the successive states are:
  #
  #     (3 - x) - 10  ==  $logical_value - 2    [subtract 2 from both sides]
  #     (3 - x)  ==  ($logical_value - 2) + 10  [add 10 to both sides]
  #     x  ==  3 - (($logical_value - 2) + 10)  [subtract both sides from 3]
  #
  # This solver is deliberately tiny.  Every equation it can solve becomes
  # part of Emboss's contract, forever, so be conservative about teaching it
  # anything new!
  remaining = expression
  for arg_index in path_to_reference:
    operation = remaining.function
    operator = operation.function
    if operator == ir_pb2.FunctionMapping.ADDITION:
      # (x + c) or (c + x):  x == $value - c
      inverse = integer_operation(ir_pb2.FunctionMapping.SUBTRACTION,
                                  inverse, operation.args[1 - arg_index])
    elif operator == ir_pb2.FunctionMapping.SUBTRACTION:
      if arg_index == 0:
        # (x - c):  x == $value + c
        inverse = integer_operation(ir_pb2.FunctionMapping.ADDITION,
                                    inverse, operation.args[1])
      else:
        # (c - x):  x == c - $value
        inverse = integer_operation(ir_pb2.FunctionMapping.SUBTRACTION,
                                    operation.args[0], inverse)
    else:
      return None
    remaining = operation.args[arg_index]

  expression_bounds.compute_constraints_of_expression(inverse, ir)
  return remaining, inverse


def _add_write_method(field, ir):
  """Fills in field.write_method, unless it has already been set.

  The write_method is chosen as follows:

  *   "physical" for any field that is not virtual.
  *   "alias" for a virtual field of the form `let v = some_field_reference`
      with no `requires` attribute, when the referenced field is itself
      writeable.
  *   "transform" for a virtual field whose expression can be inverted by
      _invert_expression and whose single referenced field is writeable.
  *   "read_only" for everything else.

  Referenced fields have their own write_method computed (recursively) before
  it is consulted.

  Arguments:
    field: the ir_pb2.Field that should receive a write_method.
    ir: the IR used to resolve field_references.

  Returns:
    None
  """
  if field.HasField("write_method"):
    # Already computed; leave it alone.
    return

  if not ir_util.field_is_virtual(field):
    # Non-virtual fields are written directly.
    field.write_method.physical = True
    return

  # A bare reference is only a direct alias when the virtual field does not
  # impose an additional requirement of its own.
  requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
  is_direct_alias = (
      field.read_transform.WhichOneof("expression") == "field_reference" and
      requires_attr is None)

  if is_direct_alias:
    destination = field.read_transform.field_reference
    function_body = None
  else:
    inverse = _invert_expression(field.read_transform, ir)
    if not inverse:
      # An expression that cannot be inverted cannot be written through.
      field.write_method.read_only = True
      return
    target_expression, function_body = inverse
    destination = target_expression.field_reference

  target = ir_util.find_object(destination.path[-1], ir)
  if not isinstance(target, ir_pb2.Field):
    # The reference resolves to a non-field (i.e., a parameter), which cannot
    # be written.
    field.write_method.read_only = True
    return

  _add_write_method(target, ir)
  if target.write_method.read_only:
    # Writes cannot pass through to a read-only field.
    field.write_method.read_only = True
    return

  if is_direct_alias:
    field.write_method.alias.CopyFrom(destination)
  else:
    field.write_method.transform.destination.CopyFrom(destination)
    field.write_method.transform.function_body.CopyFrom(function_body)


def set_write_methods(ir):
  """Computes and stores a write_method for each ir_pb2.Field found in ir.

  Arguments:
      ir: The IR whose fields should receive write_methods.

  Returns:
      A list of errors; this pass always returns an empty list.
  """
  visited_types = [ir_pb2.Field]
  traverse_ir.fast_traverse_ir_top_down(ir, visited_types, _add_write_method)
  no_errors = []
  return no_errors
