feat: allow setting custom environment variables on pip_repository and whl_library (#460)

* feat: allow setting custom environment variables on pip_repository and whl_library

* Serialize and deserialize environment dict in python process instead
of starlark.

* Refactor shared functions between extract_wheel and extract_single_wheel.
* Every structured arg now has the same key when serialized. fixes #490
* test for pip_data_exclude in arguments parsing test.

* Also update docs in repository rule attr definition

Co-authored-by: Jonathon Belotti <jonathon@canva.com>
Co-authored-by: Henry Fuller <hrofuller@gmail.com>
diff --git a/examples/pip_install/WORKSPACE b/examples/pip_install/WORKSPACE
index 56f7afd..4da2381 100644
--- a/examples/pip_install/WORKSPACE
+++ b/examples/pip_install/WORKSPACE
@@ -29,6 +29,12 @@
     # (Optional) You can set quiet to False if you want to see pip output.
     #quiet = False,
 
+    # (Optional) You can set an environment in the pip process to control its
+    # behavior. Note that pip is run in "isolated" mode so no PIP_<VAR>_<NAME>
+    # style env vars are read, but env vars that control requests and urllib3
+    # can be passed
+    #environment = {"HTTP_PROXY": "http://my.proxy.fun/"},
+
     # Uses the default repository name "pip"
     requirements = "//:requirements.txt",
 )
diff --git a/examples/pip_parse/WORKSPACE b/examples/pip_parse/WORKSPACE
index 379486c..78b6773 100644
--- a/examples/pip_parse/WORKSPACE
+++ b/examples/pip_parse/WORKSPACE
@@ -29,6 +29,12 @@
     # (Optional) You can set quiet to False if you want to see pip output.
     #quiet = False,
 
+    # (Optional) You can set an environment in the pip process to control its
+    # behavior. Note that pip is run in "isolated" mode so no PIP_<VAR>_<NAME>
+    # style env vars are read, but env vars that control requests and urllib3
+    # can be passed
+    # environment = {"HTTPS_PROXY": "http://my.proxy.fun/"},
+
     # Uses the default repository name "pip_parsed_deps"
     requirements_lock = "//:requirements_lock.txt",
 )
diff --git a/python/pip_install/extract_wheels/__init__.py b/python/pip_install/extract_wheels/__init__.py
index 96913cd..228346a 100644
--- a/python/pip_install/extract_wheels/__init__.py
+++ b/python/pip_install/extract_wheels/__init__.py
@@ -60,21 +60,21 @@
     )
     arguments.parse_common_args(parser)
     args = parser.parse_args()
+    deserialized_args = dict(vars(args))
+    arguments.deserialize_structured_args(deserialized_args)
 
-    pip_args = [sys.executable, "-m", "pip", "--isolated", "wheel", "-r", args.requirements]
-    if args.extra_pip_args:
-        pip_args += json.loads(args.extra_pip_args)["args"]
+    pip_args = (
+        [sys.executable, "-m", "pip", "--isolated", "wheel", "-r", args.requirements] +
+        deserialized_args["extra_pip_args"]
+    )
 
+    env = os.environ.copy()
+    env.update(deserialized_args["environment"])
     # Assumes any errors are logged by pip so do nothing. This command will fail if pip fails
-    subprocess.run(pip_args, check=True)
+    subprocess.run(pip_args, check=True, env=env)
 
     extras = requirements.parse_extras(args.requirements)
 
-    if args.pip_data_exclude:
-        pip_data_exclude = json.loads(args.pip_data_exclude)["exclude"]
-    else:
-        pip_data_exclude = []
-
     repo_label = "@%s" % args.repo
 
     targets = [
@@ -82,7 +82,7 @@
         % (
             repo_label,
             bazel.extract_wheel(
-                whl, extras, pip_data_exclude, args.enable_implicit_namespace_pkgs
+                whl, extras, deserialized_args["pip_data_exclude"], args.enable_implicit_namespace_pkgs
             ),
         )
         for whl in glob.glob("*.whl")
diff --git a/python/pip_install/extract_wheels/lib/arguments.py b/python/pip_install/extract_wheels/lib/arguments.py
index ee9a649..46d08a8 100644
--- a/python/pip_install/extract_wheels/lib/arguments.py
+++ b/python/pip_install/extract_wheels/lib/arguments.py
@@ -1,3 +1,4 @@
+import json
 from argparse import ArgumentParser
 
 
@@ -21,4 +22,24 @@
         action="store_true",
         help="Disables conversion of implicit namespace packages into pkg-util style packages.",
     )
