feat: add persistent worker for sphinxdocs (#2938)

This implements a simple, serialized persistent worker for Sphinxdocs
with several optimizations. It is enabled by default.

* The worker computes which inputs have changed, allowing Sphinx to only
  rebuild what is necessary.
* Doctrees are written to a separate directory so they are retained
  between builds.
* The worker tells Sphinx to write output to an internal directory, then
  copies it to the expected Bazel output directory afterwards. This
  allows Sphinx to only write output files that need to be updated.
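The input-change computation in the first bullet is essentially a digest diff between successive work requests. A simplified sketch of that logic (the request shape mirrors the `inputs` list Bazel sends in a work request; names here are illustrative, not the ones in sphinx_build.py):

```python
def changed_paths(prev_digests, inputs):
    """Return (new digest map, paths that are new or changed).

    prev_digests: dict of path -> digest from the previous request.
    inputs: list of {"path": ..., "digest": ...} from the current request.
    """
    current = {}
    changed = []
    for entry in inputs:
        path, digest = entry["path"], entry["digest"]
        current[path] = digest
        # A path is "changed" if it's new or its digest differs.
        if prev_digests.get(path) != digest:
            changed.append(path)
    return current, changed
```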

This works by having the worker compute which files have changed and
having a Sphinx extension use the `env-get-outdated` event to tell
Sphinx which files have changed. The extension is based on
https://pwrev.dev/294057, but re-implemented in-memory as part of the
worker instead of as a separate extension that projects must configure.

For rules_python's doc building, this reduces incremental building from
about 8 seconds to about 0.8 seconds. From what I can tell, about half
the time is spent generating doctrees, and the other half generating
the output files.

Worker mode is enabled by default and can be disabled on the target or
by adjusting the Bazel flags controlling execution strategy. Docs added
to explain how.
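For context, a persistent worker reads newline-delimited JSON work requests on stdin and writes JSON responses on stdout. A simplified sketch of that loop (error handling omitted; not the actual implementation in sphinx_build.py):

```python
import json
import sys

def worker_loop(handle, instream=sys.stdin, outstream=sys.stdout):
    """Run a minimal JSON persistent-worker loop.

    `handle` takes a decoded request dict and returns (exit_code, output).
    """
    # Each line is one JSON WorkRequest; an empty read means Bazel is done.
    for line in iter(instream.readline, ""):
        request = json.loads(line)
        exit_code, output = handle(request)
        outstream.write(json.dumps({
            "requestId": request.get("requestId", 0),
            "exitCode": exit_code,
            "output": output,
        }) + "\n")
        outstream.flush()
```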

Because `--doctree-dir` is now always specified and is outside the
output directory, non-worker invocations can benefit, too, if run
without sandboxing. Docs added to explain how to do this.

Along the way:

* Remove `--write-all` and `--fresh-env` from run args. This lets direct
  invocations benefit from the normal caching Sphinx does.
* Change the args formatting to `--foo=bar` so each option is a single
  element; just a bit nicer to see when debugging.


Work towards https://github.com/bazel-contrib/rules_python/issues/2878,
https://github.com/bazel-contrib/rules_python/issues/2879

---------

Co-authored-by: Kayce Basques <kayce@google.com>
Co-authored-by: Richard Levasseur <rlevasseur@google.com>
diff --git a/sphinxdocs/docs/index.md b/sphinxdocs/docs/index.md
index bd6448c..2ea1146 100644
--- a/sphinxdocs/docs/index.md
+++ b/sphinxdocs/docs/index.md
@@ -11,6 +11,29 @@
 While it is primarily oriented towards docgen for Starlark code, the core of it
 is agnostic as to what is being documented.
 
+### Optimization
+
+Normally, Sphinx keeps various cache files to improve incremental building.
+Unfortunately, programs performing their own caching don't interact well
+with Bazel's model of precisely declaring and strictly enforcing what are
+inputs, what are outputs, and what files are available when running a program.
+The net effect is programs don't have a prior invocation's cache files
+available.
+
+There are two mechanisms available to make some cache available to Sphinx under
+Bazel:
+
+* Disable sandboxing, which allows some files from prior invocations to be
+  visible to subsequent invocations. This can be done multiple ways:
+  * Set `tags = ["no-sandbox"]` on the `sphinx_docs` target
+  * `--modify_execution_info=SphinxBuildDocs=+no-sandbox` (Bazel flag)
+  * `--strategy=SphinxBuildDocs=local` (Bazel flag)
+* Use persistent workers (enabled by default) by setting
+  `allow_persistent_workers=True` on the `sphinx_docs` target. Note that other
+  Bazel flags can disable using workers even if an action supports it. Setting
+  `--strategy=SphinxBuildDocs=dynamic,worker,local,sandbox` should tell Bazel
+  to use workers if possible, otherwise fall back to non-worker invocations.
+
 
 ```{toctree}
 :hidden:
diff --git a/sphinxdocs/private/sphinx.bzl b/sphinxdocs/private/sphinx.bzl
index 8d19d87..ee6b994 100644
--- a/sphinxdocs/private/sphinx.bzl
+++ b/sphinxdocs/private/sphinx.bzl
@@ -103,6 +103,7 @@
         strip_prefix = "",
         extra_opts = [],
         tools = [],
+        allow_persistent_workers = True,
         **kwargs):
     """Generate docs using Sphinx.
 
@@ -142,6 +143,9 @@
         tools: {type}`list[label]` Additional tools that are used by Sphinx and its plugins.
             This just makes the tools available during Sphinx execution. To locate
             them, use {obj}`extra_opts` and `$(location)`.
+        allow_persistent_workers: {type}`bool` (experimental) If true, allow
+            using persistent workers for running Sphinx, if Bazel decides to do so.
+            This can improve incremental building of docs.
         **kwargs: {type}`dict` Common attributes to pass onto rules.
     """
     add_tag(kwargs, "@rules_python//sphinxdocs:sphinx_docs")
@@ -165,6 +169,7 @@
         source_tree = internal_name + "/_sources",
         extra_opts = extra_opts,
         tools = tools,
+        allow_persistent_workers = allow_persistent_workers,
         **kwargs
     )
 
@@ -209,6 +214,7 @@
             source_path = source_dir_path,
             output_prefix = paths.join(ctx.label.name, "_build"),
             inputs = inputs,
+            allow_persistent_workers = ctx.attr.allow_persistent_workers,
         )
         outputs[format] = output_dir
         per_format_args[format] = args_env
@@ -229,6 +235,10 @@
 _sphinx_docs = rule(
     implementation = _sphinx_docs_impl,
     attrs = {
+        "allow_persistent_workers": attr.bool(
+            doc = "(experimental) Whether to invoke Sphinx as a persistent worker.",
+            default = False,
+        ),
         "extra_opts": attr.string_list(
             doc = "Additional options to pass onto Sphinx. These are added after " +
                   "other options, but before the source/output args.",
@@ -254,16 +264,27 @@
     },
 )
 
-def _run_sphinx(ctx, format, source_path, inputs, output_prefix):
+def _run_sphinx(ctx, format, source_path, inputs, output_prefix, allow_persistent_workers):
     output_dir = ctx.actions.declare_directory(paths.join(output_prefix, format))
 
     run_args = []  # Copy of the args to forward along to debug runner
     args = ctx.actions.args()  # Args passed to the action
 
+    # An args file is required for persistent workers, but we don't know if
+    # the action will use worker mode or not (settings we can't see may
+    # force non-worker mode). For consistency, always use a params file.
+    args.use_param_file("@%s", use_always = True)
+    args.set_param_file_format("multiline")
+
+    # NOTE: sphinx_build.py relies on the first two args being the srcdir and
+    # outputdir, in that order.
+    args.add(source_path)
+    args.add(output_dir.path)
+
     args.add("--show-traceback")  # Full tracebacks on error
     run_args.append("--show-traceback")
-    args.add("--builder", format)
-    run_args.extend(("--builder", format))
+    args.add(format, format = "--builder=%s")
+    run_args.append("--builder={}".format(format))
 
     if ctx.attr._quiet_flag[BuildSettingInfo].value:
         # Not added to run_args because run_args is for debugging
@@ -271,11 +292,17 @@
 
     # Build in parallel, if possible
     # Don't add to run_args: parallel building breaks interactive debugging
-    args.add("--jobs", "auto")
-    args.add("--fresh-env")  # Don't try to use cache files. Bazel can't make use of them.
-    run_args.append("--fresh-env")
-    args.add("--write-all")  # Write all files; don't try to detect "changed" files
-    run_args.append("--write-all")
+    args.add("--jobs=auto")
+
+    # Put the doctree dir outside of the output directory.
+    # This allows it to be reused between invocations when possible; Bazel
+    # clears the output directory every action invocation.
+    # * For workers, they can fully re-use it.
+    # * For non-workers, it can be reused when sandboxing is disabled via
+    #   the `no-sandbox` tag or execution requirement.
+    #
+    # We also use a non-dot prefixed name so it shows up more visibly.
+    args.add(paths.join(output_dir.path + "_doctrees"), format = "--doctree-dir=%s")
 
     for opt in ctx.attr.extra_opts:
         expanded = ctx.expand_location(opt)
@@ -287,9 +314,6 @@
     for define in extra_defines:
         run_args.extend(("--define", define))
 
-    args.add(source_path)
-    args.add(output_dir.path)
-
     env = dict([
         v.split("=", 1)
         for v in ctx.attr._extra_env_flag[_FlagInfo].value
@@ -299,6 +323,14 @@
     for tool in ctx.attr.tools:
         tools.append(tool[DefaultInfo].files_to_run)
 
+    # NOTE: Command line flags or RBE capabilities may override the execution
+    # requirements and disable workers. Thus, we can't assume that these
+    # exec requirements will actually be respected.
+    execution_requirements = {}
+    if allow_persistent_workers:
+        execution_requirements["supports-workers"] = "1"
+        execution_requirements["requires-worker-protocol"] = "json"
+
     ctx.actions.run(
         executable = ctx.executable.sphinx,
         arguments = [args],
@@ -308,6 +340,7 @@
         mnemonic = "SphinxBuildDocs",
         progress_message = "Sphinx building {} for %{{label}}".format(format),
         env = env,
+        execution_requirements = execution_requirements,
     )
     return output_dir, struct(args = run_args, env = env)
 
diff --git a/sphinxdocs/private/sphinx_build.py b/sphinxdocs/private/sphinx_build.py
index 3b7b32e..e971104 100644
--- a/sphinxdocs/private/sphinx_build.py
+++ b/sphinxdocs/private/sphinx_build.py
@@ -1,8 +1,235 @@
+import contextlib
+import io
+import json
+import logging
 import os
-import pathlib
+import shutil
 import sys
+import traceback
+import typing
 
+import sphinx.application
 from sphinx.cmd.build import main
 
+WorkRequest = object
+WorkResponse = object
+
+logger = logging.getLogger("sphinxdocs_build")
+
+_WORKER_SPHINX_EXT_MODULE_NAME = "bazel_worker_sphinx_ext"
+
+# Config value name for getting the path to the request info file
+_REQUEST_INFO_CONFIG_NAME = "bazel_worker_request_info_path"
+
+
+class Worker:
+
+    def __init__(
+        self, instream: "typing.TextIO", outstream: "typing.TextIO", exec_root: str
+    ):
+        # NOTE: Sphinx performs its own logging re-configuration, so any
+        # logging config we do isn't respected by Sphinx. Controlling where
+        # stdout and stderr goes are the main mechanisms. Recall that
+        # Bazel sends worker stderr to the worker log file.
+        # outputBase=$(bazel info output_base)
+        # find $outputBase/bazel-workers/ -type f -printf '%T@ %p\n' | sort -n | tail -1 | awk '{print $2}'
+        logging.basicConfig(level=logging.WARN)
+        logger.info("Initializing worker")
+
+        # The directory that paths are relative to.
+        self._exec_root = exec_root
+        # Where requests are read from.
+        self._instream = instream
+        # Where responses are written to.
+        self._outstream = outstream
+
+        # dict[str srcdir, dict[str path, str digest]]
+        self._digests = {}
+
+        # Internal output directories the worker gives to Sphinx that need
+        # to be cleaned up upon exit.
+        # set[str path]
+        self._worker_outdirs = set()
+        self._extension = BazelWorkerExtension()
+
+        sys.modules[_WORKER_SPHINX_EXT_MODULE_NAME] = self._extension
+        sphinx.application.builtin_extensions += (_WORKER_SPHINX_EXT_MODULE_NAME,)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        for worker_outdir in self._worker_outdirs:
+            shutil.rmtree(worker_outdir, ignore_errors=True)
+
+    def run(self) -> None:
+        logger.info("Worker started")
+        try:
+            while True:
+                request = None
+                try:
+                    request = self._get_next_request()
+                    if request is None:
+                        logger.info("Empty request: exiting")
+                        break
+                    response = self._process_request(request)
+                    if response:
+                        self._send_response(response)
+                except Exception:
+                    logger.exception("Unhandled error: request=%s", request)
+                    request_id = 0 if not request else request.get("requestId", 0)
+                    output = (
+                        f"Unhandled error:\nRequest id: {request_id}\n"
+                        + traceback.format_exc()
+                    )
+                    self._send_response(
+                        {
+                            "exitCode": 3,
+                            "output": output,
+                            "requestId": request_id,
+                        }
+                    )
+        finally:
+            logger.info("Worker shutting down")
+
+    def _get_next_request(self) -> "object | None":
+        line = self._instream.readline()
+        if not line:
+            return None
+        return json.loads(line)
+
+    def _send_response(self, response: "WorkResponse") -> None:
+        self._outstream.write(json.dumps(response) + "\n")
+        self._outstream.flush()
+
+    def _prepare_sphinx(self, request):
+        sphinx_args = request["arguments"]
+        srcdir = sphinx_args[0]
+
+        incoming_digests = {}
+        current_digests = self._digests.setdefault(srcdir, {})
+        changed_paths = []
+        request_info = {"exec_root": self._exec_root, "inputs": request["inputs"]}
+        for entry in request["inputs"]:
+            path = entry["path"]
+            digest = entry["digest"]
+            # Make the path srcdir-relative so Sphinx understands it.
+            path = path.removeprefix(srcdir + "/")
+            incoming_digests[path] = digest
+
+            if path not in current_digests:
+                logger.info("path %s new", path)
+                changed_paths.append(path)
+            elif current_digests[path] != digest:
+                logger.info("path %s changed", path)
+                changed_paths.append(path)
+
+        self._digests[srcdir] = incoming_digests
+        self._extension.changed_paths = changed_paths
+        request_info["changed_sources"] = changed_paths
+
+        bazel_outdir = sphinx_args[1]
+        worker_outdir = bazel_outdir + ".worker-out.d"
+        self._worker_outdirs.add(worker_outdir)
+        sphinx_args[1] = worker_outdir
+
+        request_info_path = os.path.join(srcdir, "_bazel_worker_request_info.json")
+        with open(request_info_path, "w") as fp:
+            json.dump(request_info, fp)
+        sphinx_args.append(f"--define={_REQUEST_INFO_CONFIG_NAME}={request_info_path}")
+
+        return worker_outdir, bazel_outdir, sphinx_args
+
+    @contextlib.contextmanager
+    def _redirect_streams(self):
+        out = io.StringIO()
+        orig_stdout = sys.stdout
+        try:
+            sys.stdout = out
+            yield out
+        finally:
+            sys.stdout = orig_stdout
+
+    def _process_request(self, request: "WorkRequest") -> "WorkResponse | None":
+        logger.info("Request: %s", json.dumps(request, sort_keys=True, indent=2))
+        if request.get("cancel"):
+            return None
+
+        worker_outdir, bazel_outdir, sphinx_args = self._prepare_sphinx(request)
+
+        # Prevent anything from going to stdout because it breaks the worker
+        # protocol. We have limited control over where Sphinx sends output.
+        with self._redirect_streams() as stdout:
+            logger.info("main args: %s", sphinx_args)
+            exit_code = main(sphinx_args)
+
+        if exit_code:
+            raise Exception(
+                "Sphinx main() returned failure: "
+                + f"  exit code: {exit_code}\n"
+                + "========== STDOUT START ==========\n"
+                + stdout.getvalue().rstrip("\n")
+                + "\n"
+                + "========== STDOUT END ==========\n"
+            )
+
+        # Copying is unfortunately necessary because Bazel doesn't know to
+        # implicitly bring along what the symlinks point to.
+        shutil.copytree(worker_outdir, bazel_outdir, dirs_exist_ok=True)
+
+        response = {
+            "requestId": request.get("requestId", 0),
+            "output": stdout.getvalue(),
+            "exitCode": 0,
+        }
+        return response
+
+
+class BazelWorkerExtension:
+    """A Sphinx extension implemented as a class acting like a module."""
+
+    def __init__(self):
+        # Make it look like a Module object
+        self.__name__ = _WORKER_SPHINX_EXT_MODULE_NAME
+        # set[str] of src-dir relative path names
+        self.changed_paths = set()
+
+    def setup(self, app):
+        app.add_config_value(_REQUEST_INFO_CONFIG_NAME, "", "")
+        app.connect("env-get-outdated", self._handle_env_get_outdated)
+        return {"parallel_read_safe": True, "parallel_write_safe": True}
+
+    def _handle_env_get_outdated(self, app, env, added, changed, removed):
+        changed = {
+            # NOTE: path2doc returns None if it's not a doc path
+            env.path2doc(p)
+            for p in self.changed_paths
+        }
+
+        logger.info("changed docs: %s", changed)
+        return changed
+
+
+def _worker_main(stdin, stdout, exec_root):
+    with Worker(stdin, stdout, exec_root) as worker:
+        return worker.run()
+
+
+def _non_worker_main():
+    args = []
+    for arg in sys.argv:
+        if arg.startswith("@"):
+            with open(arg.removeprefix("@")) as fp:
+                lines = [line.strip() for line in fp if line.strip()]
+            args.extend(lines)
+        else:
+            args.append(arg)
+    sys.argv[:] = args
+    return main()
+
+
 if __name__ == "__main__":
-    sys.exit(main())
+    if "--persistent_worker" in sys.argv:
+        sys.exit(_worker_main(sys.stdin, sys.stdout, os.getcwd()))
+    else:
+        sys.exit(_non_worker_main())
diff --git a/sphinxdocs/tests/sphinx_docs/doc1.md b/sphinxdocs/tests/sphinx_docs/doc1.md
new file mode 100644
index 0000000..f6f70ba
--- /dev/null
+++ b/sphinxdocs/tests/sphinx_docs/doc1.md
@@ -0,0 +1,3 @@
+# doc1
+
+hello doc 1
diff --git a/sphinxdocs/tests/sphinx_docs/doc2.md b/sphinxdocs/tests/sphinx_docs/doc2.md
new file mode 100644
index 0000000..06eb76a
--- /dev/null
+++ b/sphinxdocs/tests/sphinx_docs/doc2.md
@@ -0,0 +1,3 @@
+# doc 2
+
+hello doc 3