diff options
author | Wink Saville <wink@google.com> | 2010-05-29 13:00:38 -0700 |
---|---|---|
committer | Wink Saville <wink@google.com> | 2010-05-29 13:00:38 -0700 |
commit | d0332953cda33fb4f8e24ebff9c49159b69c43d6 (patch) | |
tree | 81612e8b12f590310aeb0ebf1da37b304eb7baa6 /python/google | |
parent | ede38fe9b9f93888e6e41afc7abb09525f44da95 (diff) | |
download | external_protobuf-d0332953cda33fb4f8e24ebff9c49159b69c43d6.zip external_protobuf-d0332953cda33fb4f8e24ebff9c49159b69c43d6.tar.gz external_protobuf-d0332953cda33fb4f8e24ebff9c49159b69c43d6.tar.bz2 |
Add protobuf 2.3.0 sources
This is the contents of protobuf-2.3.0.tar.bz2 from
http://code.google.com/p/protobuf/downloads/list.
Change-Id: Idfde09ce7ef5ac027b07ee83f2674fbbed5c30b2
Diffstat (limited to 'python/google')
23 files changed, 3090 insertions, 3332 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 8e3fc2e..aa4ab96 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -44,12 +44,24 @@ file, in types that make this information accessible in Python. __author__ = 'robinson@google.com (Will Robinson)' + +class Error(Exception): + """Base error for this module.""" + + class DescriptorBase(object): """Descriptors base class. This class is the base of all descriptor classes. It provides common options related functionaility. + + Attributes: + has_options: True if the descriptor has non-default options. Usually it + is not necessary to read this -- just call GetOptions() which will + happily return the default instance. However, it's sometimes useful + for efficiency, and also useful inside the protobuf implementation to + avoid some bootstrapping issues. """ def __init__(self, options, options_class_name): @@ -60,6 +72,9 @@ class DescriptorBase(object): self._options = options self._options_class_name = options_class_name + # Does this descriptor have non-default options? + self.has_options = options is not None + def GetOptions(self): """Retrieves descriptor options. @@ -78,7 +93,70 @@ class DescriptorBase(object): return self._options -class Descriptor(DescriptorBase): +class _NestedDescriptorBase(DescriptorBase): + """Common class for descriptors that can be nested.""" + + def __init__(self, options, options_class_name, name, full_name, + file, containing_type, serialized_start=None, + serialized_end=None): + """Constructor. + + Args: + options: Protocol message options or None + to use default message options. + options_class_name: (str) The class name of the above options. + + name: (str) Name of this protocol message type. + full_name: (str) Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + file: (FileDescriptor) Reference to file info. + containing_type: if provided, this is a nested descriptor, with this + descriptor as parent, otherwise None. + serialized_start: The start index (inclusive) in block in the + file.serialized_pb that describes this descriptor. + serialized_end: The end index (exclusive) in block in the + file.serialized_pb that describes this descriptor. + """ + super(_NestedDescriptorBase, self).__init__( + options, options_class_name) + + self.name = name + # TODO(falk): Add function to calculate full_name instead of having it in + # memory? + self.full_name = full_name + self.file = file + self.containing_type = containing_type + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def GetTopLevelContainingType(self): + """Returns the root if this is a nested type, or itself if its the root.""" + desc = self + while desc.containing_type is not None: + desc = desc.containing_type + return desc + + def CopyToProto(self, proto): + """Copies this to the matching proto in descriptor_pb2. + + Args: + proto: An empty proto instance from descriptor_pb2. + + Raises: + Error: If self couldnt be serialized, due to to few constructor arguments. + """ + if (self.file is not None and + self._serialized_start is not None and + self._serialized_end is not None): + proto.ParseFromString(self.file.serialized_pb[ + self._serialized_start:self._serialized_end]) + else: + raise Error('Descriptor does not contain serialization.') + + +class Descriptor(_NestedDescriptorBase): """Descriptor for a protocol message type. @@ -89,10 +167,8 @@ class Descriptor(DescriptorBase): which will include protocol "package" name and the name of any enclosing types. - filename: (str) Name of the .proto file containing this message. - containing_type: (Descriptor) Reference to the descriptor of the - type containing us, or None if we have no containing type. + type containing us, or None if this is top-level. fields: (list of FieldDescriptors) Field descriptors for all fields in this type. @@ -123,20 +199,28 @@ class Descriptor(DescriptorBase): objects as |extensions|, but indexed by "name" attribute of each FieldDescriptor. + is_extendable: Does this type define any extension ranges? + options: (descriptor_pb2.MessageOptions) Protocol message options or None to use default message options. + + file: (FileDescriptor) Reference to file descriptor. """ - def __init__(self, name, full_name, filename, containing_type, - fields, nested_types, enum_types, extensions, options=None): + def __init__(self, name, full_name, filename, containing_type, fields, + nested_types, enum_types, extensions, options=None, + is_extendable=True, extension_ranges=None, file=None, + serialized_start=None, serialized_end=None): """Arguments to __init__() are as described in the description of Descriptor fields above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. """ - super(Descriptor, self).__init__(options, 'MessageOptions') - self.name = name - self.full_name = full_name - self.filename = filename - self.containing_type = containing_type + super(Descriptor, self).__init__( + options, 'MessageOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_start) # We have fields in addition to fields_by_name and fields_by_number, # so that: @@ -163,6 +247,20 @@ class Descriptor(DescriptorBase): for extension in self.extensions: extension.extension_scope = self self.extensions_by_name = dict((f.name, f) for f in extensions) + self.is_extendable = is_extendable + self.extension_ranges = extension_ranges + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.DescriptorProto. + + Args: + proto: An empty descriptor_pb2.DescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(Descriptor, self).CopyToProto(proto) # TODO(robinson): We should have aggressive checking here, @@ -195,6 +293,8 @@ class FieldDescriptor(DescriptorBase): label: (One of the LABEL_* constants below) Tells whether this field is optional, required, or repeated. + has_default_value: (bool) True if this field has a default value defined, + otherwise false. default_value: (Varies) Default value of this field. Only meaningful for non-repeated scalar fields. Repeated fields should always set this to [], and non-repeated composite @@ -272,7 +372,8 @@ class FieldDescriptor(DescriptorBase): def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, - is_extension, extension_scope, options=None): + is_extension, extension_scope, options=None, + has_default_value=True): """The arguments are as described in the description of FieldDescriptor attributes above. @@ -288,6 +389,7 @@ class FieldDescriptor(DescriptorBase): self.type = type self.cpp_type = cpp_type self.label = label + self.has_default_value = has_default_value self.default_value = default_value self.containing_type = containing_type self.message_type = message_type @@ -296,7 +398,7 @@ class FieldDescriptor(DescriptorBase): self.extension_scope = extension_scope -class EnumDescriptor(DescriptorBase): +class EnumDescriptor(_NestedDescriptorBase): """Descriptor for an enum defined in a .proto file. @@ -305,7 +407,6 @@ class EnumDescriptor(DescriptorBase): name: (str) Name of the enum type. full_name: (str) Full name of the type, including package name and any enclosing type(s). - filename: (str) Name of the .proto file in which this appears. values: (list of EnumValueDescriptors) List of the values in this enum. @@ -317,23 +418,41 @@ class EnumDescriptor(DescriptorBase): type of this enum, or None if this is an enum defined at the top level in a .proto file. Set by Descriptor's constructor if we're passed into one. + file: (FileDescriptor) Reference to file descriptor. options: (descriptor_pb2.EnumOptions) Enum options message or None to use default enum options. """ def __init__(self, name, full_name, filename, values, - containing_type=None, options=None): - """Arguments are as described in the attribute description above.""" - super(EnumDescriptor, self).__init__(options, 'EnumOptions') - self.name = name - self.full_name = full_name - self.filename = filename + containing_type=None, options=None, file=None, + serialized_start=None, serialized_end=None): + """Arguments are as described in the attribute description above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. + """ + super(EnumDescriptor, self).__init__( + options, 'EnumOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_start) + self.values = values for value in self.values: value.type = self self.values_by_name = dict((v.name, v) for v in values) self.values_by_number = dict((v.number, v) for v in values) - self.containing_type = containing_type + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.EnumDescriptorProto. + + Args: + proto: An empty descriptor_pb2.EnumDescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(EnumDescriptor, self).CopyToProto(proto) class EnumValueDescriptor(DescriptorBase): @@ -360,7 +479,7 @@ class EnumValueDescriptor(DescriptorBase): self.type = type -class ServiceDescriptor(DescriptorBase): +class ServiceDescriptor(_NestedDescriptorBase): """Descriptor for a service. @@ -372,12 +491,15 @@ class ServiceDescriptor(DescriptorBase): service. options: (descriptor_pb2.ServiceOptions) Service options message or None to use default service options. + file: (FileDescriptor) Reference to file info. """ - def __init__(self, name, full_name, index, methods, options=None): - super(ServiceDescriptor, self).__init__(options, 'ServiceOptions') - self.name = name - self.full_name = full_name + def __init__(self, name, full_name, index, methods, options=None, file=None, + serialized_start=None, serialized_end=None): + super(ServiceDescriptor, self).__init__( + options, 'ServiceOptions', name, full_name, file, + None, serialized_start=serialized_start, + serialized_end=serialized_end) self.index = index self.methods = methods # Set the containing service for each method in this service. @@ -391,6 +513,15 @@ class ServiceDescriptor(DescriptorBase): return method return None + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.ServiceDescriptorProto. + + Args: + proto: An empty descriptor_pb2.ServiceDescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(ServiceDescriptor, self).CopyToProto(proto) + class MethodDescriptor(DescriptorBase): @@ -423,6 +554,32 @@ class MethodDescriptor(DescriptorBase): self.output_type = output_type +class FileDescriptor(DescriptorBase): + """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto. + + name: name of file, relative to root of source tree. + package: name of the package + serialized_pb: (str) Byte string of serialized + descriptor_pb2.FileDescriptorProto. + """ + + def __init__(self, name, package, options=None, serialized_pb=None): + """Constructor.""" + super(FileDescriptor, self).__init__(options, 'FileOptions') + + self.name = name + self.package = package + self.serialized_pb = serialized_pb + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.FileDescriptorProto. + + Args: + proto: An empty descriptor_pb2.FileDescriptorProto. + """ + proto.ParseFromString(self.serialized_pb) + + def _ParseOptions(message, string): """Parses serialized options. @@ -430,4 +587,4 @@ def _ParseOptions(message, string): proto2 files. It must not be used outside proto2. """ message.ParseFromString(string) - return message; + return message diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index d8a825d..5cc7d6d 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -54,8 +54,7 @@ class BaseContainer(object): Args: message_listener: A MessageListener implementation. The RepeatedScalarFieldContainer will call this object's - TransitionToNonempty() method when it transitions from being empty to - being nonempty. + Modified() method when it is modified. """ self._message_listener = message_listener self._values = [] @@ -73,6 +72,9 @@ class BaseContainer(object): # The concrete classes should define __eq__. return not self == other + def __repr__(self): + return repr(self._values) + class RepeatedScalarFieldContainer(BaseContainer): @@ -86,8 +88,7 @@ class RepeatedScalarFieldContainer(BaseContainer): Args: message_listener: A MessageListener implementation. The RepeatedScalarFieldContainer will call this object's - TransitionToNonempty() method when it transitions from being empty to - being nonempty. + Modified() method when it is modified. type_checker: A type_checkers.ValueChecker instance to run on elements inserted into this container. """ @@ -96,44 +97,47 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self.insert(len(self._values), value) + self._type_checker.CheckValue(value) + self._values.append(value) + if not self._message_listener.dirty: + self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" self._type_checker.CheckValue(value) self._values.insert(key, value) - self._message_listener.ByteSizeDirty() - if len(self._values) == 1: - self._message_listener.TransitionToNonempty() + if not self._message_listener.dirty: + self._message_listener.Modified() def extend(self, elem_seq): """Extends by appending the given sequence. Similar to list.extend().""" if not elem_seq: return - orig_empty = len(self._values) == 0 new_values = [] for elem in elem_seq: self._type_checker.CheckValue(elem) new_values.append(elem) self._values.extend(new_values) - self._message_listener.ByteSizeDirty() - if orig_empty: - self._message_listener.TransitionToNonempty() + self._message_listener.Modified() + + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one. We do not check the types of the individual fields. + """ + self._values.extend(other._values) + self._message_listener.Modified() def remove(self, elem): """Removes an item from the list. Similar to list.remove().""" self._values.remove(elem) - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __setitem__(self, key, value): """Sets the item on the specified position.""" - # No need to call TransitionToNonempty(), since if we're able to - # set the element at this index, we were already nonempty before - # this method was called. - self._message_listener.ByteSizeDirty() self._type_checker.CheckValue(value) self._values[key] = value + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -146,17 +150,17 @@ class RepeatedScalarFieldContainer(BaseContainer): self._type_checker.CheckValue(value) new_values.append(value) self._values[start:stop] = new_values - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __delitem__(self, key): """Deletes the item at the specified position.""" del self._values[key] - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __delslice__(self, start, stop): """Deletes the subset of items from between the specified indices.""" del self._values[start:stop] - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __eq__(self, other): """Compares the current instance with another one.""" @@ -186,8 +190,7 @@ class RepeatedCompositeFieldContainer(BaseContainer): Args: message_listener: A MessageListener implementation. The RepeatedCompositeFieldContainer will call this object's - TransitionToNonempty() method when it transitions from being empty to - being nonempty. + Modified() method when it is modified. message_descriptor: A Descriptor instance describing the protocol type that should be present in this container. We'll use the _concrete_class field of this descriptor when the client calls add(). @@ -199,10 +202,24 @@ class RepeatedCompositeFieldContainer(BaseContainer): new_element = self._message_descriptor._concrete_class() new_element._SetListener(self._message_listener) self._values.append(new_element) - self._message_listener.ByteSizeDirty() - self._message_listener.TransitionToNonempty() + if not self._message_listener.dirty: + self._message_listener.Modified() return new_element + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one, copying each individual message. + """ + message_class = self._message_descriptor._concrete_class + listener = self._message_listener + values = self._values + for message in other._values: + new_element = message_class() + new_element._SetListener(listener) + new_element.MergeFrom(message) + values.append(new_element) + listener.Modified() + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] @@ -210,12 +227,12 @@ class RepeatedCompositeFieldContainer(BaseContainer): def __delitem__(self, key): """Deletes the item at the specified position.""" del self._values[key] - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __delslice__(self, start, stop): """Deletes the subset of items from between the specified indices.""" del self._values[start:stop] - self._message_listener.ByteSizeDirty() + self._message_listener.Modified() def __eq__(self, other): """Compares the current instance with another one.""" diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 83d6fe0..461a30c 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,182 +28,614 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Class for decoding protocol buffer primitives. - -Contains the logic for decoding every logical protocol field type -from one of the 5 physical wire types. +"""Code for decoding protocol buffer primitives. + +This code is very similar to encoder.py -- read the docs for that module first. + +A "decoder" is a function with the signature: + Decode(buffer, pos, end, message, field_dict) +The arguments are: + buffer: The string containing the encoded message. + pos: The current position in the string. + end: The position in the string where the current message ends. May be + less than len(buffer) if we're reading a sub-message. + message: The message object into which we're parsing. + field_dict: message._fields (avoids a hashtable lookup). +The decoder reads the field and stores it into field_dict, returning the new +buffer position. A decoder for a repeated field may proactively decode all of +the elements of that field, if they appear consecutively. + +Note that decoders may throw any of the following: + IndexError: Indicates a truncated message. + struct.error: Unpacking of a fixed-width field failed. + message.DecodeError: Other errors. + +Decoders are expected to raise an exception if they are called with pos > end. +This allows callers to be lax about bounds checking: it's fineto read past +"end" as long as you are sure that someone else will notice and throw an +exception later on. + +Something up the call stack is expected to catch IndexError and struct.error +and convert them to message.DecodeError. + +Decoders are constructed using decoder constructors with the signature: + MakeDecoder(field_number, is_repeated, is_packed, key, new_default) +The arguments are: + field_number: The field number of the field we want to decode. + is_repeated: Is the field a repeated field? (bool) + is_packed: Is the field a packed field? (bool) + key: The key to use when looking up the field within field_dict. + (This is actually the FieldDescriptor but nothing in this + file should depend on that.) + new_default: A function which takes a message object as a parameter and + returns a new instance of the default value for this field. + (This is called for repeated fields and sub-messages, when an + instance does not already exist.) + +As with encoders, we define a decoder constructor for every type of field. +Then, for every field of every message class we construct an actual decoder. +That decoder goes into a dict indexed by tag, so when we decode a message +we repeatedly read a tag, look up the corresponding decoder, and invoke it. """ -__author__ = 'robinson@google.com (Will Robinson)' +__author__ = 'kenton@google.com (Kenton Varda)' import struct -from google.protobuf import message -from google.protobuf.internal import input_stream +from google.protobuf.internal import encoder from google.protobuf.internal import wire_format +from google.protobuf import message +# This is not for optimization, but rather to avoid conflicts with local +# variables named "message". +_DecodeError = message.DecodeError + + +def _VarintDecoder(mask): + """Return an encoder for a basic varint value (does not include tag). + + Decoded values will be bitwise-anded with the given mask before being + returned, e.g. to limit them to 32 bits. The returned decoder does not + take the usual "end" parameter -- the caller is expected to do bounds checking + after the fact (often the caller can defer such checking until later). The + decoder returns a (value, new_pos) pair. + """ + + local_ord = ord + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = local_ord(buffer[pos]) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + result &= mask + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + + +def _SignedVarintDecoder(mask): + """Like _VarintDecoder() but decodes signed values.""" + + local_ord = ord + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = local_ord(buffer[pos]) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + if result > 0x7fffffffffffffff: + result -= (1 << 64) + result |= ~mask + else: + result &= mask + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + + +_DecodeVarint = _VarintDecoder((1 << 64) - 1) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) + +# Use these versions for values which must be limited to 32 bits. +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) + + +def ReadTag(buffer, pos): + """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple. + + We return the raw bytes of the tag rather than decoding them. The raw + bytes can then be used to look up the proper decoder. This effectively allows + us to trade some work that would be done in pure-python (decoding a varint) + for work that is done in C (searching for a byte string in a hash table). + In a low-level language it would be much cheaper to decode the varint and + use that, but not in Python. + """ + + start = pos + while ord(buffer[pos]) & 0x80: + pos += 1 + pos += 1 + return (buffer[start:pos], pos) + + +# -------------------------------------------------------------------- + + +def _SimpleDecoder(wire_type, decode_value): + """Return a constructor for a decoder for fields of a particular type. + + Args: + wire_type: The field's wire type. + decode_value: A function which decodes an individual value, e.g. + _DecodeVarint() + """ + + def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default): + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + (element, pos) = decode_value(buffer, pos) + value.append(element) + if pos > endpoint: + del value[-1] # Discard corrupt value. + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_type) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = decode_value(buffer, pos) + value.append(element) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (field_dict[key], pos) = decode_value(buffer, pos) + if pos > end: + del field_dict[key] # Discard corrupt value. + raise _DecodeError('Truncated message.') + return pos + return DecodeField + + return SpecificDecoder + + +def _ModifiedDecoder(wire_type, decode_value, modify_value): + """Like SimpleDecoder but additionally invokes modify_value on every value + before storing it. Usually modify_value is ZigZagDecode. + """ + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + def InnerDecode(buffer, pos): + (result, new_pos) = decode_value(buffer, pos) + return (modify_value(result), new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +def _StructPackDecoder(wire_type, format): + """Return a constructor for a decoder for a fixed-width field. + + Args: + wire_type: The field's wire type. + format: The format string to pass to struct.unpack(). + """ + + value_size = struct.calcsize(format) + local_unpack = struct.unpack + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + + def InnerDecode(buffer, pos): + new_pos = pos + value_size + result = local_unpack(format, buffer[pos:new_pos])[0] + return (result, new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +# -------------------------------------------------------------------- + + +Int32Decoder = EnumDecoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) + +Int64Decoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) + +UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) +UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) + +SInt32Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) +SInt64Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d') + +BoolDecoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) + + +def StringDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a string field.""" + + local_DecodeVarint = _DecodeVarint + local_unicode = unicode + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + return new_pos + return DecodeField + + +def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a bytes field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(buffer[pos:new_pos]) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + field_dict[key] = buffer[pos:new_pos] + return new_pos + return DecodeField + + +def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a group field.""" + + end_tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_END_GROUP) + end_tag_len = len(end_tag_bytes) + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_START_GROUP) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value.add()._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + return new_pos + return DecodeField + + +def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a message field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + return new_pos + return DecodeField + + +# -------------------------------------------------------------------- + +MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) + +def MessageSetItemDecoder(extensions_by_number): + """Returns a decoder for a MessageSet item. + + The parameter is the _extensions_by_number map for the message class. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + + type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) + message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) + item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) + + local_ReadTag = ReadTag + local_DecodeVarint = _DecodeVarint + local_SkipField = SkipField + + def DecodeItem(buffer, pos, end, message, field_dict): + type_id = -1 + message_start = -1 + message_end = -1 + + # Technically, type_id and message can appear in any order, so we need + # a little loop here. + while 1: + (tag_bytes, pos) = local_ReadTag(buffer, pos) + if tag_bytes == type_id_tag_bytes: + (type_id, pos) = local_DecodeVarint(buffer, pos) + elif tag_bytes == message_tag_bytes: + (size, message_start) = local_DecodeVarint(buffer, pos) + pos = message_end = message_start + size + elif tag_bytes == item_end_tag_bytes: + break + else: + pos = SkipField(buffer, pos, end, tag_bytes) + if pos == -1: + raise _DecodeError('Missing group end tag.') + + if pos > end: + raise _DecodeError('Truncated message.') + + if type_id == -1: + raise _DecodeError('MessageSet item missing type_id.') + if message_start == -1: + raise _DecodeError('MessageSet item missing message.') + + extension = extensions_by_number.get(type_id) + if extension is not None: + value = field_dict.get(extension) + if value is None: + value = field_dict.setdefault( + extension, extension.message_type._concrete_class()) + if value._InternalParse(buffer, message_start,message_end) != message_end: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + return pos + + return DecodeItem + +# -------------------------------------------------------------------- +# Optimization is not as heavy here because calls to SkipField() are rare, +# except for handling end-group tags. + +def _SkipVarint(buffer, pos, end): + """Skip a varint value. Returns the new position.""" + + while ord(buffer[pos]) & 0x80: + pos += 1 + pos += 1 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipFixed64(buffer, pos, end): + """Skip a fixed64 value. Returns the new position.""" + + pos += 8 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipLengthDelimited(buffer, pos, end): + """Skip a length-delimited value. Returns the new position.""" + + (size, pos) = _DecodeVarint(buffer, pos) + pos += size + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipGroup(buffer, pos, end): + """Skip sub-group. Returns the new position.""" + + while 1: + (tag_bytes, pos) = ReadTag(buffer, pos) + new_pos = SkipField(buffer, pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + +def _EndGroup(buffer, pos, end): + """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" + + return -1 + +def _SkipFixed32(buffer, pos, end): + """Skip a fixed32 value. Returns the new position.""" + + pos += 4 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _RaiseInvalidWireType(buffer, pos, end): + """Skip function for unknown wire types. Raises an exception.""" + + raise _DecodeError('Tag had invalid wire type.') + +def _FieldSkipper(): + """Constructs the SkipField function.""" + + WIRETYPE_TO_SKIPPER = [ + _SkipVarint, + _SkipFixed64, + _SkipLengthDelimited, + _SkipGroup, + _EndGroup, + _SkipFixed32, + _RaiseInvalidWireType, + _RaiseInvalidWireType, + ] + + wiretype_mask = wire_format.TAG_TYPE_MASK + local_ord = ord + + def SkipField(buffer, pos, end, tag_bytes): + """Skips a field with the specified tag. + + |pos| should point to the byte immediately after the tag. + + Returns: + The new position (after the tag value), or -1 if the tag is an end-group + tag (in which case the calling loop should break). + """ -# Note that much of this code is ported from //net/proto/ProtocolBuffer, and -# that the interface is strongly inspired by WireFormat from the C++ proto2 -# implementation. - - -class Decoder(object): - - """Decodes logical protocol buffer fields from the wire.""" + # The wire type is always in the first byte since varints are little-endian. + wire_type = local_ord(tag_bytes[0]) & wiretype_mask + return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) - def __init__(self, s): - """Initializes the decoder to read from s. + return SkipField - Args: - s: An immutable sequence of bytes, which must be accessible - via the Python buffer() primitive (i.e., buffer(s)). - """ - self._stream = input_stream.InputStream(s) - - def EndOfStream(self): - """Returns true iff we've reached the end of the bytes we're reading.""" - return self._stream.EndOfStream() - - def Position(self): - """Returns the 0-indexed position in |s|.""" - return self._stream.Position() - - def ReadFieldNumberAndWireType(self): - """Reads a tag from the wire. Returns a (field_number, wire_type) pair.""" - tag_and_type = self.ReadUInt32() - return wire_format.UnpackTag(tag_and_type) - - def SkipBytes(self, bytes): - """Skips the specified number of bytes on the wire.""" - self._stream.SkipBytes(bytes) - - # Note that the Read*() methods below are not exactly symmetrical with the - # corresponding Encoder.Append*() methods. Those Encoder methods first - # encode a tag, but the Read*() methods below assume that the tag has already - # been read, and that the client wishes to read a field of the specified type - # starting at the current position. - - def ReadInt32(self): - """Reads and returns a signed, varint-encoded, 32-bit integer.""" - return self._stream.ReadVarint32() - - def ReadInt64(self): - """Reads and returns a signed, varint-encoded, 64-bit integer.""" - return self._stream.ReadVarint64() - - def ReadUInt32(self): - """Reads and returns an signed, varint-encoded, 32-bit integer.""" - return self._stream.ReadVarUInt32() - - def ReadUInt64(self): - """Reads and returns an signed, varint-encoded,64-bit integer.""" - return self._stream.ReadVarUInt64() - - def ReadSInt32(self): - """Reads and returns a signed, zigzag-encoded, varint-encoded, - 32-bit integer.""" - return wire_format.ZigZagDecode(self._stream.ReadVarUInt32()) - - def ReadSInt64(self): - """Reads and returns a signed, zigzag-encoded, varint-encoded, - 64-bit integer.""" - return wire_format.ZigZagDecode(self._stream.ReadVarUInt64()) - - def ReadFixed32(self): - """Reads and returns an unsigned, fixed-width, 32-bit integer.""" - return self._stream.ReadLittleEndian32() - - def ReadFixed64(self): - """Reads and returns an unsigned, fixed-width, 64-bit integer.""" - return self._stream.ReadLittleEndian64() - - def ReadSFixed32(self): - """Reads and returns a signed, fixed-width, 32-bit integer.""" - value = self._stream.ReadLittleEndian32() - if value >= (1 << 31): - value -= (1 << 32) - return value - - def ReadSFixed64(self): - """Reads and returns a signed, fixed-width, 64-bit integer.""" - value = self._stream.ReadLittleEndian64() - if value >= (1 << 63): - value -= (1 << 64) - return value - - def ReadFloat(self): - """Reads and returns a 4-byte floating-point number.""" - serialized = self._stream.ReadBytes(4) - return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0] - - def ReadDouble(self): - """Reads and returns an 8-byte floating-point number.""" - serialized = self._stream.ReadBytes(8) - return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0] - - def ReadBool(self): - """Reads and returns a bool.""" - i = self._stream.ReadVarUInt32() - return bool(i) - - def ReadEnum(self): - """Reads and returns an enum value.""" - return self._stream.ReadVarUInt32() - - def ReadString(self): - """Reads and returns a length-delimited string.""" - bytes = self.ReadBytes() - return unicode(bytes, 'utf-8') - - def ReadBytes(self): - """Reads and returns a length-delimited byte sequence.""" - length = self._stream.ReadVarUInt32() - return self._stream.ReadBytes(length) - - def ReadMessageInto(self, msg): - """Calls msg.MergeFromString() to merge - length-delimited serialized message data into |msg|. - - REQUIRES: The decoder must be positioned at the serialized "length" - prefix to a length-delmiited serialized message. - - POSTCONDITION: The decoder is positioned just after the - serialized message, and we have merged those serialized - contents into |msg|. - """ - length = self._stream.ReadVarUInt32() - sub_buffer = self._stream.GetSubBuffer(length) - num_bytes_used = msg.MergeFromString(sub_buffer) - if num_bytes_used != length: - raise message.DecodeError( - 'Submessage told to deserialize from %d-byte encoding, ' - 'but used only %d bytes' % (length, num_bytes_used)) - self._stream.SkipBytes(num_bytes_used) - - def ReadGroupInto(self, expected_field_number, group): - """Calls group.MergeFromString() to merge - END_GROUP-delimited serialized message data into |group|. - We'll raise an exception if we don't find an END_GROUP - tag immediately after the serialized message contents. - - REQUIRES: The decoder is positioned just after the START_GROUP - tag for this group. - - POSTCONDITION: The decoder is positioned just after the - END_GROUP tag for this group, and we have merged - the contents of the group into |group|. - """ - sub_buffer = self._stream.GetSubBuffer() # No a priori length limit. - num_bytes_used = group.MergeFromString(sub_buffer) - if num_bytes_used < 0: - raise message.DecodeError('Group message reported negative bytes read.') - self._stream.SkipBytes(num_bytes_used) - field_number, field_type = self.ReadFieldNumberAndWireType() - if field_type != wire_format.WIRETYPE_END_GROUP: - raise message.DecodeError('Group message did not end with an END_GROUP.') - if field_number != expected_field_number: - raise message.DecodeError('END_GROUP tag had field ' - 'number %d, was expecting field number %d' % ( - field_number, expected_field_number)) - # We're now positioned just after the END_GROUP tag. Perfect. +SkipField = _FieldSkipper() diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py deleted file mode 100755 index 98e4647..0000000 --- a/python/google/protobuf/internal/decoder_test.py +++ /dev/null @@ -1,256 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test for google.protobuf.internal.decoder.""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import struct -import unittest -from google.protobuf.internal import decoder -from google.protobuf.internal import encoder -from google.protobuf.internal import input_stream -from google.protobuf.internal import wire_format -from google.protobuf import message -import logging -import mox - - -class DecoderTest(unittest.TestCase): - - def setUp(self): - self.mox = mox.Mox() - self.mock_stream = self.mox.CreateMock(input_stream.InputStream) - self.mock_message = self.mox.CreateMock(message.Message) - - def testReadFieldNumberAndWireType(self): - # Test field numbers that will require various varint sizes. - for expected_field_number in (1, 15, 16, 2047, 2048): - for expected_wire_type in range(6): # Highest-numbered wiretype is 5. - e = encoder.Encoder() - e.AppendTag(expected_field_number, expected_wire_type) - s = e.ToString() - d = decoder.Decoder(s) - field_number, wire_type = d.ReadFieldNumberAndWireType() - self.assertEqual(expected_field_number, field_number) - self.assertEqual(expected_wire_type, wire_type) - - def ReadScalarTestHelper(self, test_name, decoder_method, expected_result, - expected_stream_method_name, - stream_method_return, *args): - """Helper for testReadScalars below. - - Calls one of the Decoder.Read*() methods and ensures that the results are - as expected. - - Args: - test_name: Name of this test, used for logging only. - decoder_method: Unbound decoder.Decoder method to call. - expected_result: Value we expect returned from decoder_method(). - expected_stream_method_name: (string) Name of the InputStream - method we expect Decoder to call to actually read the value - on the wire. - stream_method_return: Value our mocked-out stream method should - return to the decoder. - args: Additional arguments that we expect to be passed to the - stream method. - """ - logging.info('Testing %s scalar input.\n' - 'Calling %r(), and expecting that to call the ' - 'stream method %s(%r), which will return %r. Finally, ' - 'expecting the Decoder method to return %r'% ( - test_name, decoder_method, - expected_stream_method_name, args, stream_method_return, - expected_result)) - - d = decoder.Decoder('') - d._stream = self.mock_stream - if decoder_method in (decoder.Decoder.ReadString, - decoder.Decoder.ReadBytes): - self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return)) - # We have to use names instead of methods to work around some - # mox weirdness. (ResetAll() is overzealous). - expected_stream_method = getattr(self.mock_stream, - expected_stream_method_name) - expected_stream_method(*args).AndReturn(stream_method_return) - - self.mox.ReplayAll() - result = decoder_method(d) - self.assertEqual(expected_result, result) - self.assert_(isinstance(result, type(expected_result))) - self.mox.VerifyAll() - self.mox.ResetAll() - - VAL = 1.125 # Perfectly representable as a float (no rounding error). - LITTLE_FLOAT_VAL = '\x00\x00\x90?' - LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?' - - def testReadScalars(self): - test_string = 'I can feel myself getting sutpider.' - scalar_tests = [ - ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0], - ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0], - ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0], - ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0], - ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff, - 'ReadLittleEndian32', 0xffffffff], - ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff, - 'ReadLittleEndian64', 0xffffffffffffffff], - ['sfixed32', decoder.Decoder.ReadSFixed32, long(-1), - 'ReadLittleEndian32', long(0xffffffff)], - ['sfixed64', decoder.Decoder.ReadSFixed64, long(-1), - 'ReadLittleEndian64', 0xffffffffffffffff], - ['float', decoder.Decoder.ReadFloat, self.VAL, - 'ReadBytes', self.LITTLE_FLOAT_VAL, 4], - ['double', decoder.Decoder.ReadDouble, self.VAL, - 'ReadBytes', self.LITTLE_DOUBLE_VAL, 8], - ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1], - ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23], - ['string', decoder.Decoder.ReadString, - unicode(test_string, 'utf-8'), 'ReadBytes', test_string, - len(test_string)], - ['utf8-string', decoder.Decoder.ReadString, - unicode(test_string, 'utf-8'), 'ReadBytes', test_string, - len(test_string)], - ['bytes', decoder.Decoder.ReadBytes, - test_string, 'ReadBytes', test_string, len(test_string)], - # We test zigzag decoding routines more extensively below. - ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1], - ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1], - ] - # Ensure that we're testing different Decoder methods and using - # different test names in all test cases above. - self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) - self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests))) - for args in scalar_tests: - self.ReadScalarTestHelper(*args) - - def testReadMessageInto(self): - length = 23 - def Test(simulate_error): - d = decoder.Decoder('') - d._stream = self.mock_stream - self.mock_stream.ReadVarUInt32().AndReturn(length) - sub_buffer = object() - self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer) - - if simulate_error: - self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1) - self.mox.ReplayAll() - self.assertRaises( - message.DecodeError, d.ReadMessageInto, self.mock_message) - else: - self.mock_message.MergeFromString(sub_buffer).AndReturn(length) - self.mock_stream.SkipBytes(length) - self.mox.ReplayAll() - d.ReadMessageInto(self.mock_message) - - self.mox.VerifyAll() - self.mox.ResetAll() - - Test(simulate_error=False) - Test(simulate_error=True) - - def testReadGroupInto_Success(self): - # Test both the empty and nonempty cases. - for num_bytes in (5, 0): - field_number = expected_field_number = 10 - d = decoder.Decoder('') - d._stream = self.mock_stream - sub_buffer = object() - self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) - self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes) - self.mock_stream.SkipBytes(num_bytes) - self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( - field_number, wire_format.WIRETYPE_END_GROUP)) - self.mox.ReplayAll() - d.ReadGroupInto(expected_field_number, self.mock_message) - self.mox.VerifyAll() - self.mox.ResetAll() - - def ReadGroupInto_FailureTestHelper(self, bytes_read): - d = decoder.Decoder('') - d._stream = self.mock_stream - sub_buffer = object() - self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) - self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read) - return d - - def testReadGroupInto_NegativeBytesReported(self): - expected_field_number = 10 - d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1) - self.mox.ReplayAll() - self.assertRaises(message.DecodeError, - d.ReadGroupInto, expected_field_number, - self.mock_message) - self.mox.VerifyAll() - - def testReadGroupInto_NoEndGroupTag(self): - field_number = expected_field_number = 10 - num_bytes = 5 - d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) - self.mock_stream.SkipBytes(num_bytes) - # Right field number, wrong wire type. - self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( - field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) - self.mox.ReplayAll() - self.assertRaises(message.DecodeError, - d.ReadGroupInto, expected_field_number, - self.mock_message) - self.mox.VerifyAll() - - def testReadGroupInto_WrongFieldNumberInEndGroupTag(self): - expected_field_number = 10 - field_number = expected_field_number + 1 - num_bytes = 5 - d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) - self.mock_stream.SkipBytes(num_bytes) - # Wrong field number, right wire type. - self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( - field_number, wire_format.WIRETYPE_END_GROUP)) - self.mox.ReplayAll() - self.assertRaises(message.DecodeError, - d.ReadGroupInto, expected_field_number, - self.mock_message) - self.mox.VerifyAll() - - def testSkipBytes(self): - d = decoder.Decoder('') - num_bytes = 1024 - self.mock_stream.SkipBytes(num_bytes) - d._stream = self.mock_stream - self.mox.ReplayAll() - d.SkipBytes(num_bytes) - self.mox.VerifyAll() - -if __name__ == '__main__': - unittest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index eb9f2be..05c2745 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -35,16 +35,30 @@ __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf import descriptor +from google.protobuf import text_format + + +TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """ +name: 'TestEmptyMessage' +""" + class DescriptorTest(unittest.TestCase): def setUp(self): + self.my_file = descriptor.FileDescriptor( + name='some/filename/some.proto', + package='protobuf_unittest' + ) self.my_enum = descriptor.EnumDescriptor( name='ForeignEnum', full_name='protobuf_unittest.ForeignEnum', - filename='ForeignEnum', + filename=None, + file=self.my_file, values=[ descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), @@ -53,7 +67,8 @@ class DescriptorTest(unittest.TestCase): self.my_message = descriptor.Descriptor( name='NestedMessage', full_name='protobuf_unittest.TestAllTypes.NestedMessage', - filename='some/filename/some.proto', + filename=None, + file=self.my_file, containing_type=None, fields=[ descriptor.FieldDescriptor( @@ -61,7 +76,7 @@ class DescriptorTest(unittest.TestCase): full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', index=0, number=1, type=5, cpp_type=1, label=1, - default_value=0, + has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None), ], @@ -80,6 +95,7 @@ class DescriptorTest(unittest.TestCase): self.my_service = descriptor.ServiceDescriptor( name='TestServiceWithOptions', full_name='protobuf_unittest.TestServiceWithOptions', + file=self.my_file, index=0, methods=[ self.my_method @@ -109,5 +125,210 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_service.GetOptions(), descriptor_pb2.ServiceOptions()) + def testFileDescriptorReferences(self): + self.assertEqual(self.my_enum.file, self.my_file) + self.assertEqual(self.my_message.file, self.my_file) + + def testFileDescriptor(self): + self.assertEqual(self.my_file.name, 'some/filename/some.proto') + self.assertEqual(self.my_file.package, 'protobuf_unittest') + + +class DescriptorCopyToProtoTest(unittest.TestCase): + """Tests for CopyTo functions of Descriptor.""" + + def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): + expected_proto = expected_class() + text_format.Merge(expected_ascii, expected_proto) + + self.assertEqual( + actual_proto, expected_proto, + 'Not equal,\nActual:\n%s\nExpected:\n%s\n' + % (str(actual_proto), str(expected_proto))) + + def _InternalTestCopyToProto(self, desc, expected_proto_class, + expected_proto_ascii): + actual = expected_proto_class() + desc.CopyToProto(actual) + self._AssertProtoEqual( + actual, expected_proto_class, expected_proto_ascii) + + def testCopyToProto_EmptyMessage(self): + self._InternalTestCopyToProto( + unittest_pb2.TestEmptyMessage.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII) + + def testCopyToProto_NestedMessage(self): + TEST_NESTED_MESSAGE_ASCII = """ + name: 'NestedMessage' + field: < + name: 'bb' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_NESTED_MESSAGE_ASCII) + + def testCopyToProto_ForeignNestedMessage(self): + TEST_FOREIGN_NESTED_ASCII = """ + name: 'TestForeignNested' + field: < + name: 'foreign_nested' + number: 1 + label: 1 # Optional + type: 11 # TYPE_MESSAGE + type_name: '.protobuf_unittest.TestAllTypes.NestedMessage' + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestForeignNested.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_FOREIGN_NESTED_ASCII) + + def testCopyToProto_ForeignEnum(self): + TEST_FOREIGN_ENUM_ASCII = """ + name: 'ForeignEnum' + value: < + name: 'FOREIGN_FOO' + number: 4 + > + value: < + name: 'FOREIGN_BAR' + number: 5 + > + value: < + name: 'FOREIGN_BAZ' + number: 6 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2._FOREIGNENUM, + descriptor_pb2.EnumDescriptorProto, + TEST_FOREIGN_ENUM_ASCII) + + def testCopyToProto_Options(self): + TEST_DEPRECATED_FIELDS_ASCII = """ + name: 'TestDeprecatedFields' + field: < + name: 'deprecated_int32' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + options: < + deprecated: true + > + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestDeprecatedFields.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_DEPRECATED_FIELDS_ASCII) + + def testCopyToProto_AllExtensions(self): + TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """ + name: 'TestEmptyMessageWithExtensions' + extension_range: < + start: 1 + end: 536870912 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII) + + def testCopyToProto_SeveralExtensions(self): + TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """ + name: 'TestMultipleExtensionRanges' + extension_range: < + start: 42 + end: 43 + > + extension_range: < + start: 4143 + end: 4244 + > + extension_range: < + start: 65536 + end: 536870912 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) + + def testCopyToProto_FileDescriptor(self): + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + name: 'google/protobuf/unittest_import.proto' + package: 'protobuf_unittest_import' + message_type: < + name: 'ImportMessage' + field: < + name: 'd' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + > + > + """ + + """enum_type: < + name: 'ImportEnum' + value: < + name: 'IMPORT_FOO' + number: 7 + > + value: < + name: 'IMPORT_BAR' + number: 8 + > + value: < + name: 'IMPORT_BAZ' + number: 9 + > + > + options: < + java_package: 'com.google.protobuf.test' + optimize_for: 1 # SPEED + > + """) + + self._InternalTestCopyToProto( + unittest_import_pb2.DESCRIPTOR, + descriptor_pb2.FileDescriptorProto, + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + + def testCopyToProto_ServiceDescriptor(self): + TEST_SERVICE_ASCII = """ + name: 'TestService' + method: < + name: 'Foo' + input_type: '.protobuf_unittest.FooRequest' + output_type: '.protobuf_unittest.FooResponse' + > + method: < + name: 'Bar' + input_type: '.protobuf_unittest.BarRequest' + output_type: '.protobuf_unittest.BarResponse' + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestService.DESCRIPTOR, + descriptor_pb2.ServiceDescriptorProto, + TEST_SERVICE_ASCII) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 3ec3b2b..aa05d5b 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,253 +28,659 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Class for encoding protocol message primitives. +"""Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type into one of the 5 physical wire types. + +This code is designed to push the Python interpreter's performance to the +limits. + +The basic idea is that at startup time, for every field (i.e. every +FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The +sizer takes a value of this field's type and computes its byte size. The +encoder takes a writer function and a value. It encodes the value into byte +strings and invokes the writer function to write those strings. Typically the +writer function is the write() method of a cStringIO. + +We try to do as much work as possible when constructing the writer and the +sizer rather than when calling them. In particular: +* We copy any needed global functions to local variables, so that we do not need + to do costly global table lookups at runtime. +* Similarly, we try to do any attribute lookups at startup time if possible. +* Every field's tag is encoded to bytes at startup, since it can't change at + runtime. +* Whatever component of the field size we can compute at startup, we do. +* We *avoid* sharing code if doing so would make the code slower and not sharing + does not burden us too much. For example, encoders for repeated fields do + not just call the encoders for singular fields in a loop because this would + add an extra function call overhead for every loop iteration; instead, we + manually inline the single-value encoder into the loop. +* If a Python function lacks a return statement, Python actually generates + instructions to pop the result of the last statement off the stack, push + None onto the stack, and then return that. If we really don't care what + value is returned, then we can save two instructions by returning the + result of the last statement. It looks funny but it helps. +* We assume that type and bounds checking has happened at a higher level. """ -__author__ = 'robinson@google.com (Will Robinson)' +__author__ = 'kenton@google.com (Kenton Varda)' import struct -from google.protobuf import message from google.protobuf.internal import wire_format -from google.protobuf.internal import output_stream - - -# Note that much of this code is ported from //net/proto/ProtocolBuffer, and -# that the interface is strongly inspired by WireFormat from the C++ proto2 -# implementation. - - -class Encoder(object): - - """Encodes logical protocol buffer fields to the wire format.""" - - def __init__(self): - self._stream = output_stream.OutputStream() - - def ToString(self): - """Returns all values encoded in this object as a string.""" - return self._stream.ToString() - - # Append*NoTag methods. These are necessary for serializing packed - # repeated fields. The Append*() methods call these methods to do - # the actual serialization. - def AppendInt32NoTag(self, value): - """Appends a 32-bit integer to our buffer, varint-encoded.""" - self._stream.AppendVarint32(value) - - def AppendInt64NoTag(self, value): - """Appends a 64-bit integer to our buffer, varint-encoded.""" - self._stream.AppendVarint64(value) - - def AppendUInt32NoTag(self, unsigned_value): - """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" - self._stream.AppendVarUInt32(unsigned_value) - - def AppendUInt64NoTag(self, unsigned_value): - """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" - self._stream.AppendVarUInt64(unsigned_value) - - def AppendSInt32NoTag(self, value): - """Appends a 32-bit integer to our buffer, zigzag-encoded and then - varint-encoded. - """ - zigzag_value = wire_format.ZigZagEncode(value) - self._stream.AppendVarUInt32(zigzag_value) - - def AppendSInt64NoTag(self, value): - """Appends a 64-bit integer to our buffer, zigzag-encoded and then - varint-encoded. - """ - zigzag_value = wire_format.ZigZagEncode(value) - self._stream.AppendVarUInt64(zigzag_value) - - def AppendFixed32NoTag(self, unsigned_value): - """Appends an unsigned 32-bit integer to our buffer, in little-endian - byte-order. - """ - self._stream.AppendLittleEndian32(unsigned_value) - - def AppendFixed64NoTag(self, unsigned_value): - """Appends an unsigned 64-bit integer to our buffer, in little-endian - byte-order. - """ - self._stream.AppendLittleEndian64(unsigned_value) - - def AppendSFixed32NoTag(self, value): - """Appends a signed 32-bit integer to our buffer, in little-endian - byte-order. - """ - sign = (value & 0x80000000) and -1 or 0 - if value >> 32 != sign: - raise message.EncodeError('SFixed32 out of range: %d' % value) - self._stream.AppendLittleEndian32(value & 0xffffffff) - - def AppendSFixed64NoTag(self, value): - """Appends a signed 64-bit integer to our buffer, in little-endian - byte-order. - """ - sign = (value & 0x8000000000000000) and -1 or 0 - if value >> 64 != sign: - raise message.EncodeError('SFixed64 out of range: %d' % value) - self._stream.AppendLittleEndian64(value & 0xffffffffffffffff) - - def AppendFloatNoTag(self, value): - """Appends a floating-point number to our buffer.""" - self._stream.AppendRawBytes( - struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value)) - - def AppendDoubleNoTag(self, value): - """Appends a double-precision floating-point number to our buffer.""" - self._stream.AppendRawBytes( - struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value)) - - def AppendBoolNoTag(self, value): - """Appends a boolean to our buffer.""" - self.AppendInt32NoTag(value) - - def AppendEnumNoTag(self, value): - """Appends an enum value to our buffer.""" - self.AppendInt32NoTag(value) - - - # All the Append*() methods below first append a tag+type pair to the buffer - # before appending the specified value. - - def AppendInt32(self, field_number, value): - """Appends a 32-bit integer to our buffer, varint-encoded.""" - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendInt32NoTag(value) - - def AppendInt64(self, field_number, value): - """Appends a 64-bit integer to our buffer, varint-encoded.""" - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendInt64NoTag(value) - - def AppendUInt32(self, field_number, unsigned_value): - """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendUInt32NoTag(unsigned_value) - - def AppendUInt64(self, field_number, unsigned_value): - """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendUInt64NoTag(unsigned_value) - - def AppendSInt32(self, field_number, value): - """Appends a 32-bit integer to our buffer, zigzag-encoded and then - varint-encoded. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendSInt32NoTag(value) - - def AppendSInt64(self, field_number, value): - """Appends a 64-bit integer to our buffer, zigzag-encoded and then - varint-encoded. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) - self.AppendSInt64NoTag(value) - - def AppendFixed32(self, field_number, unsigned_value): - """Appends an unsigned 32-bit integer to our buffer, in little-endian - byte-order. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) - self.AppendFixed32NoTag(unsigned_value) - - def AppendFixed64(self, field_number, unsigned_value): - """Appends an unsigned 64-bit integer to our buffer, in little-endian - byte-order. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) - self.AppendFixed64NoTag(unsigned_value) - - def AppendSFixed32(self, field_number, value): - """Appends a signed 32-bit integer to our buffer, in little-endian - byte-order. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) - self.AppendSFixed32NoTag(value) - - def AppendSFixed64(self, field_number, value): - """Appends a signed 64-bit integer to our buffer, in little-endian - byte-order. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) - self.AppendSFixed64NoTag(value) - - def AppendFloat(self, field_number, value): - """Appends a floating-point number to our buffer.""" - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) - self.AppendFloatNoTag(value) - - def AppendDouble(self, field_number, value): - """Appends a double-precision floating-point number to our buffer.""" - self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) - self.AppendDoubleNoTag(value) - - def AppendBool(self, field_number, value): - """Appends a boolean to our buffer.""" - self.AppendInt32(field_number, value) - - def AppendEnum(self, field_number, value): - """Appends an enum value to our buffer.""" - self.AppendInt32(field_number, value) - - def AppendString(self, field_number, value): - """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer, - with the length varint-encoded. - """ - self.AppendBytes(field_number, value.encode('utf-8')) - - def AppendBytes(self, field_number, value): - """Appends a length-prefixed sequence of bytes to our buffer, with the - length varint-encoded. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) - self._stream.AppendVarUInt32(len(value)) - self._stream.AppendRawBytes(value) - - # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to - # avoid the extra string copy here. We can do so if we widen the Message - # interface to be able to serialize to a stream in addition to a string. The - # challenge when thinking ahead to the Python/C API implementation of Message - # is finding a stream-like Python thing to which we can write raw bytes - # from C. I'm not sure such a thing exists(?). (array.array is pretty much - # what we want, but it's not directly exposed in the Python/C API). - - def AppendGroup(self, field_number, group): - """Appends a group to our buffer. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) - self._stream.AppendRawBytes(group.SerializeToString()) - self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) - - def AppendMessage(self, field_number, msg): - """Appends a nested message to our buffer. - """ - self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) - self._stream.AppendVarUInt32(msg.ByteSize()) - self._stream.AppendRawBytes(msg.SerializeToString()) - - def AppendMessageSetItem(self, field_number, msg): - """Appends an item using the message set wire format. - - The message set message looks like this: - message MessageSet { - repeated group Item = 1 { - required int32 type_id = 2; - required string message = 3; - } + + +def _VarintSize(value): + """Compute the size of a varint value.""" + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _SignedVarintSize(value): + """Compute the size of a signed varint value.""" + if value < 0: return 10 + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _TagSize(field_number): + """Returns the number of bytes required to serialize a tag with this field + number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarintSize(wire_format.PackTag(field_number, 0)) + + +# -------------------------------------------------------------------- +# In this section we define some generic sizers. Each of these functions +# takes parameters specific to a particular field type, e.g. int32 or fixed64. +# It returns another function which in turn takes parameters specific to a +# particular field, e.g. the field number and whether it is repeated or packed. +# Look at the next section to see how these are used. + + +def _SimpleSizer(compute_value_size): + """A sizer which uses the function compute_value_size to compute the size of + each value. Typically compute_value_size is _VarintSize.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(element) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(element) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(value) + return FieldSize + + return SpecificSizer + + +def _ModifiedSizer(compute_value_size, modify_value): + """Like SimpleSizer, but modify_value is invoked on each value before it is + passed to compute_value_size. modify_value is typically ZigZagEncode.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(modify_value(element)) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(modify_value(element)) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(modify_value(value)) + return FieldSize + + return SpecificSizer + + +def _FixedSizer(value_size): + """Like _SimpleSizer except for a fixed-size field. The input is the size + of one value.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = len(value) * value_size + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + element_size = value_size + tag_size + def RepeatedFieldSize(value): + return len(value) * element_size + return RepeatedFieldSize + else: + field_size = value_size + tag_size + def FieldSize(value): + return field_size + return FieldSize + + return SpecificSizer + + +# ==================================================================== +# Here we declare a sizer constructor for each field type. Each "sizer +# constructor" is a function that takes (field_number, is_repeated, is_packed) +# as parameters and returns a sizer, which in turn takes a field value as +# a parameter and returns its encoded size. + + +Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize) + +UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize) + +SInt32Sizer = SInt64Sizer = _ModifiedSizer( + _SignedVarintSize, wire_format.ZigZagEncode) + +Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4) +Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8) + +BoolSizer = _FixedSizer(1) + + +def StringSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a string field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element.encode('utf-8')) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value.encode('utf-8')) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def BytesSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a bytes field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def GroupSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a group field.""" + + tag_size = _TagSize(field_number) * 2 + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += element.ByteSize() + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + value.ByteSize() + return FieldSize + + +def MessageSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a message field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = element.ByteSize() + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = value.ByteSize() + return tag_size + local_VarintSize(l) + l + return FieldSize + + +# -------------------------------------------------------------------- +# MessageSet is special. + + +def MessageSetItemSizer(field_number): + """Returns a sizer for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) + + _TagSize(3)) + local_VarintSize = _VarintSize + + def FieldSize(value): + l = value.ByteSize() + return static_size + local_VarintSize(l) + l + + return FieldSize + + +# ==================================================================== +# Encoders! + + +def _VarintEncoder(): + """Return an encoder for a basic varint value (does not include tag).""" + + local_chr = chr + def EncodeVarint(write, value): + bits = value & 0x7f + value >>= 7 + while value: + write(local_chr(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_chr(bits)) + + return EncodeVarint + + +def _SignedVarintEncoder(): + """Return an encoder for a basic signed varint value (does not include + tag).""" + + local_chr = chr + def EncodeSignedVarint(write, value): + if value < 0: + value += (1 << 64) + bits = value & 0x7f + value >>= 7 + while value: + write(local_chr(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_chr(bits)) + + return EncodeSignedVarint + + +_EncodeVarint = _VarintEncoder() +_EncodeSignedVarint = _SignedVarintEncoder() + + +def _VarintBytes(value): + """Encode the given integer as a varint and return the bytes. This is only + called at startup time so it doesn't need to be fast.""" + + pieces = [] + _EncodeVarint(pieces.append, value) + return "".join(pieces) + + +def TagBytes(field_number, wire_type): + """Encode the given tag and return the bytes. Only called at startup.""" + + return _VarintBytes(wire_format.PackTag(field_number, wire_type)) + +# -------------------------------------------------------------------- +# As with sizers (see above), we have a number of common encoder +# implementations. + + +def _SimpleEncoder(wire_type, encode_value, compute_value_size): + """Return a constructor for an encoder for fields of a particular type. + + Args: + wire_type: The field's wire type, for encoding tags. + encode_value: A function which encodes an individual value, e.g. + _EncodeVarint(). + compute_value_size: A function which computes the size of an individual + value, e.g. _VarintSize(). + """ + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(element) + local_EncodeVarint(write, size) + for element in value: + encode_value(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + encode_value(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return encode_value(write, value) + return EncodeField + + return SpecificEncoder + + +def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value): + """Like SimpleEncoder but additionally invokes modify_value on every value + before passing it to encode_value. Usually modify_value is ZigZagEncode.""" + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(modify_value(element)) + local_EncodeVarint(write, size) + for element in value: + encode_value(write, modify_value(element)) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + encode_value(write, modify_value(element)) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return encode_value(write, modify_value(value)) + return EncodeField + + return SpecificEncoder + + +def _StructPackEncoder(wire_type, format): + """Return a constructor for an encoder for a fixed-width field. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + value_size = struct.calcsize(format) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size) + for element in value: + write(local_struct_pack(format, element)) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + write(local_struct_pack(format, element)) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return write(local_struct_pack(format, value)) + return EncodeField + + return SpecificEncoder + + +# ==================================================================== +# Here we declare an encoder constructor for each field type. These work +# very similarly to sizer constructors, described earlier. + + +Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) + +UInt32Encoder = UInt64Encoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) + +SInt32Encoder = SInt64Encoder = _ModifiedEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize, + wire_format.ZigZagEncode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d') + + +def BoolEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a boolean field.""" + + false_byte = chr(0) + true_byte = chr(1) + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value)) + for element in value: + if element: + write(true_byte) + else: + write(false_byte) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + if element: + write(true_byte) + else: + write(false_byte) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeField(write, value): + write(tag_bytes) + if value: + return write(true_byte) + return write(false_byte) + return EncodeField + + +def StringEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a string field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + encoded = element.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded)) + write(encoded) + return EncodeRepeatedField + else: + def EncodeField(write, value): + encoded = value.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded)) + return write(encoded) + return EncodeField + + +def BytesEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a bytes field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(tag) + local_EncodeVarint(write, local_len(element)) + write(element) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(tag) + local_EncodeVarint(write, local_len(value)) + return write(value) + return EncodeField + + +def GroupEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a group field.""" + + start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP) + end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP) + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(start_tag) + element._InternalSerialize(write) + write(end_tag) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(start_tag) + value._InternalSerialize(write) + return write(end_tag) + return EncodeField + + +def MessageEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a message field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(tag) + local_EncodeVarint(write, element.ByteSize()) + element._InternalSerialize(write) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(tag) + local_EncodeVarint(write, value.ByteSize()) + return value._InternalSerialize(write) + return EncodeField + + +# -------------------------------------------------------------------- +# As before, MessageSet is special. + + +def MessageSetItemEncoder(field_number): + """Encoder for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; } - """ - self.AppendTag(1, wire_format.WIRETYPE_START_GROUP) - self.AppendInt32(2, field_number) - self.AppendMessage(3, msg) - self.AppendTag(1, wire_format.WIRETYPE_END_GROUP) - - def AppendTag(self, field_number, wire_type): - """Appends a tag containing field number and wire type information.""" - self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) + } + """ + start_bytes = "".join([ + TagBytes(1, wire_format.WIRETYPE_START_GROUP), + TagBytes(2, wire_format.WIRETYPE_VARINT), + _VarintBytes(field_number), + TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)]) + end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP) + local_EncodeVarint = _EncodeVarint + + def EncodeField(write, value): + write(start_bytes) + local_EncodeVarint(write, value.ByteSize()) + value._InternalSerialize(write) + return write(end_bytes) + + return EncodeField diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py deleted file mode 100755 index bf75ea8..0000000 --- a/python/google/protobuf/internal/encoder_test.py +++ /dev/null @@ -1,286 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test for google.protobuf.internal.encoder.""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import struct -import logging -import unittest -from google.protobuf.internal import wire_format -from google.protobuf.internal import encoder -from google.protobuf.internal import output_stream -from google.protobuf import message -import mox - - -class EncoderTest(unittest.TestCase): - - def setUp(self): - self.mox = mox.Mox() - self.encoder = encoder.Encoder() - self.mock_stream = self.mox.CreateMock(output_stream.OutputStream) - self.mock_message = self.mox.CreateMock(message.Message) - self.encoder._stream = self.mock_stream - - def PackTag(self, field_number, wire_type): - return wire_format.PackTag(field_number, wire_type) - - def AppendScalarTestHelper(self, test_name, encoder_method, - expected_stream_method_name, - wire_type, field_value, - expected_value=None, expected_length=None, - is_tag_test=True): - """Helper for testAppendScalars. - - Calls one of the Encoder methods, and ensures that the Encoder - in turn makes the expected calls into its OutputStream. - - Args: - test_name: Name of this test, used only for logging. - encoder_method: Callable on self.encoder. This is the Encoder - method we're testing. If is_tag_test=True, the encoder method - accepts a field_number and field_value. if is_tag_test=False, - the encoder method accepts a field_value. - expected_stream_method_name: (string) Name of the OutputStream - method we expect Encoder to call to actually put the value - on the wire. - wire_type: The WIRETYPE_* constant we expect encoder to - use in the specified encoder_method. - field_value: The value we're trying to encode. Passed - into encoder_method. - expected_value: The value we expect Encoder to pass into - the OutputStream method. If None, we expect field_value - to pass through unmodified. - expected_length: The length we expect Encoder to pass to the - AppendVarUInt32 method. If None we expect the length of the - field_value. - is_tag_test: A Boolean. If True (the default), we append the - the packed field number and wire_type to the stream before - the field value. - """ - if expected_value is None: - expected_value = field_value - - logging.info('Testing %s scalar output.\n' - 'Calling %r(%r), and expecting that to call the ' - 'stream method %s(%r).' % ( - test_name, encoder_method, field_value, - expected_stream_method_name, expected_value)) - - if is_tag_test: - field_number = 10 - # Should first append the field number and type information. - self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type)) - # If we're length-delimited, we should then append the length. - if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: - if expected_length is None: - expected_length = len(field_value) - self.mock_stream.AppendVarUInt32(expected_length) - - # Should then append the value itself. - # We have to use names instead of methods to work around some - # mox weirdness. (ResetAll() is overzealous). - expected_stream_method = getattr(self.mock_stream, - expected_stream_method_name) - expected_stream_method(expected_value) - - self.mox.ReplayAll() - if is_tag_test: - encoder_method(field_number, field_value) - else: - encoder_method(field_value) - self.mox.VerifyAll() - self.mox.ResetAll() - - VAL = 1.125 # Perfectly representable as a float (no rounding error). - LITTLE_FLOAT_VAL = '\x00\x00\x90?' - LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?' - - def testAppendScalars(self): - utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82' - utf8_string = unicode(utf8_bytes, 'utf-8') - scalar_tests = [ - ['int32', self.encoder.AppendInt32, 'AppendVarint32', - wire_format.WIRETYPE_VARINT, 0], - ['int64', self.encoder.AppendInt64, 'AppendVarint64', - wire_format.WIRETYPE_VARINT, 0], - ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32', - wire_format.WIRETYPE_VARINT, 0], - ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64', - wire_format.WIRETYPE_VARINT, 0], - ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32', - wire_format.WIRETYPE_FIXED32, 0], - ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64', - wire_format.WIRETYPE_FIXED64, 0], - ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32', - wire_format.WIRETYPE_FIXED32, -1, 0xffffffff], - ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64', - wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff], - ['float', self.encoder.AppendFloat, 'AppendRawBytes', - wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL], - ['double', self.encoder.AppendDouble, 'AppendRawBytes', - wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL], - ['bool', self.encoder.AppendBool, 'AppendVarint32', - wire_format.WIRETYPE_VARINT, False], - ['enum', self.encoder.AppendEnum, 'AppendVarint32', - wire_format.WIRETYPE_VARINT, 0], - ['string', self.encoder.AppendString, 'AppendRawBytes', - wire_format.WIRETYPE_LENGTH_DELIMITED, - "You're in a maze of twisty little passages, all alike."], - ['utf8-string', self.encoder.AppendString, 'AppendRawBytes', - wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string, - utf8_bytes, len(utf8_bytes)], - # We test zigzag encoding routines more extensively below. - ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32', - wire_format.WIRETYPE_VARINT, -1, 1], - ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64', - wire_format.WIRETYPE_VARINT, -1, 1], - ] - # Ensure that we're testing different Encoder methods and using - # different test names in all test cases above. - self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) - self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests))) - for args in scalar_tests: - self.AppendScalarTestHelper(*args) - - def testAppendScalarsWithoutTags(self): - scalar_no_tag_tests = [ - ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0], - ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0], - ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0], - ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0], - ['fixed32', self.encoder.AppendFixed32NoTag, - 'AppendLittleEndian32', None, 0], - ['fixed64', self.encoder.AppendFixed64NoTag, - 'AppendLittleEndian64', None, 0], - ['sfixed32', self.encoder.AppendSFixed32NoTag, - 'AppendLittleEndian32', None, 0], - ['sfixed64', self.encoder.AppendSFixed64NoTag, - 'AppendLittleEndian64', None, 0], - ['float', self.encoder.AppendFloatNoTag, - 'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL], - ['double', self.encoder.AppendDoubleNoTag, - 'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL], - ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0], - ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0], - ['sint32', self.encoder.AppendSInt32NoTag, - 'AppendVarUInt32', None, -1, 1], - ['sint64', self.encoder.AppendSInt64NoTag, - 'AppendVarUInt64', None, -1, 1], - ] - - self.assertEqual(len(scalar_no_tag_tests), - len(set(t[0] for t in scalar_no_tag_tests))) - self.assert_(len(scalar_no_tag_tests) >= - len(set(t[1] for t in scalar_no_tag_tests))) - for args in scalar_no_tag_tests: - # For no tag tests, the wire_type is not used, so we put in None. - self.AppendScalarTestHelper(is_tag_test=False, *args) - - def testAppendGroup(self): - field_number = 23 - # Should first append the start-group marker. - self.mock_stream.AppendVarUInt32( - self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP)) - # Should then serialize itself. - self.mock_message.SerializeToString().AndReturn('foo') - self.mock_stream.AppendRawBytes('foo') - # Should finally append the end-group marker. - self.mock_stream.AppendVarUInt32( - self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP)) - - self.mox.ReplayAll() - self.encoder.AppendGroup(field_number, self.mock_message) - self.mox.VerifyAll() - - def testAppendMessage(self): - field_number = 23 - byte_size = 42 - # Should first append the field number and type information. - self.mock_stream.AppendVarUInt32( - self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) - # Should then append its length. - self.mock_message.ByteSize().AndReturn(byte_size) - self.mock_stream.AppendVarUInt32(byte_size) - # Should then serialize itself to the encoder. - self.mock_message.SerializeToString().AndReturn('foo') - self.mock_stream.AppendRawBytes('foo') - - self.mox.ReplayAll() - self.encoder.AppendMessage(field_number, self.mock_message) - self.mox.VerifyAll() - - def testAppendMessageSetItem(self): - field_number = 23 - byte_size = 42 - # Should first append the field number and type information. - self.mock_stream.AppendVarUInt32( - self.PackTag(1, wire_format.WIRETYPE_START_GROUP)) - self.mock_stream.AppendVarUInt32( - self.PackTag(2, wire_format.WIRETYPE_VARINT)) - self.mock_stream.AppendVarint32(field_number) - self.mock_stream.AppendVarUInt32( - self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED)) - # Should then append its length. - self.mock_message.ByteSize().AndReturn(byte_size) - self.mock_stream.AppendVarUInt32(byte_size) - # Should then serialize itself to the encoder. - self.mock_message.SerializeToString().AndReturn('foo') - self.mock_stream.AppendRawBytes('foo') - self.mock_stream.AppendVarUInt32( - self.PackTag(1, wire_format.WIRETYPE_END_GROUP)) - - self.mox.ReplayAll() - self.encoder.AppendMessageSetItem(field_number, self.mock_message) - self.mox.VerifyAll() - - def testAppendSFixed(self): - # Most of our bounds-checking is done in output_stream.py, - # but encoder.py is responsible for transforming signed - # fixed-width integers into unsigned ones, so we test here - # to ensure that we're not losing any entropy when we do - # that conversion. - field_number = 10 - self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, - 10, wire_format.UINT32_MAX + 1) - self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, - 10, -(1 << 32)) - self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, - 10, wire_format.UINT64_MAX + 1) - self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, - 10, -(1 << 64)) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 11fcfa0..78360b5 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -35,15 +35,20 @@ # indirect testing of the protocol compiler output. """Unittest that directly tests the output of the pure-Python protocol -compiler. See //net/proto2/internal/reflection_test.py for a test which +compiler. See //google/protobuf/reflection_test.py for a test which further ensures that we can use Python protocol message objects as we expect. """ __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import unittest_no_generic_services_pb2 + + +MAX_EXTENSION = 536870912 class GeneratorTest(unittest.TestCase): @@ -71,6 +76,46 @@ class GeneratorTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testExtremeDefaultValues(self): + message = unittest_pb2.TestExtremeDefaultValues() + + # Python pre-2.6 does not have isinf() or isnan() functions, so we have + # to provide our own. + def isnan(val): + # NaN is never equal to itself. + return val != val + def isinf(val): + # Infinity times zero equals NaN. + return not isnan(val) and isnan(val * 0) + + self.assertTrue(isinf(message.inf_double)) + self.assertTrue(message.inf_double > 0) + self.assertTrue(isinf(message.neg_inf_double)) + self.assertTrue(message.neg_inf_double < 0) + self.assertTrue(isnan(message.nan_double)) + + self.assertTrue(isinf(message.inf_float)) + self.assertTrue(message.inf_float > 0) + self.assertTrue(isinf(message.neg_inf_float)) + self.assertTrue(message.neg_inf_float < 0) + self.assertTrue(isnan(message.nan_float)) + + def testHasDefaultValues(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR + + expected_has_default_by_name = { + 'optional_int32': False, + 'repeated_int32': False, + 'optional_nested_message': False, + 'default_int32': True, + } + + has_default_by_name = dict( + [(f.name, f.has_default_value) + for f in desc.fields + if f.name in expected_has_default_by_name]) + self.assertEqual(expected_has_default_by_name, has_default_by_name) + def testContainingTypeBehaviorForExtensions(self): self.assertEqual(unittest_pb2.optional_int32_extension.containing_type, unittest_pb2.TestAllExtensions.DESCRIPTOR) @@ -95,6 +140,81 @@ class GeneratorTest(unittest.TestCase): proto = unittest_mset_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + def testNestedTypes(self): + self.assertEquals( + set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), + set([ + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, + unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR, + unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR, + ])) + self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, []) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, []) + + def testContainingType(self): + self.assertTrue( + unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None) + self.assertTrue( + unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEqual( + unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + + def testContainingTypeInEnumDescriptor(self): + self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None) + self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + + def testPackage(self): + self.assertEqual( + unittest_pb2.TestAllTypes.DESCRIPTOR.file.package, + 'protobuf_unittest') + desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR + self.assertEqual(desc.file.package, 'protobuf_unittest') + self.assertEqual( + unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package, + 'protobuf_unittest_import') + + self.assertEqual( + unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest') + self.assertEqual( + unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package, + 'protobuf_unittest') + self.assertEqual( + unittest_import_pb2._IMPORTENUM.file.package, + 'protobuf_unittest_import') + + def testExtensionRange(self): + self.assertEqual( + unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, []) + self.assertEqual( + unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges, + [(1, MAX_EXTENSION)]) + self.assertEqual( + unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges, + [(42, 43), (4143, 4244), (65536, MAX_EXTENSION)]) + + def testFileDescriptor(self): + self.assertEqual(unittest_pb2.DESCRIPTOR.name, + 'google/protobuf/unittest.proto') + self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest') + self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None) + + def testNoGenericServices(self): + # unittest_no_generic_services.proto should contain defs for everything + # except services. + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO")) + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension")) + self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService")) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py deleted file mode 100755 index 7bda17e..0000000 --- a/python/google/protobuf/internal/input_stream.py +++ /dev/null @@ -1,338 +0,0 @@ -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""InputStream is the primitive interface for reading bits from the wire. - -All protocol buffer deserialization can be expressed in terms of -the InputStream primitives provided here. -""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import array -import struct -from google.protobuf import message -from google.protobuf.internal import wire_format - - -# Note that much of this code is ported from //net/proto/ProtocolBuffer, and -# that the interface is strongly inspired by CodedInputStream from the C++ -# proto2 implementation. - - -class InputStreamBuffer(object): - - """Contains all logic for reading bits, and dealing with stream position. - - If an InputStream method ever raises an exception, the stream is left - in an indeterminate state and is not safe for further use. - """ - - def __init__(self, s): - # What we really want is something like array('B', s), where elements we - # read from the array are already given to us as one-byte integers. BUT - # using array() instead of buffer() would force full string copies to result - # from each GetSubBuffer() call. - # - # So, if the N serialized bytes of a single protocol buffer object are - # split evenly between 2 child messages, and so on recursively, using - # array('B', s) instead of buffer() would incur an additional N*logN bytes - # copied during deserialization. - # - # The higher constant overhead of having to ord() for every byte we read - # from the buffer in _ReadVarintHelper() could definitely lead to worse - # performance in many real-world scenarios, even if the asymptotic - # complexity is better. However, our real answer is that the mythical - # Python/C extension module output mode for the protocol compiler will - # be blazing-fast and will eliminate most use of this class anyway. - self._buffer = buffer(s) - self._pos = 0 - - def EndOfStream(self): - """Returns true iff we're at the end of the stream. - If this returns true, then a call to any other InputStream method - will raise an exception. - """ - return self._pos >= len(self._buffer) - - def Position(self): - """Returns the current position in the stream, or equivalently, the - number of bytes read so far. - """ - return self._pos - - def GetSubBuffer(self, size=None): - """Returns a sequence-like object that represents a portion of our - underlying sequence. - - Position 0 in the returned object corresponds to self.Position() - in this stream. - - If size is specified, then the returned object ends after the - next "size" bytes in this stream. If size is not specified, - then the returned object ends at the end of this stream. - - We guarantee that the returned object R supports the Python buffer - interface (and thus that the call buffer(R) will work). - - Note that the returned buffer is read-only. - - The intended use for this method is for nested-message and nested-group - deserialization, where we want to make a recursive MergeFromString() - call on the portion of the original sequence that contains the serialized - nested message. (And we'd like to do so without making unnecessary string - copies). - - REQUIRES: size is nonnegative. - """ - # Note that buffer() doesn't perform any actual string copy. - if size is None: - return buffer(self._buffer, self._pos) - else: - if size < 0: - raise message.DecodeError('Negative size %d' % size) - return buffer(self._buffer, self._pos, size) - - def SkipBytes(self, num_bytes): - """Skip num_bytes bytes ahead, or go to the end of the stream, whichever - comes first. - - REQUIRES: num_bytes is nonnegative. - """ - if num_bytes < 0: - raise message.DecodeError('Negative num_bytes %d' % num_bytes) - self._pos += num_bytes - self._pos = min(self._pos, len(self._buffer)) - - def ReadBytes(self, size): - """Reads up to 'size' bytes from the stream, stopping early - only if we reach the end of the stream. Returns the bytes read - as a string. - """ - if size < 0: - raise message.DecodeError('Negative size %d' % size) - s = (self._buffer[self._pos : self._pos + size]) - self._pos += len(s) # Only advance by the number of bytes actually read. - return s - - def ReadLittleEndian32(self): - """Interprets the next 4 bytes of the stream as a little-endian - encoded, unsiged 32-bit integer, and returns that integer. - """ - try: - i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, - self._buffer[self._pos : self._pos + 4]) - self._pos += 4 - return i[0] # unpack() result is a 1-element tuple. - except struct.error, e: - raise message.DecodeError(e) - - def ReadLittleEndian64(self): - """Interprets the next 8 bytes of the stream as a little-endian - encoded, unsiged 64-bit integer, and returns that integer. - """ - try: - i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, - self._buffer[self._pos : self._pos + 8]) - self._pos += 8 - return i[0] # unpack() result is a 1-element tuple. - except struct.error, e: - raise message.DecodeError(e) - - def ReadVarint32(self): - """Reads a varint from the stream, interprets this varint - as a signed, 32-bit integer, and returns the integer. - """ - i = self.ReadVarint64() - if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: - raise message.DecodeError('Value out of range for int32: %d' % i) - return int(i) - - def ReadVarUInt32(self): - """Reads a varint from the stream, interprets this varint - as an unsigned, 32-bit integer, and returns the integer. - """ - i = self.ReadVarUInt64() - if i > wire_format.UINT32_MAX: - raise message.DecodeError('Value out of range for uint32: %d' % i) - return i - - def ReadVarint64(self): - """Reads a varint from the stream, interprets this varint - as a signed, 64-bit integer, and returns the integer. - """ - i = self.ReadVarUInt64() - if i > wire_format.INT64_MAX: - i -= (1 << 64) - return i - - def ReadVarUInt64(self): - """Reads a varint from the stream, interprets this varint - as an unsigned, 64-bit integer, and returns the integer. - """ - i = self._ReadVarintHelper() - if not 0 <= i <= wire_format.UINT64_MAX: - raise message.DecodeError('Value out of range for uint64: %d' % i) - return i - - def _ReadVarintHelper(self): - """Helper for the various varint-reading methods above. - Reads an unsigned, varint-encoded integer from the stream and - returns this integer. - - Does no bounds checking except to ensure that we read at most as many bytes - as could possibly be present in a varint-encoded 64-bit number. - """ - result = 0 - shift = 0 - while 1: - if shift >= 64: - raise message.DecodeError('Too many bytes when decoding varint.') - try: - b = ord(self._buffer[self._pos]) - except IndexError: - raise message.DecodeError('Truncated varint.') - self._pos += 1 - result |= ((b & 0x7f) << shift) - shift += 7 - if not (b & 0x80): - return result - - -class InputStreamArray(object): - - """Contains all logic for reading bits, and dealing with stream position. - - If an InputStream method ever raises an exception, the stream is left - in an indeterminate state and is not safe for further use. - - This alternative to InputStreamBuffer is used in environments where buffer() - is unavailble, such as Google App Engine. - """ - - def __init__(self, s): - self._buffer = array.array('B', s) - self._pos = 0 - - def EndOfStream(self): - return self._pos >= len(self._buffer) - - def Position(self): - return self._pos - - def GetSubBuffer(self, size=None): - if size is None: - return self._buffer[self._pos : ].tostring() - else: - if size < 0: - raise message.DecodeError('Negative size %d' % size) - return self._buffer[self._pos : self._pos + size].tostring() - - def SkipBytes(self, num_bytes): - if num_bytes < 0: - raise message.DecodeError('Negative num_bytes %d' % num_bytes) - self._pos += num_bytes - self._pos = min(self._pos, len(self._buffer)) - - def ReadBytes(self, size): - if size < 0: - raise message.DecodeError('Negative size %d' % size) - s = self._buffer[self._pos : self._pos + size].tostring() - self._pos += len(s) # Only advance by the number of bytes actually read. - return s - - def ReadLittleEndian32(self): - try: - i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, - self._buffer[self._pos : self._pos + 4]) - self._pos += 4 - return i[0] # unpack() result is a 1-element tuple. - except struct.error, e: - raise message.DecodeError(e) - - def ReadLittleEndian64(self): - try: - i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, - self._buffer[self._pos : self._pos + 8]) - self._pos += 8 - return i[0] # unpack() result is a 1-element tuple. - except struct.error, e: - raise message.DecodeError(e) - - def ReadVarint32(self): - i = self.ReadVarint64() - if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: - raise message.DecodeError('Value out of range for int32: %d' % i) - return int(i) - - def ReadVarUInt32(self): - i = self.ReadVarUInt64() - if i > wire_format.UINT32_MAX: - raise message.DecodeError('Value out of range for uint32: %d' % i) - return i - - def ReadVarint64(self): - i = self.ReadVarUInt64() - if i > wire_format.INT64_MAX: - i -= (1 << 64) - return i - - def ReadVarUInt64(self): - i = self._ReadVarintHelper() - if not 0 <= i <= wire_format.UINT64_MAX: - raise message.DecodeError('Value out of range for uint64: %d' % i) - return i - - def _ReadVarintHelper(self): - result = 0 - shift = 0 - while 1: - if shift >= 64: - raise message.DecodeError('Too many bytes when decoding varint.') - try: - b = self._buffer[self._pos] - except IndexError: - raise message.DecodeError('Truncated varint.') - self._pos += 1 - result |= ((b & 0x7f) << shift) - shift += 7 - if not (b & 0x80): - return result - - -try: - buffer('') - InputStream = InputStreamBuffer -except NotImplementedError: - # Google App Engine: dev_appserver.py - InputStream = InputStreamArray -except RuntimeError: - # Google App Engine: production - InputStream = InputStreamArray diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py deleted file mode 100755 index ecec7f7..0000000 --- a/python/google/protobuf/internal/input_stream_test.py +++ /dev/null @@ -1,314 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test for google.protobuf.internal.input_stream.""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import unittest -from google.protobuf import message -from google.protobuf.internal import wire_format -from google.protobuf.internal import input_stream - - -class InputStreamBufferTest(unittest.TestCase): - - def setUp(self): - self.__original_input_stream = input_stream.InputStream - input_stream.InputStream = input_stream.InputStreamBuffer - - def tearDown(self): - input_stream.InputStream = self.__original_input_stream - - def testEndOfStream(self): - stream = input_stream.InputStream('abcd') - self.assertFalse(stream.EndOfStream()) - self.assertEqual('abcd', stream.ReadBytes(10)) - self.assertTrue(stream.EndOfStream()) - - def testPosition(self): - stream = input_stream.InputStream('abcd') - self.assertEqual(0, stream.Position()) - self.assertEqual(0, stream.Position()) # No side-effects. - stream.ReadBytes(1) - self.assertEqual(1, stream.Position()) - stream.ReadBytes(1) - self.assertEqual(2, stream.Position()) - stream.ReadBytes(10) - self.assertEqual(4, stream.Position()) # Can't go past end of stream. - - def testGetSubBuffer(self): - stream = input_stream.InputStream('abcd') - # Try leaving out the size. - self.assertEqual('abcd', str(stream.GetSubBuffer())) - stream.SkipBytes(1) - # GetSubBuffer() always starts at current size. - self.assertEqual('bcd', str(stream.GetSubBuffer())) - # Try 0-size. - self.assertEqual('', str(stream.GetSubBuffer(0))) - # Negative sizes should raise an error. - self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1) - # Positive sizes should work as expected. - self.assertEqual('b', str(stream.GetSubBuffer(1))) - self.assertEqual('bc', str(stream.GetSubBuffer(2))) - # Sizes longer than remaining bytes in the buffer should - # return the whole remaining buffer. - self.assertEqual('bcd', str(stream.GetSubBuffer(1000))) - - def testSkipBytes(self): - stream = input_stream.InputStream('') - # Skipping bytes when at the end of stream - # should have no effect. - stream.SkipBytes(0) - stream.SkipBytes(1) - stream.SkipBytes(2) - self.assertTrue(stream.EndOfStream()) - self.assertEqual(0, stream.Position()) - - # Try skipping within a stream. - stream = input_stream.InputStream('abcd') - self.assertEqual(0, stream.Position()) - stream.SkipBytes(1) - self.assertEqual(1, stream.Position()) - stream.SkipBytes(10) # Can't skip past the end. - self.assertEqual(4, stream.Position()) - - # Ensure that a negative skip raises an exception. - stream = input_stream.InputStream('abcd') - stream.SkipBytes(1) - self.assertRaises(message.DecodeError, stream.SkipBytes, -1) - - def testReadBytes(self): - s = 'abcd' - # Also test going past the total stream length. - for i in range(len(s) + 10): - stream = input_stream.InputStream(s) - self.assertEqual(s[:i], stream.ReadBytes(i)) - self.assertEqual(min(i, len(s)), stream.Position()) - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadBytes, -1) - - def EnsureFailureOnEmptyStream(self, input_stream_method): - """Helper for integer-parsing tests below. - Ensures that the given InputStream method raises a DecodeError - if called on a stream with no bytes remaining. - """ - stream = input_stream.InputStream('') - self.assertRaises(message.DecodeError, input_stream_method, stream) - - def testReadLittleEndian32(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32) - s = '' - # Read 0. - s += '\x00\x00\x00\x00' - # Read 1. - s += '\x01\x00\x00\x00' - # Read a bunch of different bytes. - s += '\x01\x02\x03\x04' - # Read max unsigned 32-bit int. - s += '\xff\xff\xff\xff' - # Try a read with fewer than 4 bytes left in the stream. - s += '\x00\x00\x00' - stream = input_stream.InputStream(s) - self.assertEqual(0, stream.ReadLittleEndian32()) - self.assertEqual(4, stream.Position()) - self.assertEqual(1, stream.ReadLittleEndian32()) - self.assertEqual(8, stream.Position()) - self.assertEqual(0x04030201, stream.ReadLittleEndian32()) - self.assertEqual(12, stream.Position()) - self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32()) - self.assertEqual(16, stream.Position()) - self.assertRaises(message.DecodeError, stream.ReadLittleEndian32) - - def testReadLittleEndian64(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64) - s = '' - # Read 0. - s += '\x00\x00\x00\x00\x00\x00\x00\x00' - # Read 1. - s += '\x01\x00\x00\x00\x00\x00\x00\x00' - # Read a bunch of different bytes. - s += '\x01\x02\x03\x04\x05\x06\x07\x08' - # Read max unsigned 64-bit int. - s += '\xff\xff\xff\xff\xff\xff\xff\xff' - # Try a read with fewer than 8 bytes left in the stream. - s += '\x00\x00\x00' - stream = input_stream.InputStream(s) - self.assertEqual(0, stream.ReadLittleEndian64()) - self.assertEqual(8, stream.Position()) - self.assertEqual(1, stream.ReadLittleEndian64()) - self.assertEqual(16, stream.Position()) - self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64()) - self.assertEqual(24, stream.Position()) - self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64()) - self.assertEqual(32, stream.Position()) - self.assertRaises(message.DecodeError, stream.ReadLittleEndian64) - - def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method): - """Helper for tests below that test successful reads of various varints. - - Args: - varints_and_ints: Iterable of (str, integer) pairs, where the string - gives the wire encoding and the integer gives the value we expect - to be returned by the read_method upon encountering this string. - read_method: Unbound InputStream method that is capable of reading - the encoded strings provided in the first elements of varints_and_ints. - """ - s = ''.join(s for s, i in varints_and_ints) - stream = input_stream.InputStream(s) - expected_pos = 0 - self.assertEqual(expected_pos, stream.Position()) - for s, expected_int in varints_and_ints: - self.assertEqual(expected_int, read_method(stream)) - expected_pos += len(s) - self.assertEqual(expected_pos, stream.Position()) - - def testReadVarint32Success(self): - varints_and_ints = [ - ('\x00', 0), - ('\x01', 1), - ('\x7f', 127), - ('\x80\x01', 128), - ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), - ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX), - ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN), - ] - self.ReadVarintSuccessTestHelper(varints_and_ints, - input_stream.InputStream.ReadVarint32) - - def testReadVarint32Failure(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32) - - # Try and fail to read INT32_MAX + 1. - s = '\x80\x80\x80\x80\x08' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarint32) - - # Try and fail to read INT32_MIN - 1. - s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarint32) - - # Try and fail to read something that looks like - # a varint with more than 10 bytes. - s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarint32) - - def testReadVarUInt32Success(self): - varints_and_ints = [ - ('\x00', 0), - ('\x01', 1), - ('\x7f', 127), - ('\x80\x01', 128), - ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX), - ] - self.ReadVarintSuccessTestHelper(varints_and_ints, - input_stream.InputStream.ReadVarUInt32) - - def testReadVarUInt32Failure(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32) - # Try and fail to read UINT32_MAX + 1 - s = '\x80\x80\x80\x80\x10' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarUInt32) - - # Try and fail to read something that looks like - # a varint with more than 10 bytes. - s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarUInt32) - - def testReadVarint64Success(self): - varints_and_ints = [ - ('\x00', 0), - ('\x01', 1), - ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), - ('\x7f', 127), - ('\x80\x01', 128), - ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX), - ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN), - ] - self.ReadVarintSuccessTestHelper(varints_and_ints, - input_stream.InputStream.ReadVarint64) - - def testReadVarint64Failure(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64) - # Try and fail to read something with the mythical 64th bit set. - s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarint64) - - # Try and fail to read something that looks like - # a varint with more than 10 bytes. - s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarint64) - - def testReadVarUInt64Success(self): - varints_and_ints = [ - ('\x00', 0), - ('\x01', 1), - ('\x7f', 127), - ('\x80\x01', 128), - ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63), - ] - self.ReadVarintSuccessTestHelper(varints_and_ints, - input_stream.InputStream.ReadVarUInt64) - - def testReadVarUInt64Failure(self): - self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64) - # Try and fail to read something with the mythical 64th bit set. - s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarUInt64) - - # Try and fail to read something that looks like - # a varint with more than 10 bytes. - s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - stream = input_stream.InputStream(s) - self.assertRaises(message.DecodeError, stream.ReadVarUInt64) - - -class InputStreamArrayTest(InputStreamBufferTest): - - def setUp(self): - # Test InputStreamArray against the same tests in InputStreamBuffer - self.__original_input_stream = input_stream.InputStream - input_stream.InputStream = input_stream.InputStreamArray - - def tearDown(self): - input_stream.InputStream = self.__original_input_stream - - -if __name__ == '__main__': - unittest.main() diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py index 4397895..1080234 100755 --- a/python/google/protobuf/internal/message_listener.py +++ b/python/google/protobuf/internal/message_listener.py @@ -39,22 +39,34 @@ __author__ = 'robinson@google.com (Will Robinson)' class MessageListener(object): - """Listens for transitions to nonempty and for invalidations of cached - byte sizes. Meant to be registered via Message._SetListener(). + """Listens for modifications made to a message. Meant to be registered via + Message._SetListener(). + + Attributes: + dirty: If True, then calling Modified() would be a no-op. This can be + used to avoid these calls entirely in the common case. """ - def TransitionToNonempty(self): - """Called the *first* time that this message becomes nonempty. - Implementations are free (but not required) to call this method multiple - times after the message has become nonempty. - """ - raise NotImplementedError + def Modified(self): + """Called every time the message is modified in such a way that the parent + message may need to be updated. This currently means either: + (a) The message was modified for the first time, so the parent message + should henceforth mark the message as present. + (b) The message's cached byte size became dirty -- i.e. the message was + modified for the first time after a previous call to ByteSize(). + Therefore the parent should also mark its byte size as dirty. + Note that (a) implies (b), since new objects start out with a client cached + size (zero). However, we document (a) explicitly because it is important. + + Modified() will *only* be called in response to one of these two events -- + not every time the sub-message is modified. - def ByteSizeDirty(self): - """Called *every* time the cached byte size value - for this object is invalidated (transitions from being - "clean" to "dirty"). + Note that if the listener's |dirty| attribute is true, then calling + Modified at the moment would be a no-op, so it can be skipped. Performance- + sensitive callers should check this attribute directly before calling since + it will be true most of the time. """ + raise NotImplementedError @@ -62,8 +74,5 @@ class NullMessageListener(object): """No-op MessageListener implementation.""" - def TransitionToNonempty(self): - pass - - def ByteSizeDirty(self): + def Modified(self): pass diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index df344cf..73a9a3a 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -30,7 +30,16 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Tests python protocol buffers against the golden message.""" +"""Tests python protocol buffers against the golden message. + +Note that the golden messages exercise every known field type, thus this +test ends up exercising and verifying nearly all of the parsing and +serialization code in the whole library. + +TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of +sense to call this a test of the "message" module, which only declares an +abstract interface. +""" __author__ = 'gps@google.com (Gregory P. Smith)' @@ -40,14 +49,41 @@ from google.protobuf import unittest_pb2 from google.protobuf.internal import test_util -class MessageTest(test_util.GoldenMessageTestCase): +class MessageTest(unittest.TestCase): def testGoldenMessage(self): golden_data = test_util.GoldenFile('golden_message').read() golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) - self.ExpectAllFieldsSet(golden_message) + test_util.ExpectAllFieldsSet(self, golden_message) + self.assertTrue(golden_message.SerializeToString() == golden_data) + + def testGoldenExtensions(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(golden_message.SerializeToString() == golden_data) + + def testGoldenPackedMessage(self): + golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(all_set.SerializeToString() == golden_data) + def testGoldenPackedExtensions(self): + golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_message = unittest_pb2.TestPackedExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(all_set.SerializeToString() == golden_data) if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/output_stream.py b/python/google/protobuf/internal/output_stream.py deleted file mode 100755 index 6c2d6f6..0000000 --- a/python/google/protobuf/internal/output_stream.py +++ /dev/null @@ -1,125 +0,0 @@ -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""OutputStream is the primitive interface for sticking bits on the wire. - -All protocol buffer serialization can be expressed in terms of -the OutputStream primitives provided here. -""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import array -import struct -from google.protobuf import message -from google.protobuf.internal import wire_format - - - -# Note that much of this code is ported from //net/proto/ProtocolBuffer, and -# that the interface is strongly inspired by CodedOutputStream from the C++ -# proto2 implementation. - - -class OutputStream(object): - - """Contains all logic for writing bits, and ToString() to get the result.""" - - def __init__(self): - self._buffer = array.array('B') - - def AppendRawBytes(self, raw_bytes): - """Appends raw_bytes to our internal buffer.""" - self._buffer.fromstring(raw_bytes) - - def AppendLittleEndian32(self, unsigned_value): - """Appends an unsigned 32-bit integer to the internal buffer, - in little-endian byte order. - """ - if not 0 <= unsigned_value <= wire_format.UINT32_MAX: - raise message.EncodeError( - 'Unsigned 32-bit out of range: %d' % unsigned_value) - self._buffer.fromstring(struct.pack( - wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value)) - - def AppendLittleEndian64(self, unsigned_value): - """Appends an unsigned 64-bit integer to the internal buffer, - in little-endian byte order. - """ - if not 0 <= unsigned_value <= wire_format.UINT64_MAX: - raise message.EncodeError( - 'Unsigned 64-bit out of range: %d' % unsigned_value) - self._buffer.fromstring(struct.pack( - wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value)) - - def AppendVarint32(self, value): - """Appends a signed 32-bit integer to the internal buffer, - encoded as a varint. (Note that a negative varint32 will - always require 10 bytes of space.) - """ - if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX: - raise message.EncodeError('Value out of range: %d' % value) - self.AppendVarint64(value) - - def AppendVarUInt32(self, value): - """Appends an unsigned 32-bit integer to the internal buffer, - encoded as a varint. - """ - if not 0 <= value <= wire_format.UINT32_MAX: - raise message.EncodeError('Value out of range: %d' % value) - self.AppendVarUInt64(value) - - def AppendVarint64(self, value): - """Appends a signed 64-bit integer to the internal buffer, - encoded as a varint. - """ - if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX: - raise message.EncodeError('Value out of range: %d' % value) - if value < 0: - value += (1 << 64) - self.AppendVarUInt64(value) - - def AppendVarUInt64(self, unsigned_value): - """Appends an unsigned 64-bit integer to the internal buffer, - encoded as a varint. - """ - if not 0 <= unsigned_value <= wire_format.UINT64_MAX: - raise message.EncodeError('Value out of range: %d' % unsigned_value) - while True: - bits = unsigned_value & 0x7f - unsigned_value >>= 7 - if not unsigned_value: - self._buffer.append(bits) - break - self._buffer.append(0x80|bits) - - def ToString(self): - """Returns a string containing the bytes in our internal buffer.""" - return self._buffer.tostring() diff --git a/python/google/protobuf/internal/output_stream_test.py b/python/google/protobuf/internal/output_stream_test.py deleted file mode 100755 index df92eec..0000000 --- a/python/google/protobuf/internal/output_stream_test.py +++ /dev/null @@ -1,178 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test for google.protobuf.internal.output_stream.""" - -__author__ = 'robinson@google.com (Will Robinson)' - -import unittest -from google.protobuf import message -from google.protobuf.internal import output_stream -from google.protobuf.internal import wire_format - - -class OutputStreamTest(unittest.TestCase): - - def setUp(self): - self.stream = output_stream.OutputStream() - - def testAppendRawBytes(self): - # Empty string. - self.stream.AppendRawBytes('') - self.assertEqual('', self.stream.ToString()) - - # Nonempty string. - self.stream.AppendRawBytes('abc') - self.assertEqual('abc', self.stream.ToString()) - - # Ensure that we're actually appending. - self.stream.AppendRawBytes('def') - self.assertEqual('abcdef', self.stream.ToString()) - - def AppendNumericTestHelper(self, append_fn, values_and_strings): - """For each (value, expected_string) pair in values_and_strings, - calls an OutputStream.Append*(value) method on an OutputStream and ensures - that the string written to that stream matches expected_string. - - Args: - append_fn: Unbound OutputStream method that takes an integer or - long value as input. - values_and_strings: Iterable of (value, expected_string) pairs. - """ - for conversion in (int, long): - for value, string in values_and_strings: - stream = output_stream.OutputStream() - expected_string = '' - append_fn(stream, conversion(value)) - expected_string += string - self.assertEqual(expected_string, stream.ToString()) - - def AppendOverflowTestHelper(self, append_fn, value): - """Calls an OutputStream.Append*(value) method and asserts - that the method raises message.EncodeError. - - Args: - append_fn: Unbound OutputStream method that takes an integer or - long value as input. - value: Value to pass to append_fn which should cause an - message.EncodeError. - """ - stream = output_stream.OutputStream() - self.assertRaises(message.EncodeError, append_fn, stream, value) - - def testAppendLittleEndian32(self): - append_fn = output_stream.OutputStream.AppendLittleEndian32 - values_and_expected_strings = [ - (0, '\x00\x00\x00\x00'), - (1, '\x01\x00\x00\x00'), - ((1 << 32) - 1, '\xff\xff\xff\xff'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, 1 << 32) - self.AppendOverflowTestHelper(append_fn, -1) - - def testAppendLittleEndian64(self): - append_fn = output_stream.OutputStream.AppendLittleEndian64 - values_and_expected_strings = [ - (0, '\x00\x00\x00\x00\x00\x00\x00\x00'), - (1, '\x01\x00\x00\x00\x00\x00\x00\x00'), - ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, 1 << 64) - self.AppendOverflowTestHelper(append_fn, -1) - - def testAppendVarint32(self): - append_fn = output_stream.OutputStream.AppendVarint32 - values_and_expected_strings = [ - (0, '\x00'), - (1, '\x01'), - (127, '\x7f'), - (128, '\x80\x01'), - (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), - (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'), - (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1) - self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1) - - def testAppendVarUInt32(self): - append_fn = output_stream.OutputStream.AppendVarUInt32 - values_and_expected_strings = [ - (0, '\x00'), - (1, '\x01'), - (127, '\x7f'), - (128, '\x80\x01'), - (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, -1) - self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1) - - def testAppendVarint64(self): - append_fn = output_stream.OutputStream.AppendVarint64 - values_and_expected_strings = [ - (0, '\x00'), - (1, '\x01'), - (127, '\x7f'), - (128, '\x80\x01'), - (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), - (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'), - (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1) - self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1) - - def testAppendVarUInt64(self): - append_fn = output_stream.OutputStream.AppendVarUInt64 - values_and_expected_strings = [ - (0, '\x00'), - (1, '\x01'), - (127, '\x7f'), - (128, '\x80\x01'), - (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), - ] - self.AppendNumericTestHelper(append_fn, values_and_expected_strings) - - self.AppendOverflowTestHelper(append_fn, -1) - self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 8610177..2c9fa30 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -38,6 +38,7 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' import operator +import struct import unittest # TODO(robinson): When we split this test in two, only some of these imports @@ -56,6 +57,51 @@ from google.protobuf.internal import test_util from google.protobuf.internal import decoder +class _MiniDecoder(object): + """Decodes a stream of values from a string. + + Once upon a time we actually had a class called decoder.Decoder. Then we + got rid of it during a redesign that made decoding much, much faster overall. + But a couple tests in this file used it to check that the serialized form of + a message was correct. So, this class implements just the methods that were + used by said tests, so that we don't have to rewrite the tests. + """ + + def __init__(self, bytes): + self._bytes = bytes + self._pos = 0 + + def ReadVarint(self): + result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) + return result + + ReadInt32 = ReadVarint + ReadInt64 = ReadVarint + ReadUInt32 = ReadVarint + ReadUInt64 = ReadVarint + + def ReadSInt64(self): + return wire_format.ZigZagDecode(self.ReadVarint()) + + ReadSInt32 = ReadSInt64 + + def ReadFieldNumberAndWireType(self): + return wire_format.UnpackTag(self.ReadVarint()) + + def ReadFloat(self): + result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] + self._pos += 4 + return result + + def ReadDouble(self): + result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] + self._pos += 8 + return result + + def EndOfStream(self): + return self._pos == len(self._bytes) + + class ReflectionTest(unittest.TestCase): def assertIs(self, values, others): @@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase): for i in range(len(values)): self.assertTrue(values[i] is others[i]) + def testScalarConstructor(self): + # Constructor with only scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_double=54.321, + optional_string='optional_string') + + self.assertEqual(24, proto.optional_int32) + self.assertEqual(54.321, proto.optional_double) + self.assertEqual('optional_string', proto.optional_string) + + def testRepeatedScalarConstructor(self): + # Constructor with only repeated scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_int32=[1, 2, 3, 4], + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_string=["optional_string"]) + + self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals(["optional_string"], list(proto.repeated_string)) + + def testRepeatedCompositeConstructor(self): + # Constructor with only repeated composite types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + repeatedgroup=[ + unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) + + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + self.assertEquals( + [unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], + list(proto.repeatedgroup)) + + def testMixedConstructor(self): + # Constructor with only mixed types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_string='optional_string', + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)]) + + self.assertEqual(24, proto.optional_int32) + self.assertEqual('optional_string', proto.optional_string) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + def testSimpleHasBits(self): # Test a scalar. proto = unittest_pb2.TestAllTypes() @@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase): proto.optional_fixed32 = 1 proto.optional_int32 = 5 proto.optional_string = 'foo' + # Access sub-message but don't set it yet. + nested_message = proto.optional_nested_message self.assertEqual( [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], proto.ListFields()) + proto.optional_nested_message.bb = 123 + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), + (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], + nested_message) ], + proto.ListFields()) + def testRepeatedListFields(self): proto = unittest_pb2.TestAllTypes() proto.repeated_fixed32.append(1) @@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase): proto.repeated_string.append('baz') proto.repeated_string.extend(str(x) for x in xrange(2)) proto.optional_int32 = 21 + proto.repeated_bool # Access but don't set anything; should not be listed. self.assertEqual( [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), @@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase): extendee_proto.ClearExtension(extension) extension_proto.foreign_message_int = 23 - self.assertTrue(not toplevel.HasField('submessage')) self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) def testExtensionFailureModes(self): @@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def assertInitialized(self, proto): + self.assertTrue(proto.IsInitialized()) + # Neither method should raise an exception. + proto.SerializeToString() + proto.SerializePartialToString() + + def assertNotInitialized(self, proto): + self.assertFalse(proto.IsInitialized()) + self.assertRaises(message.EncodeError, proto.SerializeToString) + # "Partial" serialization doesn't care if message is uninitialized. + proto.SerializePartialToString() + def testIsInitialized(self): # Trivial cases - all optional fields and extensions. proto = unittest_pb2.TestAllTypes() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) proto = unittest_pb2.TestAllExtensions() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # The case of uninitialized required fields. proto = unittest_pb2.TestRequired() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.a = proto.b = proto.c = 2 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # The case of uninitialized submessage. proto = unittest_pb2.TestRequiredForeign() - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) proto.optional_message.a = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.optional_message.b = 0 proto.optional_message.c = 0 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized repeated submessage. message1 = proto.repeated_message.add() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message1.a = message1.b = message1.c = 0 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized repeated group in an extension. proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.TestRequired.multi message1 = proto.Extensions[extension].add() message2 = proto.Extensions[extension].add() - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message1.a = 1 message1.b = 1 message1.c = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) message2.a = 2 message2.b = 2 message2.c = 2 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) # Uninitialized nonrepeated message in an extension. proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.TestRequired.single proto.Extensions[extension].a = 1 - self.assertFalse(proto.IsInitialized()) + self.assertNotInitialized(proto) proto.Extensions[extension].b = 2 proto.Extensions[extension].c = 3 - self.assertTrue(proto.IsInitialized()) + self.assertInitialized(proto) + + # Try passing an errors list. + errors = [] + proto = unittest_pb2.TestRequired() + self.assertFalse(proto.IsInitialized(errors)) + self.assertEqual(errors, ['a', 'b', 'c']) def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase): test_utf8_bytes, len(test_utf8_bytes) * '\xff') self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + def testEmptyNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.CopyFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.ParseFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + serialized = proto.SerializeToString() + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFromString(serialized) + self.assertTrue(proto2.HasField('optional_nested_message')) + + def testSetInParent(self): + proto = unittest_pb2.TestAllTypes() + self.assertFalse(proto.HasField('optionalgroup')) + proto.optionalgroup.SetInParent() + self.assertTrue(proto.HasField('optionalgroup')) + # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. @@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase): second_proto.MergeFromString(serialized) self.assertEqual(first_proto, second_proto) + def testSerializeNegativeValues(self): + first_proto = unittest_pb2.TestAllTypes() + + first_proto.optional_int32 = -1 + first_proto.optional_int64 = -(2 << 40) + first_proto.optional_sint32 = -3 + first_proto.optional_sint64 = -(4 << 40) + first_proto.optional_sfixed32 = -5 + first_proto.optional_sfixed64 = -(6 << 40) + + second_proto = unittest_pb2.TestAllTypes.FromString( + first_proto.SerializeToString()) + + self.assertEqual(first_proto, second_proto) + + def testParseTruncated(self): + first_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + + for truncation_point in xrange(len(serialized) + 1): + try: + second_proto = unittest_pb2.TestAllTypes() + unknown_fields = unittest_pb2.TestEmptyMessage() + pos = second_proto._InternalParse(serialized, 0, truncation_point) + # If we didn't raise an error then we read exactly the amount expected. + self.assertEqual(truncation_point, pos) + + # Parsing to unknown fields should not throw if parsing to known fields + # did not. + try: + pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) + self.assertEqual(truncation_point, pos2) + except message.DecodeError: + self.fail('Parsing unknown fields failed when parsing known fields ' + 'did not.') + except message.DecodeError: + # Parsing unknown fields should also fail. + self.assertRaises(message.DecodeError, unknown_fields._InternalParse, + serialized, 0, truncation_point) + def testCanonicalSerializationOrder(self): proto = more_messages_pb2.OutOfOrderFields() # These are also their tag numbers. Even though we're setting these in @@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase): proto.optional_int32 = 1 serialized = proto.SerializeToString() self.assertEqual(proto.ByteSize(), len(serialized)) - d = decoder.Decoder(serialized) + d = _MiniDecoder(serialized) ReadTag = d.ReadFieldNumberAndWireType self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) self.assertEqual(1, d.ReadInt32()) @@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.a is not set.') + 'Message is missing required fields: a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.b is not set.') + 'Message is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Required field protobuf_unittest.TestRequired.c is not set.') + 'Message is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase): self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) + def testSerializeUninitializedSubMessage(self): + proto = unittest_pb2.TestRequiredForeign() + + # Sub-message doesn't exist yet, so this succeeds. + proto.SerializeToString() + + proto.optional_message.a = 1 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'optional_message.b,optional_message.c') + + proto.optional_message.b = 2 + proto.optional_message.c = 3 + proto.SerializeToString() + + proto.repeated_message.add().a = 1 + proto.repeated_message.add().b = 2 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'repeated_message[0].b,repeated_message[0].c,' + 'repeated_message[1].a,repeated_message[1].c') + + proto.repeated_message[0].b = 2 + proto.repeated_message[0].c = 3 + proto.repeated_message[1].a = 1 + proto.repeated_message[1].c = 3 + proto.SerializeToString() + def testSerializeAllPackedFields(self): first_proto = unittest_pb2.TestPackedTypes() second_proto = unittest_pb2.TestPackedTypes() @@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase): proto.packed_float.append(2.0) # 4 bytes, will be before double serialized = proto.SerializeToString() self.assertEqual(proto.ByteSize(), len(serialized)) - d = decoder.Decoder(serialized) + d = _MiniDecoder(serialized) ReadTag = d.ReadFieldNumberAndWireType self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) self.assertEqual(1+1+1+2, d.ReadInt32()) @@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1000.0, d.ReadDouble()) self.assertTrue(d.EndOfStream()) + def testParsePackedFromUnpacked(self): + unpacked = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(unpacked) + packed = unittest_pb2.TestPackedTypes() + packed.MergeFromString(unpacked.SerializeToString()) + expected = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(expected) + self.assertEqual(expected, packed) + + def testParseUnpackedFromPacked(self): + packed = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(packed) + unpacked = unittest_pb2.TestUnpackedTypes() + unpacked.MergeFromString(packed.SerializeToString()) + expected = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(expected) + self.assertEqual(expected, unpacked) + def testFieldNumbers(self): proto = unittest_pb2.TestAllTypes() self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) @@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase): field_descriptor.label) -class UtilityTest(unittest.TestCase): - - def testImergeSorted(self): - ImergeSorted = reflection._ImergeSorted - # Various types of emptiness. - self.assertEqual([], list(ImergeSorted())) - self.assertEqual([], list(ImergeSorted([]))) - self.assertEqual([], list(ImergeSorted([], []))) - - # One nonempty list. - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], []))) - self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3]))) - - # Merging some nonempty lists together. - self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2]))) - self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], []))) - - # Elements repeated across component iterators. - self.assertEqual([1, 2, 2, 3, 3], - list(ImergeSorted([1, 2], [3], [2, 3]))) - - # Elements repeated within an iterator. - self.assertEqual([1, 2, 2, 3, 3], - list(ImergeSorted([1, 2, 2], [3], [3]))) - if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 1a0da55..1df1619 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -31,14 +31,13 @@ """Utilities for Python proto2 tests. This is intentionally modeled on C++ code in -//net/proto2/internal/test_util.*. +//google/protobuf/test_util.*. """ __author__ = 'robinson@google.com (Will Robinson)' import os.path -import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 @@ -353,198 +352,198 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized): raise ValueError('Expected %r, found %r' % (expected, serialized)) -class GoldenMessageTestCase(unittest.TestCase): - """This adds methods to TestCase useful for verifying our Golden Message.""" - - def ExpectAllFieldsSet(self, message): - """Check all fields for correct values have after Set*Fields() is called.""" - self.assertTrue(message.HasField('optional_int32')) - self.assertTrue(message.HasField('optional_int64')) - self.assertTrue(message.HasField('optional_uint32')) - self.assertTrue(message.HasField('optional_uint64')) - self.assertTrue(message.HasField('optional_sint32')) - self.assertTrue(message.HasField('optional_sint64')) - self.assertTrue(message.HasField('optional_fixed32')) - self.assertTrue(message.HasField('optional_fixed64')) - self.assertTrue(message.HasField('optional_sfixed32')) - self.assertTrue(message.HasField('optional_sfixed64')) - self.assertTrue(message.HasField('optional_float')) - self.assertTrue(message.HasField('optional_double')) - self.assertTrue(message.HasField('optional_bool')) - self.assertTrue(message.HasField('optional_string')) - self.assertTrue(message.HasField('optional_bytes')) - - self.assertTrue(message.HasField('optionalgroup')) - self.assertTrue(message.HasField('optional_nested_message')) - self.assertTrue(message.HasField('optional_foreign_message')) - self.assertTrue(message.HasField('optional_import_message')) - - self.assertTrue(message.optionalgroup.HasField('a')) - self.assertTrue(message.optional_nested_message.HasField('bb')) - self.assertTrue(message.optional_foreign_message.HasField('c')) - self.assertTrue(message.optional_import_message.HasField('d')) - - self.assertTrue(message.HasField('optional_nested_enum')) - self.assertTrue(message.HasField('optional_foreign_enum')) - self.assertTrue(message.HasField('optional_import_enum')) - - self.assertTrue(message.HasField('optional_string_piece')) - self.assertTrue(message.HasField('optional_cord')) - - self.assertEqual(101, message.optional_int32) - self.assertEqual(102, message.optional_int64) - self.assertEqual(103, message.optional_uint32) - self.assertEqual(104, message.optional_uint64) - self.assertEqual(105, message.optional_sint32) - self.assertEqual(106, message.optional_sint64) - self.assertEqual(107, message.optional_fixed32) - self.assertEqual(108, message.optional_fixed64) - self.assertEqual(109, message.optional_sfixed32) - self.assertEqual(110, message.optional_sfixed64) - self.assertEqual(111, message.optional_float) - self.assertEqual(112, message.optional_double) - self.assertEqual(True, message.optional_bool) - self.assertEqual('115', message.optional_string) - self.assertEqual('116', message.optional_bytes) - - self.assertEqual(117, message.optionalgroup.a); - self.assertEqual(118, message.optional_nested_message.bb) - self.assertEqual(119, message.optional_foreign_message.c) - self.assertEqual(120, message.optional_import_message.d) - - self.assertEqual(unittest_pb2.TestAllTypes.BAZ, - message.optional_nested_enum) - self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum) - self.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.optional_import_enum) - - # ----------------------------------------------------------------- - - self.assertEqual(2, len(message.repeated_int32)) - self.assertEqual(2, len(message.repeated_int64)) - self.assertEqual(2, len(message.repeated_uint32)) - self.assertEqual(2, len(message.repeated_uint64)) - self.assertEqual(2, len(message.repeated_sint32)) - self.assertEqual(2, len(message.repeated_sint64)) - self.assertEqual(2, len(message.repeated_fixed32)) - self.assertEqual(2, len(message.repeated_fixed64)) - self.assertEqual(2, len(message.repeated_sfixed32)) - self.assertEqual(2, len(message.repeated_sfixed64)) - self.assertEqual(2, len(message.repeated_float)) - self.assertEqual(2, len(message.repeated_double)) - self.assertEqual(2, len(message.repeated_bool)) - self.assertEqual(2, len(message.repeated_string)) - self.assertEqual(2, len(message.repeated_bytes)) - - self.assertEqual(2, len(message.repeatedgroup)) - self.assertEqual(2, len(message.repeated_nested_message)) - self.assertEqual(2, len(message.repeated_foreign_message)) - self.assertEqual(2, len(message.repeated_import_message)) - self.assertEqual(2, len(message.repeated_nested_enum)) - self.assertEqual(2, len(message.repeated_foreign_enum)) - self.assertEqual(2, len(message.repeated_import_enum)) - - self.assertEqual(2, len(message.repeated_string_piece)) - self.assertEqual(2, len(message.repeated_cord)) - - self.assertEqual(201, message.repeated_int32[0]) - self.assertEqual(202, message.repeated_int64[0]) - self.assertEqual(203, message.repeated_uint32[0]) - self.assertEqual(204, message.repeated_uint64[0]) - self.assertEqual(205, message.repeated_sint32[0]) - self.assertEqual(206, message.repeated_sint64[0]) - self.assertEqual(207, message.repeated_fixed32[0]) - self.assertEqual(208, message.repeated_fixed64[0]) - self.assertEqual(209, message.repeated_sfixed32[0]) - self.assertEqual(210, message.repeated_sfixed64[0]) - self.assertEqual(211, message.repeated_float[0]) - self.assertEqual(212, message.repeated_double[0]) - self.assertEqual(True, message.repeated_bool[0]) - self.assertEqual('215', message.repeated_string[0]) - self.assertEqual('216', message.repeated_bytes[0]) - - self.assertEqual(217, message.repeatedgroup[0].a) - self.assertEqual(218, message.repeated_nested_message[0].bb) - self.assertEqual(219, message.repeated_foreign_message[0].c) - self.assertEqual(220, message.repeated_import_message[0].d) - - self.assertEqual(unittest_pb2.TestAllTypes.BAR, - message.repeated_nested_enum[0]) - self.assertEqual(unittest_pb2.FOREIGN_BAR, - message.repeated_foreign_enum[0]) - self.assertEqual(unittest_import_pb2.IMPORT_BAR, - message.repeated_import_enum[0]) - - self.assertEqual(301, message.repeated_int32[1]) - self.assertEqual(302, message.repeated_int64[1]) - self.assertEqual(303, message.repeated_uint32[1]) - self.assertEqual(304, message.repeated_uint64[1]) - self.assertEqual(305, message.repeated_sint32[1]) - self.assertEqual(306, message.repeated_sint64[1]) - self.assertEqual(307, message.repeated_fixed32[1]) - self.assertEqual(308, message.repeated_fixed64[1]) - self.assertEqual(309, message.repeated_sfixed32[1]) - self.assertEqual(310, message.repeated_sfixed64[1]) - self.assertEqual(311, message.repeated_float[1]) - self.assertEqual(312, message.repeated_double[1]) - self.assertEqual(False, message.repeated_bool[1]) - self.assertEqual('315', message.repeated_string[1]) - self.assertEqual('316', message.repeated_bytes[1]) - - self.assertEqual(317, message.repeatedgroup[1].a) - self.assertEqual(318, message.repeated_nested_message[1].bb) - self.assertEqual(319, message.repeated_foreign_message[1].c) - self.assertEqual(320, message.repeated_import_message[1].d) - - self.assertEqual(unittest_pb2.TestAllTypes.BAZ, - message.repeated_nested_enum[1]) - self.assertEqual(unittest_pb2.FOREIGN_BAZ, - message.repeated_foreign_enum[1]) - self.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.repeated_import_enum[1]) - - # ----------------------------------------------------------------- - - self.assertTrue(message.HasField('default_int32')) - self.assertTrue(message.HasField('default_int64')) - self.assertTrue(message.HasField('default_uint32')) - self.assertTrue(message.HasField('default_uint64')) - self.assertTrue(message.HasField('default_sint32')) - self.assertTrue(message.HasField('default_sint64')) - self.assertTrue(message.HasField('default_fixed32')) - self.assertTrue(message.HasField('default_fixed64')) - self.assertTrue(message.HasField('default_sfixed32')) - self.assertTrue(message.HasField('default_sfixed64')) - self.assertTrue(message.HasField('default_float')) - self.assertTrue(message.HasField('default_double')) - self.assertTrue(message.HasField('default_bool')) - self.assertTrue(message.HasField('default_string')) - self.assertTrue(message.HasField('default_bytes')) - - self.assertTrue(message.HasField('default_nested_enum')) - self.assertTrue(message.HasField('default_foreign_enum')) - self.assertTrue(message.HasField('default_import_enum')) - - self.assertEqual(401, message.default_int32) - self.assertEqual(402, message.default_int64) - self.assertEqual(403, message.default_uint32) - self.assertEqual(404, message.default_uint64) - self.assertEqual(405, message.default_sint32) - self.assertEqual(406, message.default_sint64) - self.assertEqual(407, message.default_fixed32) - self.assertEqual(408, message.default_fixed64) - self.assertEqual(409, message.default_sfixed32) - self.assertEqual(410, message.default_sfixed64) - self.assertEqual(411, message.default_float) - self.assertEqual(412, message.default_double) - self.assertEqual(False, message.default_bool) - self.assertEqual('415', message.default_string) - self.assertEqual('416', message.default_bytes) - - self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) - self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum) - self.assertEqual(unittest_import_pb2.IMPORT_FOO, - message.default_import_enum) +def ExpectAllFieldsSet(test_case, message): + """Check all fields for correct values have after Set*Fields() is called.""" + test_case.assertTrue(message.HasField('optional_int32')) + test_case.assertTrue(message.HasField('optional_int64')) + test_case.assertTrue(message.HasField('optional_uint32')) + test_case.assertTrue(message.HasField('optional_uint64')) + test_case.assertTrue(message.HasField('optional_sint32')) + test_case.assertTrue(message.HasField('optional_sint64')) + test_case.assertTrue(message.HasField('optional_fixed32')) + test_case.assertTrue(message.HasField('optional_fixed64')) + test_case.assertTrue(message.HasField('optional_sfixed32')) + test_case.assertTrue(message.HasField('optional_sfixed64')) + test_case.assertTrue(message.HasField('optional_float')) + test_case.assertTrue(message.HasField('optional_double')) + test_case.assertTrue(message.HasField('optional_bool')) + test_case.assertTrue(message.HasField('optional_string')) + test_case.assertTrue(message.HasField('optional_bytes')) + + test_case.assertTrue(message.HasField('optionalgroup')) + test_case.assertTrue(message.HasField('optional_nested_message')) + test_case.assertTrue(message.HasField('optional_foreign_message')) + test_case.assertTrue(message.HasField('optional_import_message')) + + test_case.assertTrue(message.optionalgroup.HasField('a')) + test_case.assertTrue(message.optional_nested_message.HasField('bb')) + test_case.assertTrue(message.optional_foreign_message.HasField('c')) + test_case.assertTrue(message.optional_import_message.HasField('d')) + + test_case.assertTrue(message.HasField('optional_nested_enum')) + test_case.assertTrue(message.HasField('optional_foreign_enum')) + test_case.assertTrue(message.HasField('optional_import_enum')) + + test_case.assertTrue(message.HasField('optional_string_piece')) + test_case.assertTrue(message.HasField('optional_cord')) + + test_case.assertEqual(101, message.optional_int32) + test_case.assertEqual(102, message.optional_int64) + test_case.assertEqual(103, message.optional_uint32) + test_case.assertEqual(104, message.optional_uint64) + test_case.assertEqual(105, message.optional_sint32) + test_case.assertEqual(106, message.optional_sint64) + test_case.assertEqual(107, message.optional_fixed32) + test_case.assertEqual(108, message.optional_fixed64) + test_case.assertEqual(109, message.optional_sfixed32) + test_case.assertEqual(110, message.optional_sfixed64) + test_case.assertEqual(111, message.optional_float) + test_case.assertEqual(112, message.optional_double) + test_case.assertEqual(True, message.optional_bool) + test_case.assertEqual('115', message.optional_string) + test_case.assertEqual('116', message.optional_bytes) + + test_case.assertEqual(117, message.optionalgroup.a) + test_case.assertEqual(118, message.optional_nested_message.bb) + test_case.assertEqual(119, message.optional_foreign_message.c) + test_case.assertEqual(120, message.optional_import_message.d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, + message.optional_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.optional_import_enum) + + # ----------------------------------------------------------------- + + test_case.assertEqual(2, len(message.repeated_int32)) + test_case.assertEqual(2, len(message.repeated_int64)) + test_case.assertEqual(2, len(message.repeated_uint32)) + test_case.assertEqual(2, len(message.repeated_uint64)) + test_case.assertEqual(2, len(message.repeated_sint32)) + test_case.assertEqual(2, len(message.repeated_sint64)) + test_case.assertEqual(2, len(message.repeated_fixed32)) + test_case.assertEqual(2, len(message.repeated_fixed64)) + test_case.assertEqual(2, len(message.repeated_sfixed32)) + test_case.assertEqual(2, len(message.repeated_sfixed64)) + test_case.assertEqual(2, len(message.repeated_float)) + test_case.assertEqual(2, len(message.repeated_double)) + test_case.assertEqual(2, len(message.repeated_bool)) + test_case.assertEqual(2, len(message.repeated_string)) + test_case.assertEqual(2, len(message.repeated_bytes)) + + test_case.assertEqual(2, len(message.repeatedgroup)) + test_case.assertEqual(2, len(message.repeated_nested_message)) + test_case.assertEqual(2, len(message.repeated_foreign_message)) + test_case.assertEqual(2, len(message.repeated_import_message)) + test_case.assertEqual(2, len(message.repeated_nested_enum)) + test_case.assertEqual(2, len(message.repeated_foreign_enum)) + test_case.assertEqual(2, len(message.repeated_import_enum)) + + test_case.assertEqual(2, len(message.repeated_string_piece)) + test_case.assertEqual(2, len(message.repeated_cord)) + + test_case.assertEqual(201, message.repeated_int32[0]) + test_case.assertEqual(202, message.repeated_int64[0]) + test_case.assertEqual(203, message.repeated_uint32[0]) + test_case.assertEqual(204, message.repeated_uint64[0]) + test_case.assertEqual(205, message.repeated_sint32[0]) + test_case.assertEqual(206, message.repeated_sint64[0]) + test_case.assertEqual(207, message.repeated_fixed32[0]) + test_case.assertEqual(208, message.repeated_fixed64[0]) + test_case.assertEqual(209, message.repeated_sfixed32[0]) + test_case.assertEqual(210, message.repeated_sfixed64[0]) + test_case.assertEqual(211, message.repeated_float[0]) + test_case.assertEqual(212, message.repeated_double[0]) + test_case.assertEqual(True, message.repeated_bool[0]) + test_case.assertEqual('215', message.repeated_string[0]) + test_case.assertEqual('216', message.repeated_bytes[0]) + + test_case.assertEqual(217, message.repeatedgroup[0].a) + test_case.assertEqual(218, message.repeated_nested_message[0].bb) + test_case.assertEqual(219, message.repeated_foreign_message[0].c) + test_case.assertEqual(220, message.repeated_import_message[0].d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[0]) + test_case.assertEqual(unittest_pb2.FOREIGN_BAR, + message.repeated_foreign_enum[0]) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, + message.repeated_import_enum[0]) + + test_case.assertEqual(301, message.repeated_int32[1]) + test_case.assertEqual(302, message.repeated_int64[1]) + test_case.assertEqual(303, message.repeated_uint32[1]) + test_case.assertEqual(304, message.repeated_uint64[1]) + test_case.assertEqual(305, message.repeated_sint32[1]) + test_case.assertEqual(306, message.repeated_sint64[1]) + test_case.assertEqual(307, message.repeated_fixed32[1]) + test_case.assertEqual(308, message.repeated_fixed64[1]) + test_case.assertEqual(309, message.repeated_sfixed32[1]) + test_case.assertEqual(310, message.repeated_sfixed64[1]) + test_case.assertEqual(311, message.repeated_float[1]) + test_case.assertEqual(312, message.repeated_double[1]) + test_case.assertEqual(False, message.repeated_bool[1]) + test_case.assertEqual('315', message.repeated_string[1]) + test_case.assertEqual('316', message.repeated_bytes[1]) + + test_case.assertEqual(317, message.repeatedgroup[1].a) + test_case.assertEqual(318, message.repeated_nested_message[1].bb) + test_case.assertEqual(319, message.repeated_foreign_message[1].c) + test_case.assertEqual(320, message.repeated_import_message[1].d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.repeated_nested_enum[1]) + test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, + message.repeated_foreign_enum[1]) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.repeated_import_enum[1]) + + # ----------------------------------------------------------------- + + test_case.assertTrue(message.HasField('default_int32')) + test_case.assertTrue(message.HasField('default_int64')) + test_case.assertTrue(message.HasField('default_uint32')) + test_case.assertTrue(message.HasField('default_uint64')) + test_case.assertTrue(message.HasField('default_sint32')) + test_case.assertTrue(message.HasField('default_sint64')) + test_case.assertTrue(message.HasField('default_fixed32')) + test_case.assertTrue(message.HasField('default_fixed64')) + test_case.assertTrue(message.HasField('default_sfixed32')) + test_case.assertTrue(message.HasField('default_sfixed64')) + test_case.assertTrue(message.HasField('default_float')) + test_case.assertTrue(message.HasField('default_double')) + test_case.assertTrue(message.HasField('default_bool')) + test_case.assertTrue(message.HasField('default_string')) + test_case.assertTrue(message.HasField('default_bytes')) + + test_case.assertTrue(message.HasField('default_nested_enum')) + test_case.assertTrue(message.HasField('default_foreign_enum')) + test_case.assertTrue(message.HasField('default_import_enum')) + + test_case.assertEqual(401, message.default_int32) + test_case.assertEqual(402, message.default_int64) + test_case.assertEqual(403, message.default_uint32) + test_case.assertEqual(404, message.default_uint64) + test_case.assertEqual(405, message.default_sint32) + test_case.assertEqual(406, message.default_sint64) + test_case.assertEqual(407, message.default_fixed32) + test_case.assertEqual(408, message.default_fixed64) + test_case.assertEqual(409, message.default_sfixed32) + test_case.assertEqual(410, message.default_sfixed64) + test_case.assertEqual(411, message.default_float) + test_case.assertEqual(412, message.default_double) + test_case.assertEqual(False, message.default_bool) + test_case.assertEqual('415', message.default_string) + test_case.assertEqual('416', message.default_bytes) + + test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.default_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_FOO, + message.default_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, + message.default_import_enum) def GoldenFile(filename): """Finds the given golden file and returns a file object representing it.""" @@ -570,21 +569,21 @@ def SetAllPackedFields(message): Args: message: A unittest_pb2.TestPackedTypes instance. """ - message.packed_int32.extend([101, 102]) - message.packed_int64.extend([103, 104]) - message.packed_uint32.extend([105, 106]) - message.packed_uint64.extend([107, 108]) - message.packed_sint32.extend([109, 110]) - message.packed_sint64.extend([111, 112]) - message.packed_fixed32.extend([113, 114]) - message.packed_fixed64.extend([115, 116]) - message.packed_sfixed32.extend([117, 118]) - message.packed_sfixed64.extend([119, 120]) - message.packed_float.extend([121.0, 122.0]) - message.packed_double.extend([122.0, 123.0]) + message.packed_int32.extend([601, 701]) + message.packed_int64.extend([602, 702]) + message.packed_uint32.extend([603, 703]) + message.packed_uint64.extend([604, 704]) + message.packed_sint32.extend([605, 705]) + message.packed_sint64.extend([606, 706]) + message.packed_fixed32.extend([607, 707]) + message.packed_fixed64.extend([608, 708]) + message.packed_sfixed32.extend([609, 709]) + message.packed_sfixed64.extend([610, 710]) + message.packed_float.extend([611.0, 711.0]) + message.packed_double.extend([612.0, 712.0]) message.packed_bool.extend([True, False]) - message.packed_enum.extend([unittest_pb2.FOREIGN_FOO, - unittest_pb2.FOREIGN_BAR]) + message.packed_enum.extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) def SetAllPackedExtensions(message): @@ -596,17 +595,41 @@ def SetAllPackedExtensions(message): extensions = message.Extensions pb2 = unittest_pb2 - extensions[pb2.packed_int32_extension].append(101) - extensions[pb2.packed_int64_extension].append(102) - extensions[pb2.packed_uint32_extension].append(103) - extensions[pb2.packed_uint64_extension].append(104) - extensions[pb2.packed_sint32_extension].append(105) - extensions[pb2.packed_sint64_extension].append(106) - extensions[pb2.packed_fixed32_extension].append(107) - extensions[pb2.packed_fixed64_extension].append(108) - extensions[pb2.packed_sfixed32_extension].append(109) - extensions[pb2.packed_sfixed64_extension].append(110) - extensions[pb2.packed_float_extension].append(111.0) - extensions[pb2.packed_double_extension].append(112.0) - extensions[pb2.packed_bool_extension].append(True) - extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ) + extensions[pb2.packed_int32_extension].extend([601, 701]) + extensions[pb2.packed_int64_extension].extend([602, 702]) + extensions[pb2.packed_uint32_extension].extend([603, 703]) + extensions[pb2.packed_uint64_extension].extend([604, 704]) + extensions[pb2.packed_sint32_extension].extend([605, 705]) + extensions[pb2.packed_sint64_extension].extend([606, 706]) + extensions[pb2.packed_fixed32_extension].extend([607, 707]) + extensions[pb2.packed_fixed64_extension].extend([608, 708]) + extensions[pb2.packed_sfixed32_extension].extend([609, 709]) + extensions[pb2.packed_sfixed64_extension].extend([610, 710]) + extensions[pb2.packed_float_extension].extend([611.0, 711.0]) + extensions[pb2.packed_double_extension].extend([612.0, 712.0]) + extensions[pb2.packed_bool_extension].extend([True, False]) + extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) + + +def SetAllUnpackedFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestUnpackedTypes instance. + """ + message.unpacked_int32.extend([601, 701]) + message.unpacked_int64.extend([602, 702]) + message.unpacked_uint32.extend([603, 703]) + message.unpacked_uint64.extend([604, 704]) + message.unpacked_sint32.extend([605, 705]) + message.unpacked_sint64.extend([606, 706]) + message.unpacked_fixed32.extend([607, 707]) + message.unpacked_fixed64.extend([608, 708]) + message.unpacked_sfixed32.extend([609, 709]) + message.unpacked_sfixed64.extend([610, 710]) + message.unpacked_float.extend([611.0, 711.0]) + message.unpacked_double.extend([612.0, 712.0]) + message.unpacked_bool.extend([True, False]) + message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 0cf2718..e0991cb 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -43,7 +43,7 @@ from google.protobuf import unittest_pb2 from google.protobuf import unittest_mset_pb2 -class TextFormatTest(test_util.GoldenMessageTestCase): +class TextFormatTest(unittest.TestCase): def ReadGolden(self, golden_filename): f = test_util.GoldenFile(golden_filename) golden_lines = f.readlines() @@ -149,7 +149,7 @@ class TextFormatTest(test_util.GoldenMessageTestCase): parsed_message = unittest_pb2.TestAllTypes() text_format.Merge(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - self.ExpectAllFieldsSet(message) + test_util.ExpectAllFieldsSet(self, message) def testMergeAllExtensions(self): message = unittest_pb2.TestAllExtensions() @@ -191,7 +191,8 @@ class TextFormatTest(test_util.GoldenMessageTestCase): 'repeated_double: 1.23e+22\n' 'repeated_double: 1.23e-18\n' 'repeated_string: \n' - '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n' + 'repeated_string: "foo" \'corge\' "grault"') text_format.Merge(text, message) self.assertEqual(-9223372036854775808, message.repeated_int64[0]) @@ -201,6 +202,7 @@ class TextFormatTest(test_util.GoldenMessageTestCase): self.assertEqual(1.23e-18, message.repeated_double[2]) self.assertEqual( '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0]) + self.assertEqual('foocorgegrault', message.repeated_string[1]) def testMergeUnknownField(self): message = unittest_pb2.TestAllTypes() @@ -212,12 +214,18 @@ class TextFormatTest(test_util.GoldenMessageTestCase): text_format.Merge, text, message) def testMergeBadExtension(self): - message = unittest_pb2.TestAllTypes() + message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' self.assertRaisesWithMessage( text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', text_format.Merge, text, message) + message = unittest_pb2.TestAllTypes() + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' + 'extensions.'), + text_format.Merge, text, message) def testMergeGroupNotClosed(self): message = unittest_pb2.TestAllTypes() @@ -231,6 +239,19 @@ class TextFormatTest(test_util.GoldenMessageTestCase): text_format.ParseError, '1:16 : Expected "}".', text_format.Merge, text, message) + def testMergeEmptyGroup(self): + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: {}' + text_format.Merge(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + message.Clear() + + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: <>' + text_format.Merge(text, message) + self.assertTrue(message.HasField('optionalgroup')) + def testMergeBadEnumValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: BARR' @@ -304,10 +325,10 @@ class TokenizerTest(unittest.TestCase): '{', (tokenizer.ConsumeIdentifier, 'A'), ':', - (tokenizer.ConsumeFloat, float('inf')), + (tokenizer.ConsumeFloat, text_format._INFINITY), (tokenizer.ConsumeIdentifier, 'B'), ':', - (tokenizer.ConsumeFloat, float('-inf')), + (tokenizer.ConsumeFloat, -text_format._INFINITY), (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), @@ -392,6 +413,16 @@ class TokenizerTest(unittest.TestCase): tokenizer = text_format._Tokenizer(text) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) + def testInfNan(self): + # Make sure our infinity and NaN definitions are sound. + self.assertEquals(float, type(text_format._INFINITY)) + self.assertEquals(float, type(text_format._NAN)) + self.assertTrue(text_format._NAN != text_format._NAN) + + inf_times_zero = text_format._INFINITY * 0 + self.assertTrue(inf_times_zero != inf_times_zero) + self.assertTrue(text_format._INFINITY > 0) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index a3bc57f..2b3cd4d 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -192,47 +192,72 @@ TYPE_TO_BYTE_SIZE_FN = { } -# Maps from field type to an unbound Encoder method F, such that -# F(encoder, field_number, value) will append the serialization -# of a value of this type to the encoder. -_Encoder = encoder.Encoder -TYPE_TO_SERIALIZE_METHOD = { - _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, - _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, - _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, - _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, - _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, - _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, - _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, - _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, - _FieldDescriptor.TYPE_STRING: _Encoder.AppendString, - _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, - _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, - _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, - _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, - _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, - _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, - _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64, - _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32, - _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64, +# Maps from field types to encoder constructors. +TYPE_TO_ENCODER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder, + _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder, + _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder, + _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder, + _FieldDescriptor.TYPE_STRING: encoder.StringEncoder, + _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder, + _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder, + _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder, } -TYPE_TO_NOTAG_SERIALIZE_METHOD = { - _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag, - _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag, - _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag, - _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag, - _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag, - _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag, - _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag, - _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag, - _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag, - _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag, - _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag, - _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag, - _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag, - _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag, +# Maps from field types to sizer constructors. +TYPE_TO_SIZER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer, + _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer, + _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer, + _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer, + _FieldDescriptor.TYPE_STRING: encoder.StringSizer, + _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer, + _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer, + _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer, + } + + +# Maps from field type to a decoder constructor. +TYPE_TO_DECODER = { + _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder, + _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder, + _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder, + _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder, + _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder, + _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder, + _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder, + _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder, + _FieldDescriptor.TYPE_STRING: decoder.StringDecoder, + _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder, + _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder, + _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder, + _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder, + _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder, + _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder, + _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder, + _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder, + _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder, } # Maps from field type to expected wiretype. @@ -259,29 +284,3 @@ FIELD_TYPE_TO_WIRE_TYPE = { _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, } - - -# Maps from field type to an unbound Decoder method F, -# such that F(decoder) will read a field of the requested type. -# -# Note that Message and Group are intentionally missing here. -# They're handled by _RecursivelyMerge(). -_Decoder = decoder.Decoder -TYPE_TO_DESERIALIZE_METHOD = { - _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble, - _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat, - _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64, - _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64, - _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32, - _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64, - _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32, - _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool, - _FieldDescriptor.TYPE_STRING: _Decoder.ReadString, - _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes, - _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32, - _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum, - _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32, - _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64, - _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32, - _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64, - } diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py index da6464d..c941fe1 100755 --- a/python/google/protobuf/internal/wire_format.py +++ b/python/google/protobuf/internal/wire_format.py @@ -33,16 +33,17 @@ __author__ = 'robinson@google.com (Will Robinson)' import struct +from google.protobuf import descriptor from google.protobuf import message TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. -_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 +TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 # These numbers identify the wire type of a protocol buffer value. # We use the least-significant TAG_TYPE_BITS bits of the varint-encoded # tag-and-type to store one of these WIRETYPE_* constants. -# These values must match WireType enum in //net/proto2/public/wire_format.h. +# These values must match WireType enum in google/protobuf/wire_format.h. WIRETYPE_VARINT = 0 WIRETYPE_FIXED64 = 1 WIRETYPE_LENGTH_DELIMITED = 2 @@ -93,7 +94,7 @@ def UnpackTag(tag): """The inverse of PackTag(). Given an unsigned 32-bit number, returns a (field_number, wire_type) tuple. """ - return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK) + return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK) def ZigZagEncode(value): @@ -245,3 +246,23 @@ def _VarUInt64ByteSizeNoTag(uint64): if uint64 > UINT64_MAX: raise message.EncodeError('Value out of range: %d' % uint64) return 10 + + +NON_PACKABLE_TYPES = ( + descriptor.FieldDescriptor.TYPE_STRING, + descriptor.FieldDescriptor.TYPE_GROUP, + descriptor.FieldDescriptor.TYPE_MESSAGE, + descriptor.FieldDescriptor.TYPE_BYTES +) + + +def IsTypePackable(field_type): + """Return true iff packable = true is valid for fields of this type. + + Args: + field_type: a FieldDescriptor::Type value. + + Returns: + True iff fields of this type are packable. + """ + return field_type not in NON_PACKABLE_TYPES diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index 9a88bdc..f839847 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -99,7 +99,7 @@ class Message(object): Args: other_msg: Message to copy into the current one. """ - if self == other_msg: + if self is other_msg: return self.Clear() self.MergeFrom(other_msg) @@ -108,6 +108,15 @@ class Message(object): """Clears all data that was set in the message.""" raise NotImplementedError + def SetInParent(self): + """Mark this as present in the parent. + + This normally happens automatically when you assign a field of a + sub-message, but sometimes you want to make the sub-message + present while keeping it empty. If you find yourself using this, + you may want to reconsider your design.""" + raise NotImplementedError + def IsInitialized(self): """Checks if the message is initialized. diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index d65d8b6..5b23803 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -50,9 +50,13 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -import heapq -import threading +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO +import struct import weakref + # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers from google.protobuf.internal import decoder @@ -139,14 +143,26 @@ class GeneratedProtocolMessageType(type): type. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number)) + # We act as a "friend" class of the descriptor, setting # its _concrete_class attribute the first time we use a # given descriptor to initialize a concrete protocol message - # class. + # class. We also attach stuff to each FieldDescriptor for quick + # lookup later on. concrete_class_attr_name = '_concrete_class' if not hasattr(descriptor, concrete_class_attr_name): setattr(descriptor, concrete_class_attr_name, cls) - cls._known_extensions = [] + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + _AddEnumValues(descriptor, cls) _AddInitMethod(descriptor, cls) _AddPropertiesForFields(descriptor, cls) @@ -184,30 +200,33 @@ def _PropertyName(proto_field_name): # return proto_field_name + "_" # return proto_field_name # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. return proto_field_name -def _ValueFieldName(proto_field_name): - """Returns the name of the (internal) instance attribute which objects - should use to store the current value for a given protocol message field. - - Args: - proto_field_name: The protocol message field name, exactly - as it appears (or would appear) in a .proto file. - """ - return '_value_' + proto_field_name +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + if not isinstance(extension_handle, _FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) -def _HasFieldName(proto_field_name): - """Returns the name of the (internal) instance attribute which - objects should use to store a boolean telling whether this field - is explicitly set or not. + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) - Args: - proto_field_name: The protocol message field name, exactly - as it appears (or would appear) in a .proto file. - """ - return '_has_' + proto_field_name + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) def _AddSlots(message_descriptor, dictionary): @@ -218,16 +237,57 @@ def _AddSlots(message_descriptor, dictionary): message_descriptor: A Descriptor instance describing this message type. dictionary: Class dictionary to which we'll add a '__slots__' entry. """ - field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] - field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields - if f.label != _FieldDescriptor.LABEL_REPEATED) - field_names.extend(('Extensions', - '_cached_byte_size', - '_cached_byte_size_dirty', - '_called_transition_to_nonempty', - '_listener', - '_lock', '__weakref__')) - dictionary['__slots__'] = field_names + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor)) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) def _AddClassAttributesForNestedExtensions(descriptor, dictionary): @@ -249,44 +309,51 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) -def _DefaultValueForField(message, field): - """Returns a default value for a field. +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: message: Message instance containing this field, or a weakref proxy of same. - field: FieldDescriptor object for this field. - Returns: A default value for this field. May refer back to |message| - via a weak reference. + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. """ - # TODO(robinson): Only the repeated fields need a reference to 'message' (so - # that they can set the 'has' bit on the containing Message when someone - # append()s a value). We could special-case this, and avoid an extra - # function call on __init__() and Clear() for non-repeated fields. - - # TODO(robinson): Find a better place for the default value assertion in this - # function. No need to repeat them every time the client calls Clear('foo'). - # (We should probably just assert these things once and as early as possible, - # by tightening checking in the descriptor classes.) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( field.default_value)) - listener = _Listener(message, None) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # We can't look at _concrete_class yet since it might not have # been set. (Depends on order in which we initialize the classes). - return containers.RepeatedCompositeFieldContainer( - listener, field.message_type) + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault else: - return containers.RepeatedScalarFieldContainer( - listener, type_checkers.GetTypeChecker(field.cpp_type, field.type)) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - assert field.default_value is None + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + result = message_type._concrete_class() + result._SetListener(message._listener_for_children) + return result + return MakeSubMessageDefault - return field.default_value + def MakeScalarDefault(message): + return field.default_value + return MakeScalarDefault def _AddInitMethod(message_descriptor, cls): @@ -295,21 +362,29 @@ def _AddInitMethod(message_descriptor, cls): def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = False + self._fields = {} + self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() - self._called_transition_to_nonempty = False - # TODO(robinson): We should only create a lock if we really need one - # in this class. - self._lock = threading.Lock() - for field in fields: - default_value = _DefaultValueForField(self, field) - python_field_name = _ValueFieldName(field.name) - setattr(self, python_field_name, default_value) - if field.label != _FieldDescriptor.LABEL_REPEATED: - setattr(self, _HasFieldName(field.name), False) - self.Extensions = _ExtensionDict(self, cls._known_extensions) + self._listener_for_children = _Listener(self) for field_name, field_value in kwargs.iteritems(): field = _GetFieldByName(message_descriptor, field_name) - _MergeFieldOrExtension(self, field, field_value) + if field is None: + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (message_descriptor.name, field_name)) + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + for val in field_value: + copy.add().MergeFrom(val) + else: # Scalar + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + copy.MergeFrom(field_value) + self._fields[field] = copy + else: + self._fields[field] = field_value init.__module__ = None init.__doc__ = None @@ -336,6 +411,11 @@ def _AddPropertiesForFields(descriptor, cls): for field in descriptor.fields: _AddPropertiesForField(field, cls) + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + def _AddPropertiesForField(field, cls): """Adds a public property for a protocol message field. @@ -377,11 +457,22 @@ def _AddPropertiesForRepeatedField(field, cls): cls: The class we're constructing. """ proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) def getter(self): - return getattr(self, python_field_name) + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -407,21 +498,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): cls: The class we're constructing. """ proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + default_value = field.default_value def getter(self): - return getattr(self, python_field_name) + return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name def setter(self, new_value): type_checker.CheckValue(new_value) - setattr(self, has_field_name, True) - self._MarkByteSizeDirty() - self._MaybeCallTransitionToNonemptyCallback() - setattr(self, python_field_name, new_value) + self._fields[field] = new_value + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -444,25 +535,23 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # TODO(robinson): Remove duplication with similar method # for non-repeated scalars. proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) message_type = field.message_type def getter(self): - # TODO(robinson): Appropriately scary note about double-checked locking. - field_value = getattr(self, python_field_name) + field_value = self._fields.get(field) if field_value is None: - self._lock.acquire() - try: - field_value = getattr(self, python_field_name) - if field_value is None: - field_class = message_type._concrete_class - field_value = field_class() - field_value._SetListener(_Listener(self, has_field_name)) - setattr(self, python_field_name, field_value) - finally: - self._lock.release() + # Construct a new object to represent this field. + field_value = message_type._concrete_class() + field_value._SetListener(self._listener_for_children) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) return field_value getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -490,7 +579,27 @@ def _AddStaticMethods(cls): # TODO(robinson): This probably needs to be thread-safe(?) def RegisterExtension(extension_handle): extension_handle.containing_type = cls.DESCRIPTOR - cls._known_extensions.append(extension_handle) + _AttachFieldHelpers(cls, extension_handle) + + # Try to insert our extension, failing if an extension with the same number + # already exists. + actual_handle = cls._extensions_by_number.setdefault( + extension_handle.number, extension_handle) + if actual_handle is not extension_handle: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" with ' + 'field number %d.' % + (extension_handle.full_name, actual_handle.full_name, + cls.DESCRIPTOR.full_name, extension_handle.number)) + + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + handle = extension_handle # avoid line wrapping + if _IsMessageSetExtension(handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + cls.RegisterExtension = staticmethod(RegisterExtension) def FromString(s): @@ -500,115 +609,107 @@ def _AddStaticMethods(cls): cls.FromString = staticmethod(FromString) +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + def _AddListFieldsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - # Ensure that we always list in ascending field-number order. - # For non-extension fields, we can do the sort once, here, at import-time. - # For extensions, we sort on each ListFields() call, though - # we could do better if we have to. - fields = sorted(message_descriptor.fields, key=lambda f: f.number) - has_field_names = (_HasFieldName(f.name) for f in fields) - value_field_names = (_ValueFieldName(f.name) for f in fields) - triplets = zip(has_field_names, value_field_names, fields) - def ListFields(self): - # We need to list all extension and non-extension fields - # together, in sorted order by field number. - - # Step 0: Get an iterator over all "set" non-extension fields, - # sorted by field number. - # This iterator yields (field_number, field_descriptor, value) tuples. - def SortedSetFieldsIter(): - # Note that triplets is already sorted by field number. - for has_field_name, value_field_name, field_descriptor in triplets: - if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: - value = getattr(self, _ValueFieldName(field_descriptor.name)) - if len(value) > 0: - yield (field_descriptor.number, field_descriptor, value) - elif getattr(self, _HasFieldName(field_descriptor.name)): - value = getattr(self, _ValueFieldName(field_descriptor.name)) - yield (field_descriptor.number, field_descriptor, value) - sorted_fields = SortedSetFieldsIter() - - # Step 1: Get an iterator over all "set" extension fields, - # sorted by field number. - # This iterator ALSO yields (field_number, field_descriptor, value) tuples. - # TODO(robinson): It's not necessary to repeat this with each - # serialization call. We can do better. - sorted_extension_fields = sorted( - [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) - - # Step 2: Create a composite iterator that merges the extension- - # and non-extension fields, and that still yields fields in - # sorted order. - all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) - - # Step 3: Strip off the field numbers and return. - return [field[1:] for field in all_set_fields] + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields cls.ListFields = ListFields -def _AddHasFieldMethod(cls): + +def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" + + singular_fields = {} + for field in message_descriptor.fields: + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field + def HasField(self, field_name): try: - return getattr(self, _HasFieldName(field_name)) - except AttributeError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + field = singular_fields[field_name] + except KeyError: + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields cls.HasField = HasField -def _AddClearFieldMethod(cls): +def _AddClearFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def ClearField(self, field_name): - field = _GetFieldByName(self.DESCRIPTOR, field_name) - proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) - default_value = _DefaultValueForField(self, field) - if field.label == _FieldDescriptor.LABEL_REPEATED: - self._MarkByteSizeDirty() - else: - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - old_field_value = getattr(self, python_field_name) - if old_field_value is not None: - # Snip the old object out of the object tree. - old_field_value._SetListener(None) - if getattr(self, has_field_name): - setattr(self, has_field_name, False) - # Set dirty bit on ourself and parents only if - # we're actually changing state. - self._MarkByteSizeDirty() - setattr(self, python_field_name, default_value) + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + if field in self._fields: + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + cls.ClearField = ClearField def _AddClearExtensionMethod(cls): """Helper for _AddMessageMethods().""" def ClearExtension(self, extension_handle): - self.Extensions._ClearExtension(extension_handle) + _VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() cls.ClearExtension = ClearExtension -def _AddClearMethod(cls): +def _AddClearMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def Clear(self): # Clear fields. - fields = self.DESCRIPTOR.fields - for field in fields: - self.ClearField(field.name) - # Clear extensions. - extensions = self.Extensions._ListSetExtensions() - for extension in extensions: - self.ClearExtension(extension[0]) + self._fields = {} + self._Modified() cls.Clear = Clear def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): - return self.Extensions._HasExtension(extension_handle) + _VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields cls.HasExtension = HasExtension @@ -622,26 +723,8 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - # Compare all fields contained directly in this message. - for field_descriptor in message_descriptor.fields: - label = field_descriptor.label - property_name = _PropertyName(field_descriptor.name) - # Non-repeated field equality requires matching "has" bits as well - # as having an equal value. - if label != _FieldDescriptor.LABEL_REPEATED: - self_has = self.HasField(property_name) - other_has = other.HasField(property_name) - if self_has != other_has: - return False - if not self_has: - # If the "has" bit for this field is False, we must stop here. - # Otherwise we will recurse forever on recursively-defined protos. - continue - if getattr(self, property_name) != getattr(other, property_name): - return False + return self.ListFields() == other.ListFields() - # Compare the extensions present in both messages. - return self.Extensions == other.Extensions cls.__eq__ = __eq__ @@ -685,618 +768,202 @@ def _BytesForNonRepeatedElement(value, field_number, field_type): def _AddByteSizeMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def BytesForField(message, field, value): - """Returns the number of bytes required to serialize a single field - in message. The field may be repeated or not, composite or not. - - Args: - message: The Message instance containing a field of the given type. - field: A FieldDescriptor describing the field of interest. - value: The value whose byte size we're interested in. - - Returns: The number of bytes required to serialize the current value - of "field" in "message", including space for tags and any other - necessary information. - """ - - if _MessageSetField(field): - return wire_format.MessageSetItemByteSize(field.number, value) - - field_number, field_type = field.number, field.type - - # Repeated fields. - if field.label == _FieldDescriptor.LABEL_REPEATED: - elements = value - else: - elements = [value] - - if field.GetOptions().packed: - content_size = _ContentBytesForPackedField(message, field, elements) - if content_size: - tag_size = wire_format.TagByteSize(field_number) - length_size = wire_format.Int32ByteSizeNoTag(content_size) - return tag_size + length_size + content_size - else: - return 0 - else: - return sum(_BytesForNonRepeatedElement(element, field_number, field_type) - for element in elements) - - def _ContentBytesForPackedField(self, field, value): - """Returns the number of bytes required to serialize the actual - content of a packed field (not including the tag or the encoding - of the length. - - Args: - self: The Message instance containing a field of the given type. - field: A FieldDescriptor describing the field of interest. - value: The value whose byte size we're interested in. - - Returns: The number of bytes required to serialize the current value - of the packed "field" in "message", excluding space for tags and the - length encoding. - """ - size = sum(_BytesForNonRepeatedElement(element, field.number, field.type) - for element in value) - # In the packed case, there are no per element tags. - return size - wire_format.TagByteSize(field.number) * len(value) - - fields = message_descriptor.fields - has_field_names = (_HasFieldName(f.name) for f in fields) - zipped = zip(has_field_names, fields) - def ByteSize(self): if not self._cached_byte_size_dirty: return self._cached_byte_size size = 0 - # Hardcoded fields first. - for has_field_name, field in zipped: - if (field.label == _FieldDescriptor.LABEL_REPEATED - or getattr(self, has_field_name)): - value = getattr(self, _ValueFieldName(field.name)) - size += BytesForField(self, field, value) - # Extensions next. - for field, value in self.Extensions._ListSetExtensions(): - size += BytesForField(self, field, value) + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) self._cached_byte_size = size self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False return size - cls._ContentBytesForPackedField = _ContentBytesForPackedField cls.ByteSize = ByteSize -def _MessageSetField(field_descriptor): - """Checks if a field should be serialized using the message set wire format. - - Args: - field_descriptor: Descriptor of the field. - - Returns: - True if the field should be serialized using the message set wire format, - false otherwise. - """ - return (field_descriptor.is_extension and - field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and - field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and - field_descriptor.containing_type.GetOptions().message_set_wire_format) - - -def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): - """Appends the serialization of a single value to encoder. - - Args: - value: Value to serialize. - field_number: Field number of this value. - field_descriptor: Descriptor of the field to serialize. - encoder: encoder.Encoder object to which we should serialize this value. - """ - if _MessageSetField(field_descriptor): - encoder.AppendMessageSetItem(field_number, value) - return - - try: - method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] - method(encoder, field_number, value) - except KeyError: - raise message_mod.EncodeError('Unrecognized field type: %d' % - field_descriptor.type) - - -def _ImergeSorted(*streams): - """Merges N sorted iterators into a single sorted iterator. - Each element in streams must be an iterable that yields - its elements in sorted order, and the elements contained - in each stream must all be comparable. - - There may be repeated elements in the component streams or - across the streams; the repeated elements will all be repeated - in the merged iterator as well. - - I believe that the heapq module at HEAD in the Python - sources has a method like this, but for now we roll our own. - """ - iters = [iter(stream) for stream in streams] - heap = [] - for index, it in enumerate(iters): - try: - heap.append((it.next(), index)) - except StopIteration: - pass - heapq.heapify(heap) - - while heap: - smallest_value, idx = heap[0] - yield smallest_value - try: - next_element = iters[idx].next() - heapq.heapreplace(heap, (next_element, idx)) - except StopIteration: - heapq.heappop(heap) - - def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializeToString(self): # Check if the message has all of its required fields set. errors = [] - if not _InternalIsInitialized(self, errors): - raise message_mod.EncodeError('\n'.join(errors)) + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message is missing required fields: ' + + ','.join(self.FindInitializationErrors())) return self.SerializePartialToString() cls.SerializeToString = SerializeToString def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - Encoder = encoder.Encoder def SerializePartialToString(self): - encoder = Encoder() - # We need to serialize all extension and non-extension fields - # together, in sorted order by field number. - for field_descriptor, field_value in self.ListFields(): - if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: - repeated_value = field_value - else: - repeated_value = [field_value] - if field_descriptor.GetOptions().packed: - # First, write the field number and WIRETYPE_LENGTH_DELIMITED. - field_number = field_descriptor.number - encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) - # Next, write the number of bytes. - content_bytes = self._ContentBytesForPackedField( - field_descriptor, field_value) - encoder.AppendInt32NoTag(content_bytes) - # Finally, write the actual values. - try: - method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[ - field_descriptor.type] - for value in repeated_value: - method(encoder, value) - except KeyError: - raise message_mod.EncodeError('Unrecognized field type: %d' % - field_descriptor.type) - else: - for element in repeated_value: - _SerializeValueToEncoder(element, field_descriptor.number, - field_descriptor, encoder) - return encoder.ToString() - + out = StringIO() + self._InternalSerialize(out.write) + return out.getvalue() cls.SerializePartialToString = SerializePartialToString + def InternalSerialize(self, write_bytes): + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value) + cls._InternalSerialize = InternalSerialize -def _WireTypeForFieldType(field_type): - """Given a field type, returns the expected wire type.""" - try: - return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] - except KeyError: - raise message_mod.DecodeError('Unknown field type: %d' % field_type) - - -def _WireTypeForField(field_descriptor): - """Given a field descriptor, returns the expected wire type.""" - if field_descriptor.GetOptions().packed: - return wire_format.WIRETYPE_LENGTH_DELIMITED - else: - return _WireTypeForFieldType(field_descriptor.type) - - -def _RecursivelyMerge(field_number, field_type, decoder, message): - """Decodes a message from decoder into message. - message is either a group or a nested message within some containing - protocol message. If it's a group, we use the group protocol to - deserialize, and if it's a nested message, we use the nested-message - protocol. - - Args: - field_number: The field number of message in its enclosing protocol buffer. - field_type: The field type of message. Must be either TYPE_MESSAGE - or TYPE_GROUP. - decoder: Decoder to read from. - message: Message to deserialize into. - """ - if field_type == _FieldDescriptor.TYPE_MESSAGE: - decoder.ReadMessageInto(message) - elif field_type == _FieldDescriptor.TYPE_GROUP: - decoder.ReadGroupInto(field_number, message) - else: - raise message_mod.DecodeError('Unexpected field type: %d' % field_type) - - -def _DeserializeScalarFromDecoder(field_type, decoder): - """Deserializes a scalar of the requested type from decoder. field_type must - be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. - """ - try: - method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] - return method(decoder) - except KeyError: - raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) - - -def _SkipField(field_number, wire_type, decoder): - """Skips a field with the specified wire type. - - Args: - field_number: Tag number of the field to skip. - wire_type: Wire type of the field to skip. - decoder: Decoder used to deserialize the messsage. It must be positioned - just after reading the the tag and wire type of the field. - """ - if wire_type == wire_format.WIRETYPE_VARINT: - decoder.ReadUInt64() - elif wire_type == wire_format.WIRETYPE_FIXED64: - decoder.ReadFixed64() - elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: - decoder.SkipBytes(decoder.ReadInt32()) - elif wire_type == wire_format.WIRETYPE_START_GROUP: - _SkipGroup(field_number, decoder) - elif wire_type == wire_format.WIRETYPE_END_GROUP: - pass - elif wire_type == wire_format.WIRETYPE_FIXED32: - decoder.ReadFixed32() - else: - raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type) - - -def _SkipGroup(group_number, decoder): - """Skips a nested group from the decoder. - - Args: - group_number: Tag number of the group to skip. - decoder: Decoder used to deserialize the message. It must be positioned - exactly at the beginning of the message that should be skipped. - """ - while True: - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if (wire_type == wire_format.WIRETYPE_END_GROUP and - field_number == group_number): - return - _SkipField(field_number, wire_type, decoder) - - -def _DeserializeMessageSetItem(message, decoder): - """Deserializes a message using the message set wire format. - - Args: - message: Message to be parsed to. - decoder: The decoder to be used to deserialize encoded data. Note that the - decoder should be positioned just after reading the START_GROUP tag that - began the messageset item. - """ - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - type_id = decoder.ReadInt32() - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - extension_dict = message.Extensions - extensions_by_number = extension_dict._AllExtensionsByNumber() - if type_id not in extensions_by_number: - _SkipField(field_number, wire_type, decoder) - return - - field_descriptor = extensions_by_number[type_id] - value = extension_dict[field_descriptor] - decoder.ReadMessageInto(value) - # Read the END_GROUP tag. - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - -def _DeserializeOneEntity(message_descriptor, message, decoder): - """Deserializes the next wire entity from decoder into message. - - The next wire entity is either a scalar or a nested message, an - element in a repeated field (the wire encoding in this case is the - same), or a packed repeated field (in this case, the entire repeated - field is read by a single call to _DeserializeOneEntity). - - Args: - message_descriptor: A Descriptor instance describing all fields - in message. - message: The Message instance into which we're decoding our fields. - decoder: The Decoder we're using to deserialize encoded data. - - Returns: The number of bytes read from decoder during this method. - """ - initial_position = decoder.Position() - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - extension_dict = message.Extensions - extensions_by_number = extension_dict._AllExtensionsByNumber() - if field_number in message_descriptor.fields_by_number: - # Non-extension field. - field_descriptor = message_descriptor.fields_by_number[field_number] - value = getattr(message, _PropertyName(field_descriptor.name)) - def nonextension_setter_fn(scalar): - setattr(message, _PropertyName(field_descriptor.name), scalar) - scalar_setter_fn = nonextension_setter_fn - elif field_number in extensions_by_number: - # Extension field. - field_descriptor = extensions_by_number[field_number] - value = extension_dict[field_descriptor] - def extension_setter_fn(scalar): - extension_dict[field_descriptor] = scalar - scalar_setter_fn = extension_setter_fn - elif wire_type == wire_format.WIRETYPE_END_GROUP: - # We assume we're being parsed as the group that's ended. - return 0 - elif (wire_type == wire_format.WIRETYPE_START_GROUP and - field_number == 1 and - message_descriptor.GetOptions().message_set_wire_format): - # A Message Set item. - _DeserializeMessageSetItem(message, decoder) - return decoder.Position() - initial_position - else: - _SkipField(field_number, wire_type, decoder) - return decoder.Position() - initial_position - - # If we reach this point, we've identified the field as either - # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, - # and |value| appropriately. Now actually deserialize the thing. - # - # field_descriptor: Describes the field we're deserializing. - # value: The value currently stored in the field to deserialize. - # Used only if the field is composite and/or repeated. - # scalar_setter_fn: A function F such that F(scalar) will - # set a nonrepeated scalar value for this field. Used only - # if this field is a nonrepeated scalar. - - field_number = field_descriptor.number - expected_wire_type = _WireTypeForField(field_descriptor) - if wire_type != expected_wire_type: - # Need to fill in uninterpreted_bytes. Work for the next CL. - raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') - - property_name = _PropertyName(field_descriptor.name) - label = field_descriptor.label - field_type = field_descriptor.type - cpp_type = field_descriptor.cpp_type - - # Nonrepeated scalar. Just set the field directly. - if (label != _FieldDescriptor.LABEL_REPEATED - and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): - scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - - # Nonrepeated composite. Recursively deserialize. - if label != _FieldDescriptor.LABEL_REPEATED: - composite = value - _RecursivelyMerge(field_number, field_type, decoder, composite) - return decoder.Position() - initial_position - - # Now we know we're dealing with a repeated field of some kind. - element_list = value - - if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: - # Repeated scalar. - if not field_descriptor.GetOptions().packed: - element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - else: - # Packed repeated field. - length = _DeserializeScalarFromDecoder( - _FieldDescriptor.TYPE_INT32, decoder) - content_start = decoder.Position() - while decoder.Position() - content_start < length: - element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - else: - # Repeated composite. - composite = element_list.add() - _RecursivelyMerge(field_number, field_type, decoder, composite) - return decoder.Position() - initial_position - - -def _FieldOrExtensionValues(message, field_or_extension): - """Retrieves the list of values for the specified field or extension. - - The target field or extension can be optional, required or repeated, but it - must have value(s) set. The assumption is that the target field or extension - is set (e.g. _HasFieldOrExtension holds true). - Args: - message: Message which contains the target field or extension. - field_or_extension: Field or extension for which the list of values is - required. Must be an instance of FieldDescriptor. +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except IndexError: + raise message_mod.DecodeError('Truncated message.') + except struct.error, e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString - Returns: - A list of values for the specified field or extension. This list will only - contain a single element if the field is non-repeated. - """ - if field_or_extension.is_extension: - value = message.Extensions[field_or_extension] - else: - value = getattr(message, _ValueFieldName(field_or_extension.name)) - if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: - return [value] - else: - # In this case value is a list or repeated values. - return value + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + self._Modified() + field_dict = self._fields + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder = decoders_by_tag.get(tag_bytes) + if field_decoder is None: + new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + return pos + cls._InternalParse = InternalParse -def _HasFieldOrExtension(message, field_or_extension): - """Checks if a message has the specified field or extension set. +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" - The field or extension specified can be optional, required or repeated. If - it is repeated, this function returns True. Otherwise it checks the has bit - of the field or extension. + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] - Args: - message: Message which contains the target field or extension. - field_or_extension: Field or extension to check. This must be a - FieldDescriptor instance. + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. - Returns: - True if the message has a value set for the specified field or extension, - or if the field or extension is repeated. - """ - if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: - return True - if field_or_extension.is_extension: - return message.HasExtension(field_or_extension) - else: - return message.HasField(field_or_extension.name) + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + Returns: + True iff the specified message has all required fields set. + """ -def _IsFieldOrExtensionInitialized(message, field, errors=None): - """Checks if a message field or extension is initialized. + # Performance is critical so we avoid HasField() and ListFields(). - Args: - message: The message which contains the field or extension. - field: Field or extension to check. This must be a FieldDescriptor instance. - errors: Errors will be appended to it, if set to a meaningful value. + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False - Returns: - True if the field/extension can be considered initialized. - """ - # If the field is required and is not set, it isn't initialized. - if field.label == _FieldDescriptor.LABEL_REQUIRED: - if not _HasFieldOrExtension(message, field): - if errors is not None: - errors.append('Required field %s is not set.' % field.full_name) - return False + for field, value in self._fields.iteritems(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False - # If the field is optional and is not set, or if it - # isn't a submessage then the field is initialized. - if field.label == _FieldDescriptor.LABEL_OPTIONAL: - if not _HasFieldOrExtension(message, field): - return True - if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: return True - # The field is set and is either a single or a repeated submessage. - messages = _FieldOrExtensionValues(message, field) - # If all submessages in this field are initialized, the field is - # considered initialized. - for message in messages: - if not _InternalIsInitialized(message, errors): - return False - return True - + cls.IsInitialized = IsInitialized -def _InternalIsInitialized(message, errors=None): - """Checks if all required fields of a message are set. - - Args: - message: The message to check. - errors: If set, initialization errors will be appended to it. + def FindInitializationErrors(self): + """Finds required fields which are not initialized. - Returns: - True iff the specified message has all required fields set. - """ - fields_and_extensions = [] - fields_and_extensions.extend(message.DESCRIPTOR.fields) - fields_and_extensions.extend( - [extension[0] for extension in message.Extensions._ListSetExtensions()]) - for field_or_extension in fields_and_extensions: - if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): - return False - return True - - -def _AddMergeFromStringMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - Decoder = decoder.Decoder - def MergeFromString(self, serialized): - decoder = Decoder(serialized) - byte_count = 0 - while not decoder.EndOfStream(): - bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) - if not bytes_read: - break - byte_count += bytes_read - return byte_count - cls.MergeFromString = MergeFromString - - -def _AddIsInitializedMethod(cls): - """Adds the IsInitialized method to the protocol message class.""" - cls.IsInitialized = _InternalIsInitialized + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + errors = [] # simplify things -def _MergeFieldOrExtension(destination_msg, field, value): - """Merges a specified message field into another message.""" - property_name = _PropertyName(field.name) - is_extension = field.is_extension + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) - if not is_extension: - destination = getattr(destination_msg, property_name) - elif (field.label == _FieldDescriptor.LABEL_REPEATED or - field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): - destination = destination_msg.Extensions[field] + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = "(%s)" % field.full_name + else: + name = field.name - # Case 1 - a composite field. - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if field.label == _FieldDescriptor.LABEL_REPEATED: - for v in value: - destination.add().MergeFrom(v) - else: - destination.MergeFrom(value) - return + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): + element = value[i] + prefix = "%s[%d]." % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + else: + prefix = name + "." + sub_errors = value.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] - # Case 2 - a repeated field. - if field.label == _FieldDescriptor.LABEL_REPEATED: - for v in value: - destination.append(v) - return + return errors - # Case 3 - a singular field. - if is_extension: - destination_msg.Extensions[field] = value - else: - setattr(destination_msg, property_name, value) + cls.FindInitializationErrors = FindInitializationErrors def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + def MergeFrom(self, msg): assert msg is not self - for field in msg.ListFields(): - _MergeFieldOrExtension(self, field[0], field[1]) + self._Modified() + + fields = self._fields + + for field, value in msg._fields.iteritems(): + if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value cls.MergeFrom = MergeFrom def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) - _AddHasFieldMethod(cls) - _AddClearFieldMethod(cls) - _AddClearExtensionMethod(cls) - _AddClearMethod(cls) - _AddHasExtensionMethod(cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) _AddSetListenerMethod(cls) @@ -1304,31 +971,30 @@ def _AddMessageMethods(message_descriptor, cls): _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) _AddMergeFromStringMethod(message_descriptor, cls) - _AddIsInitializedMethod(cls) + _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) def _AddPrivateHelperMethods(cls): """Adds implementation of private helper methods to cls.""" - def MaybeCallTransitionToNonemptyCallback(self): - """Calls self._listener.TransitionToNonempty() the first time this - method is called. On all subsequent calls, this is a no-op. - """ - if not self._called_transition_to_nonempty: - self._listener.TransitionToNonempty() - self._called_transition_to_nonempty = True - cls._MaybeCallTransitionToNonemptyCallback = ( - MaybeCallTransitionToNonemptyCallback) - - def MarkByteSizeDirty(self): + def Modified(self): """Sets the _cached_byte_size_dirty bit to true, and propagates this to our listener iff this was a state change. """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. if not self._cached_byte_size_dirty: self._cached_byte_size_dirty = True - self._listener.ByteSizeDirty() - cls._MarkByteSizeDirty = MarkByteSizeDirty + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + cls._Modified = Modified + cls.SetInParent = Modified class _Listener(object): @@ -1338,22 +1004,17 @@ class _Listener(object): In order to support semantics like: - foo.bar.baz = 23 + foo.bar.baz.qux = 23 assert foo.HasField('bar') ...child objects must have back references to their parents. This helper class is at the heart of this support. """ - def __init__(self, parent_message, has_field_name): + def __init__(self, parent_message): """Args: - parent_message: The message whose _MaybeCallTransitionToNonemptyCallback() - and _MarkByteSizeDirty() methods we should call when we receive - TransitionToNonempty() and ByteSizeDirty() messages. - has_field_name: The name of the "has" field that we should set in - the parent message when we receive a TransitionToNonempty message, - or None if there's no "has" field to set. (This will be the case - for child objects in "repeated" fields). + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. """ # This listener establishes a back reference from a child (contained) object # to its parent (containing) object. We make this a weak reference to avoid @@ -1363,36 +1024,27 @@ class _Listener(object): self._parent_message_weakref = parent_message else: self._parent_message_weakref = weakref.proxy(parent_message) - self._has_field_name = has_field_name - def TransitionToNonempty(self): + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return try: - if self._has_field_name is not None: - setattr(self._parent_message_weakref, self._has_field_name, True) # Propagate the signal to our parents iff this is the first field set. - self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback() + self._parent_message_weakref._Modified() except ReferenceError: # We can get here if a client has kept a reference to a child object, # and is now setting a field on it, but the child's parent has been # garbage-collected. This is not an error. pass - def ByteSizeDirty(self): - try: - self._parent_message_weakref._MarkByteSizeDirty() - except ReferenceError: - # Same as above. - pass - # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. -# TODO(robinson): There's so much similarity between the way that -# extensions behave and the way that normal fields behave that it would -# be really nice to unify more code. It's not immediately obvious -# how to do this, though, and I'd rather get the full functionality -# implemented (and, crucially, get all the tests and specs fleshed out -# and passing), and then come back to this thorny unification problem. # TODO(robinson): Support iteritems()-style iteration over all # extensions with the "has" bits turned on? class _ExtensionDict(object): @@ -1404,250 +1056,85 @@ class _ExtensionDict(object): FieldDescriptors. """ - class _ExtensionListener(object): + def __init__(self, extended_message): + """extended_message: Message instance for which we are the Extensions dict. + """ - """Adapts an _ExtensionDict to behave as a MessageListener.""" + self._extended_message = extended_message - def __init__(self, extension_dict, handle_id): - self._extension_dict = extension_dict - self._handle_id = handle_id + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" - def TransitionToNonempty(self): - self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) + _VerifyExtensionHandle(self._extended_message, extension_handle) - def ByteSizeDirty(self): - self._extension_dict._SubmessageByteSizeBecameDirty() + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result - # TODO(robinson): Somewhere, we need to blow up if people - # try to register two extensions with the same field number. - # (And we need a test for this of course). + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value - def __init__(self, extended_message, known_extensions): - """extended_message: Message instance for which we are the Extensions dict. - known_extensions: Iterable of known extension handles. - These must be FieldDescriptors. - """ - # We keep a weak reference to extended_message, since - # it has a reference to this instance in turn. - self._extended_message = weakref.proxy(extended_message) - # We make a deep copy of known_extensions to avoid any - # thread-safety concerns, since the argument passed in - # is the global (class-level) dict of known extensions for - # this type of message, which could be modified at any time - # via a RegisterExtension() call. - # - # This dict maps from handle id to handle (a FieldDescriptor). - # - # XXX - # TODO(robinson): This isn't good enough. The client could - # instantiate an object in module A, then afterward import - # module B and pass the instance to B.Foo(). If B imports - # an extender of this proto and then tries to use it, B - # will get a KeyError, even though the extension *is* registered - # at the time of use. - # XXX - self._known_extensions = dict((id(e), e) for e in known_extensions) - # Read lock around self._values, which may be modified by multiple - # concurrent readers in the conceptually "const" __getitem__ method. - # So, we grab this lock in every "read-only" method to ensure - # that concurrent read access is safe without external locking. - self._lock = threading.Lock() - # Maps from extension handle ID to current value of that extension. - self._values = {} - # Maps from extension handle ID to a boolean "has" bit, but only - # for non-repeated extension fields. - keys = (id for id, extension in self._known_extensions.iteritems() - if extension.label != _FieldDescriptor.LABEL_REPEATED) - self._has_bits = dict.fromkeys(keys, False) - - self._extensions_by_number = dict( - (f.number, f) for f in self._known_extensions.itervalues()) - - self._extensions_by_name = {} - for extension in self._known_extensions.itervalues(): - if (extension.containing_type.GetOptions().message_set_wire_format and - extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and - extension.message_type == extension.extension_scope and - extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL): - extension_name = extension.message_type.full_name - else: - extension_name = extension.full_name - self._extensions_by_name[extension_name] = extension + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) - def __getitem__(self, extension_handle): - """Returns the current value of the given extension handle.""" - # We don't care as much about keeping critical sections short in the - # extension support, since it's presumably much less of a common case. - self._lock.acquire() - try: - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - if handle_id not in self._values: - self._AddMissingHandle(extension_handle, handle_id) - return self._values[handle_id] - finally: - self._lock.release() + return result def __eq__(self, other): - # We have to grab read locks since we're accessing _values - # in a "const" method. See the comment in the constructor. - if self is other: - return True - self._lock.acquire() - try: - other._lock.acquire() - try: - if self._has_bits != other._has_bits: - return False - # If there's a "has" bit, then only compare values where it is true. - for k, v in self._values.iteritems(): - if self._has_bits.get(k, False) and v != other._values[k]: - return False - return True - finally: - other._lock.release() - finally: - self._lock.release() + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [ field for field in my_fields if field.is_extension ] + other_fields = [ field for field in other_fields if field.is_extension ] + + return my_fields == other_fields def __ne__(self, other): return not self == other # Note that this is only meaningful for non-repeated, scalar extension - # fields. Note also that we may have to call - # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field - # this way, to set any necssary "has" bits in the ancestors of the extended - # message. + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necssary "has" bits in the + # ancestors of the extended message. def __setitem__(self, extension_handle, value): """If extension_handle specifies a non-repeated, scalar extension field, sets the value of that field. """ - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - field = extension_handle # Just shorten the name. - if (field.label == _FieldDescriptor.LABEL_OPTIONAL - and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): - # It's slightly wasteful to lookup the type checker each time, - # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) - type_checker.CheckValue(value) - self._values[handle_id] = value - self._has_bits[handle_id] = True - self._extended_message._MarkByteSizeDirty() - self._extended_message._MaybeCallTransitionToNonemptyCallback() - else: - raise TypeError('Extension is repeated and/or a composite type.') - - def _AddMissingHandle(self, extension_handle, handle_id): - """Helper internal to ExtensionDict.""" - # Special handling for non-repeated message extensions, which (like - # normal fields of this kind) are initialized lazily. - # REQUIRES: _lock already held. - cpp_type = extension_handle.cpp_type - label = extension_handle.label - if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE - and label != _FieldDescriptor.LABEL_REPEATED): - self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) - else: - self._values[handle_id] = _DefaultValueForField( - self._extended_message, extension_handle) - - def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): - """Helper internal to ExtensionDict.""" - # REQUIRES: _lock already held. - value = extension_handle.message_type._concrete_class() - value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) - self._values[handle_id] = value - - def _SubmessageTransitionedToNonempty(self, handle_id): - """Called when a submessage with a given handle id first transitions to - being nonempty. Called by _ExtensionListener. - """ - assert handle_id in self._has_bits - self._has_bits[handle_id] = True - self._extended_message._MaybeCallTransitionToNonemptyCallback() - def _SubmessageByteSizeBecameDirty(self): - """Called whenever a submessage's cached byte size becomes invalid - (goes from being "clean" to being "dirty"). Called by _ExtensionListener. - """ - self._extended_message._MarkByteSizeDirty() - - # We may wish to widen the public interface of Message.Extensions - # to expose some of this private functionality in the future. - # For now, we make all this functionality module-private and just - # implement what we need for serialization/deserialization, - # HasField()/ClearField(), etc. - - def _HasExtension(self, extension_handle): - """Method for internal use by this module. - Returns true iff we "have" this extension in the sense of the - "has" bit being set. - """ - handle_id = id(extension_handle) - # Note that this is different from the other checks. - if handle_id not in self._has_bits: - raise KeyError('Extension not known to this class, or is repeated field.') - return self._has_bits[handle_id] - - # Intentionally pretty similar to ClearField() above. - def _ClearExtension(self, extension_handle): - """Method for internal use by this module. - Clears the specified extension, unsetting its "has" bit. - """ - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - default_value = _DefaultValueForField(self._extended_message, - extension_handle) - if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: - self._extended_message._MarkByteSizeDirty() - else: - cpp_type = extension_handle.cpp_type - if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if handle_id in self._values: - # Future modifications to this object shouldn't set any - # "has" bits here. - self._values[handle_id]._SetListener(None) - if self._has_bits[handle_id]: - self._has_bits[handle_id] = False - self._extended_message._MarkByteSizeDirty() - if handle_id in self._values: - del self._values[handle_id] - - def _ListSetExtensions(self): - """Method for internal use by this module. - - Returns an sequence of all extensions that are currently "set" - in this extension dict. A "set" extension is a repeated extension, - or a non-repeated extension with its "has" bit set. - - The returned sequence contains (field_descriptor, value) pairs, - where value is the current value of the extension with the given - field descriptor. - - The sequence values are in arbitrary order. - """ - self._lock.acquire() # Read-only methods must lock around self._values. - try: - set_extensions = [] - for handle_id, value in self._values.iteritems(): - handle = self._known_extensions[handle_id] - if (handle.label == _FieldDescriptor.LABEL_REPEATED - or self._has_bits[handle_id]): - set_extensions.append((handle, value)) - return set_extensions - finally: - self._lock.release() - - def _AllExtensionsByNumber(self): - """Method for internal use by this module. - - Returns: A dict mapping field_number to (handle, field_descriptor), - for *all* registered extensions for this dict. - """ - return self._extensions_by_number + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker( + extension_handle.cpp_type, extension_handle.type) + type_checker.CheckValue(value) + self._extended_message._fields[extension_handle] = value + self._extended_message._Modified() def _FindExtensionByName(self, name): """Tries to find a known extension with the specified name. @@ -1658,4 +1145,4 @@ class _ExtensionDict(object): Returns: Extension field descriptor. """ - return self._extensions_by_name.get(name, None) + return self._extended_message._extensions_by_name.get(name, None) diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py index dd136c9..180b70e 100755 --- a/python/google/protobuf/service.py +++ b/python/google/protobuf/service.py @@ -28,12 +28,16 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Declares the RPC service interfaces. +"""DEPRECATED: Declares the RPC service interfaces. This module declares the abstract interfaces underlying proto2 RPC services. These are intended to be independent of any particular RPC implementation, so that proto2 services can be used on top of a variety -of implementations. +of implementations. Starting with version 2.3.0, RPC implementations should +not try to build on these, but should instead provide code generator plugins +which generate code specific to the particular RPC implementation. This way +the generated code can be more appropriate for the implementation in use +and can avoid unnecessary layers of indirection. """ __author__ = 'petar@google.com (Petar Petrov)' diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 1cddce6..cc6ac90 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -43,6 +43,12 @@ __all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', 'Merge' ] +# Infinity and NaN are not explicitly supported by Python pre-2.6, and +# float('inf') does not work on Windows (pre-2.6). +_INFINITY = 1e10000 # overflows, thus will actually be infinity. +_NAN = _INFINITY * 0 + + class ParseError(Exception): """Thrown in case of ASCII parsing error.""" @@ -149,6 +155,10 @@ def _MergeField(tokenizer, message): name.append(tokenizer.ConsumeIdentifier()) name = '.'.join(name) + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) field = message.Extensions._FindExtensionByName(name) if not field: raise tokenizer.ParseErrorPreviousToken( @@ -198,6 +208,7 @@ def _MergeField(tokenizer, message): sub_message = message.Extensions[field] else: sub_message = getattr(message, field.name) + sub_message.SetInParent() while not tokenizer.TryConsume(end_token): if tokenizer.AtEnd(): @@ -293,7 +304,7 @@ class _Tokenizer(object): '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier '[0-9+-][0-9a-zA-Z_.+-]*|' # a number '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string - '\'([^\"\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string + '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string _IDENTIFIER = re.compile('\w+') _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(), type_checkers.Int32ValueChecker(), @@ -473,12 +484,12 @@ class _Tokenizer(object): if re.match(self._FLOAT_INFINITY, text): self.NextToken() if text.startswith('-'): - return float('-inf') - return float('inf') + return -_INFINITY + return _INFINITY if re.match(self._FLOAT_NAN, text): self.NextToken() - return float('nan') + return _NAN try: result = float(text) @@ -525,6 +536,18 @@ class _Tokenizer(object): Raises: ParseError: If a byte array value couldn't be consumed. """ + list = [self._ConsumeSingleByteString()] + while len(self.token) > 0 and self.token[0] in ('\'', '"'): + list.append(self._ConsumeSingleByteString()) + return "".join(list) + + def _ConsumeSingleByteString(self): + """Consume one token of a string literal. + + String literals (whether bytes or text) can come in multiple adjacent + tokens which are automatically concatenated, like in C or Python. This + method only consumes one token. + """ text = self.token if len(text) < 1 or text[0] not in ('\'', '"'): raise self._ParseError('Exptected string.') |