+    parser.add_argument(
+        "--environment",
+        action="store",
+        help="Extra environment variables to set on the pip environment.",
+    )
     return parser
+
+
+def deserialize_structured_args(args):
+    """Deserialize structured arguments passed from the starlark rules.
+    Args:
+        args: dict of parsed command line arguments
+    """
+    structured_args = ("extra_pip_args", "pip_data_exclude", "environment")
+    for arg_name in structured_args:
+        if args.get(arg_name) is not None:
+            args[arg_name] = json.loads(args[arg_name])["arg"]
+        else:
+            args[arg_name] = []
+    return args
+
diff --git a/python/pip_install/extract_wheels/lib/arguments_test.py b/python/pip_install/extract_wheels/lib/arguments_test.py
index c0338bd..53e19a3 100644
--- a/python/pip_install/extract_wheels/lib/arguments_test.py
+++ b/python/pip_install/extract_wheels/lib/arguments_test.py
@@ -3,7 +3,6 @@
 import unittest
 
 from python.pip_install.extract_wheels.lib import arguments
-from python.pip_install.parse_requirements_to_bzl import deserialize_structured_args
 
 
 class ArgumentsTestCase(unittest.TestCase):
@@ -12,15 +11,26 @@
         parser = arguments.parse_common_args(parser)
         repo_name = "foo"
         index_url = "--index_url=pypi.org/simple"
+        extra_pip_args = [index_url]
         args_dict = vars(parser.parse_args(
-            args=["--repo", repo_name, "--extra_pip_args={index_url}".format(index_url=json.dumps({"args": index_url}))]))
-        args_dict = deserialize_structured_args(args_dict)
+            args=["--repo", repo_name, f"--extra_pip_args={json.dumps({'arg': extra_pip_args})}"]))
+        args_dict = arguments.deserialize_structured_args(args_dict)
         self.assertIn("repo", args_dict)
         self.assertIn("extra_pip_args", args_dict)
         self.assertEqual(args_dict["pip_data_exclude"], [])
         self.assertEqual(args_dict["enable_implicit_namespace_pkgs"], False)
         self.assertEqual(args_dict["repo"], repo_name)
-        self.assertEqual(args_dict["extra_pip_args"], index_url)
+        self.assertEqual(args_dict["extra_pip_args"], extra_pip_args)
+
+    def test_deserialize_structured_args(self) -> None:
+        serialized_args = {
+            "pip_data_exclude": json.dumps({"arg": ["**.foo"]}),
+            "environment": json.dumps({"arg": {"PIP_DO_SOMETHING": "True"}}),
+        }
+        args = arguments.deserialize_structured_args(serialized_args)
+        self.assertEqual(args["pip_data_exclude"], ["**.foo"])
+        self.assertEqual(args["environment"], {"PIP_DO_SOMETHING": "True"})
+        self.assertEqual(args["extra_pip_args"], [])
 
 
 if __name__ == "__main__":
diff --git a/python/pip_install/parse_requirements_to_bzl/__init__.py b/python/pip_install/parse_requirements_to_bzl/__init__.py
index 8802ef4..f27e2a2 100644
--- a/python/pip_install/parse_requirements_to_bzl/__init__.py
+++ b/python/pip_install/parse_requirements_to_bzl/__init__.py
@@ -45,19 +45,6 @@
         for ir, line in install_reqs
     ]
 
-def deserialize_structured_args(args):
-    """Deserialize structured arguments passed from the starlark rules.
-        Args:
-            args: dict of parsed command line arguments
-    """
-    structured_args = ("extra_pip_args", "pip_data_exclude")
-    for arg_name in structured_args:
-        if args.get(arg_name) is not None:
-            args[arg_name] = json.loads(args[arg_name])["args"]
-        else:
-            args[arg_name] = []
-    return args
-
 
 def generate_parsed_requirements_contents(all_args: argparse.Namespace) -> str:
     """
@@ -69,7 +56,7 @@
     """
 
     args = dict(vars(all_args))
