| /* |
| * |
| * Copyright 2015 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| #include "src/compiler/python_generator.h" |
| |
| #include <algorithm> |
| #include <cstddef> |
| #include <map> |
| #include <set> |
| #include <sstream> |
| #include <string> |
| |
| #include "codegen/idl_namer.h" |
| #include "codegen/namer.h" |
| #include "codegen/python.h" |
| #include "flatbuffers/idl.h" |
| #include "flatbuffers/util.h" |
| |
| namespace flatbuffers { |
| namespace python { |
| namespace grpc { |
| namespace { |
| bool ClientStreaming(const RPCCall* method) { |
| const Value* val = method->attributes.Lookup("streaming"); |
| return val != nullptr && |
| (val->constant == "client" || val->constant == "bidi"); |
| } |
| |
| bool ServerStreaming(const RPCCall* method) { |
| const Value* val = method->attributes.Lookup("streaming"); |
| return val != nullptr && |
| (val->constant == "server" || val->constant == "bidi"); |
| } |
| |
| void FormatImports(std::stringstream& ss, const Imports& imports) { |
| std::set<std::string> modules; |
| std::map<std::string, std::set<std::string>> names_by_module; |
| for (const Import& import : imports.imports) { |
| if (import.IsLocal()) continue; // skip all local imports |
| if (import.name == "") { |
| modules.insert(import.module); |
| } else { |
| names_by_module[import.module].insert(import.name); |
| } |
| } |
| |
| for (const std::string& module : modules) { |
| ss << "import " << module << '\n'; |
| } |
| ss << '\n'; |
| for (const auto& import : names_by_module) { |
| ss << "from " << import.first << " import "; |
| size_t i = 0; |
| for (const std::string& name : import.second) { |
| if (i > 0) ss << ", "; |
| ss << name; |
| ++i; |
| } |
| ss << '\n'; |
| } |
| ss << "\n\n"; |
| } |
| |
| class BaseGenerator { |
| protected: |
| BaseGenerator(const Parser& parser, const Namer::Config& config, |
| const std::string& path, const Version& version) |
| : parser_{parser}, |
| namer_{WithFlagOptions(config, parser.opts, path), Keywords(version)}, |
| version_{version}, |
| path_(path) {} |
| |
| protected: |
| std::string ModuleForFile(const std::string& file) const { |
| std::string module = parser_.opts.include_prefix + StripExtension(file) + |
| parser_.opts.filename_suffix; |
| std::replace(module.begin(), module.end(), '/', '.'); |
| return module; |
| } |
| |
| template <typename T> |
| std::string ModuleFor(const T* def) const { |
| if (parser_.opts.one_file) return ModuleForFile(def->file); |
| return namer_.NamespacedType(*def); |
| } |
| |
| std::string NamespaceDir(const Parser& parser, const std::string& path, |
| const Namespace& ns, const bool dasherize) { |
| EnsureDirExists(path); |
| if (parser.opts.one_file) return path; |
| std::string namespace_dir = path; // Either empty or ends in separator. |
| auto& namespaces = ns.components; |
| for (auto it = namespaces.begin(); it != namespaces.end(); ++it) { |
| namespace_dir += |
| !dasherize ? *it : ConvertCase(*it, Case::kDasher, Case::kUpperCamel); |
| namespace_dir += kPathSeparator; |
| EnsureDirExists(namespace_dir); |
| } |
| return namespace_dir; |
| } |
| |
| std::string NamespaceDir(const Namespace& ns, const bool dasherize) { |
| return NamespaceDir(parser_, path_, ns, dasherize); |
| } |
| |
| const Parser& parser_; |
| const IdlNamer namer_; |
| const Version version_; |
| const std::string& path_; |
| }; |
| |
| class StubGenerator : public BaseGenerator { |
| public: |
| StubGenerator(const Parser& parser, const std::string& path, |
| const Version& version) |
| : BaseGenerator(parser, kStubConfig, path, version) {} |
| |
| bool Generate() { |
| Imports imports; |
| std::stringstream stub; |
| std::string ns_name{}; |
| for (const ServiceDef* service : parser_.services_.vec) { |
| Generate(stub, service, &imports); |
| ns_name = NamespaceDir(*service->defined_namespace, false); |
| } |
| |
| std::string sanitized_suffix{parser_.opts.grpc_filename_suffix}; |
| std::replace(sanitized_suffix.begin(), sanitized_suffix.end(), '.', '_'); |
| std::string filename = |
| ns_name + kPathSeparator + |
| StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" + |
| sanitized_suffix + namer_.config_.filename_extension; |
| |
| return SaveStub(filename, imports, stub.str()); |
| } |
| |
| private: |
| bool SaveStub(const std::string& filename, const Imports& imports, |
| const std::string& content) { |
| std::stringstream ss; |
| ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n" |
| << '\n' |
| << "from __future__ import annotations\n" |
| << '\n'; |
| FormatImports(ss, imports); |
| ss << content << '\n'; |
| |
| EnsureDirExists(StripFileName(filename)); |
| return parser_.opts.file_saver->SaveFile(filename.c_str(), ss.str(), false); |
| } |
| |
| void Generate(std::stringstream& ss, const ServiceDef* service, |
| Imports* imports) { |
| imports->Import("grpc"); |
| |
| ss << "class " << service->name << "Stub(object):\n" |
| << " def __init__(self, channel: grpc.Channel) -> None: ...\n"; |
| |
| for (const RPCCall* method : service->calls.vec) { |
| std::string request = "bytes"; |
| std::string response = "bytes"; |
| |
| if (parser_.opts.grpc_python_typed_handlers) { |
| request = namer_.Type(*method->request); |
| response = namer_.Type(*method->response); |
| |
| imports->Import(ModuleFor(method->request), request); |
| imports->Import(ModuleFor(method->response), response); |
| } |
| |
| ss << " def " << method->name << "(self, "; |
| if (ClientStreaming(method)) { |
| imports->Import("typing"); |
| ss << "request_iterator: typing.Iterator[" << request << "]"; |
| } else { |
| ss << "request: " << request; |
| } |
| ss << ") -> "; |
| if (ServerStreaming(method)) { |
| imports->Import("typing"); |
| ss << "typing.Iterator[" << response << "]"; |
| } else { |
| ss << response; |
| } |
| ss << ": ...\n"; |
| } |
| |
| ss << "\n\n"; |
| ss << "class " << service->name << "Servicer(object):\n"; |
| |
| for (const RPCCall* method : service->calls.vec) { |
| std::string request = "bytes"; |
| std::string response = "bytes"; |
| |
| if (parser_.opts.grpc_python_typed_handlers) { |
| request = namer_.Type(*method->request); |
| response = namer_.Type(*method->response); |
| |
| imports->Import(ModuleFor(method->request), request); |
| imports->Import(ModuleFor(method->response), response); |
| } |
| |
| ss << " def " << method->name << "(self, "; |
| if (ClientStreaming(method)) { |
| imports->Import("typing"); |
| ss << "request_iterator: typing.Iterator[" << request << "]"; |
| } else { |
| ss << "request: " << request; |
| } |
| ss << ", context: grpc.ServicerContext) -> "; |
| if (ServerStreaming(method)) { |
| imports->Import("typing"); |
| ss << "typing.Iterator[" << response << "]"; |
| } else { |
| ss << response; |
| } |
| ss << ": ...\n"; |
| } |
| |
| ss << '\n' |
| << '\n' |
| << "def add_" << service->name |
| << "Servicer_to_server(servicer: " << service->name |
| << "Servicer, server: grpc.Server) -> None: ...\n"; |
| } |
| }; |
| |
| class ServiceGenerator : public BaseGenerator { |
| public: |
| ServiceGenerator(const Parser& parser, const std::string& path, |
| const Version& version) |
| : BaseGenerator(parser, kConfig, path, version) {} |
| |
| bool Generate() { |
| Imports imports; |
| std::stringstream ss; |
| |
| imports.Import("flatbuffers"); |
| |
| if (parser_.opts.grpc_python_typed_handlers) { |
| ss << "def _serialize_to_bytes(table):\n" |
| << " buf = table._tab.Bytes\n" |
| << " n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)\n" |
| << " if table._tab.Pos != n:\n" |
| << " raise ValueError('must be a top-level table')\n" |
| << " return bytes(buf)\n" |
| << '\n' |
| << '\n'; |
| } |
| |
| std::string ns_name{}; |
| for (const ServiceDef* service : parser_.services_.vec) { |
| GenerateStub(ss, service, &imports); |
| GenerateServicer(ss, service, &imports); |
| GenerateRegister(ss, service, &imports); |
| ns_name = NamespaceDir(*service->defined_namespace, false); |
| } |
| |
| std::string sanitized_suffix{parser_.opts.grpc_filename_suffix}; |
| std::replace(sanitized_suffix.begin(), sanitized_suffix.end(), '.', '_'); |
| std::string filename = |
| ns_name + kPathSeparator + |
| StripPath(StripExtension(parser_.file_being_parsed_)) + "_grpc" + |
| sanitized_suffix + namer_.config_.filename_extension; |
| |
| return SaveService(filename, imports, ss.str()); |
| } |
| |
| private: |
| bool SaveService(const std::string& filename, const Imports& imports, |
| const std::string& content) { |
| std::stringstream ss; |
| ss << "# Generated by the gRPC FlatBuffers compiler. DO NOT EDIT!\n" |
| << '\n'; |
| FormatImports(ss, imports); |
| ss << content << '\n'; |
| |
| EnsureDirExists(StripFileName(filename)); |
| return parser_.opts.file_saver->SaveFile(filename.c_str(), ss.str(), false); |
| } |
| |
| void GenerateStub(std::stringstream& ss, const ServiceDef* service, |
| Imports* imports) { |
| ss << "class " << service->name << "Stub"; |
| if (version_.major != 3) ss << "(object)"; |
| ss << ":\n" |
| << " '''Interface exported by the server.'''\n" |
| << '\n' |
| << " def __init__(self, channel):\n" |
| << " '''Constructor.\n" |
| << '\n' |
| << " Args:\n" |
| << " channel: A grpc.Channel.\n" |
| << " '''\n" |
| << '\n'; |
| |
| for (const RPCCall* method : service->calls.vec) { |
| std::string response = namer_.Type(*method->response); |
| |
| imports->Import(ModuleFor(method->response), response); |
| |
| ss << " self." << method->name << " = channel." |
| << (ClientStreaming(method) ? "stream" : "unary") << "_" |
| << (ServerStreaming(method) ? "stream" : "unary") << "(\n" |
| << " method='/" |
| << service->defined_namespace->GetFullyQualifiedName(service->name) |
| << "/" << method->name << "'"; |
| |
| if (parser_.opts.grpc_python_typed_handlers) { |
| ss << ",\n" |
| << " request_serializer=_serialize_to_bytes,\n" |
| << " response_deserializer=" << response << ".GetRootAs"; |
| } |
| ss << ")\n\n"; |
| } |
| |
| ss << '\n'; |
| } |
| |
| void GenerateServicer(std::stringstream& ss, const ServiceDef* service, |
| Imports* imports) { |
| imports->Import("grpc"); |
| |
| ss << "class " << service->name << "Servicer"; |
| if (version_.major != 3) ss << "(object)"; |
| ss << ":\n" |
| << " '''Interface exported by the server.'''\n" |
| << '\n'; |
| |
| for (const RPCCall* method : service->calls.vec) { |
| const std::string request_param = |
| ClientStreaming(method) ? "request_iterator" : "request"; |
| ss << " def " << method->name << "(self, " << request_param |
| << ", context):\n" |
| << " context.set_code(grpc.StatusCode.UNIMPLEMENTED)\n" |
| << " context.set_details('Method not implemented!')\n" |
| << " raise NotImplementedError('Method not implemented!')\n" |
| << '\n'; |
| } |
| |
| ss << '\n'; |
| } |
| |
| void GenerateRegister(std::stringstream& ss, const ServiceDef* service, |
| Imports* imports) { |
| imports->Import("grpc"); |
| |
| ss << "def add_" << service->name |
| << "Servicer_to_server(servicer, server):\n" |
| << " rpc_method_handlers = {\n"; |
| |
| for (const RPCCall* method : service->calls.vec) { |
| std::string request = namer_.Type(*method->request); |
| |
| imports->Import(ModuleFor(method->request), request); |
| |
| ss << " '" << method->name << "': grpc." |
| << (ClientStreaming(method) ? "stream" : "unary") << "_" |
| << (ServerStreaming(method) ? "stream" : "unary") |
| << "_rpc_method_handler(\n" |
| << " servicer." << method->name; |
| |
| if (parser_.opts.grpc_python_typed_handlers) { |
| ss << ",\n" |
| << " request_deserializer=" << request << ".GetRootAs,\n" |
| << " response_serializer=_serialize_to_bytes"; |
| } |
| ss << "),\n"; |
| } |
| ss << " }\n" |
| << '\n' |
| << " generic_handler = grpc.method_handlers_generic_handler(\n" |
| << " '" |
| << service->defined_namespace->GetFullyQualifiedName(service->name) |
| << "', rpc_method_handlers)\n" |
| << '\n' |
| << " server.add_generic_rpc_handlers((generic_handler,))\n" |
| << '\n'; |
| } |
| }; |
| } // namespace |
| |
| bool Generate(const Parser& parser, const std::string& path, |
| const Version& version) { |
| ServiceGenerator generator{parser, path, version}; |
| return generator.Generate(); |
| } |
| |
| bool GenerateStub(const Parser& parser, const std::string& path, |
| const Version& version) { |
| StubGenerator generator{parser, path, version}; |
| return generator.Generate(); |
| } |
| |
| } // namespace grpc |
| } // namespace python |
| } // namespace flatbuffers |