Improve perf of `_call_with_optional_args` (#120)

* Convert action lambdas to static functions

Cleanup in preparation for further performance improvements.

This changes lambda actions to static functions. Additionally duplicate
field actions of the form follwing form were removed:
  `ir_pb2.Field: lambda f: {"field": f}`.
These actions were already covered by the default actions and ended up
being run twice.

* Improve perf of _call_with_optional_args

This adds memoization of the results of `getfullargspec` calls. Running
against real-life examples it shows a 10% reduction of runtime for the
Emboss front-end.
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index 5f67d7c..b43ca0e 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -537,6 +537,8 @@
   # errors are just returned, rather than appended to a shared list.
   errors += _integer_bounds_errors_for_expression(expression, source_file_name)
 
+def _attribute_in_attribute_action(a):
+  return {"in_attribute": a}
 
 def check_constraints(ir):
   """Checks miscellaneous validity constraints in ir.
@@ -597,7 +599,7 @@
       parameters={"errors": errors})
   traverse_ir.fast_traverse_ir_top_down(
       ir, [ir_pb2.Expression], _check_bounds_on_runtime_integer_expressions,
-      incidental_actions={ir_pb2.Attribute: lambda a: {"in_attribute": a}},
+      incidental_actions={ir_pb2.Attribute: _attribute_in_attribute_action},
       skip_descendants_of={ir_pb2.EnumValue, ir_pb2.Expression},
       parameters={"errors": errors, "in_attribute": None})
   traverse_ir.fast_traverse_ir_top_down(
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 54bcdcc..f4fb581 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -468,6 +468,8 @@
       "visible_scopes": (field.name.canonical_name,) + visible_scopes,
   }
 
+def _module_source_from_table_action(m, table):
+  return {"module": table[m.source_file_name]}
 
 def _resolve_symbols_from_table(ir, table):
   """Resolves all references in the given IR, given the constructed table."""
@@ -477,7 +479,7 @@
   traverse_ir.fast_traverse_ir_top_down(
       ir, [ir_pb2.Import], _add_import_to_scope,
       incidental_actions={
-          ir_pb2.Module: lambda m, table: {"module": table[m.source_file_name]},
+          ir_pb2.Module: _module_source_from_table_action,
       },
       parameters={"errors": errors, "table": table})
   if errors:
@@ -490,7 +492,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"table": table, "errors": errors, "field": None})
@@ -500,7 +501,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"table": table, "errors": errors, "field": None})
@@ -515,7 +515,6 @@
       incidental_actions={
           ir_pb2.TypeDefinition: _set_visible_scopes_for_type_definition,
           ir_pb2.Module: _set_visible_scopes_for_module,
-          ir_pb2.Field: lambda f: {"field": f},
           ir_pb2.Attribute: _set_visible_scopes_for_attribute,
       },
       parameters={"errors": errors, "field": None})
diff --git a/compiler/util/traverse_ir.py b/compiler/util/traverse_ir.py
index fc04ba8..3bd95c3 100644
--- a/compiler/util/traverse_ir.py
+++ b/compiler/util/traverse_ir.py
@@ -17,27 +17,98 @@
 import inspect
 
 from compiler.util import ir_pb2
+from compiler.util import simple_memoizer
+
+
+class _FunctionCaller:
+  """Provides a template for setting up a generic call to a function.
+
+  The function parameters are inspected at run-time to build up a set of valid
+  and required arguments. When invoking the function unneccessary parameters
+  will be trimmed out. If arguments are missing an assertion will be triggered.
+
+  This is currently limited to functions that have at least one positional
+  parameter.
+
+  Example usage:
+  ```
+  def func_1(a, b, c=2): pass
+  def func_2(a, d): pass
+  caller_1 = _FunctionCaller(func_1)
+  caller_2 = _FunctionCaller(func_2)
+  generic_params = {"b": 2, "c": 3, "d": 4}
+
+  # Equivalent of: func_1(a, b=2, c=3)
+  caller_1.invoke(a, generic_params)
+
+  # Equivalent of: func_2(a, d=4)
+  caller_2.invoke(a, generic_params)
+  """
+
+  def __init__(self, function):
+    self.function = function
+    self.needs_filtering = True
+    self.valid_arg_names = set()
+    self.required_arg_names = set()
+
+    argspec = inspect.getfullargspec(function)
+    if argspec.varkw:
+      # If the function accepts a kwargs parameter, then it will accept all
+      # arguments.
+      # Note: this isn't technically true if one of the keyword arguments has the
+      # same name as one of the positional arguments.
+      self.needs_filtering = False
+    else:
+      # argspec.args is a list of all parameter names excluding keyword only
+      # args. The first element is our required positional_arg and should be
+      # ignored.
+      args = argspec.args[1:]
+      self.valid_arg_names.update(args)
+
+      # args.kwonlyargs gives us the list of keyword only args which are
+      # also valid.
+      self.valid_arg_names.update(argspec.kwonlyargs)
+
+      # Required args are positional arguments that don't have defaults.
+      # Keyword only args are always optional and can be ignored. Args with
+      # defaults are the last elements of the argsepec.args list and should
+      # be ignored.
+      if argspec.defaults:
+        # Trim the arguments with defaults.
+        args = args[: -len(argspec.defaults)]
+      self.required_arg_names.update(args)
+
+  def invoke(self, positional_arg, keyword_args):
+    """Invokes the function with the given args."""
+    if self.needs_filtering:
+      # Trim to just recognized args.
+      matched_args = {
+          k: v for k, v in keyword_args.items() if k in self.valid_arg_names
+      }
+      # Check if any required args are missing.
+      missing_args = self.required_arg_names.difference(matched_args.keys())
+      assert not missing_args, (
+          f"Attempting to call '{self.function.__name__}'; "
+          f"missing {missing_args} (have {set(keyword_args.keys())})"
+      )
+      keyword_args = matched_args
+
+    return self.function(positional_arg, **keyword_args)
+
+
+@simple_memoizer.memoize
+def _memoized_caller(function):
+  default_lambda_name = (lambda: None).__name__
+  assert (
+      callable(function) and not function.__name__ == default_lambda_name
+  ), "For performance reasons actions must be defined as static functions"
+  return _FunctionCaller(function)
 
 
 def _call_with_optional_args(function, positional_arg, keyword_args):
   """Calls function with whatever keyword_args it will accept."""
-  argspec = inspect.getfullargspec(function)
-  if argspec.varkw:
-    # If the function accepts a kwargs parameter, then it will accept all
-    # arguments.
-    # Note: this isn't technically true if one of the keyword arguments has the
-    # same name as one of the positional arguments.
-    return function(positional_arg, **keyword_args)
-  else:
-    ok_arguments = {}
-    for name in keyword_args:
-      if name in argspec.args[1:] or name in argspec.kwonlyargs:
-        ok_arguments[name] = keyword_args[name]
-    for name in argspec.args[1:len(argspec.args) - len(argspec.defaults or [])]:
-      assert name in ok_arguments, (
-          "Attempting to call '{}'; missing '{}' (have '{!r}')".format(
-              function.__name__, name, list(keyword_args.keys())))
-    return function(positional_arg, **ok_arguments)
+  caller = _memoized_caller(function)
+  return caller.invoke(positional_arg, keyword_args)
 
 
 def _fast_traverse_proto_top_down(proto, incidental_actions, pattern,
@@ -181,6 +252,17 @@
 
 _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target()
 
+def _emboss_ir_action(ir):
+  return {"ir": ir}
+
+def _module_action(m):
+  return {"source_file_name": m.source_file_name}
+
+def _type_definition_action(t):
+  return {"type_definition": t}
+
+def _field_action(f):
+  return {"field": f}
 
 def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None,
                               skip_descendants_of=(), parameters=None):
@@ -269,10 +351,10 @@
     None
   """
   all_incidental_actions = {
-      ir_pb2.EmbossIr: [lambda ir: {"ir": ir}],
-      ir_pb2.Module: [lambda m: {"source_file_name": m.source_file_name}],
-      ir_pb2.TypeDefinition: [lambda t: {"type_definition": t}],
-      ir_pb2.Field: [lambda f: {"field": f}],
+      ir_pb2.EmbossIr: [_emboss_ir_action],
+      ir_pb2.Module: [_module_action],
+      ir_pb2.TypeDefinition: [_type_definition_action],
+      ir_pb2.Field: [_field_action],
   }
   if incidental_actions:
     for key, incidental_action in incidental_actions.items():