blob: ae34f08a8adb9fddd0e7b6011e4f359e31ac3796 [file] [log] [blame]
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""Contains FieldMask class."""
from google.protobuf.descriptor import FieldDescriptor
class FieldMask(object):
  """Helper methods for the FieldMask well-known message type.

  The methods operate on ``self.paths`` and ``self.Clear()``. Neither is
  defined in this class, so both must be provided by the object the
  methods are invoked on.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Returns the mask as a proto3 JSON string.

    Each path is converted from snake_case to camelCase and the converted
    paths are joined with commas.

    Raises:
      ValueError: If a path cannot be converted to camelCase.
    """
    return ','.join([_SnakeCaseToCamelCase(path) for path in self.paths])

  def FromJsonString(self, value):
    """Replaces the paths of this mask with those parsed from a JSON string.

    Args:
      value: Comma-separated camelCase paths. An empty string yields a
        mask without paths.

    Raises:
      ValueError: If value is not a str, or a path contains "_".
    """
    if not isinstance(value, str):
      raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
    self.Clear()
    if not value:
      return
    # Paths are appended one at a time, so a conversion failure leaves the
    # paths converted so far in the mask.
    for json_path in value.split(','):
      self.paths.append(_CamelCaseToSnakeCase(json_path))

  def IsValidForDescriptor(self, message_descriptor):
    """Returns True if every path of the mask is valid for the descriptor."""
    return all(
        _IsValidPath(message_descriptor, path) for path in self.paths)

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Resets the mask to the names of the descriptor's direct fields."""
    self.Clear()
    for field_descriptor in message_descriptor.fields:
      self.paths.append(field_descriptor.name)

  def CanonicalFormFromMask(self, mask):
    """Sets this mask to the canonical form of another mask.

    A path covered by another path is dropped: "foo.bar" is covered by
    "foo" and disappears when "foo" is also present. The surviving paths
    are emitted in sorted order.

    Args:
      mask: The original FieldMask to be converted.
    """
    canonical_tree = _FieldMaskTree(mask)
    canonical_tree.ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Sets this mask to the union of mask1 and mask2.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    for operand in (mask1, mask2):
      _CheckFieldMaskMessage(operand)
    merged_tree = _FieldMaskTree(mask1)
    merged_tree.MergeFromFieldMask(mask2)
    merged_tree.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Sets this mask to the intersection of mask1 and mask2.

    Raises:
      ValueError: If mask1 or mask2 is not a FieldMask message.
    """
    for operand in (mask1, mask2):
      _CheckFieldMaskMessage(operand)
    first_tree = _FieldMaskTree(mask1)
    common_tree = _FieldMaskTree()
    for path in mask2.paths:
      first_tree.IntersectPath(path, common_tree)
    common_tree.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Copies the fields selected by this mask from source to destination.

    Args:
      source: Source message.
      destination: The destination message to be merged into.
      replace_message_field: Replace message field if True. Merge message
        field if False.
      replace_repeated_field: Replace repeated field if True. Append
        elements of repeated field if False.
    """
    selection = _FieldMaskTree(self)
    selection.MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
def _IsValidPath(message_descriptor, path):
  """Returns True if the dotted path resolves within the message descriptor.

  Every segment except the last must name a non-repeated field of message
  type; the last segment only has to name a field of the message reached.
  """
  *parent_names, leaf_name = path.split('.')
  current = message_descriptor
  for parent_name in parent_names:
    field = current.fields_by_name.get(parent_name)
    if field is None:
      return False
    is_singular_message = (
        field.label != FieldDescriptor.LABEL_REPEATED and
        field.type == FieldDescriptor.TYPE_MESSAGE)
    if not is_singular_message:
      return False
    # Descend into the sub-message to resolve the next segment.
    current = field.message_type
  return leaf_name in current.fields_by_name
def _CheckFieldMaskMessage(message):
  """Verifies that message is a FieldMask message.

  The check compares the descriptor's name and the name of the file that
  declares it.

  Raises:
    ValueError: If the descriptor is not named 'FieldMask' or is not
      declared in 'google/protobuf/field_mask.proto'.
  """
  descriptor = message.DESCRIPTOR
  is_field_mask = (
      descriptor.name == 'FieldMask' and
      descriptor.file.name == 'google/protobuf/field_mask.proto')
  if is_field_mask:
    return
  raise ValueError('Message {0} is not a FieldMask.'.format(
      descriptor.full_name))
def _SnakeCaseToCamelCase(path_name):
  """Returns path_name with every "_x" replaced by the uppercase of x.

  Raises:
    ValueError: If path_name contains an uppercase character, a "_" that
      is not followed by a lowercase character, or a trailing "_".
  """
  pieces = []
  capitalize_next = False
  for char in path_name:
    if char.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if capitalize_next:
      # The previous character was "_": only a lowercase letter may follow.
      if not char.islower():
        raise ValueError(
            'Fail to print FieldMask to Json string: The '
            'character after a "_" must be a lowercase letter '
            'in path name {0}.'.format(path_name))
      pieces.append(char.upper())
      capitalize_next = False
    elif char == '_':
      capitalize_next = True
    else:
      pieces.append(char)
  if capitalize_next:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(pieces)
def _CamelCaseToSnakeCase(path_name):
  """Returns path_name with each uppercase character X rewritten as "_x".

  Raises:
    ValueError: If path_name contains "_".
  """
  if '_' in path_name:
    raise ValueError('Fail to parse FieldMask: Path name '
                     '{0} must not contain "_"s.'.format(path_name))
  return ''.join(
      ('_' + char.lower()) if char.isupper() else char
      for char in path_name)
class _FieldMaskTree(object):
  """Tree of nested dicts holding the paths of a FieldMask.

  For the FieldMask "foo.bar,foo.baz,bar.baz" the tree is:

    [_root] -+- foo -+- bar
             |       |
             |       +- baz
             |
             +- bar --- baz

  Each node is a dict keyed by field name; a leaf (an empty dict) stands
  for one field path.
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Creates an empty tree, then merges field_mask into it if truthy."""
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Adds every path of field_mask to the tree."""
    for mask_path in field_mask.paths:
      self.AddPath(mask_path)

  def AddPath(self, path):
    """Adds one field path to the tree.

    When a leaf is met on the way down, the path is already covered by a
    shorter path and the tree is left untouched. When the path ends on an
    existing inner node, that node becomes a leaf and loses its children,
    since the path covers all of them. Otherwise the missing nodes are
    created.

    Args:
      path: The field path to add.
    """
    cursor = self._root
    for segment in path.split('.'):
      child = cursor.get(segment)
      if child is None:
        child = cursor[segment] = {}
      elif not child:
        # An existing leaf already covers everything below it.
        return
      cursor = child
    # The new path supersedes any longer paths stored under it.
    cursor.clear()

  def ToFieldMask(self, field_mask):
    """Clears field_mask and fills it with the paths stored in the tree."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Records in intersection the part of path that this tree contains.

    Args:
      path: The field path to intersect with this tree.
      intersection: The out tree that receives the common part.
    """
    cursor = self._root
    for segment in path.split('.'):
      if segment not in cursor:
        return
      child = cursor[segment]
      if not child:
        # A leaf of this tree covers the whole of path.
        intersection.AddPath(path)
        return
      cursor = child
    # path ends on an inner node: everything stored below it is common.
    intersection.AddLeafNodes(path, cursor)

  def AddLeafNodes(self, prefix, node):
    """Adds each leaf below node to this tree, with prefix prepended."""
    if not node:
      self.AddPath(prefix)
    for name, child in node.items():
      self.AddLeafNodes(prefix + '.' + name, child)

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Merges the fields selected by this tree from source to destination."""
    _MergeMessage(
        self._root, source, destination, replace_message, replace_repeated)
def _StrConvert(value):
  """Returns value unchanged if it is a str, else value.encode('utf-8')."""
  # This file is imported by the C extension, and methods such as
  # ClearField require a string for the field name. py2 and py3 have
  # different text types, so the value may arrive as unicode.
  return value if isinstance(value, str) else value.encode('utf-8')
def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merges the fields selected by a mask sub-tree from source to destination.

  Args:
    node: Dict mapping field names to child sub-trees. An empty child
      selects the whole field; a non-empty child selects sub-fields of a
      singular message field.
    source: Message the field values are read from.
    destination: Message the field values are written to.
    replace_message: If True, a selected singular message field is cleared
      in destination before source's value is merged in.
    replace_repeated: If True, a selected repeated field is cleared in
      destination before source's elements are appended.

  Raises:
    ValueError: If a name in node is not a field of source's message type,
      or a node with children names a repeated or non-message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name, child in node.items():
    # Use get() here: indexing raises KeyError for an unknown name, which
    # would make the ValueError below unreachable.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    is_repeated = field.label == FieldDescriptor.LABEL_REPEATED
    is_message = field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE
    if child:
      # Sub-paths are only allowed for singular message fields.
      if is_repeated or not is_message:
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      # Nothing to descend into when source does not have the sub-message.
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if is_repeated:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    elif is_message:
      if replace_message:
        destination.ClearField(_StrConvert(name))
      if source.HasField(name):
        getattr(destination, name).MergeFrom(getattr(source, name))
    else:
      # Any other field is copied by plain assignment.
      setattr(destination, name, getattr(source, name))
def _AddFieldPaths(node, prefix, field_mask):
  """Appends the paths of every leaf below node to field_mask.paths.

  Children are visited in sorted name order and each emitted path is
  prefix joined with the names leading to the leaf, separated by ".".
  An empty node with an empty prefix emits nothing.
  """
  # Explicit depth-first stack. Children are pushed in reverse sorted
  # order so they are popped, and therefore emitted, in sorted order.
  pending = [(prefix, node)]
  while pending:
    current_path, subtree = pending.pop()
    if current_path and not subtree:
      field_mask.paths.append(current_path)
      continue
    for name in reversed(sorted(subtree)):
      if current_path:
        child_path = current_path + '.' + name
      else:
        child_path = name
      pending.append((child_path, subtree[name]))