Add support for including comments from .proto file (#85, #645)
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index 0899129..485f5c9 100755
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -272,9 +272,82 @@
else:
return 2**32 - 1
-class Enum:
- def __init__(self, names, desc, enum_options):
- '''desc is EnumDescriptorProto'''
+
+'''
+Constants regarding path of proto elements in file descriptor.
+They are used to connect proto elements with source code information (comments)
+These values come from:
+ https://github.com/google/protobuf/blob/master/src/google/protobuf/descriptor.proto
+'''
+MESSAGE_PATH = 4
+ENUM_PATH = 5
+FIELD_PATH = 2
+
+
+class ProtoElement(object):
+ def __init__(self, path, index, comments):
+ '''
+ path is a predefined value for each element type in proto file.
+ For example, message == 4, enum == 5, service == 6
+ index is the N-th occurance of the `path` in the proto file.
+ For example, 4-th message in the proto file or 2-nd enum etc ...
+ comments is a dictionary mapping between element path & SourceCodeInfo.Location
+ (contains information about source comments).
+ '''
+ self.path = path
+ self.index = index
+ self.comments = comments
+
+ def element_path(self):
+ '''Get path to proto element.'''
+ return [self.path, self.index]
+
+ def member_path(self, member_index):
+ '''Get path to member of proto element.
+ Example paths:
+ [4, m] - message comments, m: msgIdx in proto from 0
+ [4, m, 2, f] - field comments in message, f: fieldIdx in message from 0
+ [6, s] - service comments, s: svcIdx in proto from 0
+ [6, s, 2, r] - rpc comments in service, r: rpc method def in service from 0
+ '''
+ return self.element_path() + [FIELD_PATH, member_index]
+
+ def get_comments(self, path, leading_indent=True):
+ '''Get leading & trailing comments for enum member based on path.
+
+ path is the proto path of an element or member (ex. [5 0] or [4 1 2 0])
+ leading_indent is a flag that indicates if leading comments should be indented
+ '''
+
+ # Obtain SourceCodeInfo.Location object containing comment
+ # information (based on the member path)
+ comment = self.comments.get(str(path))
+
+ leading_comment = ""
+ trailing_comment = ""
+
+ if not comment:
+ return leading_comment, trailing_comment
+
+ if comment.leading_comments:
+ leading_comment = " " if leading_indent else ""
+ leading_comment += "/* %s */" % comment.leading_comments.strip()
+
+ if comment.trailing_comments:
+ trailing_comment = "/* %s */" % comment.trailing_comments.strip()
+
+ return leading_comment, trailing_comment
+
+
+class Enum(ProtoElement):
+ def __init__(self, names, desc, enum_options, index, comments):
+ '''
+ desc is EnumDescriptorProto
+ index is the index of this enum element inside the file
+ comments is a dictionary mapping between element path & SourceCodeInfo.Location
+ (contains information about source comments)
+ '''
+ super(Enum, self).__init__(ENUM_PATH, index, comments)
self.options = enum_options
self.names = names
@@ -300,8 +373,32 @@
return max([varint_max_size(v) for n,v in self.values])
def __str__(self):
- result = 'typedef enum _%s {\n' % self.names
- result += ',\n'.join([" %s = %d" % x for x in self.values])
+ enum_path = self.element_path()
+ leading_comment, trailing_comment = self.get_comments(enum_path, leading_indent=False)
+
+ result = ''
+ if leading_comment:
+ result = '%s\n' % leading_comment
+
+ result += 'typedef enum _%s { %s\n' % (self.names, trailing_comment)
+
+ enum_length = len(self.values)
+ enum_values = []
+ for index, (name, value) in enumerate(self.values):
+ member_path = self.member_path(index)
+ leading_comment, trailing_comment = self.get_comments(member_path)
+
+ if leading_comment:
+ enum_values.append(leading_comment)
+
+ comma = ","
+ if index == enum_length - 1:
+ # last enum member should not end with a comma
+ comma = ""
+
+ enum_values.append(" %s = %d%s %s" % (name, value, comma, trailing_comment))
+
+ result += '\n'.join(enum_values)
result += '\n}'
if self.packed:
@@ -866,7 +963,8 @@
else:
self.skip = False
self.rules = 'REQUIRED' # We don't really want the has_field for extensions
- self.msg = Message(self.fullname + "extmsg", None, field_options)
+ # currently no support for comments for extension fields => provide 0, {}
+ self.msg = Message(self.fullname + "extmsg", None, field_options, 0, {})
self.msg.fields.append(self)
def tags(self):
@@ -1021,8 +1119,9 @@
# ---------------------------------------------------------------------------
-class Message:
- def __init__(self, names, desc, message_options):
+class Message(ProtoElement):
+ def __init__(self, names, desc, message_options, index, comments):
+ super(Message, self).__init__(MESSAGE_PATH, index, comments)
self.name = names
self.fields = []
self.oneofs = {}
@@ -1107,14 +1206,31 @@
return deps
def __str__(self):
- result = 'typedef struct _%s {\n' % self.name
+ message_path = self.element_path()
+ leading_comment, trailing_comment = self.get_comments(message_path, leading_indent=False)
+
+ result = ''
+ if leading_comment:
+ result = '%s\n' % leading_comment
+
+ result += 'typedef struct _%s { %s\n' % (self.name, trailing_comment)
if not self.fields:
# Empty structs are not allowed in C standard.
# Therefore add a dummy field if an empty message occurs.
result += ' char dummy_field;'
- result += '\n'.join([str(f) for f in self.fields])
+ msg_fields = []
+ for index, field in enumerate(self.fields):
+ member_path = self.member_path(index)
+ leading_comment, trailing_comment = self.get_comments(member_path)
+
+ if leading_comment:
+ msg_fields.append(leading_comment)
+
+ msg_fields.append("%s %s" % (str(field), trailing_comment))
+
+ result += '\n'.join(msg_fields)
if Globals.protoc_insertion_points:
result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name
@@ -1496,12 +1612,21 @@
else:
base_name = Names()
- for enum in self.fdesc.enum_type:
+ # process source code comment locations
+ # ignores any locations that do not contain any comment information
+ self.comment_locations = {
+ str(location.path): location
+ for location in self.fdesc.source_code_info.location
+ if location.leading_comments or location.leading_detached_comments or location.trailing_comments
+ }
+ # breakpoint()
+
+ for index, enum in enumerate(self.fdesc.enum_type):
name = create_name(enum.name)
enum_options = get_nanopb_suboptions(enum, self.file_options, name)
- self.enums.append(Enum(name, enum, enum_options))
+ self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations))
- for names, message in iterate_messages(self.fdesc, flatten):
+ for index, (names, message) in enumerate(iterate_messages(self.fdesc, flatten)):
name = create_name(names)
message_options = get_nanopb_suboptions(message, self.file_options, name)
@@ -1513,11 +1638,11 @@
if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
field.type_name = mangle_field_typename(field.type_name)
- self.messages.append(Message(name, message, message_options))
- for enum in message.enum_type:
+ self.messages.append(Message(name, message, message_options, index, self.comment_locations))
+ for index, enum in enumerate(message.enum_type):
name = create_name(names + enum.name)
enum_options = get_nanopb_suboptions(enum, message_options, name)
- self.enums.append(Enum(name, enum, enum_options))
+ self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations))
for names, extension in iterate_extensions(self.fdesc, flatten):
name = create_name(names + extension.name)