-    args = deserialize_structured_args(args)
+    args = arguments.deserialize_structured_args(args)
     args.setdefault("python_interpreter", sys.executable)
     # Pop this off because it wont be used as a config argument to the whl_library rule.
     requirements_lock = args.pop("requirements_lock")
diff --git a/python/pip_install/parse_requirements_to_bzl/extract_single_wheel/__init__.py b/python/pip_install/parse_requirements_to_bzl/extract_single_wheel/__init__.py
index 4ea3758..a46ea2e 100644
--- a/python/pip_install/parse_requirements_to_bzl/extract_single_wheel/__init__.py
+++ b/python/pip_install/parse_requirements_to_bzl/extract_single_wheel/__init__.py
@@ -23,12 +23,15 @@
     )
     arguments.parse_common_args(parser)
     args = parser.parse_args()
+    deserialized_args = dict(vars(args))
+    arguments.deserialize_structured_args(deserialized_args)
 
     configure_reproducible_wheels()
 
-    pip_args = [sys.executable, "-m", "pip", "--isolated", "wheel", "--no-deps"]
-    if args.extra_pip_args:
-        pip_args += json.loads(args.extra_pip_args)["args"]
+    pip_args = (
+        [sys.executable, "-m", "pip", "--isolated", "wheel", "--no-deps"] +
+        deserialized_args["extra_pip_args"]
+    )
 
     requirement_file = NamedTemporaryFile(mode='wb', delete=False)
     try:
@@ -41,8 +44,10 @@
         # so write our single requirement into a temp file in case it has any of those flags.
         pip_args.extend(["-r", requirement_file.name])
 
+        env = os.environ.copy()
+        env.update(deserialized_args["environment"])
         # Assumes any errors are logged by pip so do nothing. This command will fail if pip fails
-        subprocess.run(pip_args, check=True)
+        subprocess.run(pip_args, check=True, env=env)
     finally:
         try:
             os.unlink(requirement_file.name)
@@ -53,16 +58,11 @@
     name, extras_for_pkg = requirements._parse_requirement_for_extra(args.requirement)
     extras = {name: extras_for_pkg} if extras_for_pkg and name else dict()
 
-    if args.pip_data_exclude:
-        pip_data_exclude = json.loads(args.pip_data_exclude)["exclude"]
-    else:
-        pip_data_exclude = []
-
     whl = next(iter(glob.glob("*.whl")))
     bazel.extract_wheel(
         whl,
         extras,
-        pip_data_exclude,
+        deserialized_args["pip_data_exclude"],
         args.enable_implicit_namespace_pkgs,
         incremental=True,
         incremental_repo_prefix=bazel.whl_library_repo_prefix(args.repo)
diff --git a/python/pip_install/parse_requirements_to_bzl/parse_requirements_to_bzl_test.py b/python/pip_install/parse_requirements_to_bzl/parse_requirements_to_bzl_test.py
index 7199cea..0ac5668 100644
--- a/python/pip_install/parse_requirements_to_bzl/parse_requirements_to_bzl_test.py
+++ b/python/pip_install/parse_requirements_to_bzl/parse_requirements_to_bzl_test.py
@@ -23,7 +23,10 @@
             args.requirements_lock = requirements_lock.name
             args.repo = "pip_parsed_deps"
             extra_pip_args = ["--index-url=pypi.org/simple"]
-            args.extra_pip_args = json.dumps({"args": extra_pip_args})
+            pip_data_exclude = ["**.foo"]
+            args.extra_pip_args = json.dumps({"arg": extra_pip_args})
+            args.pip_data_exclude= json.dumps({"arg": pip_data_exclude})
+            args.environment= json.dumps({"arg": {}})
             contents = generate_parsed_requirements_contents(args)
             library_target = "@pip_parsed_deps_pypi__foo//:pkg"
             whl_target = "@pip_parsed_deps_pypi__foo//:whl"
@@ -32,9 +35,11 @@
             self.assertIn(all_requirements, contents, contents)
             self.assertIn(all_whl_requirements, contents, contents)
             self.assertIn(requirement_string, contents, contents)
-            self.assertIn(requirement_string, contents, contents)
             all_flags = extra_pip_args + ["--require-hashes", "True"]
             self.assertIn("'extra_pip_args': {}".format(repr(all_flags)), contents, contents)
+            self.assertIn("'pip_data_exclude': {}".format(repr(pip_data_exclude)), contents, contents)
+            # Assert it gets set to an empty dict by default.
+            self.assertIn("'environment': {}", contents, contents)
 
 
 if __name__ == "__main__":
diff --git a/python/pip_install/pip_repository.bzl b/python/pip_install/pip_repository.bzl
index da4678f..d7d1113 100644
--- a/python/pip_install/pip_repository.bzl
+++ b/python/pip_install/pip_repository.bzl
@@ -27,26 +27,38 @@
 def _parse_optional_attrs(rctx, args):
     """Helper function to parse common attributes of pip_repository and whl_library repository rules.
 
+    This function also serializes the structured arguments as JSON
+    so they can be passed on the command line to subprocesses.
+
     Args:
         rctx: Handle to the rule repository context.
         args: A list of parsed args for the rule.
     Returns: Augmented args list.
     """
-    if rctx.attr.extra_pip_args:
+
+    # Check for None so we use empty default types from our attrs.
+    # Some args want to be list, and some want to be dict.
+    if rctx.attr.extra_pip_args != None:
         args += [
             "--extra_pip_args",
-            struct(args = rctx.attr.extra_pip_args).to_json(),
+            struct(arg = rctx.attr.extra_pip_args).to_json(),
         ]
 
-    if rctx.attr.pip_data_exclude:
+    if rctx.attr.pip_data_exclude != None:
         args += [
             "--pip_data_exclude",
-            struct(exclude = rctx.attr.pip_data_exclude).to_json(),
+            struct(arg = rctx.attr.pip_data_exclude).to_json(),
         ]
 
     if rctx.attr.enable_implicit_namespace_pkgs:
         args.append("--enable_implicit_namespace_pkgs")
 
+    if rctx.attr.environment != None:
+        args += [
+            "--environment",
+            struct(arg = rctx.attr.environment).to_json(),
+        ]
+
     return args
 
 _BUILD_FILE_CONTENTS = """\
@@ -102,10 +114,8 @@
 
     result = rctx.execute(
         args,
-        environment = {
-            # Manually construct the PYTHONPATH since we cannot use the toolchain here
-            "PYTHONPATH": pypath,
-        },
+        # Manually construct the PYTHONPATH since we cannot use the toolchain here
+        environment = {"PYTHONPATH": _construct_pypath(rctx)},
         timeout = rctx.attr.timeout,
         quiet = rctx.attr.quiet,
     )
@@ -126,6 +136,16 @@
 This option is required to support some packages which cannot handle the conversion to pkg-util style.
             """,
     ),
+    "environment": attr.string_dict(
+        doc = """
+Environment variables to set in the pip subprocess.
+Can be used to set common variables such as `http_proxy`, `https_proxy` and `no_proxy`
+Note that pip is run with "--isolated" on the CLI so PIP_<VAR>_<NAME>
+style env vars are ignored, but env vars that control requests and urllib3
+can be passed.
+        """,
+        default = {},
+    ),
     "extra_pip_args": attr.string_list(
         doc = "Extra arguments to pass on to pip. Must not contain spaces.",
     ),
@@ -221,7 +241,6 @@
 def _impl_whl_library(rctx):
     # pointer to parent repo so these rules rerun if the definitions in requirements.bzl change.
     _parent_repo_label = Label("@{parent}//:requirements.bzl".format(parent = rctx.attr.repo))
-    pypath = _construct_pypath(rctx)
     args = [
         rctx.attr.python_interpreter,
         "-m",
@@ -232,12 +251,11 @@
         rctx.attr.repo,
     ]
     args = _parse_optional_attrs(rctx, args)
+
     result = rctx.execute(
         args,
-        environment = {
-            # Manually construct the PYTHONPATH since we cannot use the toolchain here
-            "PYTHONPATH": pypath,
-        },
+        # Manually construct the PYTHONPATH since we cannot use the toolchain here
+        environment = {"PYTHONPATH": _construct_pypath(rctx)},
         quiet = rctx.attr.quiet,
         timeout = rctx.attr.timeout,
     )