Add `from_file` tag to the `go_sdk` module extension
diff --git a/go/private/extensions.bzl b/go/private/extensions.bzl index 54f4cc5..24b8547 100644 --- a/go/private/extensions.bzl +++ b/go/private/extensions.bzl
@@ -15,6 +15,7 @@ load("@io_bazel_rules_go_bazel_features//:features.bzl", "bazel_features") load("//go/private:nogo.bzl", "DEFAULT_NOGO", "NOGO_DEFAULT_EXCLUDES", "NOGO_DEFAULT_INCLUDES", "go_register_nogo") load("//go/private:sdk.bzl", "detect_host_platform", "go_download_sdk_rule", "go_host_sdk_rule", "go_multiple_toolchains") +load("//go/private/extensions:go_mod.bzl", "sdk_from_go_mod") def host_compatible_toolchain_impl(ctx): ctx.file("BUILD.bazel") @@ -48,23 +49,34 @@ ), } -_download_tag = tag_class( - attrs = _COMMON_ATTRS | { - "goos": attr.string(), - "goarch": attr.string(), - "sdks": attr.string_list_dict(), - "urls": attr.string_list(default = ["https://dl.google.com/go/{}"]), - "patches": attr.label_list( - doc = "A list of patches to apply to the SDK after downloading it", +_from_file_tag = tag_class( + attrs = { + "go_mod": attr.label( + mandatory = True, + doc = "The go.mod file to read Go SDK information from", ), - "patch_strip": attr.int( - default = 0, - doc = "The number of leading path segments to be stripped from the file name in the patches.", - ), - "strip_prefix": attr.string(default = "go"), }, ) +_DOWNLOAD_ATTRS = _COMMON_ATTRS | { + "goos": attr.string(), + "goarch": attr.string(), + "sdks": attr.string_list_dict(), + "urls": attr.string_list(default = ["https://dl.google.com/go/{}"]), + "patches": attr.label_list( + doc = "A list of patches to apply to the SDK after downloading it", + ), + "patch_strip": attr.int( + default = 0, + doc = "The number of leading path segments to be stripped from the file name in the patches.", + ), + "strip_prefix": attr.string(default = "go"), +} + +_download_tag = tag_class( + attrs = _DOWNLOAD_ATTRS, +) + _host_tag = tag_class( attrs = _COMMON_ATTRS, ) @@ -160,7 +172,7 @@ host_detected_goos, host_detected_goarch = detect_host_platform(ctx) toolchains = [] for module in ctx.modules: - for index, download_tag in enumerate(module.tags.download): + for index, download_tag in
enumerate(_get_download_tags(ctx, module.tags)): # SDKs without an explicit version are fetched even when not selected by toolchain # resolution. This is acceptable if brought in by the root module, but transitive # dependencies should not slow down the build in this way. @@ -299,6 +311,29 @@ else: return None +def _get_download_tags(ctx, tags): + """Returns the module's download tags, preceded by one synthesized from its from_file tag. + + The synthesized tag's version is the go.mod toolchain directive without its "go" prefix, falling back to the + go directive, which also becomes its language_version. Its goos and goarch are empty strings and every other + download attribute is None. + """ + download_tags = tags.download + if tags.from_file: + if len(tags.from_file) > 1: + fail("go_sdk: only one from_file tag can be specified per module, got:\n", *[t for p in zip(tags.from_file, len(tags.from_file) * ["\n"]) for t in p]) + from_file_tag = tags.from_file[0] + versions = sdk_from_go_mod(ctx, from_file_tag.go_mod) + download_tag_attrs = {k: None for k in _DOWNLOAD_ATTRS.keys()} # NOTE(review): attrs that declare defaults (urls, patch_strip, strip_prefix) are None here too; confirm every consumer treats None as "use the default". + download_tag_attrs["version"] = versions.toolchain.removeprefix("go") if versions.toolchain else versions.go # NOTE(review): a bare go directive such as "1.21" may not name a downloadable release; verify. + download_tag_attrs["language_version"] = versions.go + + # Required by go_multiple_toolchains. + download_tag_attrs["goos"] = "" + download_tag_attrs["goarch"] = "" + download_tags = [struct(**download_tag_attrs)] + download_tags + return download_tags + def _default_go_sdk_name(*, module, multi_version, tag_type, index, suffix = ""): # Keep the version out of the repository name if possible to prevent unnecessary rebuilds when # it changes. @@ -335,6 +370,7 @@ go_sdk = module_extension( implementation = _go_sdk_impl, tag_classes = { + "from_file": _from_file_tag, "download": _download_tag, "host": _host_tag, "nogo": _nogo_tag,
diff --git a/go/private/extensions/BUILD.bazel b/go/private/extensions/BUILD.bazel new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/go/private/extensions/BUILD.bazel
diff --git a/go/private/extensions/go_mod.bzl b/go/private/extensions/go_mod.bzl new file mode 100644 index 0000000..8d7f1e3 --- /dev/null +++ b/go/private/extensions/go_mod.bzl
@@ -0,0 +1,180 @@
+# Copyright 2023 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+visibility(["//go/private"])
+
+def sdk_from_go_mod(module_ctx, go_mod_label):
+    """Loads the Go SDK information from go.mod.
+
+    Args:
+        module_ctx: a https://bazel.build/rules/lib/module_ctx object passed
+            from the MODULE.bazel call.
+        go_mod_label: a Label for a `go.mod` file.
+
+    Returns:
+        a struct with the following fields:
+            go: a string containing the effective go directive value (never None)
+            toolchain: a string containing the raw toolchain directive value (can be None)
+    """
+    if go_mod_label.name != "go.mod":
+        fail("go_sdk.from_file requires a 'go.mod' file, not '{}'".format(go_mod_label.name))
+
+    go_mod_path = module_ctx.path(go_mod_label)
+    go_mod_content = module_ctx.read(go_mod_path)
+    return _parse_go_mod(go_mod_content, go_mod_path)
+
+def _parse_go_mod(content, path):
+    """Extracts the go and toolchain directives from go.mod content.
+
+    See https://go.dev/ref/mod#go-mod-file. Only top-level `go` and
+    `toolchain` directives are interpreted. Other directives and the contents
+    of parenthesized blocks are skipped, apart from requiring that a block's
+    `(` and `)` are the last tokens on their lines.
+
+    Args:
+        content: the text of the go.mod file.
+        path: the path of the go.mod file, used only in error messages.
+
+    Returns:
+        a struct with the fields `go` ("1.16" if the directive is absent) and
+        `toolchain` (None if the directive is absent).
+    """
+    # Valid directive values understood by this parser never contain tabs or
+    # carriage returns, so we can simplify the parsing below by canonicalizing
+    # whitespace upfront.
+    content = content.replace("\t", " ").replace("\r", " ")
+
+    state = {
+        "go": None,
+        "toolchain": None,
+    }
+
+    # Since we only care about simple directives, we parse the file in a lax way
+    # and just skip over anything that looks like a block.
+    in_block = False
+    for line_no, line in enumerate(content.splitlines(), 1):
+        tokens, _ = _tokenize_line(line, path, line_no)
+        if not tokens:
+            continue
+        tok = tokens[0]
+        if not in_block:
+            if tok in ("go", "toolchain"):
+                if len(tokens) == 1:
+                    fail("{}:{}: expected another token after '{}'".format(path, line_no, tok))
+                if state[tok] != None:
+                    fail("{}:{}: unexpected second '{}' directive".format(path, line_no, tok))
+                if len(tokens) > 2:
+                    fail("{}:{}: unexpected token '{}' after '{}'".format(path, line_no, tokens[2], tokens[1]))
+                state[tok] = tokens[1]
+                continue
+            if len(tokens) >= 2 and tokens[1] == "(":
+                in_block = True
+                if len(tokens) > 2:
+                    fail("{}:{}: unexpected token '{}' after '('".format(path, line_no, tokens[2]))
+        elif tok == ")":
+            in_block = False
+            if len(tokens) > 1:
+                fail("{}:{}: unexpected token '{}' after ')'".format(path, line_no, tokens[1]))
+
+    return struct(
+        # "As of the Go 1.17 release, if the go directive is missing, go 1.16 is assumed."
+        go = state["go"] or "1.16",
+        toolchain = state["toolchain"],
+    )
+
+def _tokenize_line(line, path, line_no):
+    """Splits a single go.mod line into tokens, loosely following the Go lexer.
+
+    Raw (`...`) and interpreted ("...") strings each become one token without
+    their quotes. Escape sequences are not decoded: a backslash just makes the
+    next character literal. Parentheses are tokens of their own, and `//`
+    starts a comment that runs to the end of the line.
+
+    Args:
+        line: a single line of go.mod content.
+        path: the path of the go.mod file, used only in error messages.
+        line_no: the 1-based line number, used only in error messages.
+
+    Returns:
+        a tuple of the list of tokens and the stripped text of the trailing
+        comment (None if the line has no comment).
+    """
+    tokens = []
+    r = line
+
+    # Starlark has no while loop. Every iteration consumes at least one
+    # character, so len(line) iterations always suffice.
+    for _ in range(len(line)):
+        r = r.strip()
+        if not r:
+            break
+
+        if r[0] == "`":
+            end = r.find("`", 1)
+            if end == -1:
+                fail("{}:{}: unterminated raw string".format(path, line_no))
+
+            tokens.append(r[1:end])
+            r = r[end + 1:]
+
+        elif r[0] == "\"":
+            value = ""
+            escaped = False
+            found_end = False
+            for pos in range(1, len(r)):
+                c = r[pos]
+
+                if escaped:
+                    value += c
+                    escaped = False
+                    continue
+
+                if c == "\\":
+                    escaped = True
+                    continue
+
+                if c == "\"":
+                    found_end = True
+                    break
+
+                value += c
+
+            if not found_end:
+                fail("{}:{}: unterminated interpreted string".format(path, line_no))
+
+            tokens.append(value)
+            r = r[pos + 1:]
+
+        elif r.startswith("//"):
+            # A comment always ends the current line.
+            return tokens, r[len("//"):].strip()
+
+        elif r[0] in "()":
+            # Like the Go lexer, treat a parenthesis as a token of its own even
+            # without surrounding whitespace, e.g. in "require(".
+            tokens.append(r[0])
+            r = r[1:]
+
+        else:
+            # A bare token ends at a space, a parenthesis or the start of a
+            # comment, e.g. "go 1.21// x" splits into "go" and "1.21".
+            end = len(r)
+            for i in range(1, len(r)):
+                if r[i] in " ()" or r.startswith("//", i):
+                    end = i
+                    break
+            tokens.append(r[:end])
+            r = r[end:]
+
+    return tokens, None
diff --git a/tests/bcr/BUILD.bazel b/tests/bcr/BUILD.bazel index a7a0c14..10e241b 100644 --- a/tests/bcr/BUILD.bazel +++ b/tests/bcr/BUILD.bazel
@@ -1,8 +1,10 @@ +load("@bazel_skylib//rules:native_binary.bzl", "native_test") load("@my_rules_go//extras:gomock.bzl", "gomock") load( "@my_rules_go//go:def.bzl", "TOOLS_NOGO", "go_binary", + "go_cross_binary", "go_library", "go_test", "nogo", @@ -32,9 +34,23 @@ embed = [":lib"], ) -go_test( +native_test( name = "sdk_patch_test", + src = ":sdk_patch_test_bin",  # Runs :sdk_patch_test_orig as built by the go_cross_binary below with the patched SDK. +) + +go_cross_binary( + name = "sdk_patch_test_bin", + testonly = True, + # Select the patched SDK. + sdk_version = "1.23.0",  # Must match the version of the patched go_sdk.download tag in MODULE.bazel. + target = ":sdk_patch_test_orig", +) + +go_test( + name = "sdk_patch_test_orig", srcs = ["sdk_patch_test.go"], + tags = ["manual"],  # Excluded from wildcard patterns; exercised through :sdk_patch_test instead. ) go_test(
diff --git a/tests/bcr/MODULE.bazel b/tests/bcr/MODULE.bazel index 3e79a79..5cedbb8 100644 --- a/tests/bcr/MODULE.bazel +++ b/tests/bcr/MODULE.bazel
@@ -18,6 +18,7 @@ path = "../..", ) +bazel_dep(name = "bazel_skylib", version = "1.6.1")  # Provides native_test, used by //:sdk_patch_test. bazel_dep(name = "gazelle", version = "0.33.0") bazel_dep(name = "protobuf", version = "3.19.6") @@ -26,13 +27,13 @@ bazel_dep(name = "toolchains_protoc", version = "0.2.1") go_sdk = use_extension("@my_rules_go//go:extensions.bzl", "go_sdk") +go_sdk.from_file(go_mod = "//:go.mod")  # SDK version comes from go.mod's toolchain directive, else its go directive. go_sdk.download( - language_version = "1.21", patch_strip = 1, patches = [ "//:test_go_sdk.patch", ], - version = "1.23.1", + version = "1.23.0",  # Must match sdk_version of //:sdk_patch_test_bin. ) # Request an invalid SDK to verify that it isn't fetched since the first tag takes precedence.