# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A simple precompiler to generate deterministic pyc files for Bazel."""
# NOTE: Imports specific to the persistent worker should only be imported
# when a persistent worker is used. Avoiding the unnecessary imports
# saves significant startup time for non-worker invocations.
import argparse
import py_compile
import sys
def _create_parser() -> "argparse.ArgumentParser":
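    """Returns the argument parser used for both direct and worker invocations."""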
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument("--invalidation_mode", default="CHECKED_HASH")
parser.add_argument("--optimize", type=int, default=-1)
parser.add_argument("--python_version")
parser.add_argument("--src", action="append", dest="srcs")
parser.add_argument("--src_name", action="append", dest="src_names")
parser.add_argument("--pyc", action="append", dest="pycs")
parser.add_argument("--persistent_worker", action="store_true")
parser.add_argument("--log_level", default="ERROR")
parser.add_argument("--worker_impl", default="async")
return parser
def _compile(options: "argparse.Namespace") -> None:
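    """Byte-compiles each --src file into its corresponding --pyc output."""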
try:
invalidation_mode = py_compile.PycInvalidationMode[
options.invalidation_mode.upper()
]
except KeyError as e:
raise ValueError(
f"Unknown PycInvalidationMode: {options.invalidation_mode}"
) from e
if not (len(options.srcs) == len(options.src_names) == len(options.pycs)):
raise AssertionError(
"Mismatched number of --src, --src_name, and/or --pyc args"
)
for src, src_name, pyc in zip(options.srcs, options.src_names, options.pycs):
py_compile.compile(
src,
pyc,
doraise=True,
dfile=src_name,
optimize=options.optimize,
invalidation_mode=invalidation_mode,
)
# A stub type alias for readability.
# See the Bazel WorkRequest object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkRequest = object
# A stub type alias for readability.
# See the Bazel WorkResponse object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkResponse = object
class _SerialPersistentWorker:
"""Simple, synchronous, serial persistent worker."""
def __init__(self, instream: "typing.TextIO", outstream: "typing.TextIO"):
self._instream = instream
self._outstream = outstream
self._parser = _create_parser()
def run(self) -> None:
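        """Processes requests from stdin one at a time until the input closes."""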
try:
while True:
request = None
try:
request = self._get_next_request()
if request is None:
_logger.info("Empty request: exiting")
break
response = self._process_request(request)
                    if response:  # May be None for cancel requests
self._send_response(response)
except Exception:
_logger.exception("Unhandled error: request=%s", request)
output = (
f"Unhandled error:\nRequest: {request}\n"
+ traceback.format_exc()
)
request_id = 0 if not request else request.get("requestId", 0)
self._send_response(
{
"exitCode": 3,
"output": output,
"requestId": request_id,
}
)
finally:
_logger.info("Worker shutting down")
    def _get_next_request(self) -> "JsonWorkRequest | None":
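        """Reads one JSON work request per line; returns None at end of input."""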
line = self._instream.readline()
if not line:
return None
return json.loads(line)
def _process_request(self, request: "JsonWorkRequest") -> "JsonWorkResponse | None":
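        """Compiles the requested files; returns a response, or None for cancels."""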
if request.get("cancel"):
return None
options = self._options_from_request(request)
_compile(options)
response = {
"requestId": request.get("requestId", 0),
"exitCode": 0,
}
return response
    def _options_from_request(
        self, request: "JsonWorkRequest"
    ) -> "argparse.Namespace":
options = self._parser.parse_args(request["arguments"])
if request.get("sandboxDir"):
prefix = request["sandboxDir"]
options.srcs = [os.path.join(prefix, v) for v in options.srcs]
options.pycs = [os.path.join(prefix, v) for v in options.pycs]
return options
def _send_response(self, response: "JsonWorkResponse") -> None:
self._outstream.write(json.dumps(response) + "\n")
self._outstream.flush()
class _AsyncPersistentWorker:
"""Asynchronous, concurrent, persistent worker."""
def __init__(self, reader: "typing.TextIO", writer: "typing.TextIO"):
self._reader = reader
self._writer = writer
self._parser = _create_parser()
self._request_id_to_task = {}
self._task_to_request_id = {}
@classmethod
async def main(cls, instream: "typing.TextIO", outstream: "typing.TextIO") -> None:
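        """Wraps the raw stdio streams for asyncio and runs the worker loop."""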
reader, writer = await cls._connect_streams(instream, outstream)
await cls(reader, writer).run()
@classmethod
async def _connect_streams(
cls, instream: "typing.TextIO", outstream: "typing.TextIO"
) -> "tuple[asyncio.StreamReader, asyncio.StreamWriter]":
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: protocol, instream)
w_transport, w_protocol = await loop.connect_write_pipe(
asyncio.streams.FlowControlMixin, outstream
)
writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
return reader, writer
async def run(self) -> None:
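        """Reads requests forever, handling each one as a separate asyncio task."""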
while True:
_logger.info("pending requests: %s", len(self._request_id_to_task))
request = await self._get_next_request()
request_id = request.get("requestId", 0)
task = asyncio.create_task(
self._process_request(request), name=f"request_{request_id}"
)
self._request_id_to_task[request_id] = task
self._task_to_request_id[task] = request_id
task.add_done_callback(self._handle_task_done)
async def _get_next_request(self) -> "JsonWorkRequest":
_logger.debug("awaiting line")
line = await self._reader.readline()
_logger.debug("recv line: %s", line)
return json.loads(line)
def _handle_task_done(self, task: "asyncio.Task") -> None:
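        """Removes a finished task from the request/task bookkeeping maps."""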
request_id = self._task_to_request_id[task]
_logger.info("task done: %s %s", request_id, task)
del self._task_to_request_id[task]
del self._request_id_to_task[request_id]
async def _process_request(self, request: "JsonWorkRequest") -> None:
_logger.info("request %s: start: %s", request.get("requestId"), request)
try:
if request.get("cancel", False):
await self._process_cancel_request(request)
else:
await self._process_compile_request(request)
except asyncio.CancelledError:
_logger.info(
"request %s: cancel received, stopping processing",
request.get("requestId"),
)
            # We don't send a response here; the cancel request that
            # triggered the cancellation sends the response instead.
raise
except:
_logger.exception("Unhandled error: request=%s", request)
self._send_response(
{
"exitCode": 3,
"output": f"Unhandled error:\nRequest: {request}\n"
+ traceback.format_exc(),
"requestId": 0 if not request else request.get("requestId", 0),
}
)
async def _process_cancel_request(self, request: "JsonWorkRequest") -> None:
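        """Cancels the task for the given request id, if it is still running."""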
request_id = request.get("requestId", 0)
task = self._request_id_to_task.get(request_id)
if not task:
            # It must have already completed, so ignore the request, per the spec.
return
task.cancel()
self._send_response({"requestId": request_id, "wasCancelled": True})
async def _process_compile_request(self, request: "JsonWorkRequest") -> None:
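        """Compiles the requested files in a worker thread, then reports success."""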
options = self._options_from_request(request)
        # _compile performs a variety of blocking I/O calls, so run it in a
        # separate thread to avoid blocking the event loop.
await asyncio.to_thread(_compile, options)
self._send_response(
{
"requestId": request.get("requestId", 0),
"exitCode": 0,
}
)
def _options_from_request(self, request: "JsonWorkRequest") -> "argparse.Namespace":
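        """Parses the request's arguments, rebasing paths into sandboxDir when set."""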
options = self._parser.parse_args(request["arguments"])
if request.get("sandboxDir"):
prefix = request["sandboxDir"]
options.srcs = [os.path.join(prefix, v) for v in options.srcs]
options.pycs = [os.path.join(prefix, v) for v in options.pycs]
return options
def _send_response(self, response: "JsonWorkResponse") -> None:
_logger.info("request %s: respond: %s", response.get("requestId"), response)
self._writer.write(json.dumps(response).encode("utf8") + b"\n")
def main(args: "list[str]") -> int:
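    """Compiles directly, or starts a persistent worker if requested."""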
options = _create_parser().parse_args(args)
# Persistent workers are started with the `--persistent_worker` flag.
# See the following docs for details on persistent workers:
# https://bazel.build/remote/persistent
# https://bazel.build/remote/multiplex
# https://bazel.build/remote/creating
if options.persistent_worker:
global asyncio, itertools, json, logging, os, traceback, _logger
import asyncio
import itertools
import json
import logging
import os.path
import traceback
_logger = logging.getLogger("precompiler")
# Only configure logging for workers. This prevents non-worker
# invocations from spamming stderr with logging info
logging.basicConfig(level=getattr(logging, options.log_level))
_logger.info("persistent worker: impl=%s", options.worker_impl)
if options.worker_impl == "serial":
_SerialPersistentWorker(sys.stdin, sys.stdout).run()
elif options.worker_impl == "async":
asyncio.run(_AsyncPersistentWorker.main(sys.stdin, sys.stdout))
else:
raise ValueError(f"Unknown worker impl: {options.worker_impl}")
else:
_compile(options)
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))