pw_env_setup: Separate generation from data model
Separate generation of environment scripts/JSON files from the object
that stores that environment.
Change-Id: I4dba68636d79d23bc50ec737753b115ed81e26d5
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/33000
Reviewed-by: Keir Mierle <keir@google.com>
Commit-Queue: Rob Mohr <mohrr@google.com>
diff --git a/pw_env_setup/py/BUILD.gn b/pw_env_setup/py/BUILD.gn
index 5aaed810..030b671 100644
--- a/pw_env_setup/py/BUILD.gn
+++ b/pw_env_setup/py/BUILD.gn
@@ -20,6 +20,8 @@
setup = [ "setup.py" ]
sources = [
"pw_env_setup/__init__.py",
+ "pw_env_setup/apply_visitor.py",
+ "pw_env_setup/batch_visitor.py",
"pw_env_setup/cargo_setup/__init__.py",
"pw_env_setup/cipd_setup/__init__.py",
"pw_env_setup/cipd_setup/update.py",
@@ -27,12 +29,17 @@
"pw_env_setup/colors.py",
"pw_env_setup/env_setup.py",
"pw_env_setup/environment.py",
+ "pw_env_setup/json_visitor.py",
+ "pw_env_setup/shell_visitor.py",
"pw_env_setup/spinner.py",
"pw_env_setup/virtualenv_setup/__init__.py",
"pw_env_setup/virtualenv_setup/__main__.py",
"pw_env_setup/virtualenv_setup/install.py",
"pw_env_setup/windows_env_start.py",
]
- tests = [ "environment_test.py" ]
+ tests = [
+ "environment_test.py",
+ "json_visitor_test.py",
+ ]
pylintrc = "$dir_pigweed/.pylintrc"
}
diff --git a/pw_env_setup/py/json_visitor_test.py b/pw_env_setup/py/json_visitor_test.py
new file mode 100644
index 0000000..73ed3c7
--- /dev/null
+++ b/pw_env_setup/py/json_visitor_test.py
@@ -0,0 +1,102 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests for env_setup.environment.
+
+This tests the error-checking, context manager, and written environment scripts
+of the Environment class.
+
+Tests that end in "_ctx" modify the environment and validate it in-process.
+
+Tests that end in "_written" write the environment to a file intended to be
+evaluated by the shell, then launches the shell and then saves the environment.
+This environment is then validated in the test process.
+"""
+
+import json
+import unittest
+
+import six
+
+from pw_env_setup import environment, json_visitor
+
+
+# pylint: disable=super-with-arguments
+class JSONVisitorTest(unittest.TestCase):
+ """Tests for env_setup.json_visitor."""
+ def setUp(self):
+ self.env = environment.Environment()
+
+ def _write_and_parse_json(self):
+ buf = six.StringIO()
+ json_visitor.JSONVisitor(self.env, buf)
+ return json.loads(buf.getvalue())
+
+ def _assert_json(self, value):
+ self.assertEqual(self._write_and_parse_json(), value)
+
+ def test_set(self):
+ self.env.clear('VAR')
+ self.env.set('VAR', '1')
+ self._assert_json({'set': {'VAR': '1'}})
+
+ def test_clear(self):
+ self.env.set('VAR', '1')
+ self.env.clear('VAR')
+ self._assert_json({'set': {'VAR': None}})
+
+ def test_append(self):
+ self.env.append('VAR', 'path1')
+ self.env.append('VAR', 'path2')
+ self.env.append('VAR', 'path3')
+ self._assert_json(
+ {'modify': {
+ 'VAR': {
+ 'append': 'path1 path2 path3'.split()
+ }
+ }})
+
+ def test_prepend(self):
+ self.env.prepend('VAR', 'path1')
+ self.env.prepend('VAR', 'path2')
+ self.env.prepend('VAR', 'path3')
+ self._assert_json(
+ {'modify': {
+ 'VAR': {
+ 'prepend': 'path3 path2 path1'.split()
+ }
+ }})
+
+ def test_remove(self):
+ self.env.remove('VAR', 'path1')
+ self._assert_json({'modify': {'VAR': {'remove': ['path1']}}})
+
+ def test_echo(self):
+ self.env.echo('echo')
+ self._assert_json({})
+
+ def test_comment(self):
+ self.env.comment('comment')
+ self._assert_json({})
+
+ def test_command(self):
+ self.env.command('command')
+ self._assert_json({})
+
+ def test_doctor(self):
+ self.env.doctor()
+ self._assert_json({})
+
+ def test_function(self):
+ self.env.function('name', 'body')
+ self._assert_json({})
diff --git a/pw_env_setup/py/pw_env_setup/apply_visitor.py b/pw_env_setup/py/pw_env_setup/apply_visitor.py
new file mode 100644
index 0000000..b26f33a
--- /dev/null
+++ b/pw_env_setup/py/pw_env_setup/apply_visitor.py
@@ -0,0 +1,79 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Applies an Environment to the current process."""
+
+import os
+
+# Disable super() warnings since this file must be Python 2 compatible.
+# pylint: disable=super-with-arguments
+
+
+class ApplyVisitor(object): # pylint: disable=useless-object-inheritance
+ """Applies an Environment to the current process."""
+ def __init__(self, *args, **kwargs):
+ pathsep = kwargs.pop('pathsep', os.pathsep)
+ super(ApplyVisitor, self).__init__(*args, **kwargs)
+ self._pathsep = pathsep
+ self._environ = None
+ self._unapply_steps = None
+
+ def apply(self, env, environ):
+ self._unapply_steps = []
+ try:
+ self._environ = environ
+ env.accept(self)
+ finally:
+ self._environ = None
+
+ def visit_set(self, set): # pylint: disable=redefined-builtin
+ self._environ[set.name] = set.value
+
+ def visit_clear(self, clear):
+ if clear.name in self._environ:
+ del self._environ[clear.name]
+
+ def visit_remove(self, remove):
+ values = self._environ.get(remove.name, '').split(self._pathsep)
+ norm = os.path.normpath
+ values = [x for x in values if norm(x) != norm(remove.value)]
+ self._environ[remove.name] = self._pathsep.join(values)
+
+ def visit_prepend(self, prepend):
+ self._environ[prepend.name] = self._pathsep.join(
+ (prepend.value, self._environ.get(prepend.name, '')))
+
+ def visit_append(self, append):
+ self._environ[append.name] = self._pathsep.join(
+ (self._environ.get(append.name, ''), append.value))
+
+ def visit_echo(self, echo):
+ pass # Not relevant for apply.
+
+ def visit_comment(self, comment):
+ pass # Not relevant for apply.
+
+ def visit_command(self, command):
+ pass # Not relevant for apply.
+
+ def visit_doctor(self, doctor):
+ pass # Not relevant for apply.
+
+ def visit_blank_line(self, blank_line):
+ pass # Not relevant for apply.
+
+ def visit_function(self, function):
+ pass # Not relevant for apply.
+
+ def visit_hash(self, hash): # pylint: disable=redefined-builtin
+ pass # Not relevant for apply.
diff --git a/pw_env_setup/py/pw_env_setup/batch_visitor.py b/pw_env_setup/py/pw_env_setup/batch_visitor.py
new file mode 100644
index 0000000..e7c5353
--- /dev/null
+++ b/pw_env_setup/py/pw_env_setup/batch_visitor.py
@@ -0,0 +1,122 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Serializes an Environment into a batch file."""
+
+# Disable super() warnings since this file must be Python 2 compatible.
+# pylint: disable=super-with-arguments
+
+# goto label written to the end of Windows batch files for exiting a script.
+_SCRIPT_END_LABEL = '_pw_end'
+
+
+class BatchVisitor(object): # pylint: disable=useless-object-inheritance
+ """Serializes an Environment into a batch file."""
+ def __init__(self, *args, **kwargs):
+ pathsep = kwargs.pop('pathsep', ':')
+ super(BatchVisitor, self).__init__(*args, **kwargs)
+ self._replacements = ()
+ self._outs = None
+ self._pathsep = pathsep
+
+ def serialize(self, env, outs):
+ try:
+ self._replacements = tuple(
+ (key, env.get(key) if value is None else value)
+ for key, value in env.replacements)
+ self._outs = outs
+ self._outs.write('@echo off\n')
+
+ env.accept(self)
+
+ outs.write(':{}\n'.format(_SCRIPT_END_LABEL))
+
+ finally:
+ self._replacements = ()
+ self._outs = None
+
+ def _apply_replacements(self, action):
+ value = action.value
+ for var, replacement in self._replacements:
+ if var != action.name:
+ value = value.replace(replacement, '%{}%'.format(var))
+ return value
+
+ def visit_set(self, set): # pylint: disable=redefined-builtin
+ value = self._apply_replacements(set)
+ self._outs.write('set {name}={value}\n'.format(name=set.name,
+ value=value))
+
+ def visit_clear(self, clear):
+ self._outs.write('set {name}=\n'.format(name=clear.name))
+
+ def visit_remove(self, remove):
+ pass # Not supported on Windows.
+
+ def _join(self, *args):
+ if len(args) == 1 and isinstance(args[0], (list, tuple)):
+ args = args[0]
+ return self._pathsep.join(args)
+
+ def visit_prepend(self, prepend):
+ value = self._apply_replacements(prepend)
+ value = self._join(value, '%{}%'.format(prepend.name))
+ self._outs.write('set {name}={value}\n'.format(name=prepend.name,
+ value=value))
+
+ def visit_append(self, append):
+ value = self._apply_replacements(append)
+ value = self._join('%{}%'.format(append.name), value)
+ self._outs.write('set {name}={value}\n'.format(name=append.name,
+ value=value))
+
+ def visit_echo(self, echo):
+ if echo.newline:
+ if not echo.value:
+ self._outs.write('echo.\n')
+ else:
+ self._outs.write('echo {}\n'.format(echo.value))
+ else:
+ self._outs.write('<nul set /p="{}"\n'.format(echo.value))
+
+ def visit_comment(self, comment):
+ for line in comment.value.splitlines():
+ self._outs.write(':: {}\n'.format(line))
+
+ def visit_command(self, command):
+ # TODO(mohrr) use shlex.quote here?
+ self._outs.write('{}\n'.format(' '.join(command.command)))
+ if not command.exit_on_error:
+ return
+
+ # Assume failing command produced relevant output.
+ self._outs.write(
+ 'if %ERRORLEVEL% neq 0 goto {}\n'.format(_SCRIPT_END_LABEL))
+
+ def visit_doctor(self, doctor):
+ self._outs.write('if "%PW_ACTIVATE_SKIP_CHECKS%"=="" (\n')
+ self.visit_command(doctor)
+ self._outs.write(') else (\n')
+ self._outs.write('echo Skipping environment check because '
+ 'PW_ACTIVATE_SKIP_CHECKS is set\n')
+ self._outs.write(')\n')
+
+ def visit_blank_line(self, blank_line):
+ del blank_line
+ self._outs.write('\n')
+
+ def visit_function(self, function):
+ pass # Not supported on Windows.
+
+ def visit_hash(self, hash): # pylint: disable=redefined-builtin
+ pass # Not relevant on Windows.
diff --git a/pw_env_setup/py/pw_env_setup/environment.py b/pw_env_setup/py/pw_env_setup/environment.py
index 396671b..1a97c0a 100644
--- a/pw_env_setup/py/pw_env_setup/environment.py
+++ b/pw_env_setup/py/pw_env_setup/environment.py
@@ -14,7 +14,6 @@
"""Stores the environment changes necessary for Pigweed."""
import contextlib
-import json
import os
import re
@@ -27,12 +26,14 @@
except ImportError:
from io import StringIO
+from . import apply_visitor
+from . import batch_visitor
+from . import json_visitor
+from . import shell_visitor
+
# Disable super() warnings since this file must be Python 2 compatible.
# pylint: disable=super-with-arguments
-# goto label written to the end of Windows batch files for exiting a script.
-_SCRIPT_END_LABEL = '_pw_end'
-
class BadNameType(TypeError):
pass
@@ -58,12 +59,18 @@
pass
+class AcceptNotOverridden(TypeError):
+ pass
+
+
class _Action(object): # pylint: disable=useless-object-inheritance
def unapply(self, env, orig_env):
pass
- def json(self, data):
- pass
+ def accept(self, visitor):
+ del visitor
+ raise AcceptNotOverridden('accept() not overridden for {}'.format(
+ self.__class__.__name__))
def write_deactivate(self,
outs,
@@ -118,43 +125,10 @@
env.pop(self.name, None)
-def _var_form(variable, windows=(os.name == 'nt')):
- if windows:
- return '%{}%'.format(variable)
- return '${}'.format(variable)
-
-
class Set(_VariableAction):
"""Set a variable."""
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- value = self.value
- for var, replacement in replacements:
- if var != self.name:
- value = value.replace(replacement, _var_form(var, windows))
-
- if windows:
- outs.write('set {name}={value}\n'.format(name=self.name,
- value=value))
- else:
- outs.write('{name}="{value}"\nexport {name}\n'.format(
- name=self.name, value=value))
-
- def write_deactivate(self,
- outs,
- windows=(os.name == 'nt'),
- replacements=()):
- del replacements # Unused.
-
- if windows:
- outs.write('set {name}=\n'.format(name=self.name))
- else:
- outs.write('unset {name}\n'.format(name=self.name))
-
- def apply(self, env):
- env[self.name] = self.value
-
- def json(self, data):
- data['set'][self.name] = self.value
+ def accept(self, visitor):
+ visitor.visit_set(self)
class Clear(_VariableAction):
@@ -164,76 +138,14 @@
kwargs['allow_empty_values'] = True
super(Clear, self).__init__(*args, **kwargs)
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- del replacements # Unused.
- if windows:
- outs.write('set {name}=\n'.format(**vars(self)))
- else:
- outs.write('unset {name}\n'.format(**vars(self)))
-
- def apply(self, env):
- if self.name in env:
- del env[self.name]
-
- def json(self, data):
- data['set'][self.name] = None
-
-
-def _initialize_path_like_variable(data, name):
- default = {'append': [], 'prepend': [], 'remove': []}
- data['modify'].setdefault(name, default)
-
-
-def _remove_value_from_path(variable, value, pathsep):
- return ('{variable}="$(echo "${variable}"'
- ' | sed "s|{pathsep}{value}{pathsep}|{pathsep}|g;"'
- ' | sed "s|^{value}{pathsep}||g;"'
- ' | sed "s|{pathsep}{value}$||g;"'
- ')"\nexport {variable}\n'.format(variable=variable,
- value=value,
- pathsep=pathsep))
+ def accept(self, visitor):
+ visitor.visit_clear(self)
class Remove(_VariableAction):
"""Remove a value from a PATH-like variable."""
- def __init__(self, name, value, pathsep, *args, **kwargs):
- super(Remove, self).__init__(name, value, *args, **kwargs)
- self._pathsep = pathsep
-
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- value = self.value
- for var, replacement in replacements:
- if var != self.name:
- value = value.replace(replacement, _var_form(var, windows))
-
- if windows:
- pass
- # TODO(pwbug/231) This does not seem to be supported when value
- # contains a %variable%. Disabling for now.
- # outs.write(':: Remove\n:: {value}\n:: from\n:: {name}\n'
- # ':: before adding it back.\n'
- # 'set {name}=%{name}:{value}{pathsep}=%\n'.format(
- # name=self.name, value=value, pathsep=self._pathsep))
-
- else:
- outs.write('# Remove \n# {value}\n# from\n# {value}\n# before '
- 'adding it back.\n')
- outs.write(_remove_value_from_path(self.name, value,
- self._pathsep))
-
- def apply(self, env):
- env[self.name] = env[self.name].replace(
- '{}{}'.format(self.value, self._pathsep), '')
- env[self.name] = env[self.name].replace(
- '{}{}'.format(self._pathsep, self.value), '')
-
- def json(self, data):
- _initialize_path_like_variable(data, self.name)
- data['modify'][self.name]['remove'].append(self.value)
- if self.value in data['modify'][self.name]['append']:
- data['modify'][self.name]['append'].remove(self.value)
- if self.value in data['modify'][self.name]['prepend']:
- data['modify'][self.name]['prepend'].remove(self.value)
+ def accept(self, visitor):
+ visitor.visit_remove(self)
class BadVariableValue(ValueError):
@@ -251,44 +163,12 @@
super(Prepend, self).__init__(name, value, *args, **kwargs)
self._join = join
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- value = self.value
- for var, replacement in replacements:
- if var != self.name:
- value = value.replace(replacement, _var_form(var, windows))
- value = self._join(value, _var_form(self.name, windows))
-
- if windows:
- outs.write('set {name}={value}\n'.format(name=self.name,
- value=value))
- else:
- outs.write('{name}="{value}"\nexport {name}\n'.format(
- name=self.name, value=value))
-
- def write_deactivate(self,
- outs,
- windows=(os.name == 'nt'),
- replacements=()):
- value = self.value
- for var, replacement in replacements:
- if var != self.name:
- value = value.replace(replacement, _var_form(var, windows))
-
- outs.write(
- _remove_value_from_path(self.name, value, self._join.pathsep))
-
- def apply(self, env):
- env[self.name] = self._join(self.value, env.get(self.name, ''))
-
def _check(self):
super(Prepend, self)._check()
_append_prepend_check(self)
- def json(self, data):
- _initialize_path_like_variable(data, self.name)
- data['modify'][self.name]['prepend'].append(self.value)
- if self.value in data['modify'][self.name]['remove']:
- data['modify'][self.name]['remove'].remove(self.value)
+ def accept(self, visitor):
+ visitor.visit_prepend(self)
class Append(_VariableAction):
@@ -297,44 +177,12 @@
super(Append, self).__init__(name, value, *args, **kwargs)
self._join = join
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- value = self.value
- for var, repl_value in replacements:
- if var != self.name:
- value = value.replace(repl_value, _var_form(var, windows))
- value = self._join(_var_form(self.name, windows), value)
-
- if windows:
- outs.write('set {name}={value}\n'.format(name=self.name,
- value=value))
- else:
- outs.write('{name}="{value}"\nexport {name}\n'.format(
- name=self.name, value=value))
-
- def write_deactivate(self,
- outs,
- windows=(os.name == 'nt'),
- replacements=()):
- value = self.value
- for var, replacement in replacements:
- if var != self.name:
- value = value.replace(replacement, _var_form(var, windows))
-
- outs.write(
- _remove_value_from_path(self.name, value, self._join.pathsep))
-
- def apply(self, env):
- env[self.name] = self._join(env.get(self.name, ''), self.value)
-
def _check(self):
super(Append, self)._check()
_append_prepend_check(self)
- def json(self, data):
- _initialize_path_like_variable(data, self.name)
- data['modify'][self.name]['append'].append(self.value)
- if self.value in data['modify'][self.name]['remove']:
- data['modify'][self.name]['remove'].remove(self.value)
+ def accept(self, visitor):
+ visitor.visit_append(self)
class BadEchoValue(ValueError):
@@ -349,31 +197,10 @@
raise BadEchoValue(value)
super(Echo, self).__init__(*args, **kwargs)
self.value = value
- self._newline = newline
+ self.newline = newline
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- del replacements # Unused.
- # POSIX shells parse arguments and pass to echo, but Windows seems to
- # pass the command line as is without parsing, so quoting is wrong.
- if windows:
- if self._newline:
- if not self.value:
- outs.write('echo.\n')
- else:
- outs.write('echo {}\n'.format(self.value))
- else:
- outs.write('<nul set /p="{}"\n'.format(self.value))
- else:
- # TODO(mohrr) use shlex.quote().
- outs.write('if [ -z "${PW_ENVSETUP_QUIET:-}" ]; then\n')
- if self._newline:
- outs.write(' echo "{}"\n'.format(self.value))
- else:
- outs.write(' echo -n "{}"\n'.format(self.value))
- outs.write('fi\n')
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_echo(self)
class Comment(_Action):
@@ -382,14 +209,8 @@
super(Comment, self).__init__(*args, **kwargs)
self.value = value
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- del replacements # Unused.
- comment_char = '::' if windows else '#'
- for line in self.value.splitlines():
- outs.write('{} {}\n'.format(comment_char, line))
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_comment(self)
class Command(_Action):
@@ -401,22 +222,8 @@
self.command = command
self.exit_on_error = exit_on_error
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- del replacements # Unused.
- # TODO(mohrr) use shlex.quote here?
- outs.write('{}\n'.format(' '.join(self.command)))
- if not self.exit_on_error:
- return
-
- if windows:
- outs.write(
- 'if %ERRORLEVEL% neq 0 goto {}\n'.format(_SCRIPT_END_LABEL))
- else:
- # Assume failing command produced relevant output.
- outs.write('if [ "$?" -ne 0 ]; then\n return 1\nfi\n')
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_command(self)
class Doctor(Command):
@@ -427,94 +234,35 @@
*args,
**kwargs)
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- super_call = lambda: super(Doctor, self).write(
- outs, windows=windows, replacements=replacements)
-
- if windows:
- outs.write('if "%PW_ACTIVATE_SKIP_CHECKS%"=="" (\n')
- super_call()
- outs.write(') else (\n')
- outs.write('echo Skipping environment check because '
- 'PW_ACTIVATE_SKIP_CHECKS is set\n')
- outs.write(')\n')
- else:
- outs.write('if [ -z "$PW_ACTIVATE_SKIP_CHECKS" ]; then\n')
- super_call()
- outs.write('else\n')
- outs.write('echo Skipping environment check because '
- 'PW_ACTIVATE_SKIP_CHECKS is set\n')
- outs.write('fi\n')
+ def accept(self, visitor):
+ visitor.visit_doctor(self)
class BlankLine(_Action):
"""Write a blank line to the init script."""
- def write( # pylint: disable=no-self-use
- self,
- outs,
- windows=(os.name == 'nt'),
- replacements=()):
- del replacements, windows # Unused.
- outs.write('\n')
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_blank_line(self)
class Function(_Action):
def __init__(self, name, body, *args, **kwargs):
super(Function, self).__init__(*args, **kwargs)
- self._name = name
- self._body = body
+ self.name = name
+ self.body = body
- def write(self, outs, windows=(os.name == 'nt'), replacements=()):
- del replacements # Unused.
- if windows:
- return
-
- outs.write("""
-{name}() {{
-{body}
-}}
- """.strip().format(name=self._name, body=self._body))
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_function(self)
class Hash(_Action):
- def write( # pylint: disable=no-self-use
- self,
- outs,
- windows=(os.name == 'nt'),
- replacements=()):
- del replacements # Unused.
-
- if windows:
- return
-
- outs.write('''
-# This should detect bash and zsh, which have a hash command that must be
-# called to get it to forget past commands. Without forgetting past
-# commands the $PATH changes we made may not be respected.
-if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
- hash -r\n
-fi
-''')
-
- def apply(self, env):
- pass
+ def accept(self, visitor):
+ visitor.visit_hash(self)
class Join(object): # pylint: disable=useless-object-inheritance
def __init__(self, pathsep=os.pathsep):
self.pathsep = pathsep
- def __call__(self, *args):
- if len(args) == 1 and isinstance(args[0], (list, tuple)):
- args = args[0]
- return self.pathsep.join(args)
-
# TODO(mohrr) remove disable=useless-object-inheritance once in Python 3.
# pylint: disable=useless-object-inheritance
@@ -533,12 +281,12 @@
self._pathsep = pathsep
self._windows = windows
self._allcaps = allcaps
- self._replacements = []
+ self.replacements = []
self._join = Join(pathsep)
self._finalized = False
def add_replacement(self, variable, value=None):
- self._replacements.append((variable, value))
+ self.replacements.append((variable, value))
def normalize_key(self, name):
if self._allcaps:
@@ -643,49 +391,29 @@
if not self._windows:
buf = StringIO()
- for action in self._actions:
- action.write_deactivate(buf, windows=self._windows)
+ self.write_deactivate(buf)
self._actions.append(Function('_pw_deactivate', buf.getvalue()))
self._blankline()
- def write(self, outs):
- """Writes a shell init script to outs."""
- if self._windows:
- outs.write('@echo off\n')
-
- # This is a tuple and not a dictionary because we don't need random
- # access and order needs to be preserved.
- replacements = tuple((key, self.get(key) if value is None else value)
- for key, value in self._replacements)
-
+ def accept(self, visitor):
for action in self._actions:
- action.write(outs,
- windows=self._windows,
- replacements=replacements)
-
- if self._windows:
- outs.write(':{}\n'.format(_SCRIPT_END_LABEL))
+ action.accept(visitor)
def json(self, outs):
- data = {
- 'modify': {},
- 'set': {},
- }
+ json_visitor.JSONVisitor().serialize(self, outs)
- for action in self._actions:
- action.json(data)
-
- json.dump(data, outs, indent=4, separators=(',', ': '))
- outs.write('\n')
+ def write(self, outs):
+ if self._windows:
+ visitor = batch_visitor.BatchVisitor(pathsep=self._pathsep)
+ else:
+ visitor = shell_visitor.ShellVisitor(pathsep=self._pathsep)
+ visitor.serialize(self, outs)
def write_deactivate(self, outs):
if self._windows:
- outs.write('@echo off\n')
-
- for action in reversed(self._actions):
- action.write_deactivate(outs,
- windows=self._windows,
- replacements=())
+ return
+ visitor = shell_visitor.DeactivateShellVisitor(pathsep=self._pathsep)
+ visitor.serialize(self, outs)
@contextlib.contextmanager
def __call__(self, export=True):
@@ -711,15 +439,20 @@
else:
env = os.environ.copy()
- for action in self._actions:
- action.apply(env)
+ apply = apply_visitor.ApplyVisitor(pathsep=self._pathsep)
+ apply.apply(self, env)
yield env
finally:
if export:
- for action in self._actions:
- action.unapply(env=os.environ, orig_env=orig_env)
+ for key in set(os.environ):
+ try:
+ os.environ[key] = orig_env[key]
+ except KeyError:
+ del os.environ[key]
+ for key in set(orig_env) - set(os.environ):
+ os.environ[key] = orig_env[key]
def get(self, key, default=None):
"""Get the value of a variable within context of this object."""
diff --git a/pw_env_setup/py/pw_env_setup/json_visitor.py b/pw_env_setup/py/pw_env_setup/json_visitor.py
new file mode 100644
index 0000000..f96acac
--- /dev/null
+++ b/pw_env_setup/py/pw_env_setup/json_visitor.py
@@ -0,0 +1,89 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Serializes an Environment into a JSON file."""
+
+import json
+
+# Disable super() warnings since this file must be Python 2 compatible.
+# pylint: disable=super-with-arguments
+
+
+class JSONVisitor(object): # pylint: disable=useless-object-inheritance
+ """Serializes an Environment into a JSON file."""
+ def __init__(self, *args, **kwargs):
+ super(JSONVisitor, self).__init__(*args, **kwargs)
+ self._data = {}
+
+ def serialize(self, env, outs):
+ self._data = {
+ 'modify': {},
+ 'set': {},
+ }
+
+ env.accept(self)
+
+ json.dump(self._data, outs, indent=4, separators=(',', ': '))
+ outs.write('\n')
+ self._data = {}
+
+ def visit_set(self, set): # pylint: disable=redefined-builtin
+ self._data['set'][set.name] = set.value
+
+ def visit_clear(self, clear):
+ self._data['set'][clear.name] = None
+
+ def _initialize_path_like_variable(self, name):
+ default = {'append': [], 'prepend': [], 'remove': []}
+ self._data['modify'].setdefault(name, default)
+
+ def visit_remove(self, remove):
+ self._initialize_path_like_variable(remove.name)
+ self._data['modify'][remove.name]['remove'].append(remove.value)
+ if remove.value in self._data['modify'][remove.name]['append']:
+ self._data['modify'][remove.name]['append'].remove(remove.value)
+ if remove.value in self._data['modify'][remove.name]['prepend']:
+ self._data['modify'][remove.name]['prepend'].remove(remove.value)
+
+ def visit_prepend(self, prepend):
+ self._initialize_path_like_variable(prepend.name)
+ self._data['modify'][prepend.name]['prepend'].append(prepend.value)
+ if prepend.value in self._data['modify'][prepend.name]['remove']:
+ self._data['modify'][prepend.name]['remove'].remove(prepend.value)
+
+ def visit_append(self, append):
+ self._initialize_path_like_variable(append.name)
+ self._data['modify'][append.name]['append'].append(append.value)
+ if append.value in self._data['modify'][append.name]['remove']:
+ self._data['modify'][append.name]['remove'].remove(append.value)
+
+ def visit_echo(self, echo):
+ pass
+
+ def visit_comment(self, comment):
+ pass
+
+ def visit_command(self, command):
+ pass
+
+ def visit_doctor(self, doctor):
+ pass
+
+ def visit_blank_line(self, blank_line):
+ pass
+
+ def visit_function(self, function):
+ pass
+
+ def visit_hash(self, hash): # pylint: disable=redefined-builtin
+ pass
diff --git a/pw_env_setup/py/pw_env_setup/shell_visitor.py b/pw_env_setup/py/pw_env_setup/shell_visitor.py
new file mode 100644
index 0000000..18a42fb
--- /dev/null
+++ b/pw_env_setup/py/pw_env_setup/shell_visitor.py
@@ -0,0 +1,196 @@
+# Copyright 2021 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Serializes an Environment into a shell file."""
+
+import inspect
+
+# Disable super() warnings since this file must be Python 2 compatible.
+# pylint: disable=super-with-arguments
+
+
+class _BaseShellVisitor(object): # pylint: disable=useless-object-inheritance
+ def __init__(self, *args, **kwargs):
+ pathsep = kwargs.pop('pathsep', ':')
+ super(_BaseShellVisitor, self).__init__(*args, **kwargs)
+ self._pathsep = pathsep
+ self._outs = None
+
+ def _remove_value_from_path(self, variable, value):
+ return ('{variable}="$(echo "${variable}"'
+ ' | sed "s|{pathsep}{value}{pathsep}|{pathsep}|g;"'
+ ' | sed "s|^{value}{pathsep}||g;"'
+ ' | sed "s|{pathsep}{value}$||g;"'
+ ')"\nexport {variable}\n'.format(variable=variable,
+ value=value,
+ pathsep=self._pathsep))
+
+ def visit_hash(self, hash): # pylint: disable=redefined-builtin
+ del hash
+ self._outs.write(
+ inspect.cleandoc('''
+ # This should detect bash and zsh, which have a hash command that must
+ # be called to get it to forget past commands. Without forgetting past
+ # commands the $PATH changes we made may not be respected.
+ if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
+ hash -r\n
+ fi
+ '''))
+
+
+class ShellVisitor(_BaseShellVisitor):
+ """Serializes an Environment into a shell file."""
+ def __init__(self, *args, **kwargs):
+ super(ShellVisitor, self).__init__(*args, **kwargs)
+ self._replacements = ()
+
+ def serialize(self, env, outs):
+ try:
+ self._replacements = tuple(
+ (key, env.get(key) if value is None else value)
+ for key, value in env.replacements)
+ self._outs = outs
+
+ env.accept(self)
+
+ finally:
+ self._replacements = ()
+ self._outs = None
+
+ def _apply_replacements(self, action):
+ value = action.value
+ for var, replacement in self._replacements:
+ if var != action.name:
+ value = value.replace(replacement, '${}'.format(var))
+ return value
+
+ def visit_set(self, set): # pylint: disable=redefined-builtin
+ value = self._apply_replacements(set)
+ self._outs.write('{name}="{value}"\nexport {name}\n'.format(
+ name=set.name, value=value))
+
+ def visit_clear(self, clear):
+ self._outs.write('unset {name}\n'.format(**vars(clear)))
+
+ def visit_remove(self, remove):
+ value = self._apply_replacements(remove)
+ self._outs.write('# Remove \n# {value}\n# from\n# {value}\n# '
+ 'before adding it back.\n')
+ self._outs.write(self._remove_value_from_path(remove.name, value))
+
+ def _join(self, *args):
+ if len(args) == 1 and isinstance(args[0], (list, tuple)):
+ args = args[0]
+ return self._pathsep.join(args)
+
+ def visit_prepend(self, prepend):
+ value = self._apply_replacements(prepend)
+ value = self._join(value, '${}'.format(prepend.name))
+ self._outs.write('{name}="{value}"\nexport {name}\n'.format(
+ name=prepend.name, value=value))
+
+ def visit_append(self, append):
+ value = self._apply_replacements(append)
+ value = self._join('${}'.format(append.name), value)
+ self._outs.write('{name}="{value}"\nexport {name}\n'.format(
+ name=append.name, value=value))
+
+ def visit_echo(self, echo):
+ # TODO(mohrr) use shlex.quote().
+ self._outs.write('if [ -z "${PW_ENVSETUP_QUIET:-}" ]; then\n')
+ if echo.newline:
+ self._outs.write(' echo "{}"\n'.format(echo.value))
+ else:
+ self._outs.write(' echo -n "{}"\n'.format(echo.value))
+ self._outs.write('fi\n')
+
+ def visit_comment(self, comment):
+ for line in comment.value.splitlines():
+ self._outs.write('# {}\n'.format(line))
+
+ def visit_command(self, command):
+ # TODO(mohrr) use shlex.quote here?
+ self._outs.write('{}\n'.format(' '.join(command.command)))
+ if not command.exit_on_error:
+ return
+
+ # Assume failing command produced relevant output.
+ self._outs.write('if [ "$?" -ne 0 ]; then\n return 1\nfi\n')
+
+ def visit_doctor(self, doctor):
+ self._outs.write('if [ -z "$PW_ACTIVATE_SKIP_CHECKS" ]; then\n')
+ self.visit_command(doctor)
+ self._outs.write('else\n')
+ self._outs.write('echo Skipping environment check because '
+ 'PW_ACTIVATE_SKIP_CHECKS is set\n')
+ self._outs.write('fi\n')
+
+ def visit_blank_line(self, blank_line):
+ del blank_line
+ self._outs.write('\n')
+
+ def visit_function(self, function):
+ self._outs.write('{name}() {{\n{body}\n}}\n'.format(
+ name=function.name, body=function.body))
+
+
+class DeactivateShellVisitor(_BaseShellVisitor):
+ """Removes values from an Environment."""
+ def __init__(self, *args, **kwargs):
+ pathsep = kwargs.pop('pathsep', ':')
+ super(DeactivateShellVisitor, self).__init__(*args, **kwargs)
+ self._pathsep = pathsep
+
+ def serialize(self, env, outs):
+ try:
+ self._outs = outs
+
+ env.accept(self)
+
+ finally:
+ self._outs = None
+
+ def visit_set(self, set): # pylint: disable=redefined-builtin
+ self._outs.write('unset {name}\n'.format(name=set.name))
+
+ def visit_clear(self, clear):
+ pass # Not relevant.
+
+ def visit_remove(self, remove):
+ pass # Not relevant.
+
+ def visit_prepend(self, prepend):
+ self._outs.write(
+ self._remove_value_from_path(prepend.name, prepend.value))
+
+ def visit_append(self, append):
+ self._outs.write(
+ self._remove_value_from_path(append.name, append.value))
+
+ def visit_echo(self, echo):
+ pass # Not relevant.
+
+ def visit_comment(self, comment):
+ pass # Not relevant.
+
+ def visit_command(self, command):
+ pass # Not relevant.
+
+ def visit_doctor(self, doctor):
+ pass # Not relevant.
+
+ def visit_blank_line(self, blank_line):
+ pass # Not relevant.
+
+ def visit_function(self, function):
+ pass # Not relevant.