diff options
author | Jeff Davidson <jpd@google.com> | 2014-09-15 16:29:06 -0700 |
---|---|---|
committer | Jeff Davidson <jpd@google.com> | 2015-01-15 14:10:53 -0800 |
commit | a3b2a6da25a76f17c73d31def3952feb0fd2296e (patch) | |
tree | 586f7d5e9a7e05af45d0e821188097c0faa96219 /python | |
parent | c7c25812eb19d080087b71e08bfe35aff9f21433 (diff) | |
download | external_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.zip external_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.tar.gz external_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.tar.bz2 |
Update protobuf library from 2.3 to 2.6.
Copied in all files from the open source protobuf project at commit
edc5994525c79cd1919859a370837a6ff7c8e308, removing files which have
been renamed (COPYING.txt -> LICENSE, README.txt -> README.md).
Removed 2.3 prebuilts, which is an approach that will not work due to
incompatibility with the 2.6 runtime.
Merged in micro/nano-specific changes in the following files:
-Android.mk - updated list of C++/Java sources, bumped versions
-java/README.txt - merged in micro/nano instructions, bumped versions
-java/pom.xml - merged in micro/nano build rules, set packaging to jar
-src/Makefile.am - merged in references to micro/nano generators
-src/google/protobuf/compiler/javamicro/javamicro_file.h - imported
google/protobuf/compiler/code_generator.h and removed redundant
OutputDirectory class.
-src/google/protobuf/compiler/javanano/javanano_file.h - same
-Replaced instances of vector with std::vector as needed to get
libprotobuf-cpp-full to compile. Plan to upstream this fix per
discussion with protobuf maintainers.
Reran autogen.sh to update ./configure and associated scripts.
Change-Id: I949d32fb5126f1c05e2a6ed48f6636a4a9b15a48
Diffstat (limited to 'python')
66 files changed, 15256 insertions, 1659 deletions
diff --git a/python/README.txt b/python/README.txt index 96f1a73..adfa46b 100644 --- a/python/README.txt +++ b/python/README.txt @@ -43,9 +43,13 @@ Installation $ protoc --version -4) Run the tests: +4) Build and run the tests: - $ python setup.py test + $ python setup.py build + $ python setup.py google_test + + If you want to test c++ implementation, run: + $ python setup.py test --cpp_implementation If some tests fail, this library may not work correctly on your system. Continue at your own risk. @@ -61,8 +65,13 @@ Installation 5) Install: $ python setup.py install + or: + $ python setup.py install --cpp_implementation This step may require superuser privileges. + NOTE: To use C++ implementation, you need to install C++ protobuf runtime + library of the same version and export the environment variable before this + step. See the "C++ Implementation" section below for more details. Usage ===== @@ -71,3 +80,26 @@ The complete documentation for Protocol Buffers is available via the web at: http://code.google.com/apis/protocolbuffers/ + +C++ Implementation +================== + +The C++ implementation for Python messages is built as a Python extension to +improve the overall protobuf Python performance. + +To use the C++ implementation, you need to: +1) Install the C++ protobuf runtime library, please see instructions in the + parent directory. +2) Export an environment variable: + + $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp + $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 + +You need to export this variable before running setup.py script to build and +install the extension. You must also set the variable at runtime, otherwise +the pure-Python implementation will be used. In a future release, we will +change the default so that C++ implementation is used whenever it is available. +It is strongly recommended to run `python setup.py test` after setting the +variable to "cpp", so the tests will be against C++ implemented Python +messages. + diff --git a/python/ez_setup.py b/python/ez_setup.py index b7a9849..3aec98e 100755 --- a/python/ez_setup.py +++ b/python/ez_setup.py @@ -2,7 +2,7 @@ # This file was obtained from: # http://peak.telecommunity.com/dist/ez_setup.py -# on 2009/4/17. +# on 2011/1/21. """Bootstrap setuptools installation @@ -19,7 +19,7 @@ the appropriate options to ``use_setuptools()``. This file can also be run as a script to install or upgrade setuptools. """ import sys -DEFAULT_VERSION = "0.6c9" +DEFAULT_VERSION = "0.6c11" DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] md5_data = { @@ -33,6 +33,14 @@ md5_data = { 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', + 'setuptools-0.6c10-py2.3.egg': 'ce1e2ab5d3a0256456d9fc13800a7090', + 'setuptools-0.6c10-py2.4.egg': '57d6d9d6e9b80772c59a53a8433a5dd4', + 'setuptools-0.6c10-py2.5.egg': 'de46ac8b1c97c895572e5e8596aeb8c7', + 'setuptools-0.6c10-py2.6.egg': '58ea40aef06da02ce641495523a0b7f5', + 'setuptools-0.6c11-py2.3.egg': '2baeac6e13d414a9d28e7ba5b5a596de', + 'setuptools-0.6c11-py2.4.egg': 'bd639f9b0eac4c42497034dec2ec0c2b', + 'setuptools-0.6c11-py2.5.egg': '64c94f3bf7a72a13ec83e0b24f2749b2', + 'setuptools-0.6c11-py2.6.egg': 'bfa92100bd772d5a213eedd356d64086', 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', @@ -99,6 +107,7 @@ def use_setuptools( except ImportError: return do_download() try: + return do_download() pkg_resources.require("setuptools>="+version); return except pkg_resources.VersionConflict, e: if was_imported: @@ -109,11 +118,11 @@ def use_setuptools( "\n\n(Currently using %r)" ) % (version, e.args[0]) sys.exit(2) - else: - del pkg_resources, sys.modules['pkg_resources'] # reload ok - return do_download() except pkg_resources.DistributionNotFound: - return do_download() + pass + + del pkg_resources, sys.modules['pkg_resources'] # reload ok + return do_download() def download_setuptools( version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, @@ -273,9 +282,3 @@ if __name__=='__main__': update_md5(sys.argv[2:]) else: main(sys.argv[1:]) - - - - - - diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index aa4ab96..555498d 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -28,15 +28,9 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# TODO(robinson): We probably need to provide deep-copy methods for -# descriptor types. When a FieldDescriptor is passed into -# Descriptor.__init__(), we should make a deep copy and then set -# containing_type on it. Alternatively, we could just get -# rid of containing_type (iit's not needed for reflection.py, at least). +# Needs to stay compatible with Python 2.5 due to GAE. # -# TODO(robinson): Print method? -# -# TODO(robinson): Useful __repr__? +# Copyright 2007 Google Inc. All Rights Reserved. """Descriptors essentially contain exactly the information found in a .proto file, in types that make this information accessible in Python. @@ -44,11 +38,28 @@ file, in types that make this information accessible in Python. __author__ = 'robinson@google.com (Will Robinson)' +from google.protobuf.internal import api_implementation + + +if api_implementation.Type() == 'cpp': + # Used by MakeDescriptor in cpp mode + import os + import uuid + + if api_implementation.Version() == 2: + from google.protobuf.pyext import _message + else: + from google.protobuf.internal import cpp_message + class Error(Exception): """Base error for this module.""" +class TypeTransformationError(Error): + """Error transforming between python proto type and corresponding C++ type.""" + + class DescriptorBase(object): """Descriptors base class. @@ -75,6 +86,18 @@ class DescriptorBase(object): # Does this descriptor have non-default options? self.has_options = options is not None + def _SetOptions(self, options, options_class_name): + """Sets the descriptor's options + + This function is used in generated proto2 files to update descriptor + options. It must not be used outside proto2. + """ + self._options = options + self._options_class_name = options_class_name + + # Does this descriptor have non-default options? + self.has_options = options is not None + def GetOptions(self): """Retrieves descriptor options. @@ -204,13 +227,21 @@ class Descriptor(_NestedDescriptorBase): options: (descriptor_pb2.MessageOptions) Protocol message options or None to use default message options. + oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields + in this message. + oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|, + but indexed by "name" attribute. + file: (FileDescriptor) Reference to file descriptor. """ + # NOTE(tmarek): The file argument redefining a builtin is nothing we can + # fix right now since we don't know how many clients already rely on the + # name of the argument. def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, - is_extendable=True, extension_ranges=None, file=None, - serialized_start=None, serialized_end=None): + is_extendable=True, extension_ranges=None, oneofs=None, + file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin """Arguments to __init__() are as described in the description of Descriptor fields above. @@ -220,7 +251,7 @@ class Descriptor(_NestedDescriptorBase): super(Descriptor, self).__init__( options, 'MessageOptions', name, full_name, file, containing_type, serialized_start=serialized_start, - serialized_end=serialized_start) + serialized_end=serialized_end) # We have fields in addition to fields_by_name and fields_by_number, # so that: @@ -234,6 +265,8 @@ class Descriptor(_NestedDescriptorBase): self.fields_by_name = dict((f.name, f) for f in fields) self.nested_types = nested_types + for nested_type in nested_types: + nested_type.containing_type = self self.nested_types_by_name = dict((t.name, t) for t in nested_types) self.enum_types = enum_types @@ -249,9 +282,28 @@ class Descriptor(_NestedDescriptorBase): self.extensions_by_name = dict((f.name, f) for f in extensions) self.is_extendable = is_extendable self.extension_ranges = extension_ranges + self.oneofs = oneofs if oneofs is not None else [] + self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) + for oneof in self.oneofs: + oneof.containing_type = self - self._serialized_start = serialized_start - self._serialized_end = serialized_end + def EnumValueName(self, enum, value): + """Returns the string name of an enum value. + + This is just a small helper method to simplify a common operation. + + Args: + enum: string name of the Enum. + value: int, value of the enum. + + Returns: + string name of the enum value. + + Raises: + KeyError if either the Enum doesn't exist or the value is not a valid + value for the enum. + """ + return self.enum_types_by_name[enum].values_by_number[value].name def CopyToProto(self, proto): """Copies this to a descriptor_pb2.DescriptorProto. @@ -278,7 +330,7 @@ class FieldDescriptor(DescriptorBase): """Descriptor for a single field in a .proto file. - A FieldDescriptor instance has the following attriubtes: + A FieldDescriptor instance has the following attributes: name: (str) Name of this field, exactly as it appears in .proto. full_name: (str) Name of this field, including containing scope. This is @@ -319,6 +371,9 @@ class FieldDescriptor(DescriptorBase): options: (descriptor_pb2.FieldOptions) Protocol message field options or None to use default field options. + + containing_oneof: (OneofDescriptor) If the field is a member of a oneof + union, contains its descriptor. Otherwise, None. """ # Must be consistent with C++ FieldDescriptor::Type enum in @@ -361,6 +416,27 @@ class FieldDescriptor(DescriptorBase): CPPTYPE_MESSAGE = 10 MAX_CPPTYPE = 10 + _PYTHON_TO_CPP_PROTO_TYPE_MAP = { + TYPE_DOUBLE: CPPTYPE_DOUBLE, + TYPE_FLOAT: CPPTYPE_FLOAT, + TYPE_ENUM: CPPTYPE_ENUM, + TYPE_INT64: CPPTYPE_INT64, + TYPE_SINT64: CPPTYPE_INT64, + TYPE_SFIXED64: CPPTYPE_INT64, + TYPE_UINT64: CPPTYPE_UINT64, + TYPE_FIXED64: CPPTYPE_UINT64, + TYPE_INT32: CPPTYPE_INT32, + TYPE_SFIXED32: CPPTYPE_INT32, + TYPE_SINT32: CPPTYPE_INT32, + TYPE_UINT32: CPPTYPE_UINT32, + TYPE_FIXED32: CPPTYPE_UINT32, + TYPE_BYTES: CPPTYPE_STRING, + TYPE_STRING: CPPTYPE_STRING, + TYPE_BOOL: CPPTYPE_BOOL, + TYPE_MESSAGE: CPPTYPE_MESSAGE, + TYPE_GROUP: CPPTYPE_MESSAGE + } + # Must be consistent with C++ FieldDescriptor::Label enum in # descriptor.h. # @@ -370,10 +446,16 @@ class FieldDescriptor(DescriptorBase): LABEL_REPEATED = 3 MAX_LABEL = 3 + # Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber, + # and kLastReservedNumber in descriptor.h + MAX_FIELD_NUMBER = (1 << 29) - 1 + FIRST_RESERVED_FIELD_NUMBER = 19000 + LAST_RESERVED_FIELD_NUMBER = 19999 + def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, - has_default_value=True): + has_default_value=True, containing_oneof=None): """The arguments are as described in the description of FieldDescriptor attributes above. @@ -396,6 +478,45 @@ class FieldDescriptor(DescriptorBase): self.enum_type = enum_type self.is_extension = is_extension self.extension_scope = extension_scope + self.containing_oneof = containing_oneof + if api_implementation.Type() == 'cpp': + if is_extension: + if api_implementation.Version() == 2: + # pylint: disable=protected-access + self._cdescriptor = ( + _message.Message._GetExtensionDescriptor(full_name)) + # pylint: enable=protected-access + else: + self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) + else: + if api_implementation.Version() == 2: + # pylint: disable=protected-access + self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) + # pylint: enable=protected-access + else: + self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) + else: + self._cdescriptor = None + + @staticmethod + def ProtoTypeToCppProtoType(proto_type): + """Converts from a Python proto type to a C++ Proto Type. + + The Python ProtocolBuffer classes specify both the 'Python' datatype and the + 'C++' datatype - and they're not the same. This helper method should + translate from one to another. + + Args: + proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*) + Returns: + descriptor.FieldDescriptor.CPPTYPE_*, the C++ type. + Raises: + TypeTransformationError: when the Python proto type isn't known. + """ + try: + return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type] + except KeyError: + raise TypeTransformationError('Unknown proto_type: %s' % proto_type) class EnumDescriptor(_NestedDescriptorBase): @@ -434,7 +555,7 @@ class EnumDescriptor(_NestedDescriptorBase): super(EnumDescriptor, self).__init__( options, 'EnumOptions', name, full_name, file, containing_type, serialized_start=serialized_start, - serialized_end=serialized_start) + serialized_end=serialized_end) self.values = values for value in self.values: @@ -442,9 +563,6 @@ class EnumDescriptor(_NestedDescriptorBase): self.values_by_name = dict((v.name, v) for v in values) self.values_by_number = dict((v.number, v) for v in values) - self._serialized_start = serialized_start - self._serialized_end = serialized_end - def CopyToProto(self, proto): """Copies this to a descriptor_pb2.EnumDescriptorProto. @@ -479,6 +597,29 @@ class EnumValueDescriptor(DescriptorBase): self.type = type +class OneofDescriptor(object): + """Descriptor for a oneof field. + + name: (str) Name of the oneof field. + full_name: (str) Full name of the oneof field, including package name. + index: (int) 0-based index giving the order of the oneof field inside + its containing type. + containing_type: (Descriptor) Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + fields: (list of FieldDescriptor) The list of field descriptors this + oneof can contain. + """ + + def __init__(self, name, full_name, index, containing_type, fields): + """Arguments are as described in the attribute description above.""" + self.name = name + self.full_name = full_name + self.index = index + self.containing_type = containing_type + self.fields = fields + + class ServiceDescriptor(_NestedDescriptorBase): """Descriptor for a service. @@ -557,20 +698,43 @@ class MethodDescriptor(DescriptorBase): class FileDescriptor(DescriptorBase): """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto. + Note that enum_types_by_name, extensions_by_name, and dependencies + fields are only set by the message_factory module, and not by the + generated proto code. + name: name of file, relative to root of source tree. package: name of the package serialized_pb: (str) Byte string of serialized descriptor_pb2.FileDescriptorProto. + dependencies: List of other FileDescriptors this FileDescriptor depends on. + message_types_by_name: Dict of message names of their descriptors. + enum_types_by_name: Dict of enum names and their descriptors. + extensions_by_name: Dict of extension names and their descriptors. """ - def __init__(self, name, package, options=None, serialized_pb=None): + def __init__(self, name, package, options=None, serialized_pb=None, + dependencies=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') + self.message_types_by_name = {} self.name = name self.package = package self.serialized_pb = serialized_pb + self.enum_types_by_name = {} + self.extensions_by_name = {} + self.dependencies = (dependencies or []) + + if (api_implementation.Type() == 'cpp' and + self.serialized_pb is not None): + if api_implementation.Version() == 2: + # pylint: disable=protected-access + _message.Message._BuildFile(self.serialized_pb) + # pylint: enable=protected-access + else: + cpp_message.BuildFile(self.serialized_pb) + def CopyToProto(self, proto): """Copies this to a descriptor_pb2.FileDescriptorProto. @@ -588,3 +752,98 @@ def _ParseOptions(message, string): """ message.ParseFromString(string) return message + + +def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): + """Make a protobuf Descriptor given a DescriptorProto protobuf. + + Handles nested descriptors. Note that this is limited to the scope of defining + a message inside of another message. Composite fields can currently only be + resolved if the message is defined in the same scope as the field. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: Optional package name for the new message Descriptor (string). + build_file_if_cpp: Update the C++ descriptor pool if api matches. + Set to False on recursion, so no duplicates are created. + Returns: + A Descriptor for protobuf messages. + """ + if api_implementation.Type() == 'cpp' and build_file_if_cpp: + # The C++ implementation requires all descriptors to be backed by the same + # definition in the C++ descriptor pool. To do this, we build a + # FileDescriptorProto with the same definition as this descriptor and build + # it into the pool. + from google.protobuf import descriptor_pb2 + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.message_type.add().MergeFrom(desc_proto) + + # Generate a random name for this proto file to prevent conflicts with + # any imported ones. We need to specify a file name so BuildFile accepts + # our FileDescriptorProto, but it is not important what that file name + # is actually set to. + proto_name = str(uuid.uuid4()) + + if package: + file_descriptor_proto.name = os.path.join(package.replace('.', '/'), + proto_name + '.proto') + file_descriptor_proto.package = package + else: + file_descriptor_proto.name = proto_name + '.proto' + + if api_implementation.Version() == 2: + # pylint: disable=protected-access + _message.Message._BuildFile(file_descriptor_proto.SerializeToString()) + # pylint: enable=protected-access + else: + cpp_message.BuildFile(file_descriptor_proto.SerializeToString()) + + full_message_name = [desc_proto.name] + if package: full_message_name.insert(0, package) + + # Create Descriptors for enum types + enum_types = {} + for enum_proto in desc_proto.enum_type: + full_name = '.'.join(full_message_name + [enum_proto.name]) + enum_desc = EnumDescriptor( + enum_proto.name, full_name, None, [ + EnumValueDescriptor(enum_val.name, ii, enum_val.number) + for ii, enum_val in enumerate(enum_proto.value)]) + enum_types[full_name] = enum_desc + + # Create Descriptors for nested types + nested_types = {} + for nested_proto in desc_proto.nested_type: + full_name = '.'.join(full_message_name + [nested_proto.name]) + # Nested types are just those defined inside of the message, not all types + # used by fields in the message, so no loops are possible here. + nested_desc = MakeDescriptor(nested_proto, + package='.'.join(full_message_name), + build_file_if_cpp=False) + nested_types[full_name] = nested_desc + + fields = [] + for field_proto in desc_proto.field: + full_name = '.'.join(full_message_name + [field_proto.name]) + enum_desc = None + nested_desc = None + if field_proto.HasField('type_name'): + type_name = field_proto.type_name + full_type_name = '.'.join(full_message_name + + [type_name[type_name.rfind('.')+1:]]) + if full_type_name in nested_types: + nested_desc = nested_types[full_type_name] + elif full_type_name in enum_types: + enum_desc = enum_types[full_type_name] + # Else type_name references a non-local type, which isn't implemented + field = FieldDescriptor( + field_proto.name, full_name, field_proto.number - 1, + field_proto.number, field_proto.type, + FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), + field_proto.label, None, nested_desc, enum_desc, None, False, None, + has_default_value=False) + fields.append(field) + + desc_name = '.'.join(full_message_name) + return Descriptor(desc_proto.name, desc_name, None, None, fields, + nested_types.values(), enum_types.values(), []) diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py new file mode 100644 index 0000000..9f5a117 --- /dev/null +++ b/python/google/protobuf/descriptor_database.py @@ -0,0 +1,137 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides a container for DescriptorProtos.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + + +class Error(Exception): + pass + + +class DescriptorDatabaseConflictingDefinitionError(Error): + """Raised when a proto is added with the same name & different descriptor.""" + + +class DescriptorDatabase(object): + """A container accepting FileDescriptorProtos and maps DescriptorProtos.""" + + def __init__(self): + self._file_desc_protos_by_file = {} + self._file_desc_protos_by_symbol = {} + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this database. + + Args: + file_desc_proto: The FileDescriptorProto to add. + Raises: + DescriptorDatabaseException: if an attempt is made to add a proto + with the same name but different definition than an exisiting + proto in the database. + """ + proto_name = file_desc_proto.name + if proto_name not in self._file_desc_protos_by_file: + self._file_desc_protos_by_file[proto_name] = file_desc_proto + elif self._file_desc_protos_by_file[proto_name] != file_desc_proto: + raise DescriptorDatabaseConflictingDefinitionError( + '%s already added, but with different descriptor.' % proto_name) + + package = file_desc_proto.package + for message in file_desc_proto.message_type: + self._file_desc_protos_by_symbol.update( + (name, file_desc_proto) for name in _ExtractSymbols(message, package)) + for enum in file_desc_proto.enum_type: + self._file_desc_protos_by_symbol[ + '.'.join((package, enum.name))] = file_desc_proto + + def FindFileByName(self, name): + """Finds the file descriptor proto by file name. + + Typically the file name is a relative path ending to a .proto file. The + proto with the given name will have to have been added to this database + using the Add method or else an error will be raised. + + Args: + name: The file name to find. + + Returns: + The file descriptor proto matching the name. + + Raises: + KeyError if no file by the given name was added. + """ + + return self._file_desc_protos_by_file[name] + + def FindFileContainingSymbol(self, symbol): + """Finds the file descriptor proto containing the specified symbol. + + The symbol should be a fully qualified name including the file descriptor's + package and any containing messages. Some examples: + + 'some.package.name.Message' + 'some.package.name.Message.NestedEnum' + + The file descriptor proto containing the specified symbol must be added to + this database using the Add method or else an error will be raised. + + Args: + symbol: The fully qualified symbol name. + + Returns: + The file descriptor proto containing the symbol. + + Raises: + KeyError if no file contains the specified symbol. + """ + + return self._file_desc_protos_by_symbol[symbol] + + +def _ExtractSymbols(desc_proto, package): + """Pulls out all the symbols from a descriptor proto. + + Args: + desc_proto: The proto to extract symbols from. + package: The package containing the descriptor type. + + Yields: + The fully qualified name found in the descriptor. + """ + + message_name = '.'.join((package, desc_proto.name)) + yield message_name + for nested_type in desc_proto.nested_type: + for symbol in _ExtractSymbols(nested_type, message_name): + yield symbol + for enum_type in desc_proto.enum_type: + yield '.'.join((message_name, enum_type.name)) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py new file mode 100644 index 0000000..372f458 --- /dev/null +++ b/python/google/protobuf/descriptor_pool.py @@ -0,0 +1,643 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides DescriptorPool to use as a container for proto2 descriptors. + +The DescriptorPool is used in conjection with a DescriptorDatabase to maintain +a collection of protocol buffer descriptors for use when dynamically creating +message types at runtime. + +For most applications protocol buffers should be used via modules generated by +the protocol buffer compiler tool. This should only be used when the type of +protocol buffers used in an application or library cannot be predetermined. + +Below is a straightforward example on how to use this class: + + pool = DescriptorPool() + file_descriptor_protos = [ ... ] + for file_descriptor_proto in file_descriptor_protos: + pool.Add(file_descriptor_proto) + my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') + +The message descriptor can be used in conjunction with the message_factory +module in order to create a protocol buffer class that can be encoded and +decoded. + +If you want to get a Python class for the specified proto, use the +helper functions inside google.protobuf.message_factory +directly instead of this class. +""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import sys + +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import text_encoding + + +def _NormalizeFullyQualifiedName(name): + """Remove leading period from fully-qualified type name. + + Due to b/13860351 in descriptor_database.py, types in the root namespace are + generated with a leading period. This function removes that prefix. + + Args: + name: A str, the fully-qualified symbol name. + + Returns: + A str, the normalized fully-qualified symbol name. + """ + return name.lstrip('.') + + +class DescriptorPool(object): + """A collection of protobufs dynamically constructed by descriptor protos.""" + + def __init__(self, descriptor_db=None): + """Initializes a Pool of proto buffs. + + The descriptor_db argument to the constructor is provided to allow + specialized file descriptor proto lookup code to be triggered on demand. An + example would be an implementation which will read and compile a file + specified in a call to FindFileByName() and not require the call to Add() + at all. Results from this database will be cached internally here as well. + + Args: + descriptor_db: A secondary source of file descriptors. + """ + + self._internal_db = descriptor_database.DescriptorDatabase() + self._descriptor_db = descriptor_db + self._descriptors = {} + self._enum_descriptors = {} + self._file_descriptors = {} + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this pool. + + Args: + file_desc_proto: The FileDescriptorProto to add. + """ + + self._internal_db.Add(file_desc_proto) + + def AddDescriptor(self, desc): + """Adds a Descriptor to the pool, non-recursively. + + If the Descriptor contains nested messages or enums, the caller must + explicitly register them. This method also registers the FileDescriptor + associated with the message. + + Args: + desc: A Descriptor. + """ + if not isinstance(desc, descriptor.Descriptor): + raise TypeError('Expected instance of descriptor.Descriptor.') + + self._descriptors[desc.full_name] = desc + self.AddFileDescriptor(desc.file) + + def AddEnumDescriptor(self, enum_desc): + """Adds an EnumDescriptor to the pool. + + This method also registers the FileDescriptor associated with the message. + + Args: + enum_desc: An EnumDescriptor. + """ + + if not isinstance(enum_desc, descriptor.EnumDescriptor): + raise TypeError('Expected instance of descriptor.EnumDescriptor.') + + self._enum_descriptors[enum_desc.full_name] = enum_desc + self.AddFileDescriptor(enum_desc.file) + + def AddFileDescriptor(self, file_desc): + """Adds a FileDescriptor to the pool, non-recursively. + + If the FileDescriptor contains messages or enums, the caller must explicitly + register them. + + Args: + file_desc: A FileDescriptor. + """ + + if not isinstance(file_desc, descriptor.FileDescriptor): + raise TypeError('Expected instance of descriptor.FileDescriptor.') + self._file_descriptors[file_desc.name] = file_desc + + def FindFileByName(self, file_name): + """Gets a FileDescriptor by file name. + + Args: + file_name: The path to the file to get a descriptor for. + + Returns: + A FileDescriptor for the named file. + + Raises: + KeyError: if the file can not be found in the pool. + """ + + try: + return self._file_descriptors[file_name] + except KeyError: + pass + + try: + file_proto = self._internal_db.FindFileByName(file_name) + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileByName(file_name) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file named %s' % file_name) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def FindFileContainingSymbol(self, symbol): + """Gets the FileDescriptor for the file containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file can not be found in the pool. + """ + + symbol = _NormalizeFullyQualifiedName(symbol) + try: + return self._descriptors[symbol].file + except KeyError: + pass + + try: + return self._enum_descriptors[symbol].file + except KeyError: + pass + + try: + file_proto = self._internal_db.FindFileContainingSymbol(symbol) + except KeyError: + _, error, _ = sys.exc_info() #PY25 compatible for GAE. + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file containing %s' % symbol) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def FindMessageTypeByName(self, full_name): + """Loads the named descriptor from the pool. + + Args: + full_name: The full name of the descriptor to load. + + Returns: + The descriptor for the named type. + """ + + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._descriptors: + self.FindFileContainingSymbol(full_name) + return self._descriptors[full_name] + + def FindEnumTypeByName(self, full_name): + """Loads the named enum descriptor from the pool. + + Args: + full_name: The full name of the enum descriptor to load. + + Returns: + The enum descriptor for the named type. + """ + + full_name = _NormalizeFullyQualifiedName(full_name) + if full_name not in self._enum_descriptors: + self.FindFileContainingSymbol(full_name) + return self._enum_descriptors[full_name] + + def _ConvertFileProtoToFileDescriptor(self, file_proto): + """Creates a FileDescriptor from a proto or returns a cached copy. + + This method also has the side effect of loading all the symbols found in + the file into the appropriate dictionaries in the pool. + + Args: + file_proto: The proto to convert. + + Returns: + A FileDescriptor matching the passed in proto. + """ + + if file_proto.name not in self._file_descriptors: + built_deps = list(self._GetDeps(file_proto.dependency)) + direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + + file_descriptor = descriptor.FileDescriptor( + name=file_proto.name, + package=file_proto.package, + options=file_proto.options, + serialized_pb=file_proto.SerializeToString(), + dependencies=direct_deps) + scope = {} + + # This loop extracts all the message and enum types from all the + # dependencoes of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope) + file_descriptor.message_types_by_name[message_desc.name] = message_desc + + for enum_type in file_proto.enum_type: + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self.MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = extension_desc + + for desc_proto in file_proto.message_type: + self.SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' + + for desc_proto in file_proto.message_type: + desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope) + file_descriptor.message_types_by_name[desc_proto.name] = desc + self.Add(file_proto) + self._file_descriptors[file_proto.name] = file_descriptor + + return self._file_descriptors[file_proto.name] + + def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, + scope=None): + """Adds the proto to the pool in the specified package. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: The package the proto should be located in. + file_desc: The file containing this message. + scope: Dict mapping short and full symbols to message and enum types. + + Returns: + The added descriptor. + """ + + if package: + desc_name = '.'.join((package, desc_proto.name)) + else: + desc_name = desc_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + if scope is None: + scope = {} + + nested = [ + self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope) + for nested in desc_proto.nested_type] + enums = [ + self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) + for enum in desc_proto.enum_type] + fields = [self.MakeFieldDescriptor(field, desc_name, index) + for index, field in enumerate(desc_proto.field)] + extensions = [ + self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True) + for index, extension in enumerate(desc_proto.extension)] + oneofs = [ + descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), + index, None, []) + for index, desc in enumerate(desc_proto.oneof_decl)] + extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] + if extension_ranges: + is_extendable = True + else: + is_extendable = False + desc = descriptor.Descriptor( + name=desc_proto.name, + full_name=desc_name, + filename=file_name, + containing_type=None, + fields=fields, + oneofs=oneofs, + nested_types=nested, + enum_types=enums, + extensions=extensions, + options=desc_proto.options, + is_extendable=is_extendable, + extension_ranges=extension_ranges, + file=file_desc, + serialized_start=None, + serialized_end=None) + for nested in desc.nested_types: + nested.containing_type = desc + for enum in desc.enum_types: + enum.containing_type = desc + for field_index, field_desc in enumerate(desc_proto.field): + if field_desc.HasField('oneof_index'): + oneof_index = field_desc.oneof_index + oneofs[oneof_index].fields.append(fields[field_index]) + fields[field_index].containing_oneof = oneofs[oneof_index] + + scope[_PrefixWithDot(desc_name)] = desc + self._descriptors[desc_name] = desc + return desc + + def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, + containing_type=None, scope=None): + """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. + + Args: + enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. + package: Optional package name for the new message EnumDescriptor. + file_desc: The file containing the enum descriptor. + containing_type: The type containing this enum. + scope: Scope containing available types. + + Returns: + The added descriptor + """ + + if package: + enum_name = '.'.join((package, enum_proto.name)) + else: + enum_name = enum_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + values = [self._MakeEnumValueDescriptor(value, index) + for index, value in enumerate(enum_proto.value)] + desc = descriptor.EnumDescriptor(name=enum_proto.name, + full_name=enum_name, + filename=file_name, + file=file_desc, + values=values, + containing_type=containing_type, + options=enum_proto.options) + scope['.%s' % enum_name] = desc + self._enum_descriptors[enum_name] = desc + return desc + + def MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): + """Creates a field descriptor from a FieldDescriptorProto. + + For message and enum type fields, this method will do a look up + in the pool for the appropriate descriptor for that type. If it + is unavailable, it will fall back to the _source function to + create it. If this type is still unavailable, construction will + fail. + + Args: + field_proto: The proto describing the field. + message_name: The name of the containing message. + index: Index of the field + is_extension: Indication that this field is for an extension. + + Returns: + An initialized FieldDescriptor object + """ + + if message_name: + full_name = '.'.join((message_name, field_proto.name)) + else: + full_name = field_proto.name + + return descriptor.FieldDescriptor( + name=field_proto.name, + full_name=full_name, + index=index, + number=field_proto.number, + type=field_proto.type, + cpp_type=None, + message_type=None, + enum_type=None, + containing_type=None, + label=field_proto.label, + has_default_value=False, + default_value=None, + is_extension=is_extension, + extension_scope=None, + options=field_proto.options) + + def SetAllFieldTypes(self, package, desc_proto, scope): + """Sets all the descriptor's fields's types. + + This method also sets the containing types on any extensions. + + Args: + package: The current package of desc_proto. + desc_proto: The message descriptor to update. + scope: Enclosing scope of available types. + """ + + package = _PrefixWithDot(package) + + main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) + + if package == '.': + nested_package = _PrefixWithDot(desc_proto.name) + else: + nested_package = '.'.join([package, desc_proto.name]) + + for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): + self.SetFieldType(field_proto, field_desc, nested_package, scope) + + for extension_proto, extension_desc in ( + zip(desc_proto.extension, main_desc.extensions)): + extension_desc.containing_type = self._GetTypeFromScope( + nested_package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, nested_package, scope) + + for nested_type in desc_proto.nested_type: + self.SetAllFieldTypes(nested_package, nested_type, scope) + + def SetFieldType(self, field_proto, field_desc, package, scope): + """Sets the field's type, cpp_type, message_type and enum_type. + + Args: + field_proto: Data about the field in proto format. + field_desc: The descriptor to modiy. + package: The package the field's container is in. + scope: Enclosing scope of available types. + """ + if field_proto.type_name: + desc = self._GetTypeFromScope(package, field_proto.type_name, scope) + else: + desc = None + + if not field_proto.HasField('type'): + if isinstance(desc, descriptor.Descriptor): + field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE + else: + field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM + + field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( + field_proto.type) + + if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): + field_desc.message_type = desc + + if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.enum_type = desc + + if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: + field_desc.has_default_value = False + field_desc.default_value = [] + elif field_proto.HasField('default_value'): + field_desc.has_default_value = True + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = float(field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = field_proto.default_value + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = field_proto.default_value.lower() == 'true' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values_by_name[ + field_proto.default_value].index + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = text_encoding.CUnescape( + field_proto.default_value) + else: + field_desc.default_value = int(field_proto.default_value) + else: + field_desc.has_default_value = False + field_desc.default_value = None + + field_desc.type = field_proto.type + + def _MakeEnumValueDescriptor(self, value_proto, index): + """Creates a enum value descriptor object from a enum value proto. + + Args: + value_proto: The proto describing the enum value. + index: The index of the enum value. + + Returns: + An initialized EnumValueDescriptor object. + """ + + return descriptor.EnumValueDescriptor( + name=value_proto.name, + index=index, + number=value_proto.number, + options=value_proto.options, + type=None) + + def _ExtractSymbols(self, descriptors): + """Pulls out all the symbols from descriptor protos. + + Args: + descriptors: The messages to extract descriptors from. + Yields: + A two element tuple of the type name and descriptor object. + """ + + for desc in descriptors: + yield (_PrefixWithDot(desc.full_name), desc) + for symbol in self._ExtractSymbols(desc.nested_types): + yield symbol + for enum in desc.enum_types: + yield (_PrefixWithDot(enum.full_name), enum) + + def _GetDeps(self, dependencies): + """Recursively finds dependencies for file protos. + + Args: + dependencies: The names of the files being depended on. + + Yields: + Each direct and indirect dependency. + """ + + for dependency in dependencies: + dep_desc = self.FindFileByName(dependency) + yield dep_desc + for parent_dep in dep_desc.dependencies: + yield parent_dep + + def _GetTypeFromScope(self, package, type_name, scope): + """Finds a given type name in the current scope. + + Args: + package: The package the proto should be located in. + type_name: The name of the type to be found in the scope. + scope: Dict mapping short and full symbols to message and enum types. + + Returns: + The descriptor for the requested type. + """ + if type_name not in scope: + components = _PrefixWithDot(package).split('.') + while components: + possible_match = '.'.join(components + [type_name]) + if possible_match in scope: + type_name = possible_match + break + else: + components.pop(-1) + return scope[type_name] + + +def _PrefixWithDot(name): + return name if name.startswith('.') else '.%s' % name diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc new file mode 100644 index 0000000..ad6fd9c --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.cc @@ -0,0 +1,139 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +namespace google { +namespace protobuf { +namespace python { + +// Version constant. +// This is either 0 for python, 1 for CPP V1, 2 for CPP V2. +// +// 0 is default and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +// +// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1 +// +// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 +#ifdef PYTHON_PROTO2_CPP_IMPL_V1 +#if PY_MAJOR_VERSION >= 3 +#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." +#endif +static int kImplVersion = 1; +#else +#ifdef PYTHON_PROTO2_CPP_IMPL_V2 +static int kImplVersion = 2; +#else +#ifdef PYTHON_PROTO2_PYTHON_IMPL +static int kImplVersion = 0; +#else + +// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. +// Python 2 still uses the Python version by default until some compatibility +// issues can be worked around. +#if PY_MAJOR_VERSION >= 3 +static int kImplVersion = 2; +#else +static int kImplVersion = 0; +#endif + +#endif // PYTHON_PROTO2_PYTHON_IMPL +#endif // PYTHON_PROTO2_CPP_IMPL_V2 +#endif // PYTHON_PROTO2_CPP_IMPL_V1 + +static const char* kImplVersionName = "api_version"; + +static const char* kModuleName = "_api_implementation"; +static const char kModuleDocstring[] = +"_api_implementation is a module that exposes compile-time constants that\n" +"determine the default API implementation to use for Python proto2.\n" +"\n" +"It complements api_implementation.py by setting defaults using compile-time\n" +"constants defined in C, such that one can set defaults at compilation\n" +"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2)."; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + kModuleName, + kModuleDocstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__api_implementation +#define INITFUNC_ERRORVAL NULL +#else +#define INITFUNC init_api_implementation +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC() { +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&_module); +#else + PyObject *module = Py_InitModule3( + const_cast<char*>(kModuleName), + NULL, + const_cast<char*>(kModuleDocstring)); +#endif + if (module == NULL) { + return INITFUNC_ERRORVAL; + } + + // Adds the module variable "api_version". + if (PyModule_AddIntConstant( + module, + const_cast<char*>(kImplVersionName), + kImplVersion)) +#if PY_MAJOR_VERSION < 3 + return; +#else + { Py_DECREF(module); return NULL; } + + return module; +#endif + } +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py new file mode 100755 index 0000000..cbb8574 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Determine which implementation of the protobuf API is used in this process. +""" + +import os +import sys + +try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import _api_implementation + # The compile-time constants in the _api_implementation module can be used to + # switch to a certain implementation of the Python API at build time. + _api_version = _api_implementation.api_version + del _api_implementation +except ImportError: + _api_version = 0 + +_default_implementation_type = ( + 'python' if _api_version == 0 else 'cpp') +_default_version_str = ( + '1' if _api_version <= 1 else '2') + +# This environment variable can be used to switch to a certain implementation +# of the Python API, overriding the compile-time constants in the +# _api_implementation module. Right now only 'python' and 'cpp' are valid +# values. Any other value will be ignored. +_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', + _default_implementation_type) + +if _implementation_type != 'python': + _implementation_type = 'cpp' + +# This environment variable can be used to switch between the two +# 'cpp' implementations, overriding the compile-time constants in the +# _api_implementation module. Right now only 1 and 2 are valid values. Any other +# value will be ignored. +_implementation_version_str = os.getenv( + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', + _default_version_str) + +if _implementation_version_str not in ('1', '2'): + raise ValueError( + "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + + _implementation_version_str + "' (supported versions: 1, 2)" + ) + +_implementation_version = int(_implementation_version_str) + + +# Usage of this function is discouraged. Clients shouldn't care which +# implementation of the API is in use. Note that there is no guarantee +# that differences between APIs will be maintained. +# Please don't use this function if possible. +def Type(): + return _implementation_type + + +# See comment on 'Type' above. +def Version(): + return _implementation_version diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py new file mode 100644 index 0000000..b2b4128 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation_default_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test that the api_implementation defaults are what we expect.""" + +import os +import sys +# Clear environment implementation settings before the google3 imports. +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) + +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation + + +class ApiImplementationDefaultTest(basetest.TestCase): + + if sys.version_info.major <= 2: + + def testThatPythonIsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('python', api_implementation.Type()) + + else: + + def testThatCppApiV2IsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 5cc7d6d..5797e81 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -72,9 +72,20 @@ class BaseContainer(object): # The concrete classes should define __eq__. return not self == other + def __hash__(self): + raise TypeError('unhashable object') + def __repr__(self): return repr(self._values) + def sort(self, *args, **kwargs): + # Continue to support the old sort_function keyword argument. + # This is expected to be a rare occurrence, so use LBYL to avoid + # the overhead of actually catching KeyError. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._values.sort(*args, **kwargs) + class RepeatedScalarFieldContainer(BaseContainer): @@ -97,15 +108,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self._type_checker.CheckValue(value) - self._values.append(value) + self._values.append(self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" - self._type_checker.CheckValue(value) - self._values.insert(key, value) + self._values.insert(key, self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() @@ -116,8 +125,7 @@ class RepeatedScalarFieldContainer(BaseContainer): new_values = [] for elem in elem_seq: - self._type_checker.CheckValue(elem) - new_values.append(elem) + new_values.append(self._type_checker.CheckValue(elem)) self._values.extend(new_values) self._message_listener.Modified() @@ -135,9 +143,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def __setitem__(self, key, value): """Sets the item on the specified position.""" - self._type_checker.CheckValue(value) - self._values[key] = value - self._message_listener.Modified() + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -147,8 +159,7 @@ class RepeatedScalarFieldContainer(BaseContainer): """Sets the subset of items from between the specified indices.""" new_values = [] for value in values: - self._type_checker.CheckValue(value) - new_values.append(value) + new_values.append(self._type_checker.CheckValue(value)) self._values[start:stop] = new_values self._message_listener.Modified() @@ -198,28 +209,42 @@ class RepeatedCompositeFieldContainer(BaseContainer): super(RepeatedCompositeFieldContainer, self).__init__(message_listener) self._message_descriptor = message_descriptor - def add(self): - new_element = self._message_descriptor._concrete_class() + def add(self, **kwargs): + """Adds a new element at the end of the list and returns it. Keyword + arguments may be used to initialize the element. + """ + new_element = self._message_descriptor._concrete_class(**kwargs) new_element._SetListener(self._message_listener) self._values.append(new_element) if not self._message_listener.dirty: self._message_listener.Modified() return new_element - def MergeFrom(self, other): - """Appends the contents of another repeated field of the same type to this - one, copying each individual message. + def extend(self, elem_seq): + """Extends by appending the given sequence of elements of the same type + as this one, copying each individual message. """ message_class = self._message_descriptor._concrete_class listener = self._message_listener values = self._values - for message in other._values: + for message in elem_seq: new_element = message_class() new_element._SetListener(listener) new_element.MergeFrom(message) values.append(new_element) listener.Modified() + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one, copying each individual message. + """ + self.extend(other._values) + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py new file mode 100755 index 0000000..8eb38ca --- /dev/null +++ b/python/google/protobuf/internal/cpp_message.py @@ -0,0 +1,663 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains helper functions used to create protocol message classes from +Descriptor objects at runtime backed by the protocol buffer C++ API. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + +import copy_reg +import operator +from google.protobuf.internal import _net_proto2___python +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import message + + +_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED +_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL +_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE +_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE + + +def GetDescriptorPool(): + """Creates a new DescriptorPool C++ object.""" + return _net_proto2___python.NewCDescriptorPool() + + +_pool = GetDescriptorPool() + + +def GetFieldDescriptor(full_field_name): + """Searches for a field descriptor given a full field name.""" + return _pool.FindFieldByName(full_field_name) + + +def BuildFile(content): + """Registers a new proto file in the underlying C++ descriptor pool.""" + _net_proto2___python.BuildFile(content) + + +def GetExtensionDescriptor(full_extension_name): + """Searches for extension descriptor given a full field name.""" + return _pool.FindExtensionByName(full_extension_name) + + +def NewCMessage(full_message_name): + """Creates a new C++ protocol message by its name.""" + return _net_proto2___python.NewCMessage(full_message_name) + + +def ScalarProperty(cdescriptor): + """Returns a scalar property for the given descriptor.""" + + def Getter(self): + return self._cmsg.GetScalar(cdescriptor) + + def Setter(self, value): + self._cmsg.SetScalar(cdescriptor, value) + + return property(Getter, Setter) + + +def CompositeProperty(cdescriptor, message_type): + """Returns a Python property the given composite field.""" + + def Getter(self): + sub_message = self._composite_fields.get(cdescriptor.name, None) + if sub_message is None: + cmessage = self._cmsg.NewSubMessage(cdescriptor) + sub_message = message_type._concrete_class(__cmessage=cmessage) + self._composite_fields[cdescriptor.name] = sub_message + return sub_message + + return property(Getter) + + +class RepeatedScalarContainer(object): + """Container for repeated scalar fields.""" + + __slots__ = ['_message', '_cfield_descriptor', '_cmsg'] + + def __init__(self, msg, cfield_descriptor): + self._message = msg + self._cmsg = msg._cmsg + self._cfield_descriptor = cfield_descriptor + + def append(self, value): + self._cmsg.AddRepeatedScalar( + self._cfield_descriptor, value) + + def extend(self, sequence): + for element in sequence: + self.append(element) + + def insert(self, key, value): + values = self[slice(None, None, None)] + values.insert(key, value) + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def remove(self, value): + values = self[slice(None, None, None)] + values.remove(value) + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def __setitem__(self, key, value): + values = self[slice(None, None, None)] + values[key] = value + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def __getitem__(self, key): + return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key) + + def __delitem__(self, key): + self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key) + + def __len__(self): + return len(self[slice(None, None, None)]) + + def __eq__(self, other): + if self is other: + return True + if not operator.isSequenceType(other): + raise TypeError( + 'Can only compare repeated scalar fields against sequences.') + # We are presumably comparing against some other sequence type. + return other == self[slice(None, None, None)] + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + def sort(self, *args, **kwargs): + # Maintain compatibility with the previous interface. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, + sorted(self, *args, **kwargs)) + + +def RepeatedScalarProperty(cdescriptor): + """Returns a Python property the given repeated scalar field.""" + + def Getter(self): + container = self._composite_fields.get(cdescriptor.name, None) + if container is None: + container = RepeatedScalarContainer(self, cdescriptor) + self._composite_fields[cdescriptor.name] = container + return container + + def Setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % cdescriptor.name) + + doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name + return property(Getter, Setter, doc=doc) + + +class RepeatedCompositeContainer(object): + """Container for repeated composite fields.""" + + __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg'] + + def __init__(self, msg, cfield_descriptor, subclass): + self._message = msg + self._cmsg = msg._cmsg + self._subclass = subclass + self._cfield_descriptor = cfield_descriptor + + def add(self, **kwargs): + cmessage = self._cmsg.AddMessage(self._cfield_descriptor) + return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs) + + def extend(self, elem_seq): + """Extends by appending the given sequence of elements of the same type + as this one, copying each individual message. + """ + for message in elem_seq: + self.add().MergeFrom(message) + + def remove(self, value): + # TODO(protocol-devel): This is inefficient as it needs to generate a + # message pointer for each message only to do index(). Move this to a C++ + # extension function. + self.__delitem__(self[slice(None, None, None)].index(value)) + + def MergeFrom(self, other): + for message in other[:]: + self.add().MergeFrom(message) + + def __getitem__(self, key): + cmessages = self._cmsg.GetRepeatedMessage( + self._cfield_descriptor, key) + subclass = self._subclass + if not isinstance(cmessages, list): + return subclass(__cmessage=cmessages, __owner=self._message) + + return [subclass(__cmessage=m, __owner=self._message) for m in cmessages] + + def __delitem__(self, key): + self._cmsg.DeleteRepeatedField( + self._cfield_descriptor, key) + + def __len__(self): + return self._cmsg.FieldLength(self._cfield_descriptor) + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + messages = self[slice(None, None, None)] + other_messages = other[slice(None, None, None)] + return messages == other_messages + + def __hash__(self): + raise TypeError('unhashable object') + + def sort(self, cmp=None, key=None, reverse=False, **kwargs): + # Maintain compatibility with the old interface. + if cmp is None and 'sort_function' in kwargs: + cmp = kwargs.pop('sort_function') + + # The cmp function, if provided, is passed the results of the key function, + # so we only need to wrap one of them. + if key is None: + index_key = self.__getitem__ + else: + index_key = lambda i: key(self[i]) + + # Sort the list of current indexes by the underlying object. + indexes = range(len(self)) + indexes.sort(cmp=cmp, key=index_key, reverse=reverse) + + # Apply the transposition. + for dest, src in enumerate(indexes): + if dest == src: + continue + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) + # Don't swap the same value twice. + indexes[src] = src + + +def RepeatedCompositeProperty(cdescriptor, message_type): + """Returns a Python property for the given repeated composite field.""" + + def Getter(self): + container = self._composite_fields.get(cdescriptor.name, None) + if container is None: + container = RepeatedCompositeContainer( + self, cdescriptor, message_type._concrete_class) + self._composite_fields[cdescriptor.name] = container + return container + + def Setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % cdescriptor.name) + + doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name + return property(Getter, Setter, doc=doc) + + +class ExtensionDict(object): + """Extension dictionary added to each protocol message.""" + + def __init__(self, msg): + self._message = msg + self._cmsg = msg._cmsg + self._values = {} + + def __setitem__(self, extension, value): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_OPTIONAL or + cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + raise TypeError('Extension %r is repeated and/or a composite type.' % ( + extension.full_name,)) + self._cmsg.SetScalar(cdescriptor, value) + self._values[extension] = value + + def __getitem__(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_REPEATED and + cdescriptor.cpp_type != _CPPTYPE_MESSAGE): + return self._cmsg.GetScalar(cdescriptor) + + ext = self._values.get(extension, None) + if ext is not None: + return ext + + ext = self._CreateNewHandle(extension) + self._values[extension] = ext + return ext + + def ClearExtension(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + self._cmsg.ClearFieldByDescriptor(extension._cdescriptor) + if extension in self._values: + del self._values[extension] + + def HasExtension(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + return self._cmsg.HasFieldByDescriptor(extension._cdescriptor) + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._message._extensions_by_name.get(name, None) + + def _CreateNewHandle(self, extension): + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_REPEATED and + cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + cmessage = self._cmsg.NewSubMessage(cdescriptor) + return extension.message_type._concrete_class(__cmessage=cmessage) + + if cdescriptor.label == _LABEL_REPEATED: + if cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + return RepeatedCompositeContainer( + self._message, cdescriptor, extension.message_type._concrete_class) + else: + return RepeatedScalarContainer(self._message, cdescriptor) + # This shouldn't happen! + assert False + return None + + +def NewMessage(bases, message_descriptor, dictionary): + """Creates a new protocol message *class*.""" + _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) + _AddEnumValues(message_descriptor, dictionary) + _AddDescriptors(message_descriptor, dictionary) + return bases + + +def InitMessage(message_descriptor, cls): + """Constructs a new message instance (called before instance's __init__).""" + cls._extensions_by_name = {} + _AddInitMethod(message_descriptor, cls) + _AddMessageMethods(message_descriptor, cls) + _AddPropertiesForExtensions(message_descriptor, cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + +def _AddDescriptors(message_descriptor, dictionary): + """Sets up a new protocol message class dictionary. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__descriptors'] = {} + for field in message_descriptor.fields: + dictionary['__descriptors'][field.name] = GetFieldDescriptor( + field.full_name) + + dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] + + +def _AddEnumValues(message_descriptor, dictionary): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + message_descriptor: Descriptor object for this message type. + dictionary: Class dictionary that should be populated. + """ + for enum_type in message_descriptor.enum_types: + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) + for enum_value in enum_type.values: + dictionary[enum_value.name] = enum_value.number + + +def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary): + """Adds class attributes for the nested extensions.""" + extension_dict = message_descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + + # Create and attach message field properties to the message class. + # This can be done just once per message class, since property setters and + # getters are passed the message instance. + # This makes message instantiation extremely fast, and at the same time it + # doesn't require the creation of property objects for each message instance, + # which saves a lot of memory. + for field in message_descriptor.fields: + field_cdescriptor = cls.__descriptors[field.name] + if field.label == _LABEL_REPEATED: + if field.cpp_type == _CPPTYPE_MESSAGE: + value = RepeatedCompositeProperty(field_cdescriptor, field.message_type) + else: + value = RepeatedScalarProperty(field_cdescriptor) + elif field.cpp_type == _CPPTYPE_MESSAGE: + value = CompositeProperty(field_cdescriptor, field.message_type) + else: + value = ScalarProperty(field_cdescriptor) + setattr(cls, field.name, value) + + # Attach a constant with the field number. + constant_name = field.name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, field.number) + + def Init(self, **kwargs): + """Message constructor.""" + cmessage = kwargs.pop('__cmessage', None) + if cmessage: + self._cmsg = cmessage + else: + self._cmsg = NewCMessage(message_descriptor.full_name) + + # Keep a reference to the owner, as the owner keeps a reference to the + # underlying protocol buffer message. + owner = kwargs.pop('__owner', None) + if owner: + self._owner = owner + + if message_descriptor.is_extendable: + self.Extensions = ExtensionDict(self) + else: + # Reference counting in the C++ code is broken and depends on + # the Extensions reference to keep this object alive during unit + # tests (see b/4856052). Remove this once b/4945904 is fixed. + self._HACK_REFCOUNTS = self + self._composite_fields = {} + + for field_name, field_value in kwargs.iteritems(): + field_cdescriptor = self.__descriptors.get(field_name, None) + if not field_cdescriptor: + raise ValueError('Protocol message has no "%s" field.' % field_name) + if field_cdescriptor.label == _LABEL_REPEATED: + if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + field_name = getattr(self, field_name) + for val in field_value: + field_name.add().MergeFrom(val) + else: + getattr(self, field_name).extend(field_value) + elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + getattr(self, field_name).MergeFrom(field_value) + else: + setattr(self, field_name, field_value) + + Init.__module__ = None + Init.__doc__ = None + cls.__init__ = Init + + +def _IsMessageSetExtension(field): + """Checks if a field is a message set extension.""" + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _LABEL_OPTIONAL) + + +def _AddMessageMethods(message_descriptor, cls): + """Adds the methods to a protocol message class.""" + if message_descriptor.is_extendable: + + def ClearExtension(self, extension): + self.Extensions.ClearExtension(extension) + + def HasExtension(self, extension): + return self.Extensions.HasExtension(extension) + + def HasField(self, field_name): + return self._cmsg.HasField(field_name) + + def ClearField(self, field_name): + child_cmessage = None + if field_name in self._composite_fields: + child_field = self._composite_fields[field_name] + del self._composite_fields[field_name] + + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + child_cmessage = child_field._cmsg + + if child_cmessage is not None: + self._cmsg.ClearField(field_name, child_cmessage) + else: + self._cmsg.ClearField(field_name) + + def Clear(self): + cmessages_to_release = [] + for field_name, child_field in self._composite_fields.iteritems(): + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) + self._composite_fields.clear() + self._cmsg.Clear(cmessages_to_release) + + def IsInitialized(self, errors=None): + if self._cmsg.IsInitialized(): + return True + if errors is not None: + errors.extend(self.FindInitializationErrors()); + return False + + def SerializeToString(self): + if not self.IsInitialized(): + raise message.EncodeError( + 'Message %s is missing required fields: %s' % ( + self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) + return self._cmsg.SerializeToString() + + def SerializePartialToString(self): + return self._cmsg.SerializePartialToString() + + def ParseFromString(self, serialized): + self.Clear() + self.MergeFromString(serialized) + + def MergeFromString(self, serialized): + byte_size = self._cmsg.MergeFromString(serialized) + if byte_size < 0: + raise message.DecodeError('Unable to merge from string.') + return byte_size + + def MergeFrom(self, msg): + if not isinstance(msg, cls): + raise TypeError( + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) + self._cmsg.MergeFrom(msg._cmsg) + + def CopyFrom(self, msg): + self._cmsg.CopyFrom(msg._cmsg) + + def ByteSize(self): + return self._cmsg.ByteSize() + + def SetInParent(self): + return self._cmsg.SetInParent() + + def ListFields(self): + all_fields = [] + field_list = self._cmsg.ListFields() + fields_by_name = cls.DESCRIPTOR.fields_by_name + for is_extension, field_name in field_list: + if is_extension: + extension = cls._extensions_by_name[field_name] + all_fields.append((extension, self.Extensions[extension])) + else: + field_descriptor = fields_by_name[field_name] + all_fields.append( + (field_descriptor, getattr(self, field_name))) + all_fields.sort(key=lambda item: item[0].number) + return all_fields + + def FindInitializationErrors(self): + return self._cmsg.FindInitializationErrors() + + def __str__(self): + return str(self._cmsg) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, self.__class__): + return False + return self.ListFields() == other.ListFields() + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + def __unicode__(self): + # Lazy import to prevent circular import when text_format imports this file. + from google.protobuf import text_format + return text_format.MessageToString(self, as_utf8=True).decode('utf-8') + + # Attach the local methods to the message class. + for key, value in locals().copy().iteritems(): + if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'): + setattr(cls, key, value) + + # Static methods: + + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + if _IsMessageSetExtension(extension_handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(string): + msg = cls() + msg.MergeFromString(string) + return msg + cls.FromString = staticmethod(FromString) + + + +def _AddPropertiesForExtensions(message_descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = message_descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, extension_field.number) diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 461a30c..651ee0d 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. @@ -81,17 +85,26 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF +_NAN = _POS_INF * 0 + + # This is not for optimization, but rather to avoid conflicts with local # variables named "message". _DecodeError = message.DecodeError -def _VarintDecoder(mask): +def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -102,15 +115,18 @@ def _VarintDecoder(mask): """ local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -118,15 +134,17 @@ def _VarintDecoder(mask): return DecodeVarint -def _SignedVarintDecoder(mask): +def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -135,19 +153,23 @@ def _SignedVarintDecoder(mask): result |= ~mask else: result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) def ReadTag(buffer, pos): @@ -161,8 +183,10 @@ def ReadTag(buffer, pos): use that, but not in Python. """ + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes start = pos - while ord(buffer[pos]) & 0x80: + while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -269,10 +293,161 @@ def _StructPackDecoder(wire_type, format): return _SimpleDecoder(wire_type, InnerDecode) +def _FloatDecoder(): + """Returns a decoder for a float field. + + This code works around a bug in struct.unpack for non-finite 32-bit + floating-point values. + """ + + local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 + + def InnerDecode(buffer, pos): + # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. + new_pos = pos + 4 + float_bytes = buffer[pos:new_pos] + + # If this value has all its exponent bits set, then it's non-finite. + # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. + # To avoid that, we parse it specially. + if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 +##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF') + and (float_bytes[2:3] >= b('\x80'))): ##PY25 +##!PY25 and (float_bytes[2:3] >= b'\x80')): + # If at least one significand bit is set... + if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 +##!PY25 if float_bytes[0:3] != b'\x00\x00\x80': + return (_NAN, new_pos) + # If sign bit is set... + if float_bytes[3:4] == b('\xFF'): ##PY25 +##!PY25 if float_bytes[3:4] == b'\xFF': + return (_NEG_INF, new_pos) + return (_POS_INF, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<f', float_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) + + +def _DoubleDecoder(): + """Returns a decoder for a double field. + + This code works around a bug in struct.unpack for not-a-number. + """ + + local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 + + def InnerDecode(buffer, pos): + # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. + new_pos = pos + 8 + double_bytes = buffer[pos:new_pos] + + # If this value has all its exponent bits set and at least one significand + # bit set, it's not a number. In Python 2.4, struct.unpack will treat it + # as inf or -inf. To avoid that, we treat it specially. +##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') +##!PY25 and (double_bytes[6:7] >= b'\xF0') +##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25 + and (double_bytes[6:7] >= b('\xF0')) ##PY25 + and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25 + return (_NAN, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<d', double_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) + + +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + return pos + return DecodeField + + # -------------------------------------------------------------------- -Int32Decoder = EnumDecoder = _SimpleDecoder( +Int32Decoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -294,8 +469,8 @@ Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') -FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f') -DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d') +FloatDecoder = _FloatDecoder() +DoubleDecoder = _DoubleDecoder() BoolDecoder = _ModifiedDecoder( wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) @@ -307,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): local_DecodeVarint = _DecodeVarint local_unicode = unicode + def _ConvertToUnicode(byte_str): + try: + return local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError, e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -321,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -334,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos return DecodeField @@ -503,6 +686,7 @@ def MessageSetItemDecoder(extensions_by_number): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + message_set_item_start = pos type_id = -1 message_start = -1 message_end = -1 @@ -541,6 +725,11 @@ def MessageSetItemDecoder(extensions_by_number): # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, + buffer[message_set_item_start:pos])) return pos @@ -552,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number): def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - - while ord(buffer[pos]) & 0x80: + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -620,7 +811,6 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK - local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -633,7 +823,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = local_ord(tag_bytes[0]) & wiretype_mask + wire_type = ord(tag_bytes[0:1]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py new file mode 100644 index 0000000..856f472 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_database.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.apputils import basetest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database + + +class DescriptorDatabaseTest(basetest.TestCase): + + def testAdd(self): + db = descriptor_database.DescriptorDatabase() + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db.Add(file_desc_proto) + + self.assertEquals(file_desc_proto, db.FindFileByName( + 'google/protobuf/internal/factory_test2.proto')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Enum')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py new file mode 100644 index 0000000..7c1ce2e --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -0,0 +1,564 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_pool.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import os +import unittest + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation +from google.protobuf.internal import descriptor_pool_test1_pb2 +from google.protobuf.internal import descriptor_pool_test2_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool + + +class DescriptorPoolTest(basetest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + def testFindFileByName(self): + name1 = 'google/protobuf/internal/factory_test1.proto' + file_desc1 = self.pool.FindFileByName(name1) + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals(name1, file_desc1.name) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + name2 = 'google/protobuf/internal/factory_test2.proto' + file_desc2 = self.pool.FindFileByName(name2) + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals(name2, file_desc2.name) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindFileByName('Does not exist') + + def testFindFileContainingSymbol(self): + file_desc1 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory1Message') + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals('google/protobuf/internal/factory_test1.proto', + file_desc1.name) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + file_desc2 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message') + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals('google/protobuf/internal/factory_test2.proto', + file_desc2.name) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileContainingSymbolFailure(self): + with self.assertRaises(KeyError): + self.pool.FindFileContainingSymbol('Does not exist') + + def testFindMessageTypeByName(self): + msg1 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + self.assertIsInstance(msg1, descriptor.Descriptor) + self.assertEquals('Factory1Message', msg1.name) + self.assertEquals('google.protobuf.python.internal.Factory1Message', + msg1.full_name) + self.assertEquals(None, msg1.containing_type) + + nested_msg1 = msg1.nested_types[0] + self.assertEquals('NestedFactory1Message', nested_msg1.name) + self.assertEquals(msg1, nested_msg1.containing_type) + + nested_enum1 = msg1.enum_types[0] + self.assertEquals('NestedFactory1Enum', nested_enum1.name) + self.assertEquals(msg1, nested_enum1.containing_type) + + self.assertEquals(nested_msg1, msg1.fields_by_name[ + 'nested_factory_1_message'].message_type) + self.assertEquals(nested_enum1, msg1.fields_by_name[ + 'nested_factory_1_enum'].enum_type) + + msg2 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + self.assertIsInstance(msg2, descriptor.Descriptor) + self.assertEquals('Factory2Message', msg2.name) + self.assertEquals('google.protobuf.python.internal.Factory2Message', + msg2.full_name) + self.assertIsNone(msg2.containing_type) + + nested_msg2 = msg2.nested_types[0] + self.assertEquals('NestedFactory2Message', nested_msg2.name) + self.assertEquals(msg2, nested_msg2.containing_type) + + nested_enum2 = msg2.enum_types[0] + self.assertEquals('NestedFactory2Enum', nested_enum2.name) + self.assertEquals(msg2, nested_enum2.containing_type) + + self.assertEquals(nested_msg2, msg2.fields_by_name[ + 'nested_factory_2_message'].message_type) + self.assertEquals(nested_enum2, msg2.fields_by_name[ + 'nested_factory_2_enum'].enum_type) + + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) + self.assertEquals( + 1776, msg2.fields_by_name['int_with_default'].default_value) + + self.assertTrue( + msg2.fields_by_name['double_with_default'].has_default_value) + self.assertEquals( + 9.99, msg2.fields_by_name['double_with_default'].default_value) + + self.assertTrue( + msg2.fields_by_name['string_with_default'].has_default_value) + self.assertEquals( + 'hello world', msg2.fields_by_name['string_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) + self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) + self.assertEquals( + 1, msg2.fields_by_name['enum_with_default'].default_value) + + msg3 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') + self.assertEquals(nested_msg2, msg3) + + self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) + self.assertEquals( + b'a\xfb\x00c', + msg2.fields_by_name['bytes_with_default'].default_value) + + self.assertEqual(1, len(msg2.oneofs)) + self.assertEqual(1, len(msg2.oneofs_by_name)) + self.assertEqual(2, len(msg2.oneofs[0].fields)) + for name in ['oneof_int', 'oneof_string']: + self.assertEqual(msg2.oneofs[0], + msg2.fields_by_name[name].containing_oneof) + self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + + def testFindMessageTypeByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindMessageTypeByName('Does not exist') + + def testFindEnumTypeByName(self): + enum1 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory1Enum') + self.assertIsInstance(enum1, descriptor.EnumDescriptor) + self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) + self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + + nested_enum1 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') + self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) + self.assertEquals( + 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) + + enum2 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory2Enum') + self.assertIsInstance(enum2, descriptor.EnumDescriptor) + self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) + self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) + + nested_enum2 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') + self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) + self.assertEquals( + 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) + + def testFindEnumTypeByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindEnumTypeByName('Does not exist') + + def testUserDefinedDB(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + self.testFindMessageTypeByName() + + def testComplexNesting(self): + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + + +class ProtoFile(object): + + def __init__(self, name, package, messages, dependencies=None): + self.name = name + self.package = package + self.messages = messages + self.dependencies = dependencies or [] + + def CheckFile(self, test, pool): + file_desc = pool.FindFileByName(self.name) + test.assertEquals(self.name, file_desc.name) + test.assertEquals(self.package, file_desc.package) + dependencies_names = [f.name for f in file_desc.dependencies] + test.assertEqual(self.dependencies, dependencies_names) + for name, msg_type in self.messages.items(): + msg_type.CheckType(test, None, name, file_desc) + + +class EnumType(object): + + def __init__(self, values): + self.values = values + + def CheckType(self, test, msg_desc, name, file_desc): + enum_desc = msg_desc.enum_types_by_name[name] + test.assertEqual(name, enum_desc.name) + expected_enum_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_enum_full_name, enum_desc.full_name) + test.assertEqual(msg_desc, enum_desc.containing_type) + test.assertEqual(file_desc, enum_desc.file) + for index, (value, number) in enumerate(self.values): + value_desc = enum_desc.values_by_name[value] + test.assertEqual(value, value_desc.name) + test.assertEqual(index, value_desc.index) + test.assertEqual(number, value_desc.number) + test.assertEqual(enum_desc, value_desc.type) + test.assertIn(value, msg_desc.enum_values_by_name) + + +class MessageType(object): + + def __init__(self, type_dict, field_list, is_extendable=False, + extensions=None): + self.type_dict = type_dict + self.field_list = field_list + self.is_extendable = is_extendable + self.extensions = extensions or [] + + def CheckType(self, test, containing_type_desc, name, file_desc): + if containing_type_desc is None: + desc = file_desc.message_types_by_name[name] + expected_full_name = '.'.join([file_desc.package, name]) + else: + desc = containing_type_desc.nested_types_by_name[name] + expected_full_name = '.'.join([containing_type_desc.full_name, name]) + + test.assertEqual(name, desc.name) + test.assertEqual(expected_full_name, desc.full_name) + test.assertEqual(containing_type_desc, desc.containing_type) + test.assertEqual(desc.file, file_desc) + test.assertEqual(self.is_extendable, desc.is_extendable) + for name, subtype in self.type_dict.items(): + subtype.CheckType(test, desc, name, file_desc) + + for index, (name, field) in enumerate(self.field_list): + field.CheckField(test, desc, name, index) + + for index, (name, field) in enumerate(self.extensions): + field.CheckField(test, desc, name, index) + + +class EnumField(object): + + def __init__(self, number, type_name, default_value): + self.number = number + self.type_name = type_name + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + enum_desc = msg_desc.enum_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_ENUM, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(enum_desc.values_by_name[self.default_value].index, + field_desc.default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(enum_desc, field_desc.enum_type) + + +class MessageField(object): + + def __init__(self, number, type_name): + self.number = number + self.type_name = type_name + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + field_type_desc = msg_desc.nested_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(field_type_desc, field_desc.message_type) + + +class StringField(object): + + def __init__(self, number, default_value): + self.number = number + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_STRING, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_STRING, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(self.default_value, field_desc.default_value) + + +class ExtensionField(object): + + def __init__(self, number, extended_type): + self.number = number + self.extended_type = extended_type + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.extensions_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(index, field_desc.index) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertTrue(field_desc.is_extension) + test.assertEqual(msg_desc, field_desc.extension_scope) + test.assertEqual(msg_desc, field_desc.message_type) + test.assertEqual(self.extended_type, field_desc.containing_type.name) + + +class AddDescriptorTest(basetest.TestCase): + + def _TestMessage(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes').full_name) + + # AddDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') + + pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + # Files are implicitly also indexed when messages are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + + def testMessage(self): + self._TestMessage('') + self._TestMessage('.') + + def _TestEnum(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum').full_name) + + # AddEnumDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') + + pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + # Files are implicitly also indexed when enums are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + + def testEnum(self): + self._TestEnum('') + self._TestEnum('.') + + def testFile(self): + pool = descriptor_pool.DescriptorPool() + pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + # AddFileDescriptor is not recursive; messages and enums within files must + # be explicitly registered. + with self.assertRaises(KeyError): + pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes') + + +TEST1_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test1.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest1': MessageType({ + 'NestedEnum': EnumType([('ALPHA', 1), ('BETA', 2)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('EPSILON', 5), ('ZETA', 6)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('ETA', 7), ('THETA', 8)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ETA')), + ('nested_field', StringField(2, 'theta')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ZETA')), + ('nested_field', StringField(2, 'beta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'BETA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], is_extendable=True), + + 'DescriptorPoolTest2': MessageType({ + 'NestedEnum': EnumType([('GAMMA', 3), ('DELTA', 4)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('IOTA', 9), ('KAPPA', 10)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('LAMBDA', 11), ('MU', 12)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'MU')), + ('nested_field', StringField(2, 'lambda')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'IOTA')), + ('nested_field', StringField(2, 'delta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'GAMMA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ]), + }) + + +TEST2_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test2.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest3': MessageType({ + 'NestedEnum': EnumType([('NU', 13), ('XI', 14)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('OMICRON', 15), ('PI', 16)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('RHO', 17), ('SIGMA', 18)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'RHO')), + ('nested_field', StringField(2, 'sigma')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'PI')), + ('nested_field', StringField(2, 'nu')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'XI')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], extensions=[ + ('descriptor_pool_test', + ExtensionField(1001, 'DescriptorPoolTest1')), + ]), + }, + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto new file mode 100644 index 0000000..c11dcc0 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test1.proto @@ -0,0 +1,94 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + + +message DescriptorPoolTest1 { + extensions 1000 to max; + + enum NestedEnum { + ALPHA = 1; + BETA = 2; + } + + optional NestedEnum nested_enum = 1 [default = BETA]; + + message NestedMessage { + enum NestedEnum { + EPSILON = 5; + ZETA = 6; + } + optional NestedEnum nested_enum = 1 [default = ZETA]; + optional string nested_field = 2 [default = "beta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + ETA = 7; + THETA = 8; + } + optional NestedEnum nested_enum = 1 [default = ETA]; + optional string nested_field = 2 [default = "theta"]; + } + } + + optional NestedMessage nested_message = 2; +} + +message DescriptorPoolTest2 { + enum NestedEnum { + GAMMA = 3; + DELTA = 4; + } + + optional NestedEnum nested_enum = 1 [default = GAMMA]; + + message NestedMessage { + enum NestedEnum { + IOTA = 9; + KAPPA = 10; + } + optional NestedEnum nested_enum = 1 [default = IOTA]; + optional string nested_field = 2 [default = "delta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + LAMBDA = 11; + MU = 12; + } + optional NestedEnum nested_enum = 1 [default = MU]; + optional string nested_field = 2 [default = "lambda"]; + } + } + + optional NestedMessage nested_message = 2; +} diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto new file mode 100644 index 0000000..d97d39b --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -0,0 +1,70 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +import "google/protobuf/internal/descriptor_pool_test1.proto"; + + +message DescriptorPoolTest3 { + + extend DescriptorPoolTest1 { + optional DescriptorPoolTest3 descriptor_pool_test = 1001; + } + + enum NestedEnum { + NU = 13; + XI = 14; + } + + optional NestedEnum nested_enum = 1 [default = XI]; + + message NestedMessage { + enum NestedEnum { + OMICRON = 15; + PI = 16; + } + optional NestedEnum nested_enum = 1 [default = PI]; + optional string nested_field = 2 [default = "nu"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + RHO = 17; + SIGMA = 18; + } + optional NestedEnum nested_enum = 1 [default = RHO]; + optional string nested_field = 2 [default = "sigma"]; + } + } + + optional NestedMessage nested_message = 2; +} + diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py new file mode 100644 index 0000000..b3a1571 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for descriptor.py for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 05c2745..d20d945 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -34,7 +34,8 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 @@ -47,7 +48,7 @@ name: 'TestEmptyMessage' """ -class DescriptorTest(unittest.TestCase): +class DescriptorTest(basetest.TestCase): def setUp(self): self.my_file = descriptor.FileDescriptor( @@ -101,6 +102,15 @@ class DescriptorTest(unittest.TestCase): self.my_method ]) + def testEnumValueName(self): + self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), + 'FOREIGN_FOO') + + self.assertEqual( + self.my_message.enum_types_by_name[ + 'ForeignEnum'].values_by_number[4].name, + self.my_message.EnumValueName('ForeignEnum', 4)) + def testEnumFixups(self): self.assertEqual(self.my_enum, self.my_enum.values[0].type) @@ -125,6 +135,257 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_service.GetOptions(), descriptor_pb2.ServiceOptions()) + def testSimpleCustomOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["field1"] + enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"] + enum_value_descriptor =\ + message_descriptor.enum_values_by_name["ANENUM_VAL2"] + service_descriptor =\ + unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Foo") + + file_options = file_descriptor.GetOptions() + file_opt1 = unittest_custom_options_pb2.file_opt1 + self.assertEqual(9876543210, file_options.Extensions[file_opt1]) + message_options = message_descriptor.GetOptions() + message_opt1 = unittest_custom_options_pb2.message_opt1 + self.assertEqual(-56, message_options.Extensions[message_opt1]) + field_options = field_descriptor.GetOptions() + field_opt1 = unittest_custom_options_pb2.field_opt1 + self.assertEqual(8765432109, field_options.Extensions[field_opt1]) + field_opt2 = unittest_custom_options_pb2.field_opt2 + self.assertEqual(42, field_options.Extensions[field_opt2]) + enum_options = enum_descriptor.GetOptions() + enum_opt1 = unittest_custom_options_pb2.enum_opt1 + self.assertEqual(-789, enum_options.Extensions[enum_opt1]) + enum_value_options = enum_value_descriptor.GetOptions() + enum_value_opt1 = unittest_custom_options_pb2.enum_value_opt1 + self.assertEqual(123, enum_value_options.Extensions[enum_value_opt1]) + + service_options = service_descriptor.GetOptions() + service_opt1 = unittest_custom_options_pb2.service_opt1 + self.assertEqual(-9876543210, service_options.Extensions[service_opt1]) + method_options = method_descriptor.GetOptions() + method_opt1 = unittest_custom_options_pb2.method_opt1 + self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, + method_options.Extensions[method_opt1]) + + def testDifferentCustomOptionTypes(self): + kint32min = -2**31 + kint64min = -2**63 + kint32max = 2**31 - 1 + kint64max = 2**63 - 1 + kuint32max = 2**32 - 1 + kuint64max = 2**64 - 1 + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMinIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(False, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMaxIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(True, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionOtherValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(-100, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertAlmostEqual(12.3456789, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(1.234567890123456789, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + self.assertEqual("Hello, \"World\"", message_options.Extensions[ + unittest_custom_options_pb2.string_opt]) + self.assertEqual(b"Hello\0World", message_options.Extensions[ + unittest_custom_options_pb2.bytes_opt]) + dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum + self.assertEqual( + dummy_enum.TEST_OPTION_ENUM_TYPE2, + message_options.Extensions[unittest_custom_options_pb2.enum_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromPositiveInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromNegativeInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(-12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(-154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + def testComplexExtensionOptions(self): + descriptor =\ + unittest_custom_options_pb2.VariousComplexOptions.DESCRIPTOR + options = descriptor.GetOptions() + self.assertEqual(42, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].foo) + self.assertEqual(324, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(876, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(987, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].baz) + self.assertEqual(654, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.grault]) + self.assertEqual(743, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.foo) + self.assertEqual(1999, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2008, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(741, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].foo) + self.assertEqual(1998, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2121, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(1971, options.Extensions[ + unittest_custom_options_pb2.ComplexOptionType2 + .ComplexOptionType4.complex_opt4].waldo) + self.assertEqual(321, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].fred.waldo) + self.assertEqual(9, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].qux) + self.assertEqual(22, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].complexoptiontype5.plugh) + self.assertEqual(24, options.Extensions[ + unittest_custom_options_pb2.complexopt6].xyzzy) + + # Check that aggregate options were parsed and saved correctly in + # the appropriate descriptors. + def testAggregateOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.AggregateMessage.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["fieldname"] + enum_descriptor = unittest_custom_options_pb2.AggregateEnum.DESCRIPTOR + enum_value_descriptor = enum_descriptor.values_by_name["VALUE"] + service_descriptor =\ + unittest_custom_options_pb2.AggregateService.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Method") + + # Tests for the different types of data embedded in fileopt + file_options = file_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fileopt] + self.assertEqual(100, file_options.i) + self.assertEqual("FileAnnotation", file_options.s) + self.assertEqual("NestedFileAnnotation", file_options.sub.s) + self.assertEqual("FileExtensionAnnotation", file_options.file.Extensions[ + unittest_custom_options_pb2.fileopt].s) + self.assertEqual("EmbeddedMessageSetElement", file_options.mset.Extensions[ + unittest_custom_options_pb2.AggregateMessageSetElement + .message_set_extension].s) + + # Simple tests for all the other types of annotations + self.assertEqual( + "MessageAnnotation", + message_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.msgopt].s) + self.assertEqual( + "FieldAnnotation", + field_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fieldopt].s) + self.assertEqual( + "EnumAnnotation", + enum_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumopt].s) + self.assertEqual( + "EnumValueAnnotation", + enum_value_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumvalopt].s) + self.assertEqual( + "ServiceAnnotation", + service_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.serviceopt].s) + self.assertEqual( + "MethodAnnotation", + method_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.methodopt].s) + + def testNestedOptions(self): + nested_message =\ + unittest_custom_options_pb2.NestedOptionType.NestedMessage.DESCRIPTOR + self.assertEqual(1001, nested_message.GetOptions().Extensions[ + unittest_custom_options_pb2.message_opt1]) + nested_field = nested_message.fields_by_name["nested_field"] + self.assertEqual(1002, nested_field.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt1]) + outer_message =\ + unittest_custom_options_pb2.NestedOptionType.DESCRIPTOR + nested_enum = outer_message.enum_types_by_name["NestedEnum"] + self.assertEqual(1003, nested_enum.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_opt1]) + nested_enum_value = outer_message.enum_values_by_name["NESTED_ENUM_VALUE"] + self.assertEqual(1004, nested_enum_value.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_value_opt1]) + nested_extension = outer_message.extensions_by_name["nested_extension"] + self.assertEqual(1005, nested_extension.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt2]) + def testFileDescriptorReferences(self): self.assertEqual(self.my_enum.file, self.my_file) self.assertEqual(self.my_message.file, self.my_file) @@ -134,7 +395,7 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.package, 'protobuf_unittest') -class DescriptorCopyToProtoTest(unittest.TestCase): +class DescriptorCopyToProtoTest(basetest.TestCase): """Tests for CopyTo functions of Descriptor.""" def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): @@ -269,45 +530,49 @@ class DescriptorCopyToProtoTest(unittest.TestCase): descriptor_pb2.DescriptorProto, TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) - def testCopyToProto_FileDescriptor(self): - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" - name: 'google/protobuf/unittest_import.proto' - package: 'protobuf_unittest_import' - message_type: < - name: 'ImportMessage' - field: < - name: 'd' - number: 1 - label: 1 # Optional - type: 5 # TYPE_INT32 - > - > - """ + - """enum_type: < - name: 'ImportEnum' - value: < - name: 'IMPORT_FOO' - number: 7 - > - value: < - name: 'IMPORT_BAR' - number: 8 - > - value: < - name: 'IMPORT_BAZ' - number: 9 - > - > - options: < - java_package: 'com.google.protobuf.test' - optimize_for: 1 # SPEED - > - """) - - self._InternalTestCopyToProto( - unittest_import_pb2.DESCRIPTOR, - descriptor_pb2.FileDescriptorProto, - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + # Disable this test so we can make changes to the proto file. + # TODO(xiaofeng): Enable this test after cl/55530659 is submitted. + # + # def testCopyToProto_FileDescriptor(self): + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + # name: 'google/protobuf/unittest_import.proto' + # package: 'protobuf_unittest_import' + # dependency: 'google/protobuf/unittest_import_public.proto' + # message_type: < + # name: 'ImportMessage' + # field: < + # name: 'd' + # number: 1 + # label: 1 # Optional + # type: 5 # TYPE_INT32 + # > + # > + # """ + + # """enum_type: < + # name: 'ImportEnum' + # value: < + # name: 'IMPORT_FOO' + # number: 7 + # > + # value: < + # name: 'IMPORT_BAR' + # number: 8 + # > + # value: < + # name: 'IMPORT_BAZ' + # number: 9 + # > + # > + # options: < + # java_package: 'com.google.protobuf.test' + # optimize_for: 1 # SPEED + # > + # public_dependency: 0 + # """) + # self._InternalTestCopyToProto( + # unittest_import_pb2.DESCRIPTOR, + # descriptor_pb2.FileDescriptorProto, + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) def testCopyToProto_ServiceDescriptor(self): TEST_SERVICE_ASCII = """ @@ -323,12 +588,82 @@ class DescriptorCopyToProtoTest(unittest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( unittest_pb2.TestService.DESCRIPTOR, descriptor_pb2.ServiceDescriptorProto, TEST_SERVICE_ASCII) +class MakeDescriptorTest(basetest.TestCase): + + def testMakeDescriptorWithNestedFields(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo2' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + nested_type = message_type.nested_type.add() + nested_type.name = 'Sub' + enum_type = nested_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + field = message_type.field.add() + field.number = 2 + field.name = 'nested_message_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_MESSAGE + field.type_name = 'Sub' + enum_field = nested_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo2.Sub.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + self.assertEqual(result.fields[1].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_MESSAGE) + self.assertEqual(result.fields[1].message_type.containing_type, + result) + self.assertEqual(result.nested_types[0].fields[0].full_name, + 'Foo2.Sub.bar_field') + self.assertEqual(result.nested_types[0].fields[0].enum_type, + result.nested_types[0].enum_types[0]) + + def testMakeDescriptorWithUnsignedIntField(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + enum_type = message_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + enum_field = message_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index aa05d5b..0a7c041 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type @@ -67,9 +71,17 @@ sizer rather than when calling them. In particular: __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import wire_format +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF + + def _VarintSize(value): """Compute the size of a varint value.""" if value <= 0x7f: return 1 @@ -334,7 +346,8 @@ def MessageSetItemSizer(field_number): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeVarint(write, value): bits = value & 0x7f value >>= 7 @@ -351,7 +364,8 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeSignedVarint(write, value): if value < 0: value += (1 << 64) @@ -376,7 +390,8 @@ def _VarintBytes(value): pieces = [] _EncodeVarint(pieces.append, value) - return "".join(pieces) + return "".encode("latin1").join(pieces) ##PY25 +##!PY25 return b"".join(pieces) def TagBytes(field_number, wire_type): @@ -502,6 +517,90 @@ def _StructPackEncoder(wire_type, format): return SpecificEncoder +def _FloatingPointEncoder(wire_type, format): + """Return a constructor for an encoder for float fields. + + This is like StructPackEncoder, but catches errors that may be due to + passing non-finite floating-point values to struct.pack, and makes a + second attempt to encode those values. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25 + value_size = struct.calcsize(format) + if value_size == 4: + def EncodeNonFiniteOrRaise(write, value): + # Remember that the serialized form uses little-endian byte order. + if value == _POS_INF: + write(b('\x00\x00\x80\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x80\x7F') + elif value == _NEG_INF: + write(b('\x00\x00\x80\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x80\xFF') + elif value != value: # NaN + write(b('\x00\x00\xC0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\xC0\x7F') + else: + raise + elif value_size == 8: + def EncodeNonFiniteOrRaise(write, value): + if value == _POS_INF: + write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') + elif value == _NEG_INF: + write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') + elif value != value: # NaN + write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') + else: + raise + else: + raise ValueError('Can\'t encode floating-point values that are ' + '%d bytes long (only 4 or 8)' % value_size) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size) + for element in value: + # This try/except block is going to be faster than any code that + # we could write to check whether element is finite. + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + try: + write(local_struct_pack(format, value)) + except SystemError: + EncodeNonFiniteOrRaise(write, value) + return EncodeField + + return SpecificEncoder + + # ==================================================================== # Here we declare an encoder constructor for each field type. These work # very similarly to sizer constructors, described earlier. @@ -525,15 +624,17 @@ Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I') Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q') SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i') SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q') -FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f') -DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d') +FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') def BoolEncoder(field_number, is_repeated, is_packed): """Returns an encoder for a boolean field.""" - false_byte = chr(0) - true_byte = chr(1) +##!PY25 false_byte = b'\x00' +##!PY25 true_byte = b'\x01' + false_byte = '\x00'.encode('latin1') ##PY25 + true_byte = '\x01'.encode('latin1') ##PY25 if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint @@ -669,7 +770,8 @@ def MessageSetItemEncoder(field_number): } } """ - start_bytes = "".join([ + start_bytes = "".encode("latin1").join([ ##PY25 +##!PY25 start_bytes = b"".join([ TagBytes(1, wire_format.WIRETYPE_START_GROUP), TagBytes(2, wire_format.WIRETYPE_VARINT), _VarintBytes(field_number), diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py new file mode 100644 index 0000000..7b28645 --- /dev/null +++ b/python/google/protobuf/internal/enum_type_wrapper.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A simple wrapper around enum types to expose utility functions. + +Instances are created as properties with the same name as the enum they wrap +on proto classes. For usage, see: + reflection_test.py +""" + +__author__ = 'rabsatt@google.com (Kevin Rabsatt)' + + +class EnumTypeWrapper(object): + """A utility for finding the names of enum values.""" + + DESCRIPTOR = None + + def __init__(self, enum_type): + """Inits EnumTypeWrapper with an EnumDescriptor.""" + self._enum_type = enum_type + self.DESCRIPTOR = enum_type; + + def Name(self, number): + """Returns a string containing the name of an enum value.""" + if number in self._enum_type.values_by_number: + return self._enum_type.values_by_number[number].name + raise ValueError('Enum %s has no name defined for value %d' % ( + self._enum_type.name, number)) + + def Value(self, name): + """Returns the value coresponding to the given enum name.""" + if name in self._enum_type.values_by_name: + return self._enum_type.values_by_name[name].number + raise ValueError('Enum %s has no value defined for name %s' % ( + self._enum_type.name, name)) + + def keys(self): + """Return a list of the string names in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.name + for value_descriptor in self._enum_type.values] + + def values(self): + """Return a list of the integer values in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.number + for value_descriptor in self._enum_type.values] + + def items(self): + """Return a list of the (name, value) pairs of the enum. + + These are returned in the order they were defined in the .proto file. + """ + return [(value_descriptor.name, value_descriptor.number) + for value_descriptor in self._enum_type.values] diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto new file mode 100644 index 0000000..03dcb2c --- /dev/null +++ b/python/google/protobuf/internal/factory_test1.proto @@ -0,0 +1,57 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + + +enum Factory1Enum { + FACTORY_1_VALUE_0 = 0; + FACTORY_1_VALUE_1 = 1; +} + +message Factory1Message { + optional Factory1Enum factory_1_enum = 1; + enum NestedFactory1Enum { + NESTED_FACTORY_1_VALUE_0 = 0; + NESTED_FACTORY_1_VALUE_1 = 1; + } + optional NestedFactory1Enum nested_factory_1_enum = 2; + message NestedFactory1Message { + optional string value = 1; + } + optional NestedFactory1Message nested_factory_1_message = 3; + optional int32 scalar_value = 4; + repeated string list_value = 5; + + extensions 1000 to max; +} diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto new file mode 100644 index 0000000..a8c6812 --- /dev/null +++ b/python/google/protobuf/internal/factory_test2.proto @@ -0,0 +1,92 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + +import "google/protobuf/internal/factory_test1.proto"; + + +enum Factory2Enum { + FACTORY_2_VALUE_0 = 0; + FACTORY_2_VALUE_1 = 1; +} + +message Factory2Message { + required int32 mandatory = 1; + optional Factory2Enum factory_2_enum = 2; + enum NestedFactory2Enum { + NESTED_FACTORY_2_VALUE_0 = 0; + NESTED_FACTORY_2_VALUE_1 = 1; + } + optional NestedFactory2Enum nested_factory_2_enum = 3; + message NestedFactory2Message { + optional string value = 1; + } + optional NestedFactory2Message nested_factory_2_message = 4; + optional Factory1Message factory_1_message = 5; + optional Factory1Enum factory_1_enum = 6; + optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7; + optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8; + optional Factory2Message circular_message = 9; + optional string scalar_value = 10; + repeated string list_value = 11; + repeated group Grouped = 12 { + optional string part_1 = 13; + optional string part_2 = 14; + } + optional LoopMessage loop = 15; + optional int32 int_with_default = 16 [default = 1776]; + optional double double_with_default = 17 [default = 9.99]; + optional string string_with_default = 18 [default = "hello world"]; + optional bool bool_with_default = 19 [default = false]; + optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1]; + optional bytes bytes_with_default = 21 [default = "a\373\000c"]; + + + extend Factory1Message { + optional string one_more_field = 1001; + } + + oneof oneof_field { + int32 oneof_int = 22; + string oneof_string = 23; + } +} + +message LoopMessage { + optional Factory2Message loop = 1; +} + +extend Factory1Message { + optional string another_field = 1002; +} diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 78360b5..5818060 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -41,17 +41,21 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest +from google.protobuf.internal import test_bad_identifiers_pb2 +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 -from google.protobuf import unittest_pb2 from google.protobuf import unittest_no_generic_services_pb2 - +from google.protobuf import unittest_pb2 +from google.protobuf import service +from google.protobuf import symbol_database MAX_EXTENSION = 536870912 -class GeneratorTest(unittest.TestCase): +class GeneratorTest(basetest.TestCase): def testNestedMessageDescriptor(self): field_name = 'optional_nested_message' @@ -99,6 +103,7 @@ class GeneratorTest(unittest.TestCase): self.assertTrue(isinf(message.neg_inf_float)) self.assertTrue(message.neg_inf_float < 0) self.assertTrue(isnan(message.nan_float)) + self.assertEqual("? ? ?? ?? ??? ??/ ??-", message.cpp_trigraph) def testHasDefaultValues(self): desc = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -140,6 +145,13 @@ class GeneratorTest(unittest.TestCase): proto = unittest_mset_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + def testMessageWithCustomOptions(self): + proto = unittest_custom_options_pb2.TestMessageWithCustomOptions() + enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions() + self.assertTrue(enum_options is not None) + # TODO(gps): We really should test for the presense of the enum_opt1 + # extension and for its value to be set to -789. + def testNestedTypes(self): self.assertEquals( set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), @@ -206,15 +218,126 @@ class GeneratorTest(unittest.TestCase): 'google/protobuf/unittest.proto') self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest') self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None) + self.assertEqual(unittest_pb2.DESCRIPTOR.dependencies, + [unittest_import_pb2.DESCRIPTOR]) + self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies, + [unittest_import_public_pb2.DESCRIPTOR]) def testNoGenericServices(self): - # unittest_no_generic_services.proto should contain defs for everything - # except services. self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO")) self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension")) - self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService")) + # Make sure unittest_no_generic_services_pb2 has no services subclassing + # Proto2 Service class. + if hasattr(unittest_no_generic_services_pb2, "TestService"): + self.assertFalse(issubclass(unittest_no_generic_services_pb2.TestService, + service.Service)) + + def testMessageTypesByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2._TESTALLTYPES, + file_type.message_types_by_name[unittest_pb2._TESTALLTYPES.name]) + + # Nested messages shouldn't be included in the message_types_by_name + # dictionary (like in the C++ API). + self.assertFalse( + unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in + file_type.message_types_by_name) + + def testEnumTypesByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2._FOREIGNENUM, + file_type.enum_types_by_name[unittest_pb2._FOREIGNENUM.name]) + + def testExtensionsByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2.my_extension_string, + file_type.extensions_by_name[unittest_pb2.my_extension_string.name]) + + def testPublicImports(self): + # Test public imports as embedded message. + all_type_proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, all_type_proto.optional_public_import_message.e) + + # PublicImportMessage is actually defined in unittest_import_public_pb2 + # module, and is public imported by unittest_import_pb2 module. + public_import_proto = unittest_import_pb2.PublicImportMessage() + self.assertEqual(0, public_import_proto.e) + self.assertTrue(unittest_import_public_pb2.PublicImportMessage is + unittest_import_pb2.PublicImportMessage) + + def testBadIdentifiers(self): + # We're just testing that the code was imported without problems. + message = test_bad_identifiers_pb2.TestBadIdentifiers() + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message], + "foo") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor], + "bar") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection], + "baz") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service], + "qux") + + def testOneof(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR + self.assertEqual(1, len(desc.oneofs)) + self.assertEqual('oneof_field', desc.oneofs[0].name) + self.assertEqual(0, desc.oneofs[0].index) + self.assertIs(desc, desc.oneofs[0].containing_type) + self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field']) + nested_names = set(['oneof_uint32', 'oneof_nested_message', + 'oneof_string', 'oneof_bytes']) + self.assertSameElements( + nested_names, + [field.name for field in desc.oneofs[0].fields]) + for field_name, field_desc in desc.fields_by_name.iteritems(): + if field_name in nested_names: + self.assertIs(desc.oneofs[0], field_desc.containing_oneof) + else: + self.assertIsNone(field_desc.containing_oneof) + + +class SymbolDatabaseRegistrationTest(basetest.TestCase): + """Checks that messages, enums and files are correctly registered.""" + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + with self.assertRaises(KeyError): + symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage') + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + symbol_database.Default().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py new file mode 100644 index 0000000..6a2053f --- /dev/null +++ b/python/google/protobuf/internal/message_factory_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for ..public.message_factory for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_factory_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py new file mode 100644 index 0000000..c53d77b --- /dev/null +++ b/python/google/protobuf/internal/message_factory_test.py @@ -0,0 +1,131 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.apputils import basetest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message_factory + + +class MessageFactoryTest(basetest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def _ExerciseDynamicClass(self, cls): + msg = cls() + msg.mandatory = 42 + msg.nested_factory_2_enum = 0 + msg.nested_factory_2_message.value = 'nested message value' + msg.factory_1_message.factory_1_enum = 1 + msg.factory_1_message.nested_factory_1_enum = 0 + msg.factory_1_message.nested_factory_1_message.value = ( + 'nested message value') + msg.factory_1_message.scalar_value = 22 + msg.factory_1_message.list_value.extend([u'one', u'two', u'three']) + msg.factory_1_message.list_value.append(u'four') + msg.factory_1_enum = 1 + msg.nested_factory_1_enum = 0 + msg.nested_factory_1_message.value = 'nested message value' + msg.circular_message.mandatory = 1 + msg.circular_message.circular_message.mandatory = 2 + msg.circular_message.scalar_value = 'one deep' + msg.scalar_value = 'zero deep' + msg.list_value.extend([u'four', u'three', u'two']) + msg.list_value.append(u'one') + msg.grouped.add() + msg.grouped[0].part_1 = 'hello' + msg.grouped[0].part_2 = 'world' + msg.grouped.add(part_1='testing', part_2='123') + msg.loop.loop.mandatory = 2 + msg.loop.loop.loop.loop.mandatory = 4 + serialized = msg.SerializeToString() + converted = factory_test2_pb2.Factory2Message.FromString(serialized) + reserialized = converted.SerializeToString() + self.assertEquals(serialized, reserialized) + result = cls.FromString(reserialized) + self.assertEquals(msg, result) + + def testGetPrototype(self): + db = descriptor_database.DescriptorDatabase() + pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + factory = message_factory.MessageFactory() + cls = factory.GetPrototype(pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message')) + self.assertIsNot(cls, factory_test2_pb2.Factory2Message) + self._ExerciseDynamicClass(cls) + cls2 = factory.GetPrototype(pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message')) + self.assertIs(cls, cls2) + + def testGetMessages(self): + # performed twice because multiple calls with the same input must be allowed + for _ in range(2): + messages = message_factory.GetMessages([self.factory_test2_fd, + self.factory_test1_fd]) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message', + 'google.protobuf.python.internal.Factory1Message'], + messages.keys()) + self._ExerciseDynamicClass( + messages['google.protobuf.python.internal.Factory2Message']) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message.one_more_field', + 'google.protobuf.python.internal.another_field'], + (messages['google.protobuf.python.internal.Factory1Message'] + ._extensions_by_name.keys())) + factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] + msg1 = messages['google.protobuf.python.internal.Factory1Message']() + ext1 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.Factory2Message.one_more_field'] + ext2 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.another_field'] + msg1.Extensions[ext1] = 'test1' + msg1.Extensions[ext2] = 'test2' + self.assertEquals('test1', msg1.Extensions[ext1]) + self.assertEquals('test2', msg1.Extensions[ext2]) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/message_python_test.py new file mode 100644 index 0000000..baf1504 --- /dev/null +++ b/python/google/protobuf/internal/message_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for ..public.message for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 73a9a3a..f4c4ae0 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -43,47 +43,634 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. Smith)' -import unittest -from google.protobuf import unittest_import_pb2 +import copy +import math +import operator +import pickle +import sys + +from google.apputils import basetest from google.protobuf import unittest_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf import message + +# Python pre-2.6 does not have isinf() or isnan() functions, so we have +# to provide our own. +def isnan(val): + # NaN is never equal to itself. + return val != val +def isinf(val): + # Infinity times zero equals NaN. + return not isnan(val) and isnan(val * 0) +def IsPosInf(val): + return isinf(val) and (val > 0) +def IsNegInf(val): + return isinf(val) and (val < 0) -class MessageTest(unittest.TestCase): +class MessageTest(basetest.TestCase): + + def testBadUtf8String(self): + if api_implementation.Type() != 'python': + self.skipTest("Skipping testBadUtf8String, currently only the python " + "api implementation raises UnicodeDecodeError when a " + "string field contains bad utf-8.") + bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') + with self.assertRaises(UnicodeDecodeError) as context: + unittest_pb2.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', + str(context.exception)) def testGoldenMessage(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedExtensions() test_util.SetAllPackedExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleSupport(self): + golden_data = test_util.GoldenFileData('golden_message') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + self.assertEquals(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + + def testPositiveInfinity(self): + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsPosInf(golden_message.optional_float)) + self.assertTrue(IsPosInf(golden_message.optional_double)) + self.assertTrue(IsPosInf(golden_message.repeated_float[0])) + self.assertTrue(IsPosInf(golden_message.repeated_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNegativeInfinity(self): + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsNegInf(golden_message.optional_float)) + self.assertTrue(IsNegInf(golden_message.optional_double)) + self.assertTrue(IsNegInf(golden_message.repeated_float[0])) + self.assertTrue(IsNegInf(golden_message.repeated_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNotANumber(self): + golden_data = (b'\x5D\x00\x00\xC0\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' + b'\xCD\x02\x00\x00\xC0\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(isnan(golden_message.optional_float)) + self.assertTrue(isnan(golden_message.optional_double)) + self.assertTrue(isnan(golden_message.repeated_float[0])) + self.assertTrue(isnan(golden_message.repeated_double[0])) + + # The protocol buffer may serialize to any one of multiple different + # representations of a NaN. Rather than verify a specific representation, + # verify the serialized string can be converted into a correctly + # behaving protocol buffer. + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestAllTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.optional_float)) + self.assertTrue(isnan(message.optional_double)) + self.assertTrue(isnan(message.repeated_float[0])) + self.assertTrue(isnan(message.repeated_double[0])) + + def testPositiveInfinityPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsPosInf(golden_message.packed_float[0])) + self.assertTrue(IsPosInf(golden_message.packed_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNegativeInfinityPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsNegInf(golden_message.packed_float[0])) + self.assertTrue(IsNegInf(golden_message.packed_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNotANumberPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(isnan(golden_message.packed_float[0])) + self.assertTrue(isnan(golden_message.packed_double[0])) + + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestPackedTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.packed_float[0])) + self.assertTrue(isnan(message.packed_double[0])) + + def testExtremeFloatValues(self): + message = unittest_pb2.TestAllTypes() + + # Most positive exponent, no significand bits set. + kMostPosExponentNoSigBits = math.pow(2, 127) + message.optional_float = kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) + + # Most positive exponent, one significand bit set. + kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) + message.optional_float = kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) + + # Repeat last two cases with values of same magnitude, but negative. + message.optional_float = -kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) + + message.optional_float = -kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) + + # Most negative exponent, no significand bits set. + kMostNegExponentNoSigBits = math.pow(2, -127) + message.optional_float = kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) + + # Most negative exponent, one significand bit set. + kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) + message.optional_float = kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) + + # Repeat last two cases with values of the same magnitude, but negative. + message.optional_float = -kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) + + message.optional_float = -kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) + + def testExtremeDoubleValues(self): + message = unittest_pb2.TestAllTypes() + + # Most positive exponent, no significand bits set. + kMostPosExponentNoSigBits = math.pow(2, 1023) + message.optional_double = kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) + + # Most positive exponent, one significand bit set. + kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) + message.optional_double = kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) + + # Repeat last two cases with values of same magnitude, but negative. + message.optional_double = -kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) + + message.optional_double = -kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) + + # Most negative exponent, no significand bits set. + kMostNegExponentNoSigBits = math.pow(2, -1023) + message.optional_double = kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) + + # Most negative exponent, one significand bit set. + kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) + message.optional_double = kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) + + # Repeat last two cases with values of the same magnitude, but negative. + message.optional_double = -kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) + + message.optional_double = -kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) + + def testFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_float = 2.0 + self.assertEqual(str(message), 'optional_float: 2.0\n') + + def testHighPrecisionFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_double = 0.12345678912345678 + if sys.version_info.major >= 3: + self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') + else: + self.assertEqual(str(message), 'optional_double: 0.123456789123\n') + + def testUnknownFieldPrinting(self): + populated = unittest_pb2.TestAllTypes() + test_util.SetAllNonLazyFields(populated) + empty = unittest_pb2.TestEmptyMessage() + empty.ParseFromString(populated.SerializeToString()) + self.assertEqual(str(empty), '') + + def testSortingRepeatedScalarFieldsDefaultComparator(self): + """Check some different types with the default comparator.""" + message = unittest_pb2.TestAllTypes() + + # TODO(mattp): would testing more scalar types strengthen test? + message.repeated_int32.append(1) + message.repeated_int32.append(3) + message.repeated_int32.append(2) + message.repeated_int32.sort() + self.assertEqual(message.repeated_int32[0], 1) + self.assertEqual(message.repeated_int32[1], 2) + self.assertEqual(message.repeated_int32[2], 3) + + message.repeated_float.append(1.1) + message.repeated_float.append(1.3) + message.repeated_float.append(1.2) + message.repeated_float.sort() + self.assertAlmostEqual(message.repeated_float[0], 1.1) + self.assertAlmostEqual(message.repeated_float[1], 1.2) + self.assertAlmostEqual(message.repeated_float[2], 1.3) + + message.repeated_string.append('a') + message.repeated_string.append('c') + message.repeated_string.append('b') + message.repeated_string.sort() + self.assertEqual(message.repeated_string[0], 'a') + self.assertEqual(message.repeated_string[1], 'b') + self.assertEqual(message.repeated_string[2], 'c') + + message.repeated_bytes.append(b'a') + message.repeated_bytes.append(b'c') + message.repeated_bytes.append(b'b') + message.repeated_bytes.sort() + self.assertEqual(message.repeated_bytes[0], b'a') + self.assertEqual(message.repeated_bytes[1], b'b') + self.assertEqual(message.repeated_bytes[2], b'c') + + def testSortingRepeatedScalarFieldsCustomComparator(self): + """Check some different types with custom comparator.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(message.repeated_int32[0], -1) + self.assertEqual(message.repeated_int32[1], -2) + self.assertEqual(message.repeated_int32[2], -3) + + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(message.repeated_string[0], 'c') + self.assertEqual(message.repeated_string[1], 'bb') + self.assertEqual(message.repeated_string[2], 'aaa') + + def testSortingRepeatedCompositeFieldsCustomComparator(self): + """Check passing a custom comparator to sort a repeated composite field.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=operator.attrgetter('bb')) + self.assertEqual(message.repeated_nested_message[0].bb, 1) + self.assertEqual(message.repeated_nested_message[1].bb, 2) + self.assertEqual(message.repeated_nested_message[2].bb, 3) + self.assertEqual(message.repeated_nested_message[3].bb, 4) + self.assertEqual(message.repeated_nested_message[4].bb, 5) + self.assertEqual(message.repeated_nested_message[5].bb, 6) + + def testRepeatedCompositeFieldSortArguments(self): + """Check sorting a repeated composite field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + get_bb = operator.attrgetter('bb') + cmp_bb = lambda a, b: cmp(a.bb, b.bb) + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=get_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(key=get_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + if sys.version_info.major >= 3: return # No cmp sorting in PY3. + message.repeated_nested_message.sort(sort_function=cmp_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + + def testRepeatedScalarFieldSortArguments(self): + """Check sorting a scalar field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(key=abs, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + if sys.version_info.major < 3: # No cmp sorting in PY3. + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(key=len, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + if sys.version_info.major < 3: # No cmp sorting in PY3. + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testRepeatedFieldsComparable(self): + m1 = unittest_pb2.TestAllTypes() + m2 = unittest_pb2.TestAllTypes() + m1.repeated_int32.append(0) + m1.repeated_int32.append(1) + m1.repeated_int32.append(2) + m2.repeated_int32.append(0) + m2.repeated_int32.append(1) + m2.repeated_int32.append(2) + m1.repeated_nested_message.add().bb = 1 + m1.repeated_nested_message.add().bb = 2 + m1.repeated_nested_message.add().bb = 3 + m2.repeated_nested_message.add().bb = 1 + m2.repeated_nested_message.add().bb = 2 + m2.repeated_nested_message.add().bb = 3 + + if sys.version_info.major >= 3: return # No cmp() in PY3. + + # These comparisons should not raise errors. + _ = m1 < m2 + _ = m1.repeated_nested_message < m2.repeated_nested_message + + # Make sure cmp always works. If it wasn't defined, these would be + # id() comparisons and would all fail. + self.assertEqual(cmp(m1, m2), 0) + self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) + self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) + self.assertEqual(cmp(m1.repeated_nested_message, + m2.repeated_nested_message), 0) + with self.assertRaises(TypeError): + # Can't compare repeated composite containers to lists. + cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) + + # TODO(anuraag): Implement extensiondict comparison in C++ and then add test + + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + def ensureNestedMessageExists(self, msg, attribute): + """Make sure that a nested message object exists. + + As soon as a nested message attribute is accessed, it will be present in the + _fields dict, without being marked as actually being set. + """ + getattr(msg, attribute) + self.assertFalse(msg.HasField(attribute)) + + def testOneofGetCaseNonexistingField(self): + m = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + + def testOneofSemantics(self): + m = unittest_pb2.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + + m.oneof_uint32 = 11 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + + m.oneof_string = u'foo' + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertTrue(m.HasField('oneof_string')) + + m.oneof_nested_message.bb = 11 + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_string')) + self.assertTrue(m.HasField('oneof_nested_message')) + + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_nested_message')) + self.assertTrue(m.HasField('oneof_bytes')) + + def testOneofCompositeFieldReadAccess(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + + self.ensureNestedMessageExists(m, 'oneof_nested_message') + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertEqual(11, m.oneof_uint32) + + def testOneofHasField(self): + m = unittest_pb2.TestAllTypes() + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' + self.assertTrue(m.HasField('oneof_field')) + m.ClearField('oneof_bytes') + self.assertFalse(m.HasField('oneof_field')) + + def testOneofClearField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_field') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearSetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_uint32') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearUnsetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + self.ensureNestedMessageExists(m, 'oneof_nested_message') + m.ClearField('oneof_nested_message') + self.assertEqual(11, m.oneof_uint32) + self.assertTrue(m.HasField('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + + + + def testSortEmptyRepeatedCompositeContainer(self): + """Exercise a scenario that has led to segfaults in the past. + """ + m = unittest_pb2.TestAllTypes() + m.repeated_nested_message.sort() + + def testHasFieldOnRepeatedField(self): + """Using HasField on a repeated field should raise an exception. + """ + m = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError) as _: + m.HasField('repeated_int32') + + +class ValidTypeNamesTest(basetest.TestCase): + + def assertImportFromName(self, msg, base_name): + # Parse <type 'module.class_name'> to extra 'some.name' as a string. + tp_name = str(type(msg)).split("'")[1] + valid_names = ('Repeated%sContainer' % base_name, + 'Repeated%sFieldContainer' % base_name) + self.assertTrue(any(tp_name.endswith(v) for v in valid_names), + '%r does end with any of %r' % (tp_name, valid_names)) + + parts = tp_name.split('.') + class_name = parts[-1] + module_name = '.'.join(parts[:-1]) + __import__(module_name, fromlist=[class_name]) + + def testTypeNamesCanBeImported(self): + # If import doesn't work, pickling won't work either. + pb = unittest_pb2.TestAllTypes() + self.assertImportFromName(pb.repeated_int32, 'Scalar') + self.assertImportFromName(pb.repeated_nested_message, 'Composite') + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto new file mode 100644 index 0000000..c9ae58b --- /dev/null +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -0,0 +1,50 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +message TestEnumValues { + enum NestedEnum { + ZERO = 0; + ONE = 1; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} + +message TestMissingEnumValues { + enum NestedEnum { + TWO = 2; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto new file mode 100644 index 0000000..df98ac4 --- /dev/null +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -0,0 +1,49 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jasonh@google.com (Jason Hsueh) +// +// This file is used to test a corner case in the CPP implementation where the +// generated C++ type is available for the extendee, but the extension is +// defined in a file whose C++ type is not in the binary. + + +import "google/protobuf/internal/more_extensions.proto"; + +package google.protobuf.internal; + +message DynamicMessageType { + optional int32 a = 1; +} + +extend ExtendedMessage { + optional int32 dynamic_int32_extension = 100; + optional DynamicMessageType dynamic_message_extension = 101; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py new file mode 100755 index 0000000..9ee352d --- /dev/null +++ b/python/google/protobuf/internal/python_message.py @@ -0,0 +1,1247 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Keep it Python2.5 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. +# +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import sys +if sys.version_info[0] < 3: + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + import copy_reg as copyreg +else: + from io import BytesIO + import copyreg +import struct +import weakref + +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import containers +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod +from google.protobuf import text_format + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +def NewMessage(bases, descriptor, dictionary): + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + return bases + + +def InitMessage(descriptor, cls): + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number)) + + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. + return proto_field_name + + +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + + if not isinstance(extension_handle, _FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) + + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_unknown_fields', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__', + '_oneofs'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor)) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Also exporting a class-level object that can name enum values. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. + + Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: + message: Message instance containing this field, or a weakref proxy + of same. + + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. + """ + + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.has_default_value and field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault + else: + type_checker = type_checkers.GetTypeChecker(field) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + result = message_type._concrete_class() + result._SetListener(message._listener_for_children) + return result + return MakeSubMessageDefault + + def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return field.default_value + return MakeScalarDefault + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self, **kwargs): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = len(kwargs) > 0 + self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () + self._is_present_in_parent = False + self._listener = message_listener_mod.NullMessageListener() + self._listener_for_children = _Listener(self) + for field_name, field_value in kwargs.iteritems(): + field = _GetFieldByName(message_descriptor, field_name) + if field is None: + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (message_descriptor.name, field_name)) + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + for val in field_value: + copy.add().MergeFrom(val) + else: # Scalar + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + copy.MergeFrom(field_value) + self._fields[field] = copy + else: + setattr(self, field_name, field_value) + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + constant_name = field.name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, field.number) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field) + default_value = field.default_value + valid_values = set() + + def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return self._fields.get(field, default_value) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + def field_setter(self, new_value): + # pylint: disable=protected-access + self._fields[field] = type_checker.CheckValue(new_value) + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() + + if field.containing_oneof is not None: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). + message_type = field.message_type + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = message_type._concrete_class() # use field.message_type? + field_value._SetListener( + _OneofListener(self, field) + if field.containing_oneof is not None + else self._listener_for_children) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForExtensions(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, extension_field.number) + + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + _AttachFieldHelpers(cls, extension_handle) + + # Try to insert our extension, failing if an extension with the same number + # already exists. + actual_handle = cls._extensions_by_number.setdefault( + extension_handle.number, extension_handle) + if actual_handle is not extension_handle: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" with ' + 'field number %d.' % + (extension_handle.full_name, actual_handle.full_name, + cls.DESCRIPTOR.full_name, extension_handle.number)) + + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + handle = extension_handle # avoid line wrapping + if _IsMessageSetExtension(handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + + +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ListFields(self): + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields + + cls.ListFields = ListFields + + +def _AddHasFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + singular_fields = {} + for field in message_descriptor.fields: + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field + # Fields inside oneofs are never repeated (enforced by the compiler). + for field in message_descriptor.oneofs: + singular_fields[field.name] = field + + def HasField(self, field_name): + try: + field = singular_fields[field_name] + except KeyError: + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) + + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + + cls.HasField = HasField + + +def _AddClearFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + if field in self._fields: + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + self._Modified() + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + + if self is other: + return True + + if not self.ListFields() == other.ListFields(): + return False + + # Sort unknown fields because their order shouldn't affect equality test. + unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + + return unknown_fields == other_unknown_fields + + cls.__eq__ = __eq__ + + +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + +def _AddUnicodeMethod(unused_message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def __unicode__(self): + return text_format.MessageToString(self, as_utf8=True).decode('utf-8') + cls.__unicode__ = __unicode__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False + return size + + cls.ByteSize = ByteSize + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializePartialToString(self): + out = BytesIO() + self._InternalSerialize(out.write) + return out.getvalue() + cls.SerializePartialToString = SerializePartialToString + + def InternalSerialize(self, write_bytes): + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) + cls._InternalSerialize = InternalSerialize + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. + raise message_mod.DecodeError('Truncated message.') + except struct.error, e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString + + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + self._Modified() + field_dict = self._fields + unknown_field_list = self._unknown_fields + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder = decoders_by_tag.get(tag_bytes) + if field_decoder is None: + value_start_pos = new_pos + new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if new_pos == -1: + return pos + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + return pos + cls._InternalParse = InternalParse + + +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" + + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] + + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. + + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + + Returns: + True iff the specified message has all required fields set. + """ + + # Performance is critical so we avoid HasField() and ListFields(). + + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + for field, value in list(self._fields.items()): # dict can change size! + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + return True + + cls.IsInitialized = IsInitialized + + def FindInitializationErrors(self): + """Finds required fields which are not initialized. + + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + + errors = [] # simplify things + + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) + + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = "(%s)" % field.full_name + else: + name = field.name + + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): + element = value[i] + prefix = "%s[%d]." % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + else: + prefix = name + "." + sub_errors = value.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + + return errors + + cls.FindInitializationErrors = FindInitializationErrors + + +def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + + def MergeFrom(self, msg): + if not isinstance(msg, cls): + raise TypeError( + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) + + assert msg is not self + self._Modified() + + fields = self._fields + + for field, value in msg._fields.iteritems(): + if field.label == LABEL_REPEATED: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + elif field.cpp_type == CPPTYPE_MESSAGE: + if value._is_present_in_parent: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + + cls.MergeFrom = MergeFrom + + +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) + _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) + _AddUnicodeMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(message_descriptor, cls) + _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) + +def _AddPrivateHelperMethods(message_descriptor, cls): + """Adds implementation of private helper methods to cls.""" + + def Modified(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + + cls._Modified = Modified + cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz.qux = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return + try: + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._Modified() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + def __init__(self, extended_message): + """extended_message: Message instance for which we are the Extensions dict. + """ + + self._extended_message = extended_message + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result + + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) + + return result + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [ field for field in my_fields if field.is_extension ] + other_fields = [ field for field in other_fields if field.is_extension ] + + return my_fields == other_fields + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necssary "has" bits in the + # ancestors of the extended message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker( + extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) + self._extended_message._Modified() + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_name.get(name, None) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 2c9fa30..b3c414c 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,12 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy +import gc import operator import struct -import unittest -# TODO(robinson): When we split this test in two, only some of these imports -# will be necessary in each test. +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -50,6 +50,8 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 from google.protobuf.internal import wire_format @@ -102,12 +104,12 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): - def assertIs(self, values, others): + def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) for i in range(len(values)): - self.assertTrue(values[i] is others[i]) + self.assertEqual(values[i], others[i]) def testScalarConstructor(self): # Constructor with only scalar types should succeed. @@ -200,6 +202,41 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + def testConstructorTypeError(self): + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_int32="foo") + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234]) + + def testConstructorInvalidatesCachedByteSize(self): + message = unittest_pb2.TestAllTypes(optional_int32 = 12) + self.assertEquals(2, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) + self.assertEquals(3, message.ByteSize()) + def testSimpleHasBits(self): # Test a scalar. proto = unittest_pb2.TestAllTypes() @@ -284,12 +321,6 @@ class ReflectionTest(unittest.TestCase): # ...and ensure that the scalar field has returned to its default. self.assertEqual(0, getattr(composite_field, scalar_field_name)) - # Finally, ensure that modifications to the old composite field object - # don't have any effect on the parent. - # - # (NOTE that when we clear the composite field in the parent, we actually - # don't recursively clear down the tree. Instead, we just disconnect the - # cleared composite from the tree.) self.assertTrue(old_composite_field is not composite_field) setattr(old_composite_field, scalar_field_name, new_val) self.assertTrue(not composite_field.HasField(scalar_field_name)) @@ -319,6 +350,64 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) + def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') + del proto + del nested + # Force a garbage collect so that the underlying CMessages are freed along + # with the Messages they point to. This is to make sure we're not deleting + # default message instances. + gc.collect() + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + + def testDisconnectingNestedMessageAfterSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + self.assertTrue(proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertEqual(5, nested.bb) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testDisconnectingNestedMessageBeforeGettingField(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + + def testDisconnectingNestedMessageAfterMerge(self): + # This test exercises the code path that does not use ReleaseMessage(). + # The underlying fear is that if we use ReleaseMessage() incorrectly, + # we will have memory leaks. It's hard to check that that doesn't happen, + # but at least we can exercise that code path to make sure it works. + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_nested_message.bb = 5 + proto1.MergeFrom(proto2) + self.assertTrue(proto1.HasField('optional_nested_message')) + proto1.ClearField('optional_nested_message') + self.assertTrue(not proto1.HasField('optional_nested_message')) + + def testDisconnectingLazyNestedMessage(self): + # This test exercises releasing a nested message that is lazy. This test + # only exercises real code in the C++ implementation as Python does not + # support lazy parsing, but the current C++ implementation results in + # memory corruption and a crash. + if api_implementation.Type() != 'python': + return + proto = unittest_pb2.TestAllTypes() + proto.optional_lazy_message.bb = 5 + proto.ClearField('optional_lazy_message') + del proto + gc.collect() + def testHasBitsWhenModifyingRepeatedFields(self): # Test nesting when we add an element to a repeated field in a submessage. proto = unittest_pb2.TestNestedMessageHasBits() @@ -446,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -462,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -479,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -500,7 +600,6 @@ class ReflectionTest(unittest.TestCase): # proto.nonexistent_field = 23 should fail as well. self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) - # TODO(robinson): Add type-safety check for enums. def testSingleScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) @@ -508,11 +607,37 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() setattr(pb, field_name, expected_min) + self.assertEqual(expected_min, getattr(pb, field_name)) setattr(pb, field_name, expected_max) + self.assertEqual(expected_max, getattr(pb, field_name)) self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) @@ -520,7 +645,10 @@ class ReflectionTest(unittest.TestCase): TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) - TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + pb = unittest_pb2.TestAllTypes() + pb.optional_nested_enum = 1 + self.assertEqual(1, pb.optional_nested_enum) def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() @@ -534,11 +662,19 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + # Repeated enums tests. + #proto.repeated_nested_enum.append(0) + def testSingleScalarGettersAndSetters(self): proto = unittest_pb2.TestAllTypes() self.assertEqual(0, proto.optional_int32) proto.optional_int32 = 1 self.assertEqual(1, proto.optional_int32) + + proto.optional_uint64 = 0xffffffffffff + self.assertEqual(0xffffffffffff, proto.optional_uint64) + proto.optional_uint64 = 0xffffffffffffffff + self.assertEqual(0xffffffffffffffff, proto.optional_uint64) # TODO(robinson): Test all other scalar field types. def testSingleScalarClearField(self): @@ -561,6 +697,77 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testEnum_Name(self): + self.assertEqual('FOREIGN_FOO', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) + self.assertEqual('FOREIGN_BAR', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) + self.assertEqual('FOREIGN_BAZ', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Name, 11312) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual('FOO', + proto.NestedEnum.Name(proto.FOO)) + self.assertEqual('FOO', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) + self.assertEqual('BAR', + proto.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAR', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAZ', + proto.NestedEnum.Name(proto.BAZ)) + self.assertEqual('BAZ', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) + self.assertRaises(ValueError, + proto.NestedEnum.Name, 11312) + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) + + def testEnum_Value(self): + self.assertEqual(unittest_pb2.FOREIGN_FOO, + unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Value, 'FO') + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(proto.FOO, + proto.NestedEnum.Value('FOO')) + self.assertEqual(proto.FOO, + unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) + self.assertEqual(proto.BAR, + proto.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAR, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAZ, + proto.NestedEnum.Value('BAZ')) + self.assertEqual(proto.BAZ, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) + self.assertRaises(ValueError, + proto.NestedEnum.Value, 'Foo') + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') + + def testEnum_KeysAndValues(self): + self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], + unittest_pb2.ForeignEnum.keys()) + self.assertEqual([4, 5, 6], + unittest_pb2.ForeignEnum.values()) + self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), + ('FOREIGN_BAZ', 6)], + unittest_pb2.ForeignEnum.items()) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], + proto.NestedEnum.items()) + def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -619,11 +826,38 @@ class ReflectionTest(unittest.TestCase): del proto.repeated_int32[2:] self.assertEqual([5, 35], proto.repeated_int32) + # Test extending. + proto.repeated_int32.extend([3, 13]) + self.assertEqual([5, 35, 3, 13], proto.repeated_int32) + # Test clearing. proto.ClearField('repeated_int32') self.assertTrue(not proto.repeated_int32) self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(1) + self.assertEqual(1, proto.repeated_int32[-1]) + # Test assignment to a negative index. + proto.repeated_int32[-1] = 2 + self.assertEqual(2, proto.repeated_int32[-1]) + + # Test deletion at negative indices. + proto.repeated_int32[:] = [0, 1, 2, 3] + del proto.repeated_int32[-1] + self.assertEqual([0, 1, 2], proto.repeated_int32) + + del proto.repeated_int32[-2] + self.assertEqual([0, 2], proto.repeated_int32) + + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) + + del proto.repeated_int32[-2:-1] + self.assertEqual([2], proto.repeated_int32) + + del proto.repeated_int32[100:10000] + self.assertEqual([2], proto.repeated_int32) + def testRepeatedScalarsRemove(self): proto = unittest_pb2.TestAllTypes() @@ -661,7 +895,7 @@ class ReflectionTest(unittest.TestCase): m1 = proto.repeated_nested_message.add() self.assertTrue(proto.repeated_nested_message) self.assertEqual(2, len(proto.repeated_nested_message)) - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) # Test out-of-bounds indices. @@ -680,32 +914,86 @@ class ReflectionTest(unittest.TestCase): m2 = proto.repeated_nested_message.add() m3 = proto.repeated_nested_message.add() m4 = proto.repeated_nested_message.add() - self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) - self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m1, m2, m3], proto.repeated_nested_message[1:4]) + self.assertListsEqual( + [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m0, m1], proto.repeated_nested_message[:2]) + self.assertListsEqual( + [m2, m3, m4], proto.repeated_nested_message[2:]) + self.assertEqual( + m0, proto.repeated_nested_message[0]) + self.assertListsEqual( + [m0], proto.repeated_nested_message[:1]) # Test that we can use the field as an iterator. result = [] for i in proto.repeated_nested_message: result.append(i) - self.assertIs([m0, m1, m2, m3, m4], result) + self.assertListsEqual([m0, m1, m2, m3, m4], result) # Test single deletion. del proto.repeated_nested_message[2] - self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) # Test slice deletion. del proto.repeated_nested_message[2:] - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) + + # Test extending. + n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) + n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) + proto.repeated_nested_message.extend([n1,n2]) + self.assertEqual(4, len(proto.repeated_nested_message)) + self.assertEqual(n1, proto.repeated_nested_message[2]) + self.assertEqual(n2, proto.repeated_nested_message[3]) # Test clearing. proto.ClearField('repeated_nested_message') self.assertTrue(not proto.repeated_nested_message) self.assertEqual(0, len(proto.repeated_nested_message)) + # Test constructing an element while adding it. + proto.repeated_nested_message.add(bb=23) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(23, proto.repeated_nested_message[0].bb) + + def testRepeatedCompositeRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + # Need to set some differentiating variable so m0 != m1 != m2: + m0.bb = len(proto.repeated_nested_message) + m1 = proto.repeated_nested_message.add() + m1.bb = len(proto.repeated_nested_message) + self.assertTrue(m0 != m1) + m2 = proto.repeated_nested_message.add() + m2.bb = len(proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) + + self.assertEqual(3, len(proto.repeated_nested_message)) + proto.repeated_nested_message.remove(m0) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + self.assertEqual(m2, proto.repeated_nested_message[1]) + + # Removing m0 again or removing None should raise error + self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) + self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) + self.assertEqual(2, len(proto.repeated_nested_message)) + + proto.repeated_nested_message.remove(m2) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + def testHandWrittenReflection(self): - # TODO(robinson): We probably need a better way to specify - # protocol types by hand. But then again, this isn't something - # we expect many people to do. Hmm. + # Hand written extensions are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + FieldDescriptor = descriptor.FieldDescriptor foo_field_descriptor = FieldDescriptor( name='foo_field', full_name='MyProto.foo_field', @@ -730,6 +1018,68 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + def testDescriptorProtoSupport(self): + # Hand written descriptors/reflection are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + + def AddDescriptorField(proto, field_name, field_type): + AddDescriptorField.field_index += 1 + new_field = proto.field.add() + new_field.name = field_name + new_field.type = field_type + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + AddDescriptorField.field_index = 0 + + desc_proto = descriptor_pb2.DescriptorProto() + desc_proto.name = 'Car' + fdp = descriptor_pb2.FieldDescriptorProto + AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) + AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) + AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) + AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) + # Add a repeated field + AddDescriptorField.field_index += 1 + new_field = desc_proto.field.add() + new_field.name = 'owners' + new_field.type = fdp.TYPE_STRING + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + + desc = descriptor.MakeDescriptor(desc_proto) + self.assertTrue(desc.fields_by_name.has_key('name')) + self.assertTrue(desc.fields_by_name.has_key('year')) + self.assertTrue(desc.fields_by_name.has_key('automatic')) + self.assertTrue(desc.fields_by_name.has_key('price')) + self.assertTrue(desc.fields_by_name.has_key('owners')) + + class CarMessage(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = desc + + prius = CarMessage() + prius.name = 'prius' + prius.year = 2010 + prius.automatic = True + prius.price = 25134.75 + prius.owners.extend(['bob', 'susan']) + + serialized_prius = prius.SerializeToString() + new_prius = reflection.ParseMessage(desc, serialized_prius) + self.assertTrue(new_prius is not prius) + self.assertEqual(prius, new_prius) + + # these are unnecessary assuming message equality works as advertised but + # explicitly check to be safe since we're mucking about in metaclass foo + self.assertEqual(prius.name, new_prius.name) + self.assertEqual(prius.year, new_prius.year) + self.assertEqual(prius.automatic, new_prius.automatic) + self.assertEqual(prius.price, new_prius.price) + self.assertEqual(prius.owners, new_prius.owners) + def testTopLevelExtensionsForOptionalScalar(self): extendee_proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.optional_int32_extension @@ -819,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -868,7 +1226,7 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not toplevel.HasField('submessage')) foreign = toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension].add() - self.assertTrue(foreign is toplevel.submessage.Extensions[ + self.assertEqual(foreign, toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension][0]) self.assertTrue(toplevel.HasField('submessage')) @@ -971,6 +1329,12 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(123, proto2.repeated_nested_message[1].bb) self.assertEqual(321, proto2.repeated_nested_message[2].bb) + proto3 = unittest_pb2.TestAllTypes() + proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) + self.assertEqual(999, proto3.repeated_nested_message[0].bb) + self.assertEqual(123, proto3.repeated_nested_message[1].bb) + self.assertEqual(321, proto3.repeated_nested_message[2].bb) + def testMergeFromAllFields(self): # With all fields set. proto1 = unittest_pb2.TestAllTypes() @@ -1035,6 +1399,19 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(222, ext2[1].bb) self.assertEqual(333, ext2[2].bb) + def testMergeFromBug(self): + message1 = unittest_pb2.TestAllTypes() + message2 = unittest_pb2.TestAllTypes() + + # Cause optional_nested_message to be instantiated within message1, even + # though it is not considered to be "present". + message1.optional_nested_message + self.assertFalse(message1.HasField('optional_nested_message')) + + # Merge into message2. This should not instantiate the field is message2. + message2.MergeFrom(message1) + self.assertFalse(message2.HasField('optional_nested_message')) + def testCopyFromSingularField(self): # Test copy with just a singular field. proto1 = unittest_pb2.TestAllTypes() @@ -1087,9 +1464,36 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(2, proto1.optional_int32) self.assertEqual('important-text', proto1.optional_string) + def testCopyFromBadType(self): + # The python implementation doesn't raise an exception in this + # case. In theory it should. + if api_implementation.Type() == 'python': + return + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllExtensions() + self.assertRaises(TypeError, proto1.CopyFrom, proto2) + + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() - test_util.SetAllFields(proto) + # C++ implementation does not support lazy fields right now so leave it + # out for now. + if api_implementation.Type() == 'python': + test_util.SetAllFields(proto) + else: + test_util.SetAllNonLazyFields(proto) # Clear the message. proto.Clear() self.assertEquals(proto.ByteSize(), 0) @@ -1105,6 +1509,45 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def testDisconnectingBeforeClear(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + foreign = proto.optional_foreign_message + foreign.c = 6 + + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + self.assertTrue(foreign is not proto.optional_foreign_message) + self.assertEqual(5, nested.bb) + self.assertEqual(6, foreign.c) + nested.bb = 15 + foreign.c = 16 + self.assertFalse(proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertFalse(proto.HasField('optional_foreign_message')) + self.assertEqual(0, proto.optional_foreign_message.c) + + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1175,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1192,16 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - # Values of type 'str' are also accepted as long as they can be encoded in - # UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1224,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1232,18 +1709,37 @@ class ReflectionTest(unittest.TestCase): # Check that the type_id is the same as the tag ID in the .proto file. self.assertEqual(raw.item[0].type_id, 1547769) - # Check the actually bytes on the wire. + # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) - # How about if the bytes on the wire aren't a valid UTF-8 encoded string. - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') - self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + # The pure Python API throws an exception on MergeFromString(), + # if any of the string fields of the message can't be UTF-8 decoded. + # The C++ implementation of the API has no way to check that on + # MergeFromString and thus has no way to throw the exception. + # + # The pure Python API always returns objects of type 'unicode' (UTF-8 + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') + + unicode_decode_failed = False + try: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: + unicode_decode_failed = True + string_field = message2.str + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1257,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1280,12 +1779,15 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() self.second_proto = unittest_pb2.TestAllTypes() + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testSelfEquality(self): self.assertEqual(self.first_proto, self.first_proto) @@ -1293,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1303,6 +1805,9 @@ class FullProtosEqualityTest(unittest.TestCase): test_util.SetAllFields(self.first_proto) test_util.SetAllFields(self.second_proto) + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testNoneNotEqual(self): self.assertNotEqual(self.first_proto, None) self.assertNotEqual(None, self.second_proto) @@ -1371,15 +1876,12 @@ class FullProtosEqualityTest(unittest.TestCase): self.first_proto.ClearField('optional_nested_message') self.second_proto.optional_nested_message.ClearField('bb') self.assertNotEqual(self.first_proto, self.second_proto) - # TODO(robinson): Replace next two lines with method - # to set the "has" bit without changing the value, - # if/when such a method exists. self.first_proto.optional_nested_message.bb = 0 self.first_proto.optional_nested_message.ClearField('bb') self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1412,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1424,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -1438,6 +1940,14 @@ class ByteSizeTest(unittest.TestCase): def testEmptyMessage(self): self.assertEqual(0, self.proto.ByteSize()) + def testSizedOnKwargs(self): + # Use a separate message to ensure testing right after creation. + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.ByteSize()) + proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1) + # One byte for the tag, one to encode varint 1. + self.assertEqual(2, proto_kwargs.ByteSize()) + def testVarints(self): def Test(i, expected_varint_size): self.proto.Clear() @@ -1629,10 +2139,13 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(3, self.proto.ByteSize()) self.proto.ClearField('optional_foreign_message') self.assertEqual(0, self.proto.ByteSize()) - child = self.proto.optional_foreign_message - self.proto.ClearField('optional_foreign_message') - child.c = 128 - self.assertEqual(0, self.proto.ByteSize()) + + if api_implementation.Type() == 'python': + # This is only possible in pure-Python implementation of the API. + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) # Test within extension. extension = more_extensions_pb2.optional_message_extension @@ -1698,7 +2211,6 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(19, self.packed_extended_proto.ByteSize()) -# TODO(robinson): We need cross-language serialization consistency tests. # Issues to be sure to cover include: # * Handling of unrecognized tags ("uninterpreted_bytes"). # * Handling of MessageSets. @@ -1710,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -1726,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -1734,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -1753,6 +2281,10 @@ class SerializationTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) def testParseTruncated(self): + # This test is only applicable for the Python implementation of the API. + if api_implementation.Type() != 'python': + return + first_proto = unittest_pb2.TestAllTypes() test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() @@ -1822,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -1847,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -1900,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -1918,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -1928,13 +2474,15 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" try: callable_obj() - except exc_class, ex: + except exc_class as ex: # Check if the exception message is the right one. self.assertEqual(exception, str(ex)) return @@ -1946,15 +2494,22 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: a,b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: ' + 'a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() + proto2 = unittest_pb2.TestRequired() + self.assertFalse(proto2.HasField('a')) + # proto2 ParseFromString does not check that required fields are set. + proto2.ParseFromString(partial) + self.assertFalse(proto2.HasField('a')) + proto.a = 1 self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1962,7 +2517,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: c') + 'Message protobuf_unittest.TestRequired is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1972,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -1991,7 +2550,8 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign ' + 'is missing required fields: ' 'optional_message.b,optional_message.c') proto.optional_message.b = 2 @@ -2003,7 +2563,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 'repeated_message[0].b,repeated_message[0].c,' 'repeated_message[1].a,repeated_message[1].c') @@ -2043,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2076,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2085,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2137,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2155,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2205,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2232,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index e04f825..ef0981d 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -34,13 +34,13 @@ __author__ = 'petar@google.com (Petar Petrov)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service -class FooUnitTest(unittest.TestCase): +class FooUnitTest(basetest.TestCase): def testService(self): class MockRpcChannel(service.RpcChannel): @@ -133,4 +133,4 @@ class FooUnitTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py new file mode 100644 index 0000000..80bc8d6 --- /dev/null +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -0,0 +1,120 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.symbol_database.""" + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import symbol_database + + +class SymbolDatabaseTest(basetest.TestCase): + + def _Database(self): + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db + + def testGetPrototype(self): + instance = self._Database().GetPrototype( + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertTrue(instance is unittest_pb2.TestAllTypes) + + def testGetMessages(self): + messages = self._Database().GetMessages( + ['google/protobuf/unittest.proto']) + self.assertTrue( + unittest_pb2.TestAllTypes is + messages['protobuf_unittest.TestAllTypes']) + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + # Check registration of types in the pool. + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindMessageTypeByName(self): + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + def testFindFindContainingSymbol(self): + # Lookup based on either enum or message. + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes').name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto new file mode 100644 index 0000000..6a82299 --- /dev/null +++ b/python/google/protobuf/internal/test_bad_identifiers.proto @@ -0,0 +1,52 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + + +package protobuf_unittest; + +option py_generic_services = true; + +message TestBadIdentifiers { + extensions 100 to max; +} + +// Make sure these reasonable extension names don't conflict with internal +// variables. +extend TestBadIdentifiers { + optional string message = 100 [default="foo"]; + optional string descriptor = 101 [default="bar"]; + optional string reflection = 102 [default="baz"]; + optional string service = 103 [default="qux"]; +} + +message AnotherMessage {} +service AnotherService {} diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 1df1619..350d1c6 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -42,8 +42,8 @@ from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 -def SetAllFields(message): - """Sets every field in the message to a unique value. +def SetAllNonLazyFields(message): + """Sets every non-lazy field in the message to a unique value. Args: message: A unittest_pb2.TestAllTypes instance. @@ -66,26 +66,21 @@ def SetAllFields(message): message.optional_float = 111 message.optional_double = 112 message.optional_bool = True - # TODO(robinson): Firmly spec out and test how - # protos interact with unicode. One specific example: - # what happens if we change the literal below to - # u'115'? What *should* happen? Still some discussion - # to finish with Kenton about bytes vs. strings - # and forcing everything to be utf8. :-/ - message.optional_string = '115' - message.optional_bytes = '116' + message.optional_string = u'115' + message.optional_bytes = b'116' message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 + message.optional_public_import_message.e = 126 message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ - message.optional_string_piece = '124' - message.optional_cord = '125' + message.optional_string_piece = u'124' + message.optional_cord = u'125' # # Repeated fields. @@ -104,20 +99,21 @@ def SetAllFields(message): message.repeated_float.append(211) message.repeated_double.append(212) message.repeated_bool.append(True) - message.repeated_string.append('215') - message.repeated_bytes.append('216') + message.repeated_string.append(u'215') + message.repeated_bytes.append(b'216') message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 + message.repeated_lazy_message.add().bb = 227 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) - message.repeated_string_piece.append('224') - message.repeated_cord.append('225') + message.repeated_string_piece.append(u'224') + message.repeated_cord.append(u'225') # Add a second one of each field. message.repeated_int32.append(301) @@ -133,20 +129,21 @@ def SetAllFields(message): message.repeated_float.append(311) message.repeated_double.append(312) message.repeated_bool.append(False) - message.repeated_string.append('315') - message.repeated_bytes.append('316') + message.repeated_string.append(u'315') + message.repeated_bytes.append(b'316') message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 + message.repeated_lazy_message.add().bb = 327 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) - message.repeated_string_piece.append('324') - message.repeated_cord.append('325') + message.repeated_string_piece.append(u'324') + message.repeated_cord.append(u'325') # # Fields that have defaults. @@ -166,7 +163,7 @@ def SetAllFields(message): message.default_double = 412 message.default_bool = False message.default_string = '415' - message.default_bytes = '416' + message.default_bytes = b'416' message.default_nested_enum = unittest_pb2.TestAllTypes.FOO message.default_foreign_enum = unittest_pb2.FOREIGN_FOO @@ -175,6 +172,16 @@ def SetAllFields(message): message.default_string_piece = '424' message.default_cord = '425' + message.oneof_uint32 = 601 + message.oneof_nested_message.bb = 602 + message.oneof_string = '603' + message.oneof_bytes = b'604' + + +def SetAllFields(message): + SetAllNonLazyFields(message) + message.optional_lazy_message.bb = 127 + def SetAllExtensions(message): """Sets every extension in the message to a unique value. @@ -204,21 +211,23 @@ def SetAllExtensions(message): extensions[pb2.optional_float_extension] = 111 extensions[pb2.optional_double_extension] = 112 extensions[pb2.optional_bool_extension] = True - extensions[pb2.optional_string_extension] = '115' - extensions[pb2.optional_bytes_extension] = '116' + extensions[pb2.optional_string_extension] = u'115' + extensions[pb2.optional_bytes_extension] = b'116' extensions[pb2.optionalgroup_extension].a = 117 extensions[pb2.optional_nested_message_extension].bb = 118 extensions[pb2.optional_foreign_message_extension].c = 119 extensions[pb2.optional_import_message_extension].d = 120 + extensions[pb2.optional_public_import_message_extension].e = 126 + extensions[pb2.optional_lazy_message_extension].bb = 127 extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ - extensions[pb2.optional_string_piece_extension] = '124' - extensions[pb2.optional_cord_extension] = '125' + extensions[pb2.optional_string_piece_extension] = u'124' + extensions[pb2.optional_cord_extension] = u'125' # # Repeated fields. @@ -237,20 +246,21 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(211) extensions[pb2.repeated_double_extension].append(212) extensions[pb2.repeated_bool_extension].append(True) - extensions[pb2.repeated_string_extension].append('215') - extensions[pb2.repeated_bytes_extension].append('216') + extensions[pb2.repeated_string_extension].append(u'215') + extensions[pb2.repeated_bytes_extension].append(b'216') extensions[pb2.repeatedgroup_extension].add().a = 217 extensions[pb2.repeated_nested_message_extension].add().bb = 218 extensions[pb2.repeated_foreign_message_extension].add().c = 219 extensions[pb2.repeated_import_message_extension].add().d = 220 + extensions[pb2.repeated_lazy_message_extension].add().bb = 227 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) - extensions[pb2.repeated_string_piece_extension].append('224') - extensions[pb2.repeated_cord_extension].append('225') + extensions[pb2.repeated_string_piece_extension].append(u'224') + extensions[pb2.repeated_cord_extension].append(u'225') # Append a second one of each field. extensions[pb2.repeated_int32_extension].append(301) @@ -266,20 +276,21 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(311) extensions[pb2.repeated_double_extension].append(312) extensions[pb2.repeated_bool_extension].append(False) - extensions[pb2.repeated_string_extension].append('315') - extensions[pb2.repeated_bytes_extension].append('316') + extensions[pb2.repeated_string_extension].append(u'315') + extensions[pb2.repeated_bytes_extension].append(b'316') extensions[pb2.repeatedgroup_extension].add().a = 317 extensions[pb2.repeated_nested_message_extension].add().bb = 318 extensions[pb2.repeated_foreign_message_extension].add().c = 319 extensions[pb2.repeated_import_message_extension].add().d = 320 + extensions[pb2.repeated_lazy_message_extension].add().bb = 327 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) - extensions[pb2.repeated_string_piece_extension].append('324') - extensions[pb2.repeated_cord_extension].append('325') + extensions[pb2.repeated_string_piece_extension].append(u'324') + extensions[pb2.repeated_cord_extension].append(u'325') # # Fields with defaults. @@ -298,16 +309,21 @@ def SetAllExtensions(message): extensions[pb2.default_float_extension] = 411 extensions[pb2.default_double_extension] = 412 extensions[pb2.default_bool_extension] = False - extensions[pb2.default_string_extension] = '415' - extensions[pb2.default_bytes_extension] = '416' + extensions[pb2.default_string_extension] = u'415' + extensions[pb2.default_bytes_extension] = b'416' extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO - extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_string_piece_extension] = u'424' extensions[pb2.default_cord_extension] = '425' + extensions[pb2.oneof_uint32_extension] = 601 + extensions[pb2.oneof_nested_message_extension].bb = 602 + extensions[pb2.oneof_string_extension] = u'603' + extensions[pb2.oneof_bytes_extension] = b'604' + def SetAllFieldsAndExtensions(message): """Sets every field and extension in the message to a unique value. @@ -346,7 +362,7 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized): message.my_float = 1.0 expected_strings.append(message.SerializeToString()) message.Clear() - expected = ''.join(expected_strings) + expected = b''.join(expected_strings) if expected != serialized: raise ValueError('Expected %r, found %r' % (expected, serialized)) @@ -401,12 +417,14 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(112, message.optional_double) test_case.assertEqual(True, message.optional_bool) test_case.assertEqual('115', message.optional_string) - test_case.assertEqual('116', message.optional_bytes) + test_case.assertEqual(b'116', message.optional_bytes) test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) + test_case.assertEqual(126, message.optional_public_import_message.e) + test_case.assertEqual(127, message.optional_lazy_message.bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) @@ -458,12 +476,13 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(212, message.repeated_double[0]) test_case.assertEqual(True, message.repeated_bool[0]) test_case.assertEqual('215', message.repeated_string[0]) - test_case.assertEqual('216', message.repeated_bytes[0]) + test_case.assertEqual(b'216', message.repeated_bytes[0]) test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) + test_case.assertEqual(227, message.repeated_lazy_message[0].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, message.repeated_nested_enum[0]) @@ -486,12 +505,13 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(312, message.repeated_double[1]) test_case.assertEqual(False, message.repeated_bool[1]) test_case.assertEqual('315', message.repeated_string[1]) - test_case.assertEqual('316', message.repeated_bytes[1]) + test_case.assertEqual(b'316', message.repeated_bytes[1]) test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) + test_case.assertEqual(327, message.repeated_lazy_message[1].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.repeated_nested_enum[1]) @@ -536,7 +556,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(412, message.default_double) test_case.assertEqual(False, message.default_bool) test_case.assertEqual('415', message.default_string) - test_case.assertEqual('416', message.default_bytes) + test_case.assertEqual(b'416', message.default_bytes) test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) @@ -545,6 +565,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, message.default_import_enum) + def GoldenFile(filename): """Finds the given golden file and returns a file object representing it.""" @@ -558,9 +579,15 @@ def GoldenFile(filename): path = os.path.join(path, '..') raise RuntimeError( - 'Could not find golden files. This test must be run from within the ' - 'protobuf source package so that it can read test data files from the ' - 'C++ source tree.') + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def GoldenFileData(filename): + """Finds the given golden file and returns its contents.""" + with GoldenFile(filename) as f: + return f.read() def SetAllPackedFields(message): diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py new file mode 100755 index 0000000..ba0e45d --- /dev/null +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -0,0 +1,68 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.text_encoding.""" + +from google.apputils import basetest +from google.protobuf import text_encoding + +TEST_VALUES = [ + ("foo\\rbar\\nbaz\\t", + "foo\\rbar\\nbaz\\t", + b"foo\rbar\nbaz\t"), + ("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + "\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + b"'full of \"sound\" and \"fury\"'"), + ("signi\\\\fying\\\\ nothing\\\\", + "signi\\\\fying\\\\ nothing\\\\", + b"signi\\fying\\ nothing\\"), + ("\\010\\t\\n\\013\\014\\r", + "\x08\\t\\n\x0b\x0c\\r", + b"\010\011\012\013\014\015")] + + +class TextEncodingTestCase(basetest.TestCase): + def testCEscape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(escaped, + text_encoding.CEscape(unescaped, as_utf8=False)) + self.assertEquals(escaped_utf8, + text_encoding.CEscape(unescaped, as_utf8=True)) + + def testCUnescape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) + self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) + + +if __name__ == "__main__": + basetest.main() diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index e0991cb..d27ff7a 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -34,48 +34,71 @@ __author__ = 'kenton@google.com (Kenton Varda)' -import difflib +import re -import unittest +from google.apputils import basetest from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import unittest_pb2 from google.protobuf import unittest_mset_pb2 +class TextFormatTest(basetest.TestCase): -class TextFormatTest(unittest.TestCase): def ReadGolden(self, golden_filename): - f = test_util.GoldenFile(golden_filename) - golden_lines = f.readlines() - f.close() - return golden_lines + with test_util.GoldenFile(golden_filename) as f: + return (f.readlines() if str is bytes else # PY3 + [golden_line.decode('utf-8') for golden_line in f]) def CompareToGoldenFile(self, text, golden_filename): golden_lines = self.ReadGolden(golden_filename) - self.CompareToGoldenLines(text, golden_lines) + self.assertMultiLineEqual(text, ''.join(golden_lines)) def CompareToGoldenText(self, text, golden_text): - self.CompareToGoldenLines(text, golden_text.splitlines(1)) - - def CompareToGoldenLines(self, text, golden_lines): - actual_lines = text.splitlines(1) - self.assertEqual(golden_lines, actual_lines, - "Text doesn't match golden. Diff:\n" + - ''.join(difflib.ndiff(golden_lines, actual_lines))) + self.assertMultiLineEqual(text, golden_text) def testPrintAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n') def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') def testPrintMessageSet(self): message = unittest_mset_pb2.TestMessageSetContainer() @@ -83,33 +106,179 @@ class TextFormatTest(unittest.TestCase): ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText(text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') + self.CompareToGoldenText( + text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') def testPrintExotic(self): message = unittest_pb2.TestAllTypes() - message.repeated_int64.append(-9223372036854775808); - message.repeated_uint64.append(18446744073709551615); - message.repeated_double.append(123.456); - message.repeated_double.append(1.23e22); - message.repeated_double.append(1.23e-18); - message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"'); + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string:' + ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintExoticUnicodeSubclass(self): + class UnicodeSub(unicode): + pass + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) + self.CompareToGoldenText( + text_format.MessageToString(message), + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintNestedMessageAsOneLine(self): + message = unittest_pb2.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'repeated_nested_message { bb: 42 }') + + def testPrintRepeatedFieldsAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int32.append(1) + message.repeated_int32.append(1) + message.repeated_int32.append(3) + message.repeated_string.append("Google") + message.repeated_string.append("Zurich") self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'repeated_int64: -9223372036854775808\n' - 'repeated_uint64: 18446744073709551615\n' - 'repeated_double: 123.456\n' - 'repeated_double: 1.23e+22\n' - 'repeated_double: 1.23e-18\n' - 'repeated_string: ' - '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + text_format.MessageToString(message, as_one_line=True), + 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' + 'repeated_string: "Google" repeated_string: "Zurich"') + + def testPrintNestedNewLineInStringAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.optional_string = "a\nnew\nline" + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'optional_string: "a\\nnew\\nline"') + + def testPrintMessageSetAsOneLine(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'message_set {' + ' [protobuf_unittest.TestMessageSetExtension1] {' + ' i: 23' + ' }' + ' [protobuf_unittest.TestMessageSetExtension2] {' + ' str: \"foo\"' + ' }' + ' }') + + def testPrintExoticAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + self.CompareToGoldenText( + self.RemoveRedundantZeros( + text_format.MessageToString(message, as_one_line=True)), + 'repeated_int64: -9223372036854775808' + ' repeated_uint64: 18446744073709551615' + ' repeated_double: 123.456' + ' repeated_double: 1.23e+22' + ' repeated_double: 1.23e-18' + ' repeated_string: ' + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' + ' repeated_string: "\\303\\274\\352\\234\\237"') + + def testRoundTripExoticAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + + # Test as_utf8 = False. + wire_text = text_format.MessageToString( + message, as_one_line=True, as_utf8=False) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message) + + # Test as_utf8 = True. + wire_text = text_format.MessageToString( + message, as_one_line=True, as_utf8=True) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintRawUtf8String(self): + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(u'\u00fc\ua71f') + text = text_format.MessageToString(message, as_utf8=True) + self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') + parsed_message = unittest_pb2.TestAllTypes() + text_format.Parse(text, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintFloatFormat(self): + # Check that float_format argument is passed to sub-message formatting. + message = unittest_pb2.NestedTestAllTypes() + # We use 1.25 as it is a round number in binary. The proto 32-bit float + # will not gain additional imprecise digits as a 64-bit Python float and + # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: + # >>> struct.unpack('f', struct.pack('f', 1.2))[0] + # 1.2000000476837158 + # >>> struct.unpack('f', struct.pack('f', 1.25))[0] + # 1.25 + message.payload.optional_float = 1.25 + # Check rounding at 15 significant digits + message.payload.optional_double = -.000003456789012345678 + # Check no decimal point. + message.payload.repeated_float.append(-5642) + # Check no trailing zeros. + message.payload.repeated_double.append(.000078900) + formatted_fields = ['optional_float: 1.25', + 'optional_double: -3.45678901234568e-6', + 'repeated_float: -5642', + 'repeated_double: 7.89e-5'] + text_message = text_format.MessageToString(message, float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) + # as_one_line=True is a separate code branch where float_format is passed. + text_message = text_format.MessageToString(message, as_one_line=True, + float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) def testMessageToString(self): message = unittest_pb2.ForeignMessage() @@ -119,52 +288,57 @@ class TextFormatTest(unittest.TestCase): def RemoveRedundantZeros(self, text): # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove # these zeros in order to match the golden file. - return text.replace('e+0','e+').replace('e+0','e+') \ + text = text.replace('e+0','e+').replace('e+0','e+') \ .replace('e-0','e-').replace('e-0','e-') + # Floating point fields are printed with .0 suffix even if they are + # actualy integer numbers. + text = re.compile('\.0$', re.MULTILINE).sub('', text) + return text - def testMergeGolden(self): + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(golden_text, parsed_message) + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.assertEquals(message, parsed_message) - def testMergeGoldenExtensions(self): + def testParseGoldenExtensions(self): golden_text = '\n'.join(self.ReadGolden( 'text_format_unittest_extensions_data.txt')) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(golden_text, parsed_message) + text_format.Parse(golden_text, parsed_message) message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.assertEquals(message, parsed_message) - def testMergeAllFields(self): + def testParseAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) test_util.ExpectAllFieldsSet(self, message) - def testMergeAllExtensions(self): + def testParseAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - def testMergeMessageSet(self): + def testParseMessageSet(self): message = unittest_pb2.TestAllTypes() text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -177,13 +351,13 @@ class TextFormatTest(unittest.TestCase): ' str: \"foo\"\n' ' }\n' '}\n') - text_format.Merge(text, message) + text_format.Parse(text, message) ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension self.assertEquals(23, message.message_set.Extensions[ext1].i) self.assertEquals('foo', message.message_set.Extensions[ext2].str) - def testMergeExotic(self): + def testParseExotic(self): message = unittest_pb2.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' @@ -191,9 +365,12 @@ class TextFormatTest(unittest.TestCase): 'repeated_double: 1.23e+22\n' 'repeated_double: 1.23e-18\n' 'repeated_string: \n' - '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n' - 'repeated_string: "foo" \'corge\' "grault"') - text_format.Merge(text, message) + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "foo" \'corge\' "grault"\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n' + 'repeated_string: "\\xc3\\xbc"\n' + 'repeated_string: "\xc3\xbc"\n') + text_format.Parse(text, message) self.assertEqual(-9223372036854775808, message.repeated_int64[0]) self.assertEqual(18446744073709551615, message.repeated_uint64[0]) @@ -201,95 +378,217 @@ class TextFormatTest(unittest.TestCase): self.assertEqual(1.23e22, message.repeated_double[1]) self.assertEqual(1.23e-18, message.repeated_double[2]) self.assertEqual( - '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0]) + '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0]) self.assertEqual('foocorgegrault', message.repeated_string[1]) + self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) + self.assertEqual(u'\u00fc', message.repeated_string[3]) + + def testParseTrailingCommas(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: 100;\n' + 'repeated_int64: 200;\n' + 'repeated_int64: 300,\n' + 'repeated_string: "one",\n' + 'repeated_string: "two";\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + + def testParseEmptyText(self): + message = unittest_pb2.TestAllTypes() + text = '' + text_format.Parse(text, message) + self.assertEquals(unittest_pb2.TestAllTypes(), message) + + def testParseInvalidUtf8(self): + message = unittest_pb2.TestAllTypes() + text = 'repeated_string: "\\xc3\\xc3"' + self.assertRaises(text_format.ParseError, text_format.Parse, text, message) + + def testParseSingleWord(self): + message = unittest_pb2.TestAllTypes() + text = 'foo' + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' + '"foo".'), + text_format.Parse, text, message) - def testMergeUnknownField(self): + def testParseUnknownField(self): message = unittest_pb2.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"unknown_field".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadExtension(self): + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' 'extensions.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeGroupNotClosed(self): + def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected ">".', - text_format.Merge, text, message) + text_format.Parse, text, message) text = 'RepeatedGroup: {' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected "}".', - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeEmptyGroup(self): + def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: {}' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) message.Clear() message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: <>' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) - def testMergeBadEnumValue(self): + def testParseBadEnumValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value named BARR.'), - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value with number 100.'), - text_format.Merge, text, message) - - def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): - """Same as assertRaises, but also compares the exception message.""" - if hasattr(e_class, '__name__'): - exc_name = e_class.__name__ - else: - exc_name = str(e_class) - - try: - func(*args, **kwargs) - except e_class, expr: - if str(expr) != e: - msg = '%s raised, but with wrong message: "%s" instead of "%s"' - raise self.failureException(msg % (exc_name, - str(expr).encode('string_escape'), - e.encode('string_escape'))) - return - else: - raise self.failureException('%s not raised' % exc_name) - - -class TokenizerTest(unittest.TestCase): + text_format.Parse, text, message) + + def testParseBadIntValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_int32: bork' + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:17 : Couldn\'t parse integer: bork'), + text_format.Parse, text, message) + + def testParseStringFieldUnescape(self): + message = unittest_pb2.TestAllTypes() + text = r'''repeated_string: "\xf\x62" + repeated_string: "\\xf\\x62" + repeated_string: "\\\xf\\\x62" + repeated_string: "\\\\xf\\\\x62" + repeated_string: "\\\\\xf\\\\\x62" + repeated_string: "\x5cx20"''' + text_format.Parse(text, message) + + SLASH = '\\' + self.assertEqual('\x0fb', message.repeated_string[0]) + self.assertEqual(SLASH + 'xf' + SLASH + 'x62', message.repeated_string[1]) + self.assertEqual(SLASH + '\x0f' + SLASH + 'b', message.repeated_string[2]) + self.assertEqual(SLASH + SLASH + 'xf' + SLASH + SLASH + 'x62', + message.repeated_string[3]) + self.assertEqual(SLASH + SLASH + '\x0f' + SLASH + SLASH + 'b', + message.repeated_string[4]) + self.assertEqual(SLASH + 'x20', message.repeated_string[5]) + + def testMergeRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + r = text_format.Merge(text, message) + self.assertIs(r, message) + self.assertEqual(67, message.optional_int32) + + def testParseRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + r = text_format.Merge(text, message) + self.assertTrue(r is message) + self.assertEqual(2, message.optional_nested_message.bb) + + def testParseRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + text_format.Merge(text, message) + self.assertEqual( + 67, + message.Extensions[unittest_pb2.optional_int32_extension]) + + def testParseRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' + 'should not have multiple ' + '"protobuf_unittest.optional_int32_extension" extensions.'), + text_format.Parse, text, message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + +class TokenizerTest(basetest.TestCase): def testSimpleTokenCases(self): text = ('identifier1:"string1"\n \n\n' @@ -297,8 +596,9 @@ class TokenizerTest(unittest.TestCase): 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' - 'ID12: 2222222222222222222') - tokenizer = text_format._Tokenizer(text) + 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' + 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') + tokenizer = text_format._Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -325,10 +625,10 @@ class TokenizerTest(unittest.TestCase): '{', (tokenizer.ConsumeIdentifier, 'A'), ':', - (tokenizer.ConsumeFloat, text_format._INFINITY), + (tokenizer.ConsumeFloat, float('inf')), (tokenizer.ConsumeIdentifier, 'B'), ':', - (tokenizer.ConsumeFloat, -text_format._INFINITY), + (tokenizer.ConsumeFloat, -float('inf')), (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), @@ -347,7 +647,25 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeInt32, -22), (tokenizer.ConsumeIdentifier, 'ID12'), ':', - (tokenizer.ConsumeUint64, 2222222222222222222)] + (tokenizer.ConsumeUint64, 2222222222222222222), + (tokenizer.ConsumeIdentifier, 'ID13'), + ':', + (tokenizer.ConsumeFloat, 1.23456), + (tokenizer.ConsumeIdentifier, 'ID14'), + ':', + (tokenizer.ConsumeFloat, 1.2e+2), + (tokenizer.ConsumeIdentifier, 'false_bool'), + ':', + (tokenizer.ConsumeBool, False), + (tokenizer.ConsumeIdentifier, 'true_BOOL'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'true_bool1'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'false_BOOL1'), + ':', + (tokenizer.ConsumeBool, False)] i = 0 while not tokenizer.AtEnd(): @@ -366,7 +684,7 @@ class TokenizerTest(unittest.TestCase): int64_max = (1 << 63) - 1 uint32_max = (1 << 32) - 1 text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) self.assertEqual(-1, tokenizer.ConsumeInt32()) @@ -380,7 +698,7 @@ class TokenizerTest(unittest.TestCase): self.assertTrue(tokenizer.AtEnd()) text = '-0 -0 0 0' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertEqual(0, tokenizer.ConsumeUint32()) self.assertEqual(0, tokenizer.ConsumeUint64()) self.assertEqual(0, tokenizer.ConsumeUint32()) @@ -389,40 +707,30 @@ class TokenizerTest(unittest.TestCase): def testConsumeByteString(self): text = '"string1\'' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = 'string1"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\xt"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\x"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) def testConsumeBool(self): text = 'not-a-bool' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) - def testInfNan(self): - # Make sure our infinity and NaN definitions are sound. - self.assertEquals(float, type(text_format._INFINITY)) - self.assertEquals(float, type(text_format._NAN)) - self.assertTrue(text_format._NAN != text_format._NAN) - - inf_times_zero = text_format._INFINITY * 0 - self.assertTrue(inf_times_zero != inf_times_zero) - self.assertTrue(text_format._INFINITY > 0) - if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 2b3cd4d..8e1b3cc 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2008 Google Inc. All Rights Reserved. + """Provides type checking routines. This module defines type checking utilities in the forms of dictionaries: @@ -45,6 +49,9 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' +import sys ##PY25 +if sys.version < '2.6': bytes = str ##PY25 +from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import wire_format @@ -53,21 +60,22 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor -def GetTypeChecker(cpp_type, field_type): +def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. Args: - cpp_type: C++ type of the field (see descriptor.py). - field_type: Protocol message field type (see descriptor.py). + field: FieldDescriptor object for this field. Returns: An instance of TypeChecker which can be used to verify the types of values assigned to a field of the specified type. """ - if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and - field_type == _FieldDescriptor.TYPE_STRING): + if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() - return _VALUE_CHECKERS[cpp_type] + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + return EnumValueChecker(field.enum_type) + return _VALUE_CHECKERS[field.cpp_type] # None of the typecheckers below make any attempt to guard against people @@ -85,10 +93,15 @@ class TypeChecker(object): self._acceptable_types = acceptable_types def CheckValue(self, proposed_value): + """Type check the provided value and return it. + + The returned value might have been normalized to another type. + """ if not isinstance(proposed_value, self._acceptable_types): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), self._acceptable_types)) raise TypeError(message) + return proposed_value # IntValueChecker and its subclasses perform integer type-checks @@ -104,28 +117,54 @@ class IntValueChecker(object): raise TypeError(message) if not self._MIN <= proposed_value <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + proposed_value = self._TYPE(proposed_value) + return proposed_value + + +class EnumValueChecker(object): + + """Checker used for enum fields. Performs type-check and range check.""" + + def __init__(self, enum_type): + self._enum_type = enum_type + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if proposed_value not in self._enum_type.values_by_number: + raise ValueError('Unknown enum value: %d' % proposed_value) + return proposed_value class UnicodeValueChecker(object): - """Checker used for string fields.""" + """Checker used for string fields. + + Always returns a unicode value, even if the input is of type str. + """ def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (str, unicode)): + if not isinstance(proposed_value, (bytes, unicode)): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (str, unicode))) + (proposed_value, type(proposed_value), (bytes, unicode))) raise TypeError(message) - # If the value is of type 'str' make sure that it is in 7-bit ASCII + # If the value is of type 'bytes' make sure that it is in 7-bit ASCII # encoding. - if isinstance(proposed_value, str): + if isinstance(proposed_value, bytes): try: - unicode(proposed_value, 'ascii') + proposed_value = proposed_value.decode('ascii') except UnicodeDecodeError: - raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII ' 'encoding. Non-ASCII strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + return proposed_value class Int32ValueChecker(IntValueChecker): @@ -133,21 +172,25 @@ class Int32ValueChecker(IntValueChecker): # efficient. _MIN = -2147483648 _MAX = 2147483647 + _TYPE = int class Uint32ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 32) - 1 + _TYPE = int class Int64ValueChecker(IntValueChecker): _MIN = -(1 << 63) _MAX = (1 << 63) - 1 + _TYPE = long class Uint64ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 64) - 1 + _TYPE = long # Type-checkers for all scalar CPPTYPEs. @@ -161,8 +204,7 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( float, int, long), _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py new file mode 100755 index 0000000..8f3354c --- /dev/null +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -0,0 +1,231 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for preservation of unknown fields in the pure Python implementation.""" + +__author__ = 'bohdank@google.com (Bohdan Koval)' + +from google.apputils import basetest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import encoder +from google.protobuf.internal import missing_enum_values_pb2 +from google.protobuf.internal import test_util +from google.protobuf.internal import type_checkers + + +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + self.unknown_fields = self.empty_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes] + decoder(value, 0, len(value), self.all_fields, result_dict) + return result_dict[field_descriptor] + + def testEnum(self): + value = self.GetField('optional_nested_enum') + self.assertEqual(self.all_fields.optional_nested_enum, value) + + def testRepeatedEnum(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.all_fields.repeated_nested_enum, value) + + def testVarint(self): + value = self.GetField('optional_int32') + self.assertEqual(self.all_fields.optional_int32, value) + + def testFixed32(self): + value = self.GetField('optional_fixed32') + self.assertEqual(self.all_fields.optional_fixed32, value) + + def testFixed64(self): + value = self.GetField('optional_fixed64') + self.assertEqual(self.all_fields.optional_fixed64, value) + + def testLengthDelimited(self): + value = self.GetField('optional_string') + self.assertEqual(self.all_fields.optional_string, value) + + def testGroup(self): + value = self.GetField('optionalgroup') + self.assertEqual(self.all_fields.optionalgroup, value) + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testCopyFrom(self): + message = unittest_pb2.TestEmptyMessage() + message.CopyFrom(self.empty_message) + self.assertEqual(self.unknown_fields, message._unknown_fields) + + def testMergeFrom(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 1 + message.optional_uint32 = 2 + source = unittest_pb2.TestEmptyMessage() + source.ParseFromString(message.SerializeToString()) + + message.ClearField('optional_int32') + message.optional_int64 = 3 + message.optional_uint32 = 4 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_fields = destination._unknown_fields[:] + + destination.MergeFrom(source) + self.assertEqual(unknown_fields + source._unknown_fields, + destination._unknown_fields) + + def testClear(self): + self.empty_message.Clear() + self.assertEqual(0, len(self.empty_message._unknown_fields)) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testUnknownExtensions(self): + message = unittest_pb2.TestEmptyMessageWithExtensions() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message._unknown_fields, + message._unknown_fields) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. + self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 1545009 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR + + self.message = missing_enum_values_pb2.TestEnumValues() + self.message.optional_nested_enum = ( + missing_enum_values_pb2.TestEnumValues.ZERO) + self.message.repeated_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message.packed_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message_data = self.message.SerializeToString() + self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() + self.missing_message.ParseFromString(self.message_data) + self.unknown_fields = self.missing_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ + tag_bytes] + decoder(value, 0, len(value), self.message, result_dict) + return result_dict[field_descriptor] + + def testUnknownEnumValue(self): + self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + value = self.GetField('optional_nested_enum') + self.assertEqual(self.message.optional_nested_enum, value) + + def testUnknownRepeatedEnumValue(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.message.repeated_nested_enum, value) + + def testUnknownPackedEnumValue(self): + value = self.GetField('packed_nested_enum') + self.assertEqual(self.message.packed_nested_enum, value) + + def testRoundTrip(self): + new_message = missing_enum_values_pb2.TestEnumValues() + new_message.ParseFromString(self.missing_message.SerializeToString()) + self.assertEqual(self.message, new_message) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 7600778..9362c72 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -34,12 +34,12 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import message from google.protobuf.internal import wire_format -class WireFormatTest(unittest.TestCase): +class WireFormatTest(basetest.TestCase): def testPackTag(self): field_number = 0xabc @@ -195,7 +195,7 @@ class WireFormatTest(unittest.TestCase): # Test UTF-8 string byte size calculation. # 1 byte for tag, 1 byte for length, 8 bytes for content. self.assertEqual(10, wire_format.StringByteSize( - 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + 5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8'))) class MockMessage(object): def __init__(self, byte_size): @@ -250,4 +250,4 @@ class WireFormatTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index f839847..37b0af1 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -67,14 +67,28 @@ class Message(object): DESCRIPTOR = None + def __deepcopy__(self, memo=None): + clone = type(self)() + clone.MergeFrom(self) + return clone + def __eq__(self, other_msg): + """Recursively compares two messages by value and structure.""" raise NotImplementedError def __ne__(self, other_msg): # Can't just say self != other_msg, since that would infinitely recurse. :) return not self == other_msg + def __hash__(self): + raise TypeError('unhashable object') + def __str__(self): + """Outputs a human-readable representation of the message.""" + raise NotImplementedError + + def __unicode__(self): + """Outputs a human-readable representation of the message.""" raise NotImplementedError def MergeFrom(self, other_msg): @@ -163,7 +177,11 @@ class Message(object): raise NotImplementedError def ParseFromString(self, serialized): - """Like MergeFromString(), except we clear the object first.""" + """Parse serialized protocol buffer data into this message. + + Like MergeFromString(), except we clear the object first and + do not return the value that MergeFromString returns. + """ self.Clear() self.MergeFromString(serialized) @@ -215,6 +233,9 @@ class Message(object): raise NotImplementedError def HasField(self, field_name): + """Checks if a certain field is set for the message. Note if the + field_name is not defined in the message descriptor, ValueError will be + raised.""" raise NotImplementedError def ClearField(self, field_name): @@ -252,3 +273,12 @@ class Message(object): via a previous _SetListener() call. """ raise NotImplementedError + + def __getstate__(self): + """Support the pickle protocol.""" + return dict(serialized=self.SerializePartialToString()) + + def __setstate__(self, state): + """Support the pickle protocol.""" + self.__init__() + self.ParseFromString(state['serialized']) diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py new file mode 100644 index 0000000..9004ffd --- /dev/null +++ b/python/google/protobuf/message_factory.py @@ -0,0 +1,155 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#PY25 compatible for GAE. +# +# Copyright 2012 Google Inc. All Rights Reserved. + +"""Provides a factory class for generating dynamic messages. + +The easiest way to use this class is if you have access to the FileDescriptor +protos containing the messages you want to create you can just do the following: + +message_classes = message_factory.GetMessages(iterable_of_file_descriptors) +my_proto_instance = message_classes['some.proto.package.MessageName']() +""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import sys ##PY25 +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message +from google.protobuf import reflection + + +class MessageFactory(object): + """Factory for creating Proto2 messages from descriptors in a pool.""" + + def __init__(self, pool=None): + """Initializes a new factory.""" + self.pool = (pool or descriptor_pool.DescriptorPool( + descriptor_database.DescriptorDatabase())) + + # local cache of all classes built from protobuf descriptors + self._classes = {} + + def GetPrototype(self, descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + if descriptor.full_name not in self._classes: + descriptor_name = descriptor.name + if sys.version_info[0] < 3: ##PY25 +##!PY25 if str is bytes: # PY2 + descriptor_name = descriptor.name.encode('ascii', 'ignore') + result_class = reflection.GeneratedProtocolMessageType( + descriptor_name, + (message.Message,), + {'DESCRIPTOR': descriptor, '__module__': None}) + # If module not set, it wrongly points to the reflection.py module. + self._classes[descriptor.full_name] = result_class + for field in descriptor.fields: + if field.message_type: + self.GetPrototype(field.message_type) + for extension in result_class.DESCRIPTOR.extensions: + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) + return self._classes[descriptor.full_name] + + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if the descriptor + pool cannot satisfy them. + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + result = {} + for file_name in files: + file_desc = self.pool.FindFileByName(file_name) + for name, msg in file_desc.message_types_by_name.iteritems(): + if file_desc.package: + full_name = '.'.join([file_desc.package, name]) + else: + full_name = msg.name + result[full_name] = self.GetPrototype( + self.pool.FindMessageTypeByName(full_name)) + + # While the extension FieldDescriptors are created by the descriptor pool, + # the python classes created in the factory need them to be registered + # explicitly, which is done below. + # + # The call to RegisterExtension will specifically check if the + # extension was already registered on the object and either + # ignore the registration if the original was the same, or raise + # an error if they were different. + + for name, extension in file_desc.extensions_by_name.iteritems(): + if extension.containing_type.full_name not in self._classes: + self.GetPrototype(extension.containing_type) + extended_class = self._classes[extension.containing_type.full_name] + extended_class.RegisterExtension(extension) + return result + + +_FACTORY = MessageFactory() + + +def GetMessages(file_protos): + """Builds a dictionary of all the messages available in a set of files. + + Args: + file_protos: A sequence of file protos to build messages out of. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + """ + for file_proto in file_protos: + _FACTORY.pool.Add(file_proto) + return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos]) diff --git a/python/google/protobuf/pyext/README b/python/google/protobuf/pyext/README new file mode 100644 index 0000000..6d61cb4 --- /dev/null +++ b/python/google/protobuf/pyext/README @@ -0,0 +1,6 @@ +This is the 'v2' C++ implementation for python proto2. + +It is active when: + +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 diff --git a/python/google/protobuf/pyext/__init__.py b/python/google/protobuf/pyext/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/google/protobuf/pyext/__init__.py diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py new file mode 100644 index 0000000..ba87f8e --- /dev/null +++ b/python/google/protobuf/pyext/cpp_message.py @@ -0,0 +1,61 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Protocol message implementation hooks for C++ implementation. + +Contains helper functions used to create protocol message classes from +Descriptor objects at runtime backed by the protocol buffer C++ API. +""" + +__author__ = 'tibell@google.com (Johan Tibell)' + +from google.protobuf.pyext import _message +from google.protobuf import message + + +def NewMessage(bases, message_descriptor, dictionary): + """Creates a new protocol message *class*.""" + new_bases = [] + for base in bases: + if base is message.Message: + # _message.Message must come before message.Message as it + # overrides methods in that class. + new_bases.append(_message.Message) + new_bases.append(base) + return tuple(new_bases) + + +def InitMessage(message_descriptor, cls): + """Constructs a new message instance (called before instance's __init__).""" + + def SubInit(self, **kwargs): + super(cls, self).__init__(message_descriptor, **kwargs) + cls.__init__ = SubInit + cls.AddDescriptors(message_descriptor) diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc new file mode 100644 index 0000000..cbf42c0 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor.cc @@ -0,0 +1,357 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#include <Python.h> +#include <string> + +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#define C(str) const_cast<char*>(str) + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif + + +static google::protobuf::DescriptorPool* g_descriptor_pool = NULL; + +namespace cfield_descriptor { + +static void Dealloc(CFieldDescriptor* self) { + Py_CLEAR(self->descriptor_field); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* GetFullName(CFieldDescriptor* self, void *closure) { + return PyString_FromStringAndSize( + self->descriptor->full_name().c_str(), + self->descriptor->full_name().size()); +} + +static PyObject* GetName(CFieldDescriptor *self, void *closure) { + return PyString_FromStringAndSize( + self->descriptor->name().c_str(), + self->descriptor->name().size()); +} + +static PyObject* GetCppType(CFieldDescriptor *self, void *closure) { + return PyInt_FromLong(self->descriptor->cpp_type()); +} + +static PyObject* GetLabel(CFieldDescriptor *self, void *closure) { + return PyInt_FromLong(self->descriptor->label()); +} + +static PyObject* GetID(CFieldDescriptor *self, void *closure) { + return PyLong_FromVoidPtr(self); +} + +static PyGetSetDef Getters[] = { + { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, + { C("name"), (getter)GetName, NULL, "last name", NULL}, + { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, + { C("label"), (getter)GetLabel, NULL, "Label", NULL}, + { C("id"), (getter)GetID, NULL, "ID", NULL}, + {NULL} +}; + +} // namespace cfield_descriptor + +PyTypeObject CFieldDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_net_proto2___python." + "CFieldDescriptor"), // tp_name + sizeof(CFieldDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)cfield_descriptor::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Field Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + cfield_descriptor::Getters, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + PyType_GenericAlloc, // tp_alloc + PyType_GenericNew, // tp_new + PyObject_Del, // tp_free +}; + +namespace cdescriptor_pool { + +static void Dealloc(CDescriptorPool* self) { + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyObject* NewCDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor) { + CFieldDescriptor* cfield_descriptor = PyObject_New( + CFieldDescriptor, &CFieldDescriptor_Type); + if (cfield_descriptor == NULL) { + return NULL; + } + cfield_descriptor->descriptor = field_descriptor; + cfield_descriptor->descriptor_field = NULL; + + return reinterpret_cast<PyObject*>(cfield_descriptor); +} + +PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name) { + const char* full_field_name = PyString_AsString(name); + if (full_field_name == NULL) { + return NULL; + } + + const google::protobuf::FieldDescriptor* field_descriptor = NULL; + + field_descriptor = self->pool->FindFieldByName(full_field_name); + + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + full_field_name); + return NULL; + } + + return NewCDescriptor(field_descriptor); +} + +PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg) { + const char* full_field_name = PyString_AsString(arg); + if (full_field_name == NULL) { + return NULL; + } + + const google::protobuf::FieldDescriptor* field_descriptor = + self->pool->FindExtensionByName(full_field_name); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + full_field_name); + return NULL; + } + + return NewCDescriptor(field_descriptor); +} + +static PyMethodDef Methods[] = { + { C("FindFieldByName"), + (PyCFunction)FindFieldByName, + METH_O, + C("Searches for a field descriptor by full name.") }, + { C("FindExtensionByName"), + (PyCFunction)FindExtensionByName, + METH_O, + C("Searches for extension descriptor by full name.") }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject CDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_net_proto2___python." + "CFieldDescriptor"), // tp_name + sizeof(CDescriptorPool), // tp_basicsize + 0, // tp_itemsize + (destructor)cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Descriptor Pool"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + PyType_GenericAlloc, // tp_alloc + PyType_GenericNew, // tp_new + PyObject_Del, // tp_free +}; + +google::protobuf::DescriptorPool* GetDescriptorPool() { + if (g_descriptor_pool == NULL) { + g_descriptor_pool = new google::protobuf::DescriptorPool( + google::protobuf::DescriptorPool::generated_pool()); + } + return g_descriptor_pool; +} + +PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args) { + CDescriptorPool* cdescriptor_pool = PyObject_New( + CDescriptorPool, &CDescriptorPool_Type); + if (cdescriptor_pool == NULL) { + return NULL; + } + cdescriptor_pool->pool = GetDescriptorPool(); + return reinterpret_cast<PyObject*>(cdescriptor_pool); +} + + +// Collects errors that occur during proto file building to allow them to be +// propagated in the python exception instead of only living in ERROR logs. +class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector { + public: + BuildFileErrorCollector() : error_message(""), had_errors(false) {} + + void AddError(const string& filename, const string& element_name, + const Message* descriptor, ErrorLocation location, + const string& message) { + // Replicates the logging behavior that happens in the C++ implementation + // when an error collector is not passed in. + if (!had_errors) { + error_message += + ("Invalid proto descriptor for file \"" + filename + "\":\n"); + } + // As this only happens on failure and will result in the program not + // running at all, no effort is made to optimize this string manipulation. + error_message += (" " + element_name + ": " + message + "\n"); + } + + string error_message; + bool had_errors; +}; + +PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) { + char* message_type; + Py_ssize_t message_len; + + if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) { + return NULL; + } + + google::protobuf::FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(message_type, message_len)) { + PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); + return NULL; + } + + if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName( + file_proto.name()) != NULL) { + Py_RETURN_NONE; + } + + BuildFileErrorCollector error_collector; + const google::protobuf::FileDescriptor* descriptor = + GetDescriptorPool()->BuildFileCollectingErrors(file_proto, + &error_collector); + if (descriptor == NULL) { + PyErr_Format(PyExc_TypeError, + "Couldn't build proto file into descriptor pool!\n%s", + error_collector.error_message.c_str()); + return NULL; + } + + Py_RETURN_NONE; +} + +bool InitDescriptor() { + CFieldDescriptor_Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&CFieldDescriptor_Type) < 0) + return false; + + CDescriptorPool_Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&CDescriptorPool_Type) < 0) + return false; + + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h new file mode 100644 index 0000000..d114425 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor.h @@ -0,0 +1,96 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: petar@google.com (Petar Petrov) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ + +#include <Python.h> +#include <structmember.h> + +#include <google/protobuf/descriptor.h> + +#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) +typedef int Py_ssize_t; +#define PY_SSIZE_T_MAX INT_MAX +#define PY_SSIZE_T_MIN INT_MIN +#endif + +namespace google { +namespace protobuf { +namespace python { + +typedef struct CFieldDescriptor { + PyObject_HEAD + + // The proto2 descriptor that this object represents. + const google::protobuf::FieldDescriptor* descriptor; + + // Reference to the original field object in the Python DESCRIPTOR. + PyObject* descriptor_field; +} CFieldDescriptor; + +typedef struct { + PyObject_HEAD + + const google::protobuf::DescriptorPool* pool; +} CDescriptorPool; + +extern PyTypeObject CFieldDescriptor_Type; + +extern PyTypeObject CDescriptorPool_Type; + +namespace cdescriptor_pool { + +// Looks up a field by name. Returns a CDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name); + +// Looks up an extension by name. Returns a CDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg); + +} // namespace cdescriptor_pool + +PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args); +PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); +bool InitDescriptor(); +google::protobuf::DescriptorPool* GetDescriptorPool(); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ diff --git a/python/google/protobuf/pyext/descriptor_cpp2_test.py b/python/google/protobuf/pyext/descriptor_cpp2_test.py new file mode 100644 index 0000000..3a3ff29 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_cpp2_test.py @@ -0,0 +1,58 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.pyext behavior.""" + +__author__ = 'anuraag@google.com (Anuraag Agrawal)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmCppApi2Test(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc new file mode 100644 index 0000000..1e14b42 --- /dev/null +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -0,0 +1,338 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/extension_dict.h> + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> +#include <google/protobuf/stubs/shared_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace extension_dict { + +// TODO(tibell): Always use self->message for clarity, just like in +// RepeatedCompositeContainer. +static google::protobuf::Message* GetMessage(ExtensionDict* self) { + if (self->parent != NULL) { + return self->parent->message; + } else { + return self->message; + } +} + +CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension) { + PyObject* cdescriptor = PyObject_GetAttrString(extension, "_cdescriptor"); + if (cdescriptor == NULL) { + PyErr_SetString(PyExc_KeyError, "Unregistered extension."); + return NULL; + } + if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor"); + Py_DECREF(cdescriptor); + return NULL; + } + CFieldDescriptor* descriptor = + reinterpret_cast<CFieldDescriptor*>(cdescriptor); + return descriptor; +} + +PyObject* len(ExtensionDict* self) { +#if PY_MAJOR_VERSION >= 3 + return PyLong_FromLong(PyDict_Size(self->values)); +#else + return PyInt_FromLong(PyDict_Size(self->values)); +#endif +} + +// TODO(tibell): Use VisitCompositeField. +int ReleaseExtension(ExtensionDict* self, + PyObject* extension, + const google::protobuf::FieldDescriptor* descriptor) { + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (repeated_composite_container::Release( + reinterpret_cast<RepeatedCompositeContainer*>( + extension)) < 0) { + return -1; + } + } else { + if (repeated_scalar_container::Release( + reinterpret_cast<RepeatedScalarContainer*>( + extension)) < 0) { + return -1; + } + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (cmessage::ReleaseSubMessage( + GetMessage(self), descriptor, + reinterpret_cast<CMessage*>(extension)) < 0) { + return -1; + } + } + + return 0; +} + +PyObject* subscript(ExtensionDict* self, PyObject* key) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + key); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + if (descriptor == NULL) { + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + return cmessage::InternalGetScalar(self->parent, descriptor); + } + + PyObject* value = PyDict_GetItem(self->values, key); + if (value != NULL) { + Py_INCREF(value); + return value; + } + + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = cmessage::InternalGetSubMessage( + self->parent, cdescriptor); + if (sub_message == NULL) { + return NULL; + } + PyDict_SetItem(self->values, key, sub_message); + return sub_message; + } + + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + // COPIED + PyObject* py_container = PyObject_CallObject( + reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type), + NULL); + if (py_container == NULL) { + return NULL; + } + RepeatedCompositeContainer* container = + reinterpret_cast<RepeatedCompositeContainer*>(py_container); + PyObject* field = cdescriptor->descriptor_field; + PyObject* message_type = PyObject_GetAttrString(field, "message_type"); + PyObject* concrete_class = PyObject_GetAttrString(message_type, + "_concrete_class"); + container->owner = self->owner; + container->parent = self->parent; + container->message = self->parent->message; + container->parent_field = cdescriptor; + container->subclass_init = concrete_class; + Py_DECREF(message_type); + PyDict_SetItem(self->values, key, py_container); + return py_container; + } else { + // COPIED + ScopedPyObjectPtr init_args(PyTuple_Pack(2, self->parent, cdescriptor)); + PyObject* py_container = PyObject_CallObject( + reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), + init_args); + if (py_container == NULL) { + return NULL; + } + PyDict_SetItem(self->values, key, py_container); + return py_container; + } + } + PyErr_SetString(PyExc_ValueError, "control reached unexpected line"); + return NULL; +} + +int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + key); + if (cdescriptor == NULL) { + return -1; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL || + descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite " + "type"); + return -1; + } + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } + // TODO(tibell): We shouldn't write scalars to the cache. + PyDict_SetItem(self->values, key, value); + return 0; +} + +PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + extension); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); + PyObject* value = PyDict_GetItem(self->values, extension); + if (value != NULL) { + if (ReleaseExtension(self, value, cdescriptor->descriptor) < 0) { + return NULL; + } + } + if (cmessage::ClearFieldByDescriptor(self->parent, + cdescriptor->descriptor) == NULL) { + return NULL; + } + if (PyDict_DelItem(self->values, extension) < 0) { + PyErr_Clear(); + } + Py_RETURN_NONE; +} + +PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { + CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension( + extension); + if (cdescriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); + PyObject* result = cmessage::HasFieldByDescriptor( + self->parent, cdescriptor->descriptor); + return result; +} + +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { + ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString( + reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name")); + if (extensions_by_name == NULL) { + return NULL; + } + PyObject* result = PyDict_GetItem(extensions_by_name, name); + if (result == NULL) { + Py_RETURN_NONE; + } else { + Py_INCREF(result); + return result; + } +} + +int init(ExtensionDict* self, PyObject* args, PyObject* kwargs) { + self->parent = NULL; + self->message = NULL; + self->values = PyDict_New(); + return 0; +} + +void dealloc(ExtensionDict* self) { + Py_CLEAR(self->values); + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PyMappingMethods MpMethods = { + (lenfunc)len, /* mp_length */ + (binaryfunc)subscript, /* mp_subscript */ + (objobjargproc)ass_subscript,/* mp_ass_subscript */ +}; + +#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc } +static PyMethodDef Methods[] = { + EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."), + EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), + EDMETHOD(_FindExtensionByName, METH_O, + "Finds an extension by name."), + { NULL, NULL } +}; + +} // namespace extension_dict + +PyTypeObject ExtensionDict_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.ExtensionDict", // tp_name + sizeof(ExtensionDict), // tp_basicsize + 0, // tp_itemsize + (destructor)extension_dict::dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &extension_dict::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "An extension dict", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + extension_dict::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)extension_dict::init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h new file mode 100644 index 0000000..1343001 --- /dev/null +++ b/python/google/protobuf/pyext/extension_dict.h @@ -0,0 +1,123 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + + +namespace google { +namespace protobuf { + +class Message; +class FieldDescriptor; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; +struct CFieldDescriptor; + +typedef struct ExtensionDict { + PyObject_HEAD; + shared_ptr<Message> owner; + CMessage* parent; + Message* message; + PyObject* values; +} ExtensionDict; + +extern PyTypeObject ExtensionDict_Type; + +namespace extension_dict { + +// Gets the _cdescriptor reference to a CFieldDescriptor object given a +// python descriptor object. +// +// Returns a new reference. +CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension); + +// Gets the number of extension values in this ExtensionDict as a python object. +// +// Returns a new reference. +PyObject* len(ExtensionDict* self); + +// Releases extensions referenced outside this dictionary to keep outside +// references alive. +// +// Returns 0 on success, -1 on failure. +int ReleaseExtension(ExtensionDict* self, + PyObject* extension, + const google::protobuf::FieldDescriptor* descriptor); + +// Gets an extension from the dict for the given extension descriptor. +// +// Returns a new reference. +PyObject* subscript(ExtensionDict* self, PyObject* key); + +// Assigns a value to an extension in the dict. Can only be used for singular +// simple types. +// +// Returns 0 on success, -1 on failure. +int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value); + +// Clears an extension from the dict. Will release the extension if there +// is still an external reference left to it. +// +// Returns None on success. +PyObject* ClearExtension(ExtensionDict* self, + PyObject* extension); + +// Checks if the dict has an extension. +// +// Returns a new python boolean reference. +PyObject* HasExtension(ExtensionDict* self, PyObject* extension); + +// Gets an extension from the dict given the extension name as opposed to +// descriptor. +// +// Returns a new reference. +PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name); + +} // namespace extension_dict +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc new file mode 100644 index 0000000..c45cbf0 --- /dev/null +++ b/python/google/protobuf/pyext/message.cc @@ -0,0 +1,2561 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/message.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif +#include <string> +#include <vector> + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/text_format.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/extension_dict.h> +#include <google/protobuf/pyext/repeated_composite_container.h> +#include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_Check PyLong_Check + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyString_Check PyUnicode_Check + #define PyString_FromString PyUnicode_FromString + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Forward declarations +namespace cmessage { +static PyObject* GetDescriptor(CMessage* self, PyObject* name); +static string GetMessageName(CMessage* self); +int InternalReleaseFieldByDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* composite_field, + google::protobuf::Message* parent_message); +} // namespace cmessage + +// --------------------------------------------------------------------- +// Visiting the composite children of a CMessage + +struct ChildVisitor { + // Returns 0 on success, -1 on failure. + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + return 0; + } + + // Returns 0 on success, -1 on failure. + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + return 0; + } + + // Returns 0 on success, -1 on failure. + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return 0; + } +}; + +// Apply a function to a composite field. Does nothing if child is of +// non-composite type. +template<class Visitor> +static int VisitCompositeField(const FieldDescriptor* descriptor, + PyObject* child, + Visitor visitor) { + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + RepeatedCompositeContainer* container = + reinterpret_cast<RepeatedCompositeContainer*>(child); + if (visitor.VisitRepeatedCompositeContainer(container) == -1) + return -1; + } else { + RepeatedScalarContainer* container = + reinterpret_cast<RepeatedScalarContainer*>(child); + if (visitor.VisitRepeatedScalarContainer(container) == -1) + return -1; + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + CMessage* cmsg = reinterpret_cast<CMessage*>(child); + if (visitor.VisitCMessage(cmsg, descriptor) == -1) + return -1; + } + // The ExtensionDict might contain non-composite fields, which we + // skip here. + return 0; +} + +// Visit each composite field and extension field of this CMessage. +// Returns -1 on error and 0 on success. +template<class Visitor> +int ForEachCompositeField(CMessage* self, Visitor visitor) { + Py_ssize_t pos = 0; + PyObject* key; + PyObject* field; + + // Visit normal fields. + while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { + PyObject* cdescriptor = cmessage::GetDescriptor(self, key); + if (cdescriptor != NULL) { + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor; + if (VisitCompositeField(descriptor, field, visitor) == -1) + return -1; + } + } + + // Visit extension fields. + if (self->extensions != NULL) { + while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { + CFieldDescriptor* cdescriptor = + extension_dict::InternalGetCDescriptorFromExtension(key); + if (cdescriptor == NULL) + return -1; + if (VisitCompositeField(cdescriptor->descriptor, field, visitor) == -1) + return -1; + } + } + + return 0; +} + +// --------------------------------------------------------------------- + +// Constants used for integer type range checking. +PyObject* kPythonZero; +PyObject* kint32min_py; +PyObject* kint32max_py; +PyObject* kuint32max_py; +PyObject* kint64min_py; +PyObject* kint64max_py; +PyObject* kuint64max_py; + +PyObject* EnumTypeWrapper_class; +PyObject* EncodeError_class; +PyObject* DecodeError_class; +PyObject* PickleError_class; + +// Constant PyString values used for GetAttr/GetItem. +static PyObject* kDESCRIPTOR; +static PyObject* k__descriptors; +static PyObject* kfull_name; +static PyObject* kname; +static PyObject* kmessage_type; +static PyObject* kis_extendable; +static PyObject* kextensions_by_name; +static PyObject* k_extensions_by_name; +static PyObject* k_extensions_by_number; +static PyObject* k_concrete_class; +static PyObject* kfields_by_name; + +static CDescriptorPool* descriptor_pool; + +/* Is 64bit */ +void FormatTypeError(PyObject* arg, char* expected_types) { + PyObject* repr = PyObject_Repr(arg); + if (repr) { + PyErr_Format(PyExc_TypeError, + "%.100s has type %.100s, but expected one of: %s", + PyString_AsString(repr), + Py_TYPE(arg)->tp_name, + expected_types); + Py_DECREF(repr); + } +} + +template<class T> +bool CheckAndGetInteger( + PyObject* arg, T* value, PyObject* min, PyObject* max) { + bool is_long = PyLong_Check(arg); +#if PY_MAJOR_VERSION < 3 + if (!PyInt_Check(arg) && !is_long) { + FormatTypeError(arg, "int, long"); + return false; + } + if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) { +#else + if (!is_long) { + FormatTypeError(arg, "int"); + return false; + } + if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || + PyObject_RichCompareBool(max, arg, Py_GE) != 1) { +#endif + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } + return false; + } +#if PY_MAJOR_VERSION < 3 + if (!is_long) { + *value = static_cast<T>(PyInt_AsLong(arg)); + } else // NOLINT +#endif + { + if (min == kPythonZero) { + *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg)); + } else { + *value = static_cast<T>(PyLong_AsLongLong(arg)); + } + } + return true; +} + +// These are referenced by repeated_scalar_container, and must +// be explicitly instantiated. +template bool CheckAndGetInteger<int32>( + PyObject*, int32*, PyObject*, PyObject*); +template bool CheckAndGetInteger<int64>( + PyObject*, int64*, PyObject*, PyObject*); +template bool CheckAndGetInteger<uint32>( + PyObject*, uint32*, PyObject*, PyObject*); +template bool CheckAndGetInteger<uint64>( + PyObject*, uint64*, PyObject*, PyObject*); + +bool CheckAndGetDouble(PyObject* arg, double* value) { + if (!PyInt_Check(arg) && !PyLong_Check(arg) && + !PyFloat_Check(arg)) { + FormatTypeError(arg, "int, long, float"); + return false; + } + *value = PyFloat_AsDouble(arg); + return true; +} + +bool CheckAndGetFloat(PyObject* arg, float* value) { + double double_value; + if (!CheckAndGetDouble(arg, &double_value)) { + return false; + } + *value = static_cast<float>(double_value); + return true; +} + +bool CheckAndGetBool(PyObject* arg, bool* value) { + if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) { + FormatTypeError(arg, "int, long, bool"); + return false; + } + *value = static_cast<bool>(PyInt_AsLong(arg)); + return true; +} + +bool CheckAndSetString( + PyObject* arg, google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + const google::protobuf::Reflection* reflection, + bool append, + int index) { + GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING || + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES); + if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { + FormatTypeError(arg, "bytes, unicode"); + return false; + } + + if (PyBytes_Check(arg)) { + PyObject* unicode = PyUnicode_FromEncodedObject(arg, "ascii", NULL); + if (unicode == NULL) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't in 7-bit ASCII " + "encoding. Non-ASCII strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return false; + } else { + Py_DECREF(unicode); + } + } + } else if (!PyBytes_Check(arg)) { + FormatTypeError(arg, "bytes"); + return false; + } + + PyObject* encoded_string = NULL; + if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (PyBytes_Check(arg)) { +#if PY_MAJOR_VERSION < 3 + encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL); +#else + encoded_string = arg; // Already encoded. + Py_INCREF(encoded_string); +#endif + } else { + encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL); + } + } else { + // In this case field type is "bytes". + encoded_string = arg; + Py_INCREF(encoded_string); + } + + if (encoded_string == NULL) { + return false; + } + + char* value; + Py_ssize_t value_len; + if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { + Py_DECREF(encoded_string); + return false; + } + + string value_string(value, value_len); + if (append) { + reflection->AddString(message, descriptor, value_string); + } else if (index < 0) { + reflection->SetString(message, descriptor, value_string); + } else { + reflection->SetRepeatedString(message, descriptor, index, value_string); + } + Py_DECREF(encoded_string); + return true; +} + +PyObject* ToStringObject( + const google::protobuf::FieldDescriptor* descriptor, string value) { + if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) { + return PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + + PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL); + // If the string can't be decoded in UTF-8, just return a string object that + // contains the raw bytes. This can't happen if the value was assigned using + // the members of the Python message object, but can happen if the values were + // parsed from the wire (binary). + if (result == NULL) { + PyErr_Clear(); + result = PyBytes_FromStringAndSize(value.c_str(), value.length()); + } + return result; +} + +google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace cmessage { + +static int MaybeReleaseOverlappingOneofField( + CMessage* cmessage, + const google::protobuf::FieldDescriptor* field) { +#ifdef GOOGLE_PROTOBUF_HAS_ONEOF + google::protobuf::Message* message = cmessage->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + if (!field->containing_oneof() || + !reflection->HasOneof(*message, field->containing_oneof()) || + reflection->HasField(*message, field)) { + // No other field in this oneof, no need to release. + return 0; + } + + const OneofDescriptor* oneof = field->containing_oneof(); + const FieldDescriptor* existing_field = + reflection->GetOneofFieldDescriptor(*message, oneof); + if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + // Non-message fields don't need to be released. + return 0; + } + const char* field_name = existing_field->name().c_str(); + PyObject* child_message = PyDict_GetItemString( + cmessage->composite_fields, field_name); + if (child_message == NULL) { + // No python reference to this field so no need to release. + return 0; + } + + if (InternalReleaseFieldByDescriptor( + existing_field, child_message, message) < 0) { + return -1; + } + return PyDict_DelItemString(cmessage->composite_fields, field_name); +#else + return 0; +#endif +} + +// --------------------------------------------------------------------- +// Making a message writable + +static google::protobuf::Message* GetMutableMessage( + CMessage* parent, + const google::protobuf::FieldDescriptor* parent_field) { + google::protobuf::Message* parent_message = parent->message; + const google::protobuf::Reflection* reflection = parent_message->GetReflection(); + if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) { + return NULL; + } + return reflection->MutableMessage( + parent_message, parent_field, global_message_factory); +} + +struct FixupMessageReference : public ChildVisitor { + // message must outlive this object. + explicit FixupMessageReference(google::protobuf::Message* message) : + message_(message) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + container->message = message_; + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + container->message = message_; + return 0; + } + + private: + google::protobuf::Message* message_; +}; + +int AssureWritable(CMessage* self) { + if (self == NULL || !self->read_only) { + return 0; + } + + if (self->parent == NULL) { + // If parent is NULL but we are trying to modify a read-only message, this + // is a reference to a constant default instance that needs to be replaced + // with a mutable top-level message. + const Message* prototype = global_message_factory->GetPrototype( + self->message->GetDescriptor()); + self->message = prototype->New(); + self->owner.reset(self->message); + } else { + // Otherwise, we need a mutable child message. + if (AssureWritable(self->parent) == -1) + return -1; + + // Make self->message writable. + google::protobuf::Message* parent_message = self->parent->message; + google::protobuf::Message* mutable_message = GetMutableMessage( + self->parent, + self->parent_field->descriptor); + if (mutable_message == NULL) { + return -1; + } + self->message = mutable_message; + } + self->read_only = false; + + // When a CMessage is made writable its Message pointer is updated + // to point to a new mutable Message. When that happens we need to + // update any references to the old, read-only CMessage. There are + // three places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, and ExtensionDict. + if (self->extensions != NULL) + self->extensions->message = self->message; + if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) + return -1; + + return 0; +} + +// --- Globals: + +static PyObject* GetDescriptor(CMessage* self, PyObject* name) { + PyObject* descriptors = + PyDict_GetItem(Py_TYPE(self)->tp_dict, k__descriptors); + if (descriptors == NULL) { + PyErr_SetString(PyExc_TypeError, "No __descriptors"); + return NULL; + } + + return PyDict_GetItem(descriptors, name); +} + +static const google::protobuf::Message* CreateMessage(const char* message_type) { + string message_name(message_type); + const google::protobuf::Descriptor* descriptor = + GetDescriptorPool()->FindMessageTypeByName(message_name); + if (descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, message_type); + return NULL; + } + return global_message_factory->GetPrototype(descriptor); +} + +// If cmessage_list is not NULL, this function releases values into the +// container CMessages instead of just removing. Repeated composite container +// needs to do this to make sure CMessages stay alive if they're still +// referenced after deletion. Repeated scalar container doesn't need to worry. +int InternalDeleteRepeatedField( + google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* slice, + PyObject* cmessage_list) { + Py_ssize_t length, from, to, step, slice_length; + const google::protobuf::Reflection* reflection = message->GetReflection(); + int min, max; + length = reflection->FieldSize(*message, field_descriptor); + + if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + step = 1; + min = max = from; + + // Range check. + if (from < 0 || from >= length) { + PyErr_Format(PyExc_IndexError, "list assignment index out of range"); + return -1; + } + } else if (PySlice_Check(slice)) { + from = to = step = slice_length = 0; + PySlice_GetIndicesEx( +#if PY_MAJOR_VERSION < 3 + reinterpret_cast<PySliceObject*>(slice), +#else + slice, +#endif + length, &from, &to, &step, &slice_length); + if (from < to) { + min = from; + max = to - 1; + } else { + min = to + 1; + max = from; + } + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + Py_ssize_t i = from; + std::vector<bool> to_delete(length, false); + while (i >= min && i <= max) { + to_delete[i] = true; + i += step; + } + + to = 0; + for (i = 0; i < length; ++i) { + if (!to_delete[i]) { + if (i != to) { + reflection->SwapElements(message, field_descriptor, i, to); + if (cmessage_list != NULL) { + // If a list of cmessages is passed in (i.e. from a repeated + // composite container), swap those as well to correspond to the + // swaps in the underlying message so they're in the right order + // when we start releasing. + PyObject* tmp = PyList_GET_ITEM(cmessage_list, i); + PyList_SET_ITEM(cmessage_list, i, + PyList_GET_ITEM(cmessage_list, to)); + PyList_SET_ITEM(cmessage_list, to, tmp); + } + } + ++to; + } + } + + while (i > to) { + if (cmessage_list == NULL) { + reflection->RemoveLast(message, field_descriptor); + } else { + CMessage* last_cmessage = reinterpret_cast<CMessage*>( + PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1)); + repeated_composite_container::ReleaseLastTo( + field_descriptor, message, last_cmessage); + if (PySequence_DelItem(cmessage_list, -1) < 0) { + return -1; + } + } + --i; + } + + return 0; +} + +int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) { + ScopedPyObjectPtr descriptor; + if (arg == NULL) { + descriptor.reset( + PyObject_GetAttr(reinterpret_cast<PyObject*>(self), kDESCRIPTOR)); + if (descriptor == NULL) { + return NULL; + } + } else { + descriptor.reset(arg); + descriptor.inc(); + } + ScopedPyObjectPtr is_extendable(PyObject_GetAttr(descriptor, kis_extendable)); + if (is_extendable == NULL) { + return NULL; + } + int retcode = PyObject_IsTrue(is_extendable); + if (retcode == -1) { + return NULL; + } + if (retcode) { + PyObject* py_extension_dict = PyObject_CallObject( + reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL); + if (py_extension_dict == NULL) { + return NULL; + } + ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>( + py_extension_dict); + extension_dict->parent = self; + extension_dict->message = self->message; + self->extensions = extension_dict; + } + + if (kwargs == NULL) { + return 0; + } + + Py_ssize_t pos = 0; + PyObject* name; + PyObject* value; + while (PyDict_Next(kwargs, &pos, &name, &value)) { + if (!PyString_Check(name)) { + PyErr_SetString(PyExc_ValueError, "Field name must be a string"); + return -1; + } + PyObject* py_cdescriptor = GetDescriptor(self, name); + if (py_cdescriptor == NULL) { + PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + PyString_AsString(name)); + return -1; + } + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast<CFieldDescriptor*>(py_cdescriptor)->descriptor; + if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + ScopedPyObjectPtr container(GetAttr(self, name)); + if (container == NULL) { + return -1; + } + if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (repeated_composite_container::Extend( + reinterpret_cast<RepeatedCompositeContainer*>(container.get()), + value) + == NULL) { + return -1; + } + } else { + if (repeated_scalar_container::Extend( + reinterpret_cast<RepeatedScalarContainer*>(container.get()), + value) == + NULL) { + return -1; + } + } + } else if (descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + ScopedPyObjectPtr message(GetAttr(self, name)); + if (message == NULL) { + return -1; + } + if (MergeFrom(reinterpret_cast<CMessage*>(message.get()), + value) == NULL) { + return -1; + } + } else { + if (SetAttr(self, name, value) < 0) { + return -1; + } + } + } + return 0; +} + +static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + CMessage* self = reinterpret_cast<CMessage*>(type->tp_alloc(type, 0)); + if (self == NULL) { + return NULL; + } + + self->message = NULL; + self->parent = NULL; + self->parent_field = NULL; + self->read_only = false; + self->extensions = NULL; + + self->composite_fields = PyDict_New(); + if (self->composite_fields == NULL) { + return NULL; + } + return reinterpret_cast<PyObject*>(self); +} + +PyObject* NewEmpty(PyObject* type) { + return New(reinterpret_cast<PyTypeObject*>(type), NULL, NULL); +} + +static int Init(CMessage* self, PyObject* args, PyObject* kwargs) { + if (kwargs == NULL) { + // TODO(anuraag): Set error + return -1; + } + + PyObject* descriptor = PyTuple_GetItem(args, 0); + if (descriptor == NULL || PyTuple_Size(args) != 1) { + PyErr_SetString(PyExc_ValueError, "args must contain one arg: descriptor"); + return -1; + } + + ScopedPyObjectPtr py_message_type(PyObject_GetAttr(descriptor, kfull_name)); + if (py_message_type == NULL) { + return -1; + } + + const char* message_type = PyString_AsString(py_message_type.get()); + const google::protobuf::Message* message = CreateMessage(message_type); + if (message == NULL) { + return -1; + } + + self->message = message->New(); + self->owner.reset(self->message); + + if (InitAttributes(self, descriptor, kwargs) < 0) { + return -1; + } + return 0; +} + +// --------------------------------------------------------------------- +// Deallocating a CMessage +// +// Deallocating a CMessage requires that we clear any weak references +// from children to the message being deallocated. + +// Clear the weak reference from the child to the parent. +struct ClearWeakReferences : public ChildVisitor { + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + container->parent = NULL; + // The elements in the container have the same parent as the + // container itself, so NULL out that pointer as well. + const Py_ssize_t n = PyList_GET_SIZE(container->child_messages); + for (Py_ssize_t i = 0; i < n; ++i) { + CMessage* child_cmessage = reinterpret_cast<CMessage*>( + PyList_GET_ITEM(container->child_messages, i)); + child_cmessage->parent = NULL; + } + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + container->parent = NULL; + return 0; + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + cmessage->parent = NULL; + return 0; + } +}; + +static void Dealloc(CMessage* self) { + // Null out all weak references from children to this message. + GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + + Py_CLEAR(self->extensions); + Py_CLEAR(self->composite_fields); + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +// --------------------------------------------------------------------- + + +PyObject* IsInitialized(CMessage* self, PyObject* args) { + PyObject* errors = NULL; + if (PyArg_ParseTuple(args, "|O", &errors) < 0) { + return NULL; + } + if (self->message->IsInitialized()) { + Py_RETURN_TRUE; + } + if (errors != NULL) { + ScopedPyObjectPtr initialization_errors( + FindInitializationErrors(self)); + if (initialization_errors == NULL) { + return NULL; + } + ScopedPyObjectPtr extend_name(PyString_FromString("extend")); + if (extend_name == NULL) { + return NULL; + } + ScopedPyObjectPtr result(PyObject_CallMethodObjArgs( + errors, + extend_name.get(), + initialization_errors.get(), + NULL)); + if (result == NULL) { + return NULL; + } + } + Py_RETURN_FALSE; +} + +PyObject* HasFieldByDescriptor( + CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) { + google::protobuf::Message* message = self->message; + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString(PyExc_KeyError, + "Field does not belong to message!"); + return NULL; + } + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_SetString(PyExc_KeyError, + "Field is repeated. A singular method is required."); + return NULL; + } + bool has_field = + message->GetReflection()->HasField(*message, field_descriptor); + return PyBool_FromLong(has_field ? 1 : 0); +} + +const google::protobuf::FieldDescriptor* FindFieldWithOneofs( + const google::protobuf::Message* message, const char* field_name, bool* in_oneof) { + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + const google::protobuf::FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == NULL) { + const google::protobuf::OneofDescriptor* oneof_desc = + message->GetDescriptor()->FindOneofByName(field_name); + if (oneof_desc == NULL) { + *in_oneof = false; + return NULL; + } else { + *in_oneof = true; + return message->GetReflection()->GetOneofFieldDescriptor( + *message, oneof_desc); + } + } + return field_descriptor; +} + +PyObject* HasField(CMessage* self, PyObject* arg) { +#if PY_MAJOR_VERSION < 3 + char* field_name; + if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { +#else + char* field_name = PyUnicode_AsUTF8(arg); + if (!field_name) { +#endif + return NULL; + } + + google::protobuf::Message* message = self->message; + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + bool is_in_oneof; + const google::protobuf::FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, field_name, &is_in_oneof); + if (field_descriptor == NULL) { + if (!is_in_oneof) { + PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name); + return NULL; + } else { + Py_RETURN_FALSE; + } + } + + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no singular \"%s\" field.", field_name); + return NULL; + } + + bool has_field = + message->GetReflection()->HasField(*message, field_descriptor); + if (!has_field && field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { + // We may have an invalid enum value stored in the UnknownFieldSet and need + // to check presence in there as well. + const google::protobuf::UnknownFieldSet& unknown_field_set = + message->GetReflection()->GetUnknownFields(*message); + for (int i = 0; i < unknown_field_set.field_count(); ++i) { + if (unknown_field_set.field(i).number() == field_descriptor->number()) { + Py_RETURN_TRUE; + } + } + Py_RETURN_FALSE; + } + return PyBool_FromLong(has_field ? 1 : 0); +} + +PyObject* ClearExtension(CMessage* self, PyObject* arg) { + if (self->extensions != NULL) { + return extension_dict::ClearExtension(self->extensions, arg); + } + PyErr_SetString(PyExc_TypeError, "Message is not extendable"); + return NULL; +} + +PyObject* HasExtension(CMessage* self, PyObject* arg) { + if (self->extensions != NULL) { + return extension_dict::HasExtension(self->extensions, arg); + } + PyErr_SetString(PyExc_TypeError, "Message is not extendable"); + return NULL; +} + +// --------------------------------------------------------------------- +// Releasing messages +// +// The Python API's ClearField() and Clear() methods behave +// differently than their C++ counterparts. While the C++ versions +// clears the children the Python versions detaches the children, +// without touching their content. This impedance mismatch causes +// some complexity in the implementation, which is captured in this +// section. +// +// When a CMessage field is cleared we need to: +// +// * Release the Message used as the backing store for the CMessage +// from its parent. +// +// * Change the owner field of the released CMessage and all of its +// children to point to the newly released Message. +// +// * Clear the weak references from the released CMessage to the +// parent. +// +// When a RepeatedCompositeContainer field is cleared we need to: +// +// * Release all the Message used as the backing store for the +// CMessages stored in the container. +// +// * Change the owner field of all the released CMessage and all of +// their children to point to the newly released Messages. +// +// * Clear the weak references from the released container to the +// parent. + +struct SetOwnerVisitor : public ChildVisitor { + // new_owner must outlive this object. + explicit SetOwnerVisitor(const shared_ptr<Message>& new_owner) + : new_owner_(new_owner) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + repeated_composite_container::SetOwner(container, new_owner_); + return 0; + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + repeated_scalar_container::SetOwner(container, new_owner_); + return 0; + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return SetOwner(cmessage, new_owner_); + } + + private: + const shared_ptr<Message>& new_owner_; +}; + +// Change the owner of this CMessage and all its children, recursively. +int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { + self->owner = new_owner; + if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1) + return -1; + return 0; +} + +// Releases the message specified by 'field' and returns the +// pointer. If the field does not exist a new message is created using +// 'descriptor'. The caller takes ownership of the returned pointer. +Message* ReleaseMessage(google::protobuf::Message* message, + const google::protobuf::Descriptor* descriptor, + const google::protobuf::FieldDescriptor* field_descriptor) { + Message* released_message = message->GetReflection()->ReleaseMessage( + message, field_descriptor, global_message_factory); + // ReleaseMessage will return NULL which differs from + // child_cmessage->message, if the field does not exist. In this case, + // the latter points to the default instance via a const_cast<>, so we + // have to reset it to a new mutable object since we are taking ownership. + if (released_message == NULL) { + const Message* prototype = global_message_factory->GetPrototype( + descriptor); + GOOGLE_DCHECK(prototype != NULL); + released_message = prototype->New(); + } + + return released_message; +} + +int ReleaseSubMessage(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* child_cmessage) { + // Release the Message + shared_ptr<Message> released_message(ReleaseMessage( + message, child_cmessage->message->GetDescriptor(), field_descriptor)); + child_cmessage->message = released_message.get(); + child_cmessage->owner.swap(released_message); + child_cmessage->parent = NULL; + child_cmessage->parent_field = NULL; + child_cmessage->read_only = false; + return ForEachCompositeField(child_cmessage, + SetOwnerVisitor(child_cmessage->owner)); +} + +struct ReleaseChild : public ChildVisitor { + // message must outlive this object. + explicit ReleaseChild(google::protobuf::Message* parent_message) : + parent_message_(parent_message) {} + + int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { + return repeated_composite_container::Release( + reinterpret_cast<RepeatedCompositeContainer*>(container)); + } + + int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { + return repeated_scalar_container::Release( + reinterpret_cast<RepeatedScalarContainer*>(container)); + } + + int VisitCMessage(CMessage* cmessage, + const google::protobuf::FieldDescriptor* field_descriptor) { + return ReleaseSubMessage(parent_message_, field_descriptor, + reinterpret_cast<CMessage*>(cmessage)); + } + + google::protobuf::Message* parent_message_; +}; + +int InternalReleaseFieldByDescriptor( + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* composite_field, + google::protobuf::Message* parent_message) { + return VisitCompositeField( + field_descriptor, + composite_field, + ReleaseChild(parent_message)); +} + +int InternalReleaseField(CMessage* self, PyObject* composite_field, + PyObject* name) { + PyObject* cdescriptor = GetDescriptor(self, name); + if (cdescriptor != NULL) { + const google::protobuf::FieldDescriptor* descriptor = + reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor; + return InternalReleaseFieldByDescriptor( + descriptor, composite_field, self->message); + } + + return 0; +} + +PyObject* ClearFieldByDescriptor( + CMessage* self, + const google::protobuf::FieldDescriptor* descriptor) { + if (!FIELD_BELONGS_TO_MESSAGE(descriptor, self->message)) { + PyErr_SetString(PyExc_KeyError, + "Field does not belong to message!"); + return NULL; + } + AssureWritable(self); + self->message->GetReflection()->ClearField(self->message, descriptor); + Py_RETURN_NONE; +} + +PyObject* ClearField(CMessage* self, PyObject* arg) { + char* field_name; + if (!PyString_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "field name must be a string"); + return NULL; + } +#if PY_MAJOR_VERSION < 3 + if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { + return NULL; + } +#else + field_name = PyUnicode_AsUTF8(arg); +#endif + AssureWritable(self); + google::protobuf::Message* message = self->message; + const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + ScopedPyObjectPtr arg_in_oneof; + bool is_in_oneof; + const google::protobuf::FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, field_name, &is_in_oneof); + if (field_descriptor == NULL) { + if (!is_in_oneof) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no \"%s\" field.", field_name); + return NULL; + } else { + Py_RETURN_NONE; + } + } else if (is_in_oneof) { + arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str())); + arg = arg_in_oneof.get(); + } + + PyObject* composite_field = PyDict_GetItem(self->composite_fields, + arg); + + // Only release the field if there's a possibility that there are + // references to it. + if (composite_field != NULL) { + if (InternalReleaseField(self, composite_field, arg) < 0) { + return NULL; + } + PyDict_DelItem(self->composite_fields, arg); + } + message->GetReflection()->ClearField(message, field_descriptor); + if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { + google::protobuf::UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + unknown_field_set->DeleteByNumber(field_descriptor->number()); + } + + Py_RETURN_NONE; +} + +PyObject* Clear(CMessage* self) { + AssureWritable(self); + if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1) + return NULL; + + // The old ExtensionDict still aliases this CMessage, but all its + // fields have been released. + if (self->extensions != NULL) { + Py_CLEAR(self->extensions); + PyObject* py_extension_dict = PyObject_CallObject( + reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL); + if (py_extension_dict == NULL) { + return NULL; + } + ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>( + py_extension_dict); + extension_dict->parent = self; + extension_dict->message = self->message; + self->extensions = extension_dict; + } + PyDict_Clear(self->composite_fields); + self->message->Clear(); + Py_RETURN_NONE; +} + +// --------------------------------------------------------------------- + +static string GetMessageName(CMessage* self) { + if (self->parent_field != NULL) { + return self->parent_field->descriptor->full_name(); + } else { + return self->message->GetDescriptor()->full_name(); + } +} + +static PyObject* SerializeToString(CMessage* self, PyObject* args) { + if (!self->message->IsInitialized()) { + ScopedPyObjectPtr errors(FindInitializationErrors(self)); + if (errors == NULL) { + return NULL; + } + ScopedPyObjectPtr comma(PyString_FromString(",")); + if (comma == NULL) { + return NULL; + } + ScopedPyObjectPtr joined( + PyObject_CallMethod(comma.get(), "join", "O", errors.get())); + if (joined == NULL) { + return NULL; + } + PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s", + GetMessageName(self).c_str(), PyString_AsString(joined.get())); + return NULL; + } + int size = self->message->ByteSize(); + if (size <= 0) { + return PyBytes_FromString(""); + } + PyObject* result = PyBytes_FromStringAndSize(NULL, size); + if (result == NULL) { + return NULL; + } + char* buffer = PyBytes_AS_STRING(result); + self->message->SerializeWithCachedSizesToArray( + reinterpret_cast<uint8*>(buffer)); + return result; +} + +static PyObject* SerializePartialToString(CMessage* self) { + string contents; + self->message->SerializePartialToString(&contents); + return PyBytes_FromStringAndSize(contents.c_str(), contents.size()); +} + +// Formats proto fields for ascii dumps using python formatting functions where +// appropriate. +class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter { + public: + PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} + + // Python has some differences from C++ when printing floating point numbers. + // + // 1) Trailing .0 is always printed. + // 2) Outputted is rounded to 12 digits. + // + // We override floating point printing with the C-API function for printing + // Python floats to ensure consistency. + string PrintFloat(float value) const { return PrintDouble(value); } + string PrintDouble(double value) const { + reinterpret_cast<PyFloatObject*>(float_holder_.get())->ob_fval = value; + ScopedPyObjectPtr s(PyObject_Str(float_holder_.get())); + if (s == NULL) return string(); +#if PY_MAJOR_VERSION < 3 + char *cstr = PyBytes_AS_STRING(static_cast<PyObject*>(s)); +#else + char *cstr = PyUnicode_AsUTF8(s); +#endif + return string(cstr); + } + + private: + // Holder for a python float object which we use to allow us to use + // the Python API for printing doubles. We initialize once and then + // directly modify it for every float printed to save on allocations + // and refcounting. + ScopedPyObjectPtr float_holder_; +}; + +static PyObject* ToStr(CMessage* self) { + google::protobuf::TextFormat::Printer printer; + // Passes ownership + printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter()); + printer.SetHideUnknownFields(true); + string output; + if (!printer.PrintToString(*self->message, &output)) { + PyErr_SetString(PyExc_ValueError, "Unable to convert message to str"); + return NULL; + } + return PyString_FromString(output.c_str()); +} + +PyObject* MergeFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Must be a message"); + return NULL; + } + + other_message = reinterpret_cast<CMessage*>(arg); + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Tried to merge from a message with a different type. " + "to: %s, from: %s", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + AssureWritable(self); + + // TODO(tibell): Message::MergeFrom might turn some child Messages + // into mutable messages, invalidating the message field in the + // corresponding CMessages. We should run a FixupMessageReferences + // pass here. + + self->message->MergeFrom(*other_message->message); + Py_RETURN_NONE; +} + +static PyObject* CopyFrom(CMessage* self, PyObject* arg) { + CMessage* other_message; + if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Must be a message"); + return NULL; + } + + other_message = reinterpret_cast<CMessage*>(arg); + + if (self == other_message) { + Py_RETURN_NONE; + } + + if (other_message->message->GetDescriptor() != + self->message->GetDescriptor()) { + PyErr_Format(PyExc_TypeError, + "Tried to copy from a message with a different type. " + "to: %s, from: %s", + self->message->GetDescriptor()->full_name().c_str(), + other_message->message->GetDescriptor()->full_name().c_str()); + return NULL; + } + + AssureWritable(self); + + // CopyFrom on the message will not clean up self->composite_fields, + // which can leave us in an inconsistent state, so clear it out here. + Clear(self); + + self->message->CopyFrom(*other_message->message); + + Py_RETURN_NONE; +} + +static PyObject* MergeFromString(CMessage* self, PyObject* arg) { + const void* data; + Py_ssize_t data_length; + if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) { + return NULL; + } + + AssureWritable(self); + google::protobuf::io::CodedInputStream input( + reinterpret_cast<const uint8*>(data), data_length); + input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory); + bool success = self->message->MergePartialFromCodedStream(&input); + if (success) { + return PyInt_FromLong(input.CurrentPosition()); + } else { + PyErr_Format(DecodeError_class, "Error parsing message"); + return NULL; + } +} + +static PyObject* ParseFromString(CMessage* self, PyObject* arg) { + if (Clear(self) == NULL) { + return NULL; + } + return MergeFromString(self, arg); +} + +static PyObject* ByteSize(CMessage* self, PyObject* args) { + return PyLong_FromLong(self->message->ByteSize()); +} + +static PyObject* RegisterExtension(PyObject* cls, + PyObject* extension_handle) { + ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR)); + if (message_descriptor == NULL) { + return NULL; + } + if (PyObject_SetAttrString(extension_handle, "containing_type", + message_descriptor) < 0) { + return NULL; + } + ScopedPyObjectPtr extensions_by_name( + PyObject_GetAttr(cls, k_extensions_by_name)); + if (extensions_by_name == NULL) { + PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class"); + return NULL; + } + ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name)); + if (full_name == NULL) { + return NULL; + } + if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) { + return NULL; + } + + // Also store a mapping from extension number to implementing class. + ScopedPyObjectPtr extensions_by_number( + PyObject_GetAttr(cls, k_extensions_by_number)); + if (extensions_by_number == NULL) { + PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); + return NULL; + } + ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); + if (number == NULL) { + return NULL; + } + if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) { + return NULL; + } + + CFieldDescriptor* cdescriptor = + extension_dict::InternalGetCDescriptorFromExtension(extension_handle); + ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor)); + if (cdescriptor == NULL) { + return NULL; + } + Py_INCREF(extension_handle); + cdescriptor->descriptor_field = extension_handle; + const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor; + // Check if it's a message set + if (descriptor->is_extension() && + descriptor->containing_type()->options().message_set_wire_format() && + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE && + descriptor->message_type() == descriptor->extension_scope() && + descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) { + ScopedPyObjectPtr message_name(PyString_FromStringAndSize( + descriptor->message_type()->full_name().c_str(), + descriptor->message_type()->full_name().size())); + if (message_name == NULL) { + return NULL; + } + PyDict_SetItem(extensions_by_name, message_name, extension_handle); + } + + Py_RETURN_NONE; +} + +static PyObject* SetInParent(CMessage* self, PyObject* args) { + AssureWritable(self); + Py_RETURN_NONE; +} + +static PyObject* WhichOneof(CMessage* self, PyObject* arg) { + char* oneof_name; + if (!PyString_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "field name must be a string"); + return NULL; + } + oneof_name = PyString_AsString(arg); + if (oneof_name == NULL) { + return NULL; + } + const google::protobuf::OneofDescriptor* oneof_desc = + self->message->GetDescriptor()->FindOneofByName(oneof_name); + if (oneof_desc == NULL) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no oneof \"%s\" field.", oneof_name); + return NULL; + } + const google::protobuf::FieldDescriptor* field_in_oneof = + self->message->GetReflection()->GetOneofFieldDescriptor( + *self->message, oneof_desc); + if (field_in_oneof == NULL) { + Py_RETURN_NONE; + } else { + return PyString_FromString(field_in_oneof->name().c_str()); + } +} + +static PyObject* ListFields(CMessage* self) { + vector<const google::protobuf::FieldDescriptor*> fields; + self->message->GetReflection()->ListFields(*self->message, &fields); + + PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR); + if (descriptor == NULL) { + return NULL; + } + ScopedPyObjectPtr fields_by_name( + PyObject_GetAttr(descriptor, kfields_by_name)); + if (fields_by_name == NULL) { + return NULL; + } + ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr( + reinterpret_cast<PyObject*>(Py_TYPE(self)), k_extensions_by_name)); + if (extensions_by_name == NULL) { + PyErr_SetString(PyExc_ValueError, "no extensionsbyname"); + return NULL; + } + // Normally, the list will be exactly the size of the fields. + PyObject* all_fields = PyList_New(fields.size()); + if (all_fields == NULL) { + return NULL; + } + + // When there are unknown extensions, the py list will *not* contain + // the field information. Thus the actual size of the py list will be + // smaller than the size of fields. Set the actual size at the end. + Py_ssize_t actual_size = 0; + for (Py_ssize_t i = 0; i < fields.size(); ++i) { + ScopedPyObjectPtr t(PyTuple_New(2)); + if (t == NULL) { + Py_DECREF(all_fields); + return NULL; + } + + if (fields[i]->is_extension()) { + const string& field_name = fields[i]->full_name(); + PyObject* extension_field = PyDict_GetItemString(extensions_by_name, + field_name.c_str()); + if (extension_field == NULL) { + // If we couldn't fetch extension_field, it means the module that + // defines this extension has not been explicitly imported in Python + // code, and the extension hasn't been registered. There's nothing much + // we can do about this, so just skip it in the output to match the + // behavior of the python implementation. + continue; + } + PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions); + if (extensions == NULL) { + Py_DECREF(all_fields); + return NULL; + } + // 'extension' reference later stolen by PyTuple_SET_ITEM. + PyObject* extension = PyObject_GetItem(extensions, extension_field); + if (extension == NULL) { + Py_DECREF(all_fields); + return NULL; + } + Py_INCREF(extension_field); + PyTuple_SET_ITEM(t.get(), 0, extension_field); + // Steals reference to 'extension' + PyTuple_SET_ITEM(t.get(), 1, extension); + } else { + const string& field_name = fields[i]->name(); + ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( + field_name.c_str(), field_name.length())); + if (py_field_name == NULL) { + PyErr_SetString(PyExc_ValueError, "bad string"); + Py_DECREF(all_fields); + return NULL; + } + PyObject* field_descriptor = + PyDict_GetItem(fields_by_name, py_field_name); + if (field_descriptor == NULL) { + Py_DECREF(all_fields); + return NULL; + } + + PyObject* field_value = GetAttr(self, py_field_name); + if (field_value == NULL) { + PyErr_SetObject(PyExc_ValueError, py_field_name); + Py_DECREF(all_fields); + return NULL; + } + Py_INCREF(field_descriptor); + PyTuple_SET_ITEM(t.get(), 0, field_descriptor); + PyTuple_SET_ITEM(t.get(), 1, field_value); + } + PyList_SET_ITEM(all_fields, actual_size, t.release()); + ++actual_size; + } + Py_SIZE(all_fields) = actual_size; + return all_fields; +} + +PyObject* FindInitializationErrors(CMessage* self) { + google::protobuf::Message* message = self->message; + vector<string> errors; + message->FindInitializationErrors(&errors); + + PyObject* error_list = PyList_New(errors.size()); + if (error_list == NULL) { + return NULL; + } + for (Py_ssize_t i = 0; i < errors.size(); ++i) { + const string& error = errors[i]; + PyObject* error_string = PyString_FromStringAndSize( + error.c_str(), error.length()); + if (error_string == NULL) { + Py_DECREF(error_list); + return NULL; + } + PyList_SET_ITEM(error_list, i, error_string); + } + return error_list; +} + +static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { + if (!PyObject_TypeCheck(other, &CMessage_Type)) { + if (opid == Py_EQ) { + Py_RETURN_FALSE; + } else if (opid == Py_NE) { + Py_RETURN_TRUE; + } + } + if (opid == Py_EQ || opid == Py_NE) { + ScopedPyObjectPtr self_fields(ListFields(self)); + ScopedPyObjectPtr other_fields(ListFields( + reinterpret_cast<CMessage*>(other))); + return PyObject_RichCompare(self_fields, other_fields, opid); + } else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + +PyObject* InternalGetScalar( + CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return NULL; + } + + PyObject* result = NULL; + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int32 value = reflection->GetInt32(*message, field_descriptor); + result = PyInt_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + int64 value = reflection->GetInt64(*message, field_descriptor); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = reflection->GetUInt32(*message, field_descriptor); + result = PyInt_FromSize_t(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = reflection->GetUInt64(*message, field_descriptor); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + float value = reflection->GetFloat(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + double value = reflection->GetDouble(*message, field_descriptor); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + bool value = reflection->GetBool(*message, field_descriptor); + result = PyBool_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetString(*message, field_descriptor); + result = ToStringObject(field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + if (!message->GetReflection()->HasField(*message, field_descriptor)) { + // Look for the value in the unknown fields. + google::protobuf::UnknownFieldSet* unknown_field_set = + message->GetReflection()->MutableUnknownFields(message); + for (int i = 0; i < unknown_field_set->field_count(); ++i) { + if (unknown_field_set->field(i).number() == + field_descriptor->number()) { + result = PyInt_FromLong(unknown_field_set->field(i).varint()); + break; + } + } + } + + if (result == NULL) { + const google::protobuf::EnumValueDescriptor* enum_value = + message->GetReflection()->GetEnum(*message, field_descriptor); + result = PyInt_FromLong(enum_value->number()); + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Getting a value from a field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +PyObject* InternalGetSubMessage(CMessage* self, + CFieldDescriptor* cfield_descriptor) { + PyObject* field = cfield_descriptor->descriptor_field; + ScopedPyObjectPtr message_type(PyObject_GetAttr(field, kmessage_type)); + if (message_type == NULL) { + return NULL; + } + ScopedPyObjectPtr concrete_class( + PyObject_GetAttr(message_type, k_concrete_class)); + if (concrete_class == NULL) { + return NULL; + } + PyObject* py_cmsg = cmessage::NewEmpty(concrete_class); + if (py_cmsg == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(py_cmsg, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a CMessage!"); + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + + const google::protobuf::FieldDescriptor* field_descriptor = + cfield_descriptor->descriptor; + const google::protobuf::Reflection* reflection = self->message->GetReflection(); + const google::protobuf::Message& sub_message = reflection->GetMessage( + *self->message, field_descriptor, global_message_factory); + cmsg->owner = self->owner; + cmsg->parent = self; + cmsg->parent_field = cfield_descriptor; + cmsg->read_only = !reflection->HasField(*self->message, field_descriptor); + cmsg->message = const_cast<google::protobuf::Message*>(&sub_message); + + if (InitAttributes(cmsg, NULL, NULL) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + return py_cmsg; +} + +int InternalSetScalar( + CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* arg) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + return -1; + } + + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetUInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetUInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetFloat(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetDouble(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetBool(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + arg, message, field_descriptor, reflection, false, -1)) { + return -1; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetEnum(message, field_descriptor, enum_value); + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return -1; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + + return 0; +} + +PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { + PyObject* py_cmsg = PyObject_CallObject( + reinterpret_cast<PyObject*>(cls), NULL); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + + ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized)); + if (py_length == NULL) { + Py_DECREF(py_cmsg); + return NULL; + } + + if (InitAttributes(cmsg, NULL, NULL) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + return py_cmsg; +} + +static PyObject* AddDescriptors(PyTypeObject* cls, + PyObject* descriptor) { + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + k_extensions_by_name, PyDict_New()) < 0) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + k_extensions_by_number, PyDict_New()) < 0) { + return NULL; + } + + ScopedPyObjectPtr field_descriptors(PyDict_New()); + + ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields")); + if (fields == NULL) { + return NULL; + } + + ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER")); + if (_NUMBER_string == NULL) { + return NULL; + } + + const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get()); + for (int i = 0; i < fields_size; ++i) { + PyObject* field = PyList_GET_ITEM(fields.get(), i); + ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname)); + ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name)); + if (field_name == NULL || full_field_name == NULL) { + PyErr_SetString(PyExc_TypeError, "Name is null"); + return NULL; + } + + PyObject* field_descriptor = + cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name); + if (field_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "Couldn't find field"); + return NULL; + } + Py_INCREF(field); + CFieldDescriptor* cfield_descriptor = reinterpret_cast<CFieldDescriptor*>( + field_descriptor); + cfield_descriptor->descriptor_field = field; + if (PyDict_SetItem(field_descriptors, field_name, field_descriptor) < 0) { + return NULL; + } + + // The FieldDescriptor's name field might either be of type bytes or + // of type unicode, depending on whether the FieldDescriptor was + // parsed from a serialized message or read from the + // <message>_pb2.py module. + ScopedPyObjectPtr field_name_upcased( + PyObject_CallMethod(field_name, "upper", NULL)); + if (field_name_upcased == NULL) { + return NULL; + } + + ScopedPyObjectPtr field_number_name(PyObject_CallMethod( + field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); + if (field_number_name == NULL) { + return NULL; + } + + ScopedPyObjectPtr number(PyInt_FromLong( + cfield_descriptor->descriptor->number())); + if (number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + field_number_name, number) == -1) { + return NULL; + } + } + + PyDict_SetItem(cls->tp_dict, k__descriptors, field_descriptors); + + // Enum Values + ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor, + "enum_types")); + if (enum_types == NULL) { + return NULL; + } + ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types)); + if (type_iter == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_type; + while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) { + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname)); + if (enum_name == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + enum_name, wrapped) == -1) { + return NULL; + } + + ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values")); + if (enum_values == NULL) { + return NULL; + } + ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values)); + if (values_iter == NULL) { + return NULL; + } + ScopedPyObjectPtr enum_value; + while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) { + ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname)); + if (value_name == NULL) { + return NULL; + } + ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value, + "number")); + if (value_number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + value_name, value_number) == -1) { + return NULL; + } + } + if (PyErr_Occurred()) { // If PyIter_Next failed + return NULL; + } + } + if (PyErr_Occurred()) { // If PyIter_Next failed + return NULL; + } + + ScopedPyObjectPtr extension_dict( + PyObject_GetAttr(descriptor, kextensions_by_name)); + if (extension_dict == NULL || !PyDict_Check(extension_dict)) { + PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict"); + return NULL; + } + Py_ssize_t pos = 0; + PyObject* extension_name; + PyObject* extension_field; + + while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) { + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + extension_name, extension_field) == -1) { + return NULL; + } + ScopedPyObjectPtr py_cfield_descriptor( + PyObject_GetAttrString(extension_field, "_cdescriptor")); + if (py_cfield_descriptor == NULL) { + return NULL; + } + CFieldDescriptor* cfield_descriptor = + reinterpret_cast<CFieldDescriptor*>(py_cfield_descriptor.get()); + Py_INCREF(extension_field); + cfield_descriptor->descriptor_field = extension_field; + + ScopedPyObjectPtr field_name_upcased( + PyObject_CallMethod(extension_name, "upper", NULL)); + if (field_name_upcased == NULL) { + return NULL; + } + ScopedPyObjectPtr field_number_name(PyObject_CallMethod( + field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); + if (field_number_name == NULL) { + return NULL; + } + ScopedPyObjectPtr number(PyInt_FromLong( + cfield_descriptor->descriptor->number())); + if (number == NULL) { + return NULL; + } + if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls), + field_number_name, PyInt_FromLong( + cfield_descriptor->descriptor->number())) == -1) { + return NULL; + } + } + + Py_RETURN_NONE; +} + +PyObject* DeepCopy(CMessage* self, PyObject* arg) { + PyObject* clone = PyObject_CallObject( + reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL); + if (clone == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(clone, &CMessage_Type)) { + Py_DECREF(clone); + return NULL; + } + if (InitAttributes(reinterpret_cast<CMessage*>(clone), NULL, NULL) < 0) { + Py_DECREF(clone); + return NULL; + } + if (MergeFrom(reinterpret_cast<CMessage*>(clone), + reinterpret_cast<PyObject*>(self)) == NULL) { + Py_DECREF(clone); + return NULL; + } + return clone; +} + +PyObject* ToUnicode(CMessage* self) { + // Lazy import to prevent circular dependencies + ScopedPyObjectPtr text_format( + PyImport_ImportModule("google.protobuf.text_format")); + if (text_format == NULL) { + return NULL; + } + ScopedPyObjectPtr method_name(PyString_FromString("MessageToString")); + if (method_name == NULL) { + return NULL; + } + Py_INCREF(Py_True); + ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name, + self, Py_True, NULL)); + Py_DECREF(Py_True); + if (encoded == NULL) { + return NULL; + } +#if PY_MAJOR_VERSION < 3 + PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL); +#else + PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL); +#endif + if (decoded == NULL) { + return NULL; + } + return decoded; +} + +PyObject* Reduce(CMessage* self) { + ScopedPyObjectPtr constructor(reinterpret_cast<PyObject*>(Py_TYPE(self))); + constructor.inc(); + ScopedPyObjectPtr args(PyTuple_New(0)); + if (args == NULL) { + return NULL; + } + ScopedPyObjectPtr state(PyDict_New()); + if (state == NULL) { + return NULL; + } + ScopedPyObjectPtr serialized(SerializePartialToString(self)); + if (serialized == NULL) { + return NULL; + } + if (PyDict_SetItemString(state, "serialized", serialized) < 0) { + return NULL; + } + return Py_BuildValue("OOO", constructor.get(), args.get(), state.get()); +} + +PyObject* SetState(CMessage* self, PyObject* state) { + if (!PyDict_Check(state)) { + PyErr_SetString(PyExc_TypeError, "state not a dict"); + return NULL; + } + PyObject* serialized = PyDict_GetItemString(state, "serialized"); + if (serialized == NULL) { + return NULL; + } + if (ParseFromString(self, serialized) == NULL) { + return NULL; + } + Py_RETURN_NONE; +} + +// CMessage static methods: +PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindFieldByName(descriptor_pool, arg); +} + +PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg); +} + +static PyMemberDef Members[] = { + {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, + "Extension dict"}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the message." }, + { "__setstate__", (PyCFunction)SetState, METH_O, + "Inputs picklable representation of the message." }, + { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, + "Outputs a unicode representation of the message." }, + { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS, + "Adds field descriptors to the class" }, + { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, + "Returns the size of the message in bytes." }, + { "Clear", (PyCFunction)Clear, METH_NOARGS, + "Clears the message." }, + { "ClearExtension", (PyCFunction)ClearExtension, METH_O, + "Clears a message field." }, + { "ClearField", (PyCFunction)ClearField, METH_O, + "Clears a message field." }, + { "CopyFrom", (PyCFunction)CopyFrom, METH_O, + "Copies a protocol message into the current message." }, + { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, + METH_NOARGS, + "Finds unset required fields." }, + { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS, + "Creates new method instance from given serialized data." }, + { "HasExtension", (PyCFunction)HasExtension, METH_O, + "Checks if a message field is set." }, + { "HasField", (PyCFunction)HasField, METH_O, + "Checks if a message field is set." }, + { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS, + "Checks if all required fields of a protocol message are set." }, + { "ListFields", (PyCFunction)ListFields, METH_NOARGS, + "Lists all set fields of a message." }, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a protocol message into the current message." }, + { "MergeFromString", (PyCFunction)MergeFromString, METH_O, + "Merges a serialized message into the current message." }, + { "ParseFromString", (PyCFunction)ParseFromString, METH_O, + "Parses a serialized message into the current message." }, + { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS, + "Registers an extension with the current message." }, + { "SerializePartialToString", (PyCFunction)SerializePartialToString, + METH_NOARGS, + "Serializes the message to a string, even if it isn't initialized." }, + { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS, + "Serializes the message to a string, only for initialized messages." }, + { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, + "Sets the has bit of the given field in its parent message." }, + { "WhichOneof", (PyCFunction)WhichOneof, METH_O, + "Returns the name of the field set inside a oneof, " + "or None if no field is set." }, + + // Static Methods. + { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC, + "Registers a new protocol buffer file in the global C++ descriptor pool." }, + { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor, + METH_O | METH_STATIC, "Finds a field descriptor in the message pool." }, + { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor, + METH_O | METH_STATIC, + "Finds a extension descriptor in the message pool." }, + { NULL, NULL} +}; + +PyObject* GetAttr(CMessage* self, PyObject* name) { + PyObject* value = PyDict_GetItem(self->composite_fields, name); + if (value != NULL) { + Py_INCREF(value); + return value; + } + + PyObject* descriptor = GetDescriptor(self, name); + if (descriptor != NULL) { + CFieldDescriptor* cdescriptor = + reinterpret_cast<CFieldDescriptor*>(descriptor); + const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* py_container = PyObject_CallObject( + reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type), + NULL); + if (py_container == NULL) { + return NULL; + } + RepeatedCompositeContainer* container = + reinterpret_cast<RepeatedCompositeContainer*>(py_container); + PyObject* field = cdescriptor->descriptor_field; + PyObject* message_type = PyObject_GetAttr(field, kmessage_type); + if (message_type == NULL) { + return NULL; + } + PyObject* concrete_class = + PyObject_GetAttr(message_type, k_concrete_class); + if (concrete_class == NULL) { + return NULL; + } + container->parent = self; + container->parent_field = cdescriptor; + container->message = self->message; + container->owner = self->owner; + container->subclass_init = concrete_class; + Py_DECREF(message_type); + if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } else { + ScopedPyObjectPtr init_args(PyTuple_Pack(2, self, cdescriptor)); + PyObject* py_container = PyObject_CallObject( + reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), + init_args); + if (py_container == NULL) { + return NULL; + } + if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + } else { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = InternalGetSubMessage(self, cdescriptor); + if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) { + Py_DECREF(sub_message); + return NULL; + } + return sub_message; + } else { + return InternalGetScalar(self, field_descriptor); + } + } + } + + return CMessage_Type.tp_base->tp_getattro(reinterpret_cast<PyObject*>(self), + name); +} + +int SetAttr(CMessage* self, PyObject* name, PyObject* value) { + if (PyDict_Contains(self->composite_fields, name)) { + PyErr_SetString(PyExc_TypeError, "Can't set composite field"); + return -1; + } + + PyObject* descriptor = GetDescriptor(self, name); + if (descriptor != NULL) { + AssureWritable(self); + CFieldDescriptor* cdescriptor = + reinterpret_cast<CFieldDescriptor*>(descriptor); + const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor; + if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { + if (field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { + return InternalSetScalar(self, field_descriptor, value); + } + } + } + + PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + return -1; +} + +} // namespace cmessage + +PyTypeObject CMessage_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.CMessage", // tp_name + sizeof(CMessage), // tp_basicsize + 0, // tp_itemsize + (destructor)cmessage::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + (reprfunc)cmessage::ToStr, // tp_str + (getattrofunc)cmessage::GetAttr, // tp_getattro + (setattrofunc)cmessage::SetAttr, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "A ProtocolMessage", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)cmessage::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cmessage::Methods, // tp_methods + cmessage::Members, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)cmessage::Init, // tp_init + 0, // tp_alloc + cmessage::New, // tp_new +}; + +// --- Exposing the C proto living inside Python proto to C code: + +const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); +Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); + +static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { + if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); + return cmsg->message; +} + +static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { + if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(msg); + if (PyDict_Size(cmsg->composite_fields) != 0 || + (cmsg->extensions != NULL && + PyDict_Size(cmsg->extensions->values) != 0)) { + // There is currently no way of accurately syncing arbitrary changes to + // the underlying C++ message back to the CMessage (e.g. removed repeated + // composite containers). We only allow direct mutation of the underlying + // C++ message if there is no child data in the CMessage. + return NULL; + } + cmessage::AssureWritable(cmsg); + return cmsg->message; +} + +static const char module_docstring[] = +"python-proto2 is a module that can be used to enhance proto2 Python API\n" +"performance.\n" +"\n" +"It provides access to the protocol buffers C++ reflection API that\n" +"implements the basic protocol buffer functions."; + +void InitGlobals() { + // TODO(gps): Check all return values in this function for NULL and propagate + // the error (MemoryError) on up to result in an import failure. These should + // also be freed and reset to NULL during finalization. + kPythonZero = PyInt_FromLong(0); + kint32min_py = PyInt_FromLong(kint32min); + kint32max_py = PyInt_FromLong(kint32max); + kuint32max_py = PyLong_FromLongLong(kuint32max); + kint64min_py = PyLong_FromLongLong(kint64min); + kint64max_py = PyLong_FromLongLong(kint64max); + kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max); + + kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); + k__descriptors = PyString_FromString("__descriptors"); + kfull_name = PyString_FromString("full_name"); + kis_extendable = PyString_FromString("is_extendable"); + kextensions_by_name = PyString_FromString("extensions_by_name"); + k_extensions_by_name = PyString_FromString("_extensions_by_name"); + k_extensions_by_number = PyString_FromString("_extensions_by_number"); + k_concrete_class = PyString_FromString("_concrete_class"); + kmessage_type = PyString_FromString("message_type"); + kname = PyString_FromString("name"); + kfields_by_name = PyString_FromString("fields_by_name"); + + global_message_factory = new DynamicMessageFactory(GetDescriptorPool()); + global_message_factory->SetDelegateToGeneratedFactory(true); + + descriptor_pool = reinterpret_cast<google::protobuf::python::CDescriptorPool*>( + Python_NewCDescriptorPool(NULL, NULL)); +} + +bool InitProto2MessageModule(PyObject *m) { + InitGlobals(); + + google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) { + return false; + } + + // All three of these are actually set elsewhere, directly onto the child + // protocol buffer message class, but set them here as well to document that + // subclasses need to set these. + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + k_extensions_by_name, Py_None); + PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + k_extensions_by_number, Py_None); + + PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>( + &google::protobuf::python::CMessage_Type)); + + google::protobuf::python::RepeatedScalarContainer_Type.tp_new = PyType_GenericNew; + google::protobuf::python::RepeatedScalarContainer_Type.tp_hash = + PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "RepeatedScalarContainer", + reinterpret_cast<PyObject*>( + &google::protobuf::python::RepeatedScalarContainer_Type)); + + google::protobuf::python::RepeatedCompositeContainer_Type.tp_new = PyType_GenericNew; + google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash = + PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast<PyObject*>( + &google::protobuf::python::RepeatedCompositeContainer_Type)); + + google::protobuf::python::ExtensionDict_Type.tp_new = PyType_GenericNew; + google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) { + return false; + } + + PyModule_AddObject( + m, "ExtensionDict", + reinterpret_cast<PyObject*>(&google::protobuf::python::ExtensionDict_Type)); + + if (!google::protobuf::python::InitDescriptor()) { + return false; + } + + PyObject* enum_type_wrapper = PyImport_ImportModule( + "google.protobuf.internal.enum_type_wrapper"); + if (enum_type_wrapper == NULL) { + return false; + } + google::protobuf::python::EnumTypeWrapper_class = + PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper"); + Py_DECREF(enum_type_wrapper); + + PyObject* message_module = PyImport_ImportModule( + "google.protobuf.message"); + if (message_module == NULL) { + return false; + } + google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module, + "EncodeError"); + google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module, + "DecodeError"); + Py_DECREF(message_module); + + PyObject* pickle_module = PyImport_ImportModule("pickle"); + if (pickle_module == NULL) { + return false; + } + google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module, + "PickleError"); + Py_DECREF(pickle_module); + + // Override {Get,Mutable}CProtoInsidePyProto. + google::protobuf::python::GetCProtoInsidePyProtoPtr = + google::protobuf::python::GetCProtoInsidePyProtoImpl; + google::protobuf::python::MutableCProtoInsidePyProtoPtr = + google::protobuf::python::MutableCProtoInsidePyProtoImpl; + + return true; +} + +} // namespace python +} // namespace protobuf + + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + "_message", + google::protobuf::python::module_docstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__message +#define INITFUNC_ERRORVAL NULL +#else // Python 2 +#define INITFUNC init_message +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC(void) { + PyObject* m; +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&_module); +#else + m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring); +#endif + if (m == NULL) { + return INITFUNC_ERRORVAL; + } + + if (!google::protobuf::python::InitProto2MessageModule(m)) { + Py_DECREF(m); + return INITFUNC_ERRORVAL; + } + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif + } +} +} // namespace google diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h new file mode 100644 index 0000000..28e504f --- /dev/null +++ b/python/google/protobuf/pyext/message.h @@ -0,0 +1,305 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif +#include <string> + + +namespace google { +namespace protobuf { + +class Message; +class Reflection; +class FieldDescriptor; + +using internal::shared_ptr; + +namespace python { + +struct CFieldDescriptor; +struct ExtensionDict; + +typedef struct CMessage { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python CMessage holds a reference to it in + // order to keep it alive as long as there's a Python object that + // references any part of the tree. + shared_ptr<Message> owner; + + // Weak reference to a parent CMessage object. This is NULL for any top-level + // message and is set for any child message (i.e. a child submessage or a + // part of a repeated composite field). + // + // Used to make sure all ancestors are also mutable when first modifying + // a child submessage (in other words, turning a default message instance + // into a mutable one). + // + // If a submessage is released (becomes a new top-level message), this field + // MUST be set to NULL. The parent may get deallocated and further attempts + // to use this pointer will result in a crash. + struct CMessage* parent; + + // Weak reference to the parent's descriptor that describes this submessage. + // Used together with the parent's message when making a default message + // instance mutable. + // TODO(anuraag): With a bit of work on the Python/C++ layer, it should be + // possible to make this a direct pointer to a C++ FieldDescriptor, this would + // be easier if this implementation replaces upstream. + CFieldDescriptor* parent_field; + + // Pointer to the C++ Message object for this CMessage. The + // CMessage does not own this pointer. + Message* message; + + // Indicates this submessage is pointing to a default instance of a message. + // Submessages are always first created as read only messages and are then + // made writable, at which point this field is set to false. + bool read_only; + + // A reference to a Python dictionary containing CMessage, + // RepeatedCompositeContainer, and RepeatedScalarContainer + // objects. Used as a cache to make sure we don't have to make a + // Python wrapper for the C++ Message objects on every access, or + // deal with the synchronization nightmare that could create. + PyObject* composite_fields; + + // A reference to the dictionary containing the message's extensions. + // Similar to composite_fields, acting as a cache, but also contains the + // required extension dict logic. + ExtensionDict* extensions; +} CMessage; + +extern PyTypeObject CMessage_Type; + +namespace cmessage { + +// Create a new empty message that can be populated by the parent. +PyObject* NewEmpty(PyObject* type); + +// Release a submessage from its proto tree, making it a new top-level messgae. +// A new message will be created if this is a read-only default instance. +// +// Corresponds to reflection api method ReleaseMessage. +int ReleaseSubMessage(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* child_cmessage); + +// Initializes a new CMessage instance for a submessage. Only called once per +// submessage as the result is cached in composite_fields. +// +// Corresponds to reflection api method GetMessage. +PyObject* InternalGetSubMessage(CMessage* self, + CFieldDescriptor* cfield_descriptor); + +// Deletes a range of C++ submessages in a repeated field (following a +// removal in a RepeatedCompositeContainer). +// +// Releases messages to the provided cmessage_list if it is not NULL rather +// than just removing them from the underlying proto. This cmessage_list must +// have a CMessage for each underlying submessage. The CMessages refered to +// by slice will be removed from cmessage_list by this function. +// +// Corresponds to reflection api method RemoveLast. +int InternalDeleteRepeatedField(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* slice, PyObject* cmessage_list); + +// Sets the specified scalar value to the message. +int InternalSetScalar(CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor, + PyObject* value); + +// Retrieves the specified scalar value from the message. +// +// Returns a new python reference. +PyObject* InternalGetScalar(CMessage* self, + const google::protobuf::FieldDescriptor* field_descriptor); + +// Clears the message, removing all contained data. Extension dictionary and +// submessages are released first if there are remaining external references. +// +// Corresponds to message api method Clear. +PyObject* Clear(CMessage* self); + +// Clears the data described by the given descriptor. Used to clear extensions +// (which don't have names). Extension release is handled by ExtensionDict +// class, not this function. +// TODO(anuraag): Try to make this discrepancy in release semantics with +// ClearField less confusing. +// +// Corresponds to reflection api method ClearField. +PyObject* ClearFieldByDescriptor( + CMessage* self, + const google::protobuf::FieldDescriptor* descriptor); + +// Clears the data for the given field name. The message is released if there +// are any external references. +// +// Corresponds to reflection api method ClearField. +PyObject* ClearField(CMessage* self, PyObject* arg); + +// Checks if the message has the field described by the descriptor. Used for +// extensions (which have no name). +// +// Corresponds to reflection api method HasField +PyObject* HasFieldByDescriptor( + CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor); + +// Checks if the message has the named field. +// +// Corresponds to reflection api method HasField. +PyObject* HasField(CMessage* self, PyObject* arg); + +// Initializes constants/enum values on a message. This is called by +// RepeatedCompositeContainer and ExtensionDict after calling the constructor. +// TODO(anuraag): Make it always called from within the constructor since it can +int InitAttributes(CMessage* self, PyObject* descriptor, PyObject* kwargs); + +PyObject* MergeFrom(CMessage* self, PyObject* arg); + +// Retrieves an attribute named 'name' from CMessage 'self'. Returns +// the attribute value on success, or NULL on failure. +// +// Returns a new reference. +PyObject* GetAttr(CMessage* self, PyObject* name); + +// Set the value of the attribute named 'name', for CMessage 'self', +// to the value 'value'. Returns -1 on failure. +int SetAttr(CMessage* self, PyObject* name, PyObject* value); + +PyObject* FindInitializationErrors(CMessage* self); + +// Set the owner field of self and any children of self, recursively. +// Used when self is being released and thus has a new owner (the +// released Message.) +int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner); + +int AssureWritable(CMessage* self); + +} // namespace cmessage + +/* Is 64bit */ +#define IS_64BIT (SIZEOF_LONG == 8) + +#define FIELD_BELONGS_TO_MESSAGE(field_descriptor, message) \ + ((message)->GetDescriptor() == (field_descriptor)->containing_type()) + +#define FIELD_IS_REPEATED(field_descriptor) \ + ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) + +#define GOOGLE_CHECK_GET_INT32(arg, value, err) \ + int32 value; \ + if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_INT64(arg, value, err) \ + int64 value; \ + if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \ + uint32 value; \ + if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \ + uint64 value; \ + if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_FLOAT(arg, value, err) \ + float value; \ + if (!CheckAndGetFloat(arg, &value)) { \ + return err; \ + } \ + +#define GOOGLE_CHECK_GET_DOUBLE(arg, value, err) \ + double value; \ + if (!CheckAndGetDouble(arg, &value)) { \ + return err; \ + } + +#define GOOGLE_CHECK_GET_BOOL(arg, value, err) \ + bool value; \ + if (!CheckAndGetBool(arg, &value)) { \ + return err; \ + } + + +extern PyObject* kPythonZero; +extern PyObject* kint32min_py; +extern PyObject* kint32max_py; +extern PyObject* kuint32max_py; +extern PyObject* kint64min_py; +extern PyObject* kint64max_py; +extern PyObject* kuint64max_py; + +#define C(str) const_cast<char*>(str) + +void FormatTypeError(PyObject* arg, char* expected_types); +template<class T> +bool CheckAndGetInteger( + PyObject* arg, T* value, PyObject* min, PyObject* max); +bool CheckAndGetDouble(PyObject* arg, double* value); +bool CheckAndGetFloat(PyObject* arg, float* value); +bool CheckAndGetBool(PyObject* arg, bool* value); +bool CheckAndSetString( + PyObject* arg, google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* descriptor, + const google::protobuf::Reflection* reflection, + bool append, + int index); +PyObject* ToStringObject( + const google::protobuf::FieldDescriptor* descriptor, string value); + +extern PyObject* PickleError_class; + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ diff --git a/python/google/protobuf/pyext/message_factory_cpp2_test.py b/python/google/protobuf/pyext/message_factory_cpp2_test.py new file mode 100644 index 0000000..fb52e1b --- /dev/null +++ b/python/google/protobuf/pyext/message_factory_cpp2_test.py @@ -0,0 +1,56 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_factory_test import * + + +class ConfirmCppApi2Test(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/pyext/proto2_api_test.proto b/python/google/protobuf/pyext/proto2_api_test.proto new file mode 100644 index 0000000..eef9b73 --- /dev/null +++ b/python/google/protobuf/pyext/proto2_api_test.proto @@ -0,0 +1,38 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import "google/protobuf/internal/cpp/proto1_api_test.proto"; + +package google.protobuf.python.internal; + +message TestNestedProto1APIMessage { + optional int32 a = 1; + optional TestMessage.NestedMessage b = 2; +} diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto new file mode 100644 index 0000000..ee6d5ab --- /dev/null +++ b/python/google/protobuf/pyext/python.proto @@ -0,0 +1,66 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: tibell@google.com (Johan Tibell) +// +// These message definitions are used to exercises known corner cases +// in the C++ implementation of the Python API. + + +package google.protobuf.python.internal; + +// Protos optimized for SPEED use a strict superset of the generated code +// of equivalent ones optimized for CODE_SIZE, so we should optimize all our +// tests for speed unless explicitly testing code size optimization. +option optimize_for = SPEED; + +message TestAllTypes { + message NestedMessage { + optional int32 bb = 1; + optional ForeignMessage cc = 2; + } + + repeated NestedMessage repeated_nested_message = 1; + optional NestedMessage optional_nested_message = 2; + optional int32 optional_int32 = 3; +} + +message ForeignMessage { + optional int32 c = 1; + repeated int32 d = 2; +} + +message TestAllExtensions { + extensions 1 to max; +} + +extend TestAllExtensions { + optional TestAllTypes.NestedMessage optional_nested_message_extension = 1; +} diff --git a/python/google/protobuf/pyext/python_protobuf.h b/python/google/protobuf/pyext/python_protobuf.h new file mode 100644 index 0000000..c5b0b1c --- /dev/null +++ b/python/google/protobuf/pyext/python_protobuf.h @@ -0,0 +1,57 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: qrczak@google.com (Marcin Kowalczyk) +// +// This module exposes the C proto inside the given Python proto, in +// case the Python proto is implemented with a C proto. + +#ifndef GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ +#define GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ + +#include <Python.h> + +namespace google { +namespace protobuf { + +class Message; + +namespace python { + +// Return the pointer to the C proto inside the given Python proto, +// or NULL when this is not a Python proto implemented with a C proto. +const Message* GetCProtoInsidePyProto(PyObject* msg); +Message* MutableCProtoInsidePyProto(PyObject* msg); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ diff --git a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py new file mode 100755 index 0000000..d7fce5f --- /dev/null +++ b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py @@ -0,0 +1,94 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which tests the generated C++ implementation.""" + +__author__ = 'jasonh@google.com (Jason Hsueh)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' + +from google.apputils import basetest +from google.protobuf.internal import api_implementation +from google.protobuf.internal import more_extensions_dynamic_pb2 +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal.reflection_test import * + + +class ReflectionCppTest(basetest.TestCase): + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + def testExtensionOfGeneratedTypeInDynamicFile(self): + """Tests that a file built dynamically can extend a generated C++ type. + + The C++ implementation uses a DescriptorPool that has the generated + DescriptorPool as an underlay. Typically, a type can only find + extensions in its own pool. With the python C-extension, the generated C++ + extendee may be available, but not the extension. This tests that the + C-extension implements the correct special handling to make such extensions + available. + """ + pb1 = more_extensions_pb2.ExtendedMessage() + # Test that basic accessors work. + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + + # Now serialize the data and parse to a new message. + pb2 = more_extensions_pb2.ExtendedMessage() + pb2.MergeFromString(pb1.SerializeToString()) + + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + self.assertEqual( + 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) + self.assertEqual( + 24, + pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) + + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc new file mode 100644 index 0000000..b164505 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -0,0 +1,763 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/repeated_composite_container.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_Check PyLong_Check + #define PyInt_AsLong PyLong_AsLong + #define PyInt_FromLong PyLong_FromLong +#endif + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace repeated_composite_container { + +// TODO(tibell): We might also want to check: +// GOOGLE_CHECK_NOTNULL((self)->owner.get()); +#define GOOGLE_CHECK_ATTACHED(self) \ + do { \ + GOOGLE_CHECK_NOTNULL((self)->message); \ + GOOGLE_CHECK_NOTNULL((self)->parent_field); \ + } while (0); + +#define GOOGLE_CHECK_RELEASED(self) \ + do { \ + GOOGLE_CHECK((self)->owner.get() == NULL); \ + GOOGLE_CHECK((self)->message == NULL); \ + GOOGLE_CHECK((self)->parent_field == NULL); \ + GOOGLE_CHECK((self)->parent == NULL); \ + } while (0); + +// Returns a new reference. +static PyObject* GetKey(PyObject* x) { + // Just the identity function. + Py_INCREF(x); + return x; +} + +#define GET_KEY(keyfunc, value) \ + ((keyfunc) == NULL ? \ + GetKey((value)) : \ + PyObject_CallFunctionObjArgs((keyfunc), (value), NULL)) + +// Converts a comparison function that returns -1, 0, or 1 into a +// less-than predicate. +// +// Returns -1 on error, 1 if x < y, 0 if x >= y. +static int islt(PyObject *x, PyObject *y, PyObject *compare) { + if (compare == NULL) + return PyObject_RichCompareBool(x, y, Py_LT); + + ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL)); + if (res == NULL) + return -1; + if (!PyInt_Check(res)) { + PyErr_Format(PyExc_TypeError, + "comparison function must return int, not %.200s", + Py_TYPE(res)->tp_name); + return -1; + } + return PyInt_AsLong(res) < 0; +} + +// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps +// TODO(anuraag): Is there a better way to do this then reinventing the wheel? +static int InternalQuickSort(RepeatedCompositeContainer* self, + Py_ssize_t start, + Py_ssize_t limit, + PyObject* cmp, + PyObject* keyfunc) { + if (limit - start <= 1) + return 0; // Nothing to sort. + + GOOGLE_CHECK_ATTACHED(self); + + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; + Py_ssize_t left; + Py_ssize_t right; + + PyObject* children = self->child_messages; + + do { + left = start; + right = limit; + ScopedPyObjectPtr mid( + GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2))); + do { + ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left))); + int is_lt = islt(key, mid, cmp); + if (is_lt == -1) + return -1; + /* array[left]<x */ + while (is_lt) { + ++left; + ScopedPyObjectPtr key(GET_KEY(keyfunc, + PyList_GET_ITEM(children, left))); + is_lt = islt(key, mid, cmp); + if (is_lt == -1) + return -1; + } + key.reset(GET_KEY(keyfunc, PyList_GET_ITEM(children, right - 1))); + is_lt = islt(mid, key, cmp); + if (is_lt == -1) + return -1; + while (is_lt) { + --right; + ScopedPyObjectPtr key(GET_KEY(keyfunc, + PyList_GET_ITEM(children, right - 1))); + is_lt = islt(mid, key, cmp); + if (is_lt == -1) + return -1; + } + if (left < right) { + --right; + if (left < right) { + reflection->SwapElements(message, descriptor, left, right); + PyObject* tmp = PyList_GET_ITEM(children, left); + PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right)); + PyList_SET_ITEM(children, right, tmp); + } + ++left; + } + } while (left < right); + + if ((right - start) < (limit - left)) { + /* sort [start..right[ */ + if (start < (right - 1)) { + InternalQuickSort(self, start, right, cmp, keyfunc); + } + + /* sort [left..limit[ */ + start = left; + } else { + /* sort [left..limit[ */ + if (left < (limit - 1)) { + InternalQuickSort(self, left, limit, cmp, keyfunc); + } + + /* sort [start..right[ */ + limit = right; + } + } while (start < (limit - 1)); + + return 0; +} + +#undef GET_KEY + +// --------------------------------------------------------------------- +// len() + +static Py_ssize_t Length(RepeatedCompositeContainer* self) { + google::protobuf::Message* message = self->message; + if (message != NULL) { + return message->GetReflection()->FieldSize(*message, + self->parent_field->descriptor); + } else { + // The container has been released (i.e. by a call to Clear() or + // ClearField() on the parent) and thus there's no message. + return PyList_GET_SIZE(self->child_messages); + } +} + +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int UpdateChildMessages(RepeatedCompositeContainer* self) { + if (self->message == NULL) + return 0; + + // A MergeFrom on a parent message could have caused extra messages to be + // added in the underlying protobuf so add them to our list. They can never + // be removed in such a way so there's no need to worry about that. + Py_ssize_t message_length = Length(self); + Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages); + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + for (Py_ssize_t i = child_length; i < message_length; ++i) { + const Message& sub_message = reflection->GetRepeatedMessage( + *(self->message), self->parent_field->descriptor, i); + ScopedPyObjectPtr py_cmsg(cmessage::NewEmpty(self->subclass_init)); + if (py_cmsg == NULL) { + return -1; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg.get()); + cmsg->owner = self->owner; + cmsg->message = const_cast<google::protobuf::Message*>(&sub_message); + cmsg->parent = self->parent; + if (cmessage::InitAttributes(cmsg, NULL, NULL) < 0) { + return -1; + } + PyList_Append(self->child_messages, py_cmsg); + } + return 0; +} + +// --------------------------------------------------------------------- +// add() + +static PyObject* AddToAttached(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + GOOGLE_CHECK_ATTACHED(self); + + if (UpdateChildMessages(self) < 0) { + return NULL; + } + if (cmessage::AssureWritable(self->parent) == -1) + return NULL; + google::protobuf::Message* message = self->message; + google::protobuf::Message* sub_message = + message->GetReflection()->AddMessage(message, + self->parent_field->descriptor); + PyObject* py_cmsg = cmessage::NewEmpty(self->subclass_init); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + + cmsg->owner = self->owner; + cmsg->message = sub_message; + cmsg->parent = self->parent; + // cmessage::InitAttributes must be called after cmsg->message has + // been set. + if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + PyList_Append(self->child_messages, py_cmsg); + return py_cmsg; +} + +static PyObject* AddToReleased(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + GOOGLE_CHECK_RELEASED(self); + + // Create the CMessage + PyObject* py_cmsg = PyObject_CallObject(self->subclass_init, NULL); + if (py_cmsg == NULL) + return NULL; + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) { + Py_DECREF(py_cmsg); + return NULL; + } + + // The Message got created by the call to subclass_init above and + // it set self->owner to the newly allocated message. + + PyList_Append(self->child_messages, py_cmsg); + return py_cmsg; +} + +PyObject* Add(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + if (self->message == NULL) + return AddToReleased(self, args, kwargs); + else + return AddToAttached(self, args, kwargs); +} + +// --------------------------------------------------------------------- +// extend() + +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + if (UpdateChildMessages(self) < 0) { + return NULL; + } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return NULL; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + if (!PyObject_TypeCheck(next, &CMessage_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a cmessage"); + return NULL; + } + ScopedPyObjectPtr new_message(Add(self, NULL, NULL)); + if (new_message == NULL) { + return NULL; + } + CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get()); + if (cmessage::MergeFrom(new_cmessage, next) == NULL) { + return NULL; + } + } + if (PyErr_Occurred()) { + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + return Extend(self, other); +} + +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length = Length(self); + Py_ssize_t slicelength; + if (PySlice_Check(slice)) { +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return NULL; + } + return PyList_GetSlice(self->child_messages, from, to); + } else if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + PyObject* result = PyList_GetItem(self->child_messages, from); + if (result == NULL) { + return NULL; + } + Py_INCREF(result); + return result; + } + PyErr_SetString(PyExc_TypeError, "index must be an integer or slice"); + return NULL; +} + +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value) { + if (UpdateChildMessages(self) < 0) { + return -1; + } + if (value != NULL) { + PyErr_SetString(PyExc_TypeError, "does not support assignment"); + return -1; + } + + // Delete from the underlying Message, if any. + if (self->message != NULL) { + if (cmessage::InternalDeleteRepeatedField(self->message, + self->parent_field->descriptor, + slice, + self->child_messages) < 0) { + return -1; + } + } else { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length = Length(self); + Py_ssize_t slicelength; + if (PySlice_Check(slice)) { +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return -1; + } + return PySequence_DelSlice(self->child_messages, from, to); + } else if (PyInt_Check(slice) || PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + if (from < 0) { + from = to = length + from; + } + return PySequence_DelItem(self->child_messages, from); + } + } + + return 0; +} + +static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t index = PySequence_Index(self->child_messages, value); + if (index == -1) { + return NULL; + } + ScopedPyObjectPtr py_index(PyLong_FromLong(index)); + if (AssignSubscript(self, py_index, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* RichCompare(RepeatedCompositeContainer* self, + PyObject* other, + int opid) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + if (!PyObject_TypeCheck(other, &RepeatedCompositeContainer_Type)) { + PyErr_SetString(PyExc_TypeError, + "Can only compare repeated composite fields " + "against other repeated composite fields."); + return NULL; + } + if (opid == Py_EQ || opid == Py_NE) { + // TODO(anuraag): Don't make new lists just for this... + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + ScopedPyObjectPtr other_list( + Subscript( + reinterpret_cast<RepeatedCompositeContainer*>(other), full_slice)); + if (other_list == NULL) { + return NULL; + } + return PyObject_RichCompare(list, other_list, opid); + } else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + +// --------------------------------------------------------------------- +// sort() + +static PyObject* SortAttached(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + // Sort the underlying Message array. + PyObject *compare = NULL; + int reverse = 0; + PyObject *keyfunc = NULL; + static char *kwlist[] = {"cmp", "key", "reverse", 0}; + + if (args != NULL) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort", + kwlist, &compare, &keyfunc, &reverse)) + return NULL; + } + if (compare == Py_None) + compare = NULL; + if (keyfunc == Py_None) + keyfunc = NULL; + + const Py_ssize_t length = Length(self); + if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0) + return NULL; + + // Finally reverse the result if requested. + if (reverse) { + google::protobuf::Message* message = self->message; + const google::protobuf::Reflection* reflection = message->GetReflection(); + const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor; + + // Reverse the Message array. + for (int i = 0; i < length / 2; ++i) + reflection->SwapElements(message, descriptor, i, length - i - 1); + + // Reverse the Python list. + ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages, + "reverse", NULL)); + if (res == NULL) + return NULL; + } + + Py_RETURN_NONE; +} + +static PyObject* SortReleased(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort")); + if (m == NULL) + return NULL; + if (PyObject_Call(m, args, kwds) == NULL) + return NULL; + Py_RETURN_NONE; +} + +static PyObject* Sort(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != NULL) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != NULL) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + PyDict_SetItemString(kwds, "cmp", sort_func); + PyDict_DelItemString(kwds, "sort_function"); + } + } + + if (UpdateChildMessages(self) < 0) + return NULL; + if (self->message == NULL) { + return SortReleased(self, args, kwds); + } else { + return SortAttached(self, args, kwds); + } +} + +// --------------------------------------------------------------------- + +static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) { + if (UpdateChildMessages(self) < 0) { + return NULL; + } + Py_ssize_t length = Length(self); + if (index < 0) { + index = length + index; + } + PyObject* item = PyList_GetItem(self->child_messages, index); + if (item == NULL) { + return NULL; + } + Py_INCREF(item); + return item; +} + +// The caller takes ownership of the returned Message. +Message* ReleaseLast(const FieldDescriptor* field, + const Descriptor* type, + Message* message) { + GOOGLE_CHECK_NOTNULL(field); + GOOGLE_CHECK_NOTNULL(type); + GOOGLE_CHECK_NOTNULL(message); + + Message* released_message = message->GetReflection()->ReleaseLast( + message, field); + // TODO(tibell): Deal with proto1. + + // ReleaseMessage will return NULL which differs from + // child_cmessage->message, if the field does not exist. In this case, + // the latter points to the default instance via a const_cast<>, so we + // have to reset it to a new mutable object since we are taking ownership. + if (released_message == NULL) { + const Message* prototype = global_message_factory->GetPrototype(type); + GOOGLE_CHECK_NOTNULL(prototype); + return prototype->New(); + } else { + return released_message; + } +} + +// Release field of message and transfer the ownership to cmessage. +void ReleaseLastTo(const FieldDescriptor* field, + Message* message, + CMessage* cmessage) { + GOOGLE_CHECK_NOTNULL(field); + GOOGLE_CHECK_NOTNULL(message); + GOOGLE_CHECK_NOTNULL(cmessage); + + shared_ptr<Message> released_message( + ReleaseLast(field, cmessage->message->GetDescriptor(), message)); + cmessage->parent = NULL; + cmessage->parent_field = NULL; + cmessage->message = released_message.get(); + cmessage->read_only = false; + cmessage::SetOwner(cmessage, released_message); +} + +// Called to release a container using +// ClearField('container_field_name') on the parent. +int Release(RepeatedCompositeContainer* self) { + if (UpdateChildMessages(self) < 0) { + PyErr_WriteUnraisable(PyBytes_FromString("Failed to update released " + "messages")); + return -1; + } + + Message* message = self->message; + const FieldDescriptor* field = self->parent_field->descriptor; + + // The reflection API only lets us release the last message in a + // repeated field. Therefore we iterate through the children + // starting with the last one. + const Py_ssize_t size = PyList_GET_SIZE(self->child_messages); + GOOGLE_DCHECK_EQ(size, message->GetReflection()->FieldSize(*message, field)); + for (Py_ssize_t i = size - 1; i >= 0; --i) { + CMessage* child_cmessage = reinterpret_cast<CMessage*>( + PyList_GET_ITEM(self->child_messages, i)); + ReleaseLastTo(field, message, child_cmessage); + } + + // Detach from containing message. + self->parent = NULL; + self->parent_field = NULL; + self->message = NULL; + self->owner.reset(); + + return 0; +} + +int SetOwner(RepeatedCompositeContainer* self, + const shared_ptr<Message>& new_owner) { + GOOGLE_CHECK_ATTACHED(self); + + self->owner = new_owner; + const Py_ssize_t n = PyList_GET_SIZE(self->child_messages); + for (Py_ssize_t i = 0; i < n; ++i) { + PyObject* msg = PyList_GET_ITEM(self->child_messages, i); + if (cmessage::SetOwner(reinterpret_cast<CMessage*>(msg), new_owner) == -1) { + return -1; + } + } + return 0; +} + +static int Init(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs) { + self->message = NULL; + self->parent = NULL; + self->parent_field = NULL; + self->subclass_init = NULL; + self->child_messages = PyList_New(0); + return 0; +} + +static void Dealloc(RepeatedCompositeContainer* self) { + Py_CLEAR(self->child_messages); + // TODO(tibell): Do we need to call delete on these objects to make + // sure their destructors are called? + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +static PySequenceMethods SqMethods = { + (lenfunc)Length, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)Item /* sq_item */ +}; + +static PyMappingMethods MpMethods = { + (lenfunc)Length, /* mp_length */ + (binaryfunc)Subscript, /* mp_subscript */ + (objobjargproc)AssignSubscript,/* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + { "add", (PyCFunction) Add, METH_VARARGS | METH_KEYWORDS, + "Adds an object to the repeated container." }, + { "extend", (PyCFunction) Extend, METH_O, + "Adds objects to the repeated container." }, + { "remove", (PyCFunction) Remove, METH_O, + "Removes an object from the repeated container." }, + { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container." }, + { "MergeFrom", (PyCFunction) MergeFrom, METH_O, + "Adds objects to the repeated container." }, + { NULL, NULL } +}; + +} // namespace repeated_composite_container + +PyTypeObject RepeatedCompositeContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.RepeatedCompositeContainer", // tp_name + sizeof(RepeatedCompositeContainer), // tp_basicsize + 0, // tp_itemsize + (destructor)repeated_composite_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &repeated_composite_container::SqMethods, // tp_as_sequence + &repeated_composite_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)repeated_composite_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + repeated_composite_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)repeated_composite_container::Init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h new file mode 100644 index 0000000..e8ed30e --- /dev/null +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -0,0 +1,172 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif +#include <string> +#include <vector> + + +namespace google { +namespace protobuf { + +class FieldDescriptor; +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; +struct CFieldDescriptor; + +// A RepeatedCompositeContainer can be in one of two states: attached +// or released. +// +// When in the attached state all modifications to the container are +// done both on the 'message' and on the 'child_messages' +// list. In this state all Messages refered to by the children in +// 'child_messages' are owner by the 'owner'. +// +// When in the released state 'message', 'owner', 'parent', and +// 'parent_field' are NULL. +typedef struct RepeatedCompositeContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python RepeatedCompositeContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr<Message> owner; + + // Weak reference to parent object. May be NULL. Used to make sure + // the parent is writable before modifying the + // RepeatedCompositeContainer. + CMessage* parent; + + // A descriptor used to modify the underlying 'message'. + CFieldDescriptor* parent_field; + + // Pointer to the C++ Message that contains this container. The + // RepeatedCompositeContainer does not own this pointer. + // + // If NULL, this message has been released from its parent (by + // calling Clear() or ClearField() on the parent. + Message* message; + + // A callable that is used to create new child messages. + PyObject* subclass_init; + + // A list of child messages. + PyObject* child_messages; +} RepeatedCompositeContainer; + +extern PyTypeObject RepeatedCompositeContainer_Type; + +namespace repeated_composite_container { + +// Returns the number of items in this repeated composite container. +static Py_ssize_t Length(RepeatedCompositeContainer* self); + +// Appends a new CMessage to the container and returns it. The +// CMessage is initialized using the content of kwargs. +// +// Returns a new reference if successful; returns NULL and sets an +// exception if unsuccessful. +PyObject* Add(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwargs); + +// Appends all the CMessages in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value); + +// Appends a new message to the container for each message in the +// input iterator, merging each data element in. Equivalent to extend. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other); + +// Accesses messages in the container. +// +// Returns a new reference to the message for an integer parameter. +// Returns a new reference to a list of messages for a slice. +PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice); + +// Deletes items from the container (cannot be used for assignment). +// +// Returns 0 on success, -1 on failure. +int AssignSubscript(RepeatedCompositeContainer* self, + PyObject* slice, + PyObject* value); + +// Releases the messages in the container to the given message. +// +// Returns 0 on success, -1 on failure. +int ReleaseToMessage(RepeatedCompositeContainer* self, + google::protobuf::Message* new_message); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(RepeatedCompositeContainer* self); + +// Returns 0 on success, -1 on failure. +int SetOwner(RepeatedCompositeContainer* self, + const shared_ptr<Message>& new_owner); + +// Removes the last element of the repeated message field 'field' on +// the Message 'message', and transfers the ownership of the released +// Message to 'cmessage'. +// +// Corresponds to reflection api method ReleaseMessage. +void ReleaseLastTo(const FieldDescriptor* field, + Message* message, + CMessage* cmessage); + +} // namespace repeated_composite_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc new file mode 100644 index 0000000..b0fcd81 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -0,0 +1,825 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#include <google/protobuf/pyext/repeated_scalar_container.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.h> +#include <google/protobuf/dynamic_message.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #else + #define PyString_AsString(ob) \ + (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob)) + #endif +#endif + +namespace google { +namespace protobuf { +namespace python { + +extern google::protobuf::DynamicMessageFactory* global_message_factory; + +namespace repeated_scalar_container { + +static int InternalAssignRepeatedField( + RepeatedScalarContainer* self, PyObject* list) { + self->message->GetReflection()->ClearField(self->message, + self->parent_field->descriptor); + for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) { + PyObject* value = PyList_GET_ITEM(list, i); + if (Append(self, value) == NULL) { + return -1; + } + } + return 0; +} + +static Py_ssize_t Len(RepeatedScalarContainer* self) { + google::protobuf::Message* message = self->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field->descriptor); +} + +static int AssignItem(RepeatedScalarContainer* self, + Py_ssize_t index, + PyObject* arg) { + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + const google::protobuf::Reflection* reflection = message->GetReflection(); + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, + "list assignment index (%d) out of range", + static_cast<int>(index)); + return -1; + } + + if (arg == NULL) { + ScopedPyObjectPtr py_index(PyLong_FromLong(index)); + return cmessage::InternalDeleteRepeatedField(message, field_descriptor, + py_index, NULL); + } + + if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) { + PyErr_SetString(PyExc_TypeError, "Value must be scalar"); + return -1; + } + + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + reflection->SetRepeatedInt32(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(arg, value, -1); + reflection->SetRepeatedInt64(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(arg, value, -1); + reflection->SetRepeatedUInt32(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(arg, value, -1); + reflection->SetRepeatedUInt64(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(arg, value, -1); + reflection->SetRepeatedFloat(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); + reflection->SetRepeatedDouble(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(arg, value, -1); + reflection->SetRepeatedBool(message, field_descriptor, index, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + arg, message, field_descriptor, reflection, false, index)) { + return -1; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(arg, value, -1); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetRepeatedEnum(message, field_descriptor, index, + enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(arg)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return -1; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return -1; + } + return 0; +} + +static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + const google::protobuf::Reflection* reflection = message->GetReflection(); + + int field_size = reflection->FieldSize(*message, field_descriptor); + if (index < 0) { + index = field_size + index; + } + if (index < 0 || index >= field_size) { + PyErr_Format(PyExc_IndexError, + "list assignment index (%d) out of range", + static_cast<int>(index)); + return NULL; + } + + PyObject* result = NULL; + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + int32 value = reflection->GetRepeatedInt32( + *message, field_descriptor, index); + result = PyInt_FromLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + int64 value = reflection->GetRepeatedInt64( + *message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = reflection->GetRepeatedUInt32( + *message, field_descriptor, index); + result = PyLong_FromLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = reflection->GetRepeatedUInt64( + *message, field_descriptor, index); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + float value = reflection->GetRepeatedFloat( + *message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + double value = reflection->GetRepeatedDouble( + *message, field_descriptor, index); + result = PyFloat_FromDouble(value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + bool value = reflection->GetRepeatedBool( + *message, field_descriptor, index); + result = PyBool_FromLong(value ? 1 : 0); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + const google::protobuf::EnumValueDescriptor* enum_value = + message->GetReflection()->GetRepeatedEnum( + *message, field_descriptor, index); + result = PyInt_FromLong(enum_value->number()); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetRepeatedString( + *message, field_descriptor, index); + result = ToStringObject(field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>( + &CMessage_Type), NULL); + if (py_cmsg == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg); + const google::protobuf::Message& msg = reflection->GetRepeatedMessage( + *message, field_descriptor, index); + cmsg->owner = self->owner; + cmsg->parent = self->parent; + cmsg->message = const_cast<google::protobuf::Message*>(&msg); + cmsg->read_only = false; + result = reinterpret_cast<PyObject*>(py_cmsg); + break; + } + default: + PyErr_Format( + PyExc_SystemError, + "Getting value from a repeated field of unknown type %d", + field_descriptor->cpp_type()); + } + + return result; +} + +static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool return_list = false; +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else // NOLINT +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + length = Len(self); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return NULL; + } + return_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return NULL; + } + + if (!return_list) { + return Item(self, from); + } + + PyObject* list = PyList_New(0); + if (list == NULL) { + return NULL; + } + if (from <= to) { + if (step < 0) { + return list; + } + for (Py_ssize_t index = from; index < to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(self, index)); + PyList_Append(list, s); + } + } else { + if (step > 0) { + return list; + } + for (Py_ssize_t index = from; index > to; index += step) { + if (index < 0 || index >= length) { + break; + } + ScopedPyObjectPtr s(Item(self, index)); + PyList_Append(list, s); + } + } + return list; +} + +PyObject* Append(RepeatedScalarContainer* self, PyObject* item) { + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + + if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return NULL; + } + + const google::protobuf::Reflection* reflection = message->GetReflection(); + switch (field_descriptor->cpp_type()) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(item, value, NULL); + reflection->AddInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(item, value, NULL); + reflection->AddInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(item, value, NULL); + reflection->AddUInt32(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(item, value, NULL); + reflection->AddUInt64(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(item, value, NULL); + reflection->AddFloat(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(item, value, NULL); + reflection->AddDouble(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(item, value, NULL); + reflection->AddBool(message, field_descriptor, value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + if (!CheckAndSetString( + item, message, field_descriptor, reflection, true, -1)) { + return NULL; + } + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(item, value, NULL); + const google::protobuf::EnumDescriptor* enum_descriptor = + field_descriptor->enum_type(); + const google::protobuf::EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->AddEnum(message, field_descriptor, enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(item)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s.get())); + } + return NULL; + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Adding value to a field of unknown type %d", + field_descriptor->cpp_type()); + return NULL; + } + + Py_RETURN_NONE; +} + +static int AssSubscript(RepeatedScalarContainer* self, + PyObject* slice, + PyObject* value) { + Py_ssize_t from; + Py_ssize_t to; + Py_ssize_t step; + Py_ssize_t length; + Py_ssize_t slicelength; + bool create_list = false; + + cmessage::AssureWritable(self->parent); + google::protobuf::Message* message = self->message; + const google::protobuf::FieldDescriptor* field_descriptor = + self->parent_field->descriptor; + +#if PY_MAJOR_VERSION < 3 + if (PyInt_Check(slice)) { + from = to = PyInt_AsLong(slice); + } else +#endif + if (PyLong_Check(slice)) { + from = to = PyLong_AsLong(slice); + } else if (PySlice_Check(slice)) { + const google::protobuf::Reflection* reflection = message->GetReflection(); + length = reflection->FieldSize(*message, field_descriptor); +#if PY_MAJOR_VERSION >= 3 + if (PySlice_GetIndicesEx(slice, +#else + if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice), +#endif + length, &from, &to, &step, &slicelength) == -1) { + return -1; + } + create_list = true; + } else { + PyErr_SetString(PyExc_TypeError, "list indices must be integers"); + return -1; + } + + if (value == NULL) { + return cmessage::InternalDeleteRepeatedField( + message, field_descriptor, slice, NULL); + } + + if (!create_list) { + return AssignItem(self, from, value); + } + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return -1; + } + ScopedPyObjectPtr new_list(Subscript(self, full_slice)); + if (new_list == NULL) { + return -1; + } + if (PySequence_SetSlice(new_list, from, to, value) < 0) { + return -1; + } + + return InternalAssignRepeatedField(self, new_list); +} + +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { + cmessage::AssureWritable(self->parent); + if (PyObject_Not(value)) { + Py_RETURN_NONE; + } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return NULL; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + if (Append(self, next) == NULL) { + return NULL; + } + } + if (PyErr_Occurred()) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) { + Py_ssize_t index; + PyObject* value; + if (!PyArg_ParseTuple(args, "lO", &index, &value)) { + return NULL; + } + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + ScopedPyObjectPtr new_list(Subscript(self, full_slice)); + if (PyList_Insert(new_list, index, value) < 0) { + return NULL; + } + int ret = InternalAssignRepeatedField(self, new_list); + if (ret < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) { + Py_ssize_t match_index = -1; + for (Py_ssize_t i = 0; i < Len(self); ++i) { + ScopedPyObjectPtr elem(Item(self, i)); + if (PyObject_RichCompareBool(elem, value, Py_EQ)) { + match_index = i; + break; + } + } + if (match_index == -1) { + PyErr_SetString(PyExc_ValueError, "remove(x): x not in container"); + return NULL; + } + if (AssignItem(self, match_index, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static PyObject* RichCompare(RepeatedScalarContainer* self, + PyObject* other, + int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + // Copy the contents of this repeated scalar container, and other if it is + // also a repeated scalar container, into Python lists so we can delegate + // to the list's compare method. + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + + ScopedPyObjectPtr other_list_deleter; + if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) { + other_list_deleter.reset(Subscript( + reinterpret_cast<RepeatedScalarContainer*>(other), full_slice)); + other = other_list_deleter.get(); + } + + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + return PyObject_RichCompare(list, other, opid); +} + +PyObject* Reduce(RepeatedScalarContainer* unused_self) { + PyErr_Format( + PickleError_class, + "can't pickle repeated message fields, convert to list first"); + return NULL; +} + +static PyObject* Sort(RepeatedScalarContainer* self, + PyObject* args, + PyObject* kwds) { + // Support the old sort_function argument for backwards + // compatibility. + if (kwds != NULL) { + PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function"); + if (sort_func != NULL) { + // Must set before deleting as sort_func is a borrowed reference + // and kwds might be the only thing keeping it alive. + if (PyDict_SetItemString(kwds, "cmp", sort_func) == -1) + return NULL; + if (PyDict_DelItemString(kwds, "sort_function") == -1) + return NULL; + } + } + + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return NULL; + } + ScopedPyObjectPtr list(Subscript(self, full_slice)); + if (list == NULL) { + return NULL; + } + ScopedPyObjectPtr m(PyObject_GetAttrString(list, "sort")); + if (m == NULL) { + return NULL; + } + ScopedPyObjectPtr res(PyObject_Call(m, args, kwds)); + if (res == NULL) { + return NULL; + } + int ret = InternalAssignRepeatedField(self, list); + if (ret < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static int Init(RepeatedScalarContainer* self, + PyObject* args, + PyObject* kwargs) { + PyObject* py_parent; + PyObject* py_parent_field; + if (!PyArg_UnpackTuple(args, "__init__()", 2, 2, &py_parent, + &py_parent_field)) { + return -1; + } + + if (!PyObject_TypeCheck(py_parent, &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, + "expect %s, but got %s", + CMessage_Type.tp_name, + Py_TYPE(py_parent)->tp_name); + return -1; + } + + if (!PyObject_TypeCheck(py_parent_field, &CFieldDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, + "expect %s, but got %s", + CFieldDescriptor_Type.tp_name, + Py_TYPE(py_parent_field)->tp_name); + return -1; + } + + CMessage* cmessage = reinterpret_cast<CMessage*>(py_parent); + CFieldDescriptor* cdescriptor = reinterpret_cast<CFieldDescriptor*>( + py_parent_field); + + if (!FIELD_BELONGS_TO_MESSAGE(cdescriptor->descriptor, cmessage->message)) { + PyErr_SetString( + PyExc_KeyError, "Field does not belong to message!"); + return -1; + } + + self->message = cmessage->message; + self->parent = cmessage; + self->parent_field = cdescriptor; + self->owner = cmessage->owner; + return 0; +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer( + RepeatedScalarContainer* from, + RepeatedScalarContainer* to) { + ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL)); + if (full_slice == NULL) { + return -1; + } + ScopedPyObjectPtr values(Subscript(from, full_slice)); + if (values == NULL) { + return -1; + } + google::protobuf::Message* new_message = global_message_factory->GetPrototype( + from->message->GetDescriptor())->New(); + to->parent = NULL; + // TODO(anuraag): Document why it's OK to hang on to parent_field, + // even though it's a weak reference. It ought to be enough to + // hold on to the FieldDescriptor only. + to->parent_field = from->parent_field; + to->message = new_message; + to->owner.reset(new_message); + if (InternalAssignRepeatedField(to, values) < 0) { + return -1; + } + return 0; +} + +int Release(RepeatedScalarContainer* self) { + return InitializeAndCopyToParentContainer(self, self); +} + +PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) { + ScopedPyObjectPtr init_args( + PyTuple_Pack(2, self->parent, self->parent_field)); + PyObject* clone = PyObject_CallObject( + reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), init_args); + if (clone == NULL) { + return NULL; + } + if (!PyObject_TypeCheck(clone, &RepeatedScalarContainer_Type)) { + Py_DECREF(clone); + return NULL; + } + if (InitializeAndCopyToParentContainer( + self, reinterpret_cast<RepeatedScalarContainer*>(clone)) < 0) { + Py_DECREF(clone); + return NULL; + } + return clone; +} + +static void Dealloc(RepeatedScalarContainer* self) { + self->owner.reset(); + Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); +} + +void SetOwner(RepeatedScalarContainer* self, + const shared_ptr<Message>& new_owner) { + self->owner = new_owner; +} + +static PySequenceMethods SqMethods = { + (lenfunc)Len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)Item, /* sq_item */ + 0, /* sq_slice */ + (ssizeobjargproc)AssignItem /* sq_ass_item */ +}; + +static PyMappingMethods MpMethods = { + (lenfunc)Len, /* mp_length */ + (binaryfunc)Subscript, /* mp_subscript */ + (objobjargproc)AssSubscript, /* mp_ass_subscript */ +}; + +static PyMethodDef Methods[] = { + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + { "append", (PyCFunction)Append, METH_O, + "Appends an object to the repeated container." }, + { "extend", (PyCFunction)Extend, METH_O, + "Appends objects to the repeated container." }, + { "insert", (PyCFunction)Insert, METH_VARARGS, + "Appends objects to the repeated container." }, + { "remove", (PyCFunction)Remove, METH_O, + "Removes an object from the repeated container." }, + { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS, + "Sorts the repeated container."}, + { NULL, NULL } +}; + +} // namespace repeated_scalar_container + +PyTypeObject RepeatedScalarContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "cpp._message.RepeatedScalarContainer", // tp_name + sizeof(RepeatedScalarContainer), // tp_basicsize + 0, // tp_itemsize + (destructor)repeated_scalar_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + &repeated_scalar_container::SqMethods, // tp_as_sequence + &repeated_scalar_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Repeated scalar container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)repeated_scalar_container::RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + repeated_scalar_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + (initproc)repeated_scalar_container::Init, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h new file mode 100644 index 0000000..8a30138 --- /dev/null +++ b/python/google/protobuf/pyext/repeated_scalar_container.h @@ -0,0 +1,112 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: anuraag@google.com (Anuraag Agrawal) +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + + +namespace google { +namespace protobuf { + +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CFieldDescriptor; +struct CMessage; + +typedef struct RepeatedScalarContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python RepeatedScalarContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr<Message> owner; + + // Pointer to the C++ Message that contains this container. The + // RepeatedScalarContainer does not own this pointer. + Message* message; + + // Weak reference to a parent CMessage object (i.e. may be NULL.) + // + // Used to make sure all ancestors are also mutable when first + // modifying the container. + CMessage* parent; + + // Weak reference to the parent's descriptor that describes this + // field. Used together with the parent's message when making a + // default message instance mutable. + CFieldDescriptor* parent_field; +} RepeatedScalarContainer; + +extern PyTypeObject RepeatedScalarContainer_Type; + +namespace repeated_scalar_container { + +// Appends the scalar 'item' to the end of the container 'self'. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Append(RepeatedScalarContainer* self, PyObject* item); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(RepeatedScalarContainer* self); + +// Appends all the elements in the input iterator to the container. +// +// Returns None if successful; returns NULL and sets an exception if +// unsuccessful. +PyObject* Extend(RepeatedScalarContainer* self, PyObject* value); + +// Set the owner field of self and any children of self. +void SetOwner(RepeatedScalarContainer* self, + const shared_ptr<Message>& new_owner); + +} // namespace repeated_scalar_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h new file mode 100644 index 0000000..1b27a89 --- /dev/null +++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h @@ -0,0 +1,95 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: tibell@google.com (Johan Tibell) + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ + +#include <Python.h> + +namespace google { +class ScopedPyObjectPtr { + public: + // Constructor. Defaults to intializing with NULL. + // There is no way to create an uninitialized ScopedPyObjectPtr. + explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { } + + // Destructor. If there is a PyObject object, delete it. + ~ScopedPyObjectPtr() { + Py_XDECREF(ptr_); + } + + // Reset. Deletes the current owned object, if any. + // Then takes ownership of a new object, if given. + // this->reset(this->get()) works. + PyObject* reset(PyObject* p = NULL) { + if (p != ptr_) { + Py_XDECREF(ptr_); + ptr_ = p; + } + return ptr_; + } + + // Releases ownership of the object. + PyObject* release() { + PyObject* p = ptr_; + ptr_ = NULL; + return p; + } + + operator PyObject*() { return ptr_; } + + PyObject* operator->() const { + assert(ptr_ != NULL); + return ptr_; + } + + PyObject* get() const { return ptr_; } + + Py_ssize_t refcnt() const { return Py_REFCNT(ptr_); } + + void inc() const { Py_INCREF(ptr_); } + + // Comparison operators. + // These return whether a ScopedPyObjectPtr and a raw pointer + // refer to the same object, not just to two different but equal + // objects. + bool operator==(const PyObject* p) const { return ptr_ == p; } + bool operator!=(const PyObject* p) const { return ptr_ != p; } + + private: + PyObject* ptr_; + + GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPyObjectPtr); +}; + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__ diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 5b23803..7aac623 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -29,9 +29,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # This code is meant to work on Python 2.4 and above only. -# -# TODO(robinson): Helpers for verbose, common checks like seeing if a -# descriptor's cpp_type is CPPTYPE_MESSAGE. """Contains a metaclass and helper functions used to create protocol message classes from Descriptor objects at runtime. @@ -50,27 +47,29 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -import struct -import weakref -# We use "as" to avoid name collisions with variables. -from google.protobuf.internal import containers -from google.protobuf.internal import decoder -from google.protobuf.internal import encoder -from google.protobuf.internal import message_listener as message_listener_mod -from google.protobuf.internal import type_checkers -from google.protobuf.internal import wire_format +from google.protobuf.internal import api_implementation from google.protobuf import descriptor as descriptor_mod -from google.protobuf import message as message_mod -from google.protobuf import text_format +from google.protobuf import message _FieldDescriptor = descriptor_mod.FieldDescriptor +if api_implementation.Type() == 'cpp': + if api_implementation.Version() == 2: + from google.protobuf.pyext import cpp_message + _NewMessage = cpp_message.NewMessage + _InitMessage = cpp_message.InitMessage + else: + from google.protobuf.internal import cpp_message + _NewMessage = cpp_message.NewMessage + _InitMessage = cpp_message.InitMessage +else: + from google.protobuf.internal import python_message + _NewMessage = python_message.NewMessage + _InitMessage = python_message.InitMessage + + class GeneratedProtocolMessageType(type): """Metaclass for protocol message classes created at runtime from Descriptors. @@ -92,6 +91,10 @@ class GeneratedProtocolMessageType(type): myproto_instance = MyProtoClass() myproto.foo_field = 23 ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. """ # Must be consistent with the protocol-compiler code in @@ -120,10 +123,12 @@ class GeneratedProtocolMessageType(type): Newly-allocated class. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - _AddSlots(descriptor, dictionary) - _AddClassAttributesForNestedExtensions(descriptor, dictionary) + bases = _NewMessage(bases, descriptor, dictionary) superclass = super(GeneratedProtocolMessageType, cls) - return superclass.__new__(cls, name, bases, dictionary) + + new_class = superclass.__new__(cls, name, bases, dictionary) + setattr(descriptor, '_concrete_class', new_class) + return new_class def __init__(cls, name, bases, dictionary): """Here we perform the majority of our work on the class. @@ -143,1006 +148,58 @@ class GeneratedProtocolMessageType(type): type. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - - cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} - if (descriptor.has_options and - descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number)) - - # We act as a "friend" class of the descriptor, setting - # its _concrete_class attribute the first time we use a - # given descriptor to initialize a concrete protocol message - # class. We also attach stuff to each FieldDescriptor for quick - # lookup later on. - concrete_class_attr_name = '_concrete_class' - if not hasattr(descriptor, concrete_class_attr_name): - setattr(descriptor, concrete_class_attr_name, cls) - for field in descriptor.fields: - _AttachFieldHelpers(cls, field) - - _AddEnumValues(descriptor, cls) - _AddInitMethod(descriptor, cls) - _AddPropertiesForFields(descriptor, cls) - _AddPropertiesForExtensions(descriptor, cls) - _AddStaticMethods(cls) - _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(cls) + _InitMessage(descriptor, cls) superclass = super(GeneratedProtocolMessageType, cls) superclass.__init__(name, bases, dictionary) -# Stateless helpers for GeneratedProtocolMessageType below. -# Outside clients should not access these directly. -# -# I opted not to make any of these methods on the metaclass, to make it more -# clear that I'm not really using any state there and to keep clients from -# thinking that they have direct access to these construction helpers. - - -def _PropertyName(proto_field_name): - """Returns the name of the public property attribute which - clients can use to get and (in some cases) set the value - of a protocol message field. - - Args: - proto_field_name: The protocol message field name, exactly - as it appears (or would appear) in a .proto file. - """ - # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. - # nnorwitz makes my day by writing: - # """ - # FYI. See the keyword module in the stdlib. This could be as simple as: - # - # if keyword.iskeyword(proto_field_name): - # return proto_field_name + "_" - # return proto_field_name - # """ - # Kenton says: The above is a BAD IDEA. People rely on being able to use - # getattr() and setattr() to reflectively manipulate field values. If we - # rename the properties, then every such user has to also make sure to apply - # the same transformation. Note that currently if you name a field "yield", - # you can still access it just fine using getattr/setattr -- it's not even - # that cumbersome to do so. - # TODO(kenton): Remove this method entirely if/when everyone agrees with my - # position. - return proto_field_name - - -def _VerifyExtensionHandle(message, extension_handle): - """Verify that the given extension handle is valid.""" - - if not isinstance(extension_handle, _FieldDescriptor): - raise KeyError('HasExtension() expects an extension handle, got: %s' % - extension_handle) - - if not extension_handle.is_extension: - raise KeyError('"%s" is not an extension.' % extension_handle.full_name) - - if extension_handle.containing_type is not message.DESCRIPTOR: - raise KeyError('Extension "%s" extends message type "%s", but this ' - 'message is of type "%s".' % - (extension_handle.full_name, - extension_handle.containing_type.full_name, - message.DESCRIPTOR.full_name)) - - -def _AddSlots(message_descriptor, dictionary): - """Adds a __slots__ entry to dictionary, containing the names of all valid - attributes for this message type. - - Args: - message_descriptor: A Descriptor instance describing this message type. - dictionary: Class dictionary to which we'll add a '__slots__' entry. - """ - dictionary['__slots__'] = ['_cached_byte_size', - '_cached_byte_size_dirty', - '_fields', - '_is_present_in_parent', - '_listener', - '_listener_for_children', - '__weakref__'] - - -def _IsMessageSetExtension(field): - return (field.is_extension and - field.containing_type.has_options and - field.containing_type.GetOptions().message_set_wire_format and - field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type == field.extension_scope and - field.label == _FieldDescriptor.LABEL_OPTIONAL) - - -def _AttachFieldHelpers(cls, field_descriptor): - is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - - if _IsMessageSetExtension(field_descriptor): - field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) - sizer = encoder.MessageSetItemSizer(field_descriptor.number) - else: - field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed) - sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed) - - field_descriptor._encoder = field_encoder - field_descriptor._sizer = sizer - field_descriptor._default_constructor = _DefaultValueConstructorForField( - field_descriptor) - - def AddDecoder(wiretype, is_packed): - tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - cls._decoders_by_tag[tag_bytes] = ( - type_checkers.TYPE_TO_DECODER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor)) - - AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], - False) - - if is_repeated and wire_format.IsTypePackable(field_descriptor.type): - # To support wire compatibility of adding packed = true, add a decoder for - # packed values regardless of the field's options. - AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) - - -def _AddClassAttributesForNestedExtensions(descriptor, dictionary): - extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): - assert extension_name not in dictionary - dictionary[extension_name] = extension_field - - -def _AddEnumValues(descriptor, cls): - """Sets class-level attributes for all enum fields defined in this message. - - Args: - descriptor: Descriptor object for this message type. - cls: Class we're constructing for this message type. - """ - for enum_type in descriptor.enum_types: - for enum_value in enum_type.values: - setattr(cls, enum_value.name, enum_value.number) - - -def _DefaultValueConstructorForField(field): - """Returns a function which returns a default value for a field. +def ParseMessage(descriptor, byte_str): + """Generate a new Message instance from this Descriptor and a byte string. Args: - field: FieldDescriptor object for this field. + descriptor: Protobuf Descriptor object + byte_str: Serialized protocol buffer byte string - The returned function has one argument: - message: Message instance containing this field, or a weakref proxy - of same. - - That function in turn returns a default value for this field. The default - value may refer back to |message| via a weak reference. - """ - - if field.label == _FieldDescriptor.LABEL_REPEATED: - if field.default_value != []: - raise ValueError('Repeated field default value not empty list: %s' % ( - field.default_value)) - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - # We can't look at _concrete_class yet since it might not have - # been set. (Depends on order in which we initialize the classes). - message_type = field.message_type - def MakeRepeatedMessageDefault(message): - return containers.RepeatedCompositeFieldContainer( - message._listener_for_children, field.message_type) - return MakeRepeatedMessageDefault - else: - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) - def MakeRepeatedScalarDefault(message): - return containers.RepeatedScalarFieldContainer( - message._listener_for_children, type_checker) - return MakeRepeatedScalarDefault - - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - # _concrete_class may not yet be initialized. - message_type = field.message_type - def MakeSubMessageDefault(message): - result = message_type._concrete_class() - result._SetListener(message._listener_for_children) - return result - return MakeSubMessageDefault - - def MakeScalarDefault(message): - return field.default_value - return MakeScalarDefault - - -def _AddInitMethod(message_descriptor, cls): - """Adds an __init__ method to cls.""" - fields = message_descriptor.fields - def init(self, **kwargs): - self._cached_byte_size = 0 - self._cached_byte_size_dirty = False - self._fields = {} - self._is_present_in_parent = False - self._listener = message_listener_mod.NullMessageListener() - self._listener_for_children = _Listener(self) - for field_name, field_value in kwargs.iteritems(): - field = _GetFieldByName(message_descriptor, field_name) - if field is None: - raise TypeError("%s() got an unexpected keyword argument '%s'" % - (message_descriptor.name, field_name)) - if field.label == _FieldDescriptor.LABEL_REPEATED: - copy = field._default_constructor(self) - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - for val in field_value: - copy.add().MergeFrom(val) - else: # Scalar - copy.extend(field_value) - self._fields[field] = copy - elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - copy = field._default_constructor(self) - copy.MergeFrom(field_value) - self._fields[field] = copy - else: - self._fields[field] = field_value - - init.__module__ = None - init.__doc__ = None - cls.__init__ = init - - -def _GetFieldByName(message_descriptor, field_name): - """Returns a field descriptor by field name. - - Args: - message_descriptor: A Descriptor describing all fields in message. - field_name: The name of the field to retrieve. Returns: - The field descriptor associated with the field name. - """ - try: - return message_descriptor.fields_by_name[field_name] - except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) - - -def _AddPropertiesForFields(descriptor, cls): - """Adds properties for all fields in this protocol message type.""" - for field in descriptor.fields: - _AddPropertiesForField(field, cls) - - if descriptor.is_extendable: - # _ExtensionDict is just an adaptor with no state so we allocate a new one - # every time it is accessed. - cls.Extensions = property(lambda self: _ExtensionDict(self)) - - -def _AddPropertiesForField(field, cls): - """Adds a public property for a protocol message field. - Clients can use this property to get and (in the case - of non-repeated scalar fields) directly set the value - of a protocol message field. - - Args: - field: A FieldDescriptor for this field. - cls: The class we're constructing. - """ - # Catch it if we add other types that we should - # handle specially here. - assert _FieldDescriptor.MAX_CPPTYPE == 10 - - constant_name = field.name.upper() + "_FIELD_NUMBER" - setattr(cls, constant_name, field.number) - - if field.label == _FieldDescriptor.LABEL_REPEATED: - _AddPropertiesForRepeatedField(field, cls) - elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - _AddPropertiesForNonRepeatedCompositeField(field, cls) - else: - _AddPropertiesForNonRepeatedScalarField(field, cls) - - -def _AddPropertiesForRepeatedField(field, cls): - """Adds a public property for a "repeated" protocol message field. Clients - can use this property to get the value of the field, which will be either a - _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see - below). - - Note that when clients add values to these containers, we perform - type-checking in the case of repeated scalar fields, and we also set any - necessary "has" bits as a side-effect. - - Args: - field: A FieldDescriptor for this field. - cls: The class we're constructing. - """ - proto_field_name = field.name - property_name = _PropertyName(proto_field_name) - - def getter(self): - field_value = self._fields.get(field) - if field_value is None: - # Construct a new object to represent this field. - field_value = field._default_constructor(self) - - # Atomically check if another thread has preempted us and, if not, swap - # in the new object we just created. If someone has preempted us, we - # take that object and discard ours. - # WARNING: We are relying on setdefault() being atomic. This is true - # in CPython but we haven't investigated others. This warning appears - # in several other locations in this file. - field_value = self._fields.setdefault(field, field_value) - return field_value - getter.__module__ = None - getter.__doc__ = 'Getter for %s.' % proto_field_name - - # We define a setter just so we can throw an exception with a more - # helpful error message. - def setter(self, new_value): - raise AttributeError('Assignment not allowed to repeated field ' - '"%s" in protocol message object.' % proto_field_name) - - doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) - - -def _AddPropertiesForNonRepeatedScalarField(field, cls): - """Adds a public property for a nonrepeated, scalar protocol message field. - Clients can use this property to get and directly set the value of the field. - Note that when the client sets the value of a field by using this property, - all necessary "has" bits are set as a side-effect, and we also perform - type-checking. - - Args: - field: A FieldDescriptor for this field. - cls: The class we're constructing. - """ - proto_field_name = field.name - property_name = _PropertyName(proto_field_name) - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) - default_value = field.default_value - - def getter(self): - return self._fields.get(field, default_value) - getter.__module__ = None - getter.__doc__ = 'Getter for %s.' % proto_field_name - def setter(self, new_value): - type_checker.CheckValue(new_value) - self._fields[field] = new_value - # Check _cached_byte_size_dirty inline to improve performance, since scalar - # setters are called frequently. - if not self._cached_byte_size_dirty: - self._Modified() - setter.__module__ = None - setter.__doc__ = 'Setter for %s.' % proto_field_name - - # Add a property to encapsulate the getter/setter. - doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) - - -def _AddPropertiesForNonRepeatedCompositeField(field, cls): - """Adds a public property for a nonrepeated, composite protocol message field. - A composite field is a "group" or "message" field. - - Clients can use this property to get the value of the field, but cannot - assign to the property directly. - - Args: - field: A FieldDescriptor for this field. - cls: The class we're constructing. + Newly created protobuf Message object. """ - # TODO(robinson): Remove duplication with similar method - # for non-repeated scalars. - proto_field_name = field.name - property_name = _PropertyName(proto_field_name) - message_type = field.message_type - - def getter(self): - field_value = self._fields.get(field) - if field_value is None: - # Construct a new object to represent this field. - field_value = message_type._concrete_class() - field_value._SetListener(self._listener_for_children) - - # Atomically check if another thread has preempted us and, if not, swap - # in the new object we just created. If someone has preempted us, we - # take that object and discard ours. - # WARNING: We are relying on setdefault() being atomic. This is true - # in CPython but we haven't investigated others. This warning appears - # in several other locations in this file. - field_value = self._fields.setdefault(field, field_value) - return field_value - getter.__module__ = None - getter.__doc__ = 'Getter for %s.' % proto_field_name - - # We define a setter just so we can throw an exception with a more - # helpful error message. - def setter(self, new_value): - raise AttributeError('Assignment not allowed to composite field ' - '"%s" in protocol message object.' % proto_field_name) - - # Add a property to encapsulate the getter. - doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) - - -def _AddPropertiesForExtensions(descriptor, cls): - """Adds properties for all fields in this protocol message type.""" - extension_dict = descriptor.extensions_by_name - for extension_name, extension_field in extension_dict.iteritems(): - constant_name = extension_name.upper() + "_FIELD_NUMBER" - setattr(cls, constant_name, extension_field.number) - - -def _AddStaticMethods(cls): - # TODO(robinson): This probably needs to be thread-safe(?) - def RegisterExtension(extension_handle): - extension_handle.containing_type = cls.DESCRIPTOR - _AttachFieldHelpers(cls, extension_handle) - - # Try to insert our extension, failing if an extension with the same number - # already exists. - actual_handle = cls._extensions_by_number.setdefault( - extension_handle.number, extension_handle) - if actual_handle is not extension_handle: - raise AssertionError( - 'Extensions "%s" and "%s" both try to extend message type "%s" with ' - 'field number %d.' % - (extension_handle.full_name, actual_handle.full_name, - cls.DESCRIPTOR.full_name, extension_handle.number)) - - cls._extensions_by_name[extension_handle.full_name] = extension_handle - - handle = extension_handle # avoid line wrapping - if _IsMessageSetExtension(handle): - # MessageSet extension. Also register under type name. - cls._extensions_by_name[ - extension_handle.message_type.full_name] = extension_handle - - cls.RegisterExtension = staticmethod(RegisterExtension) - - def FromString(s): - message = cls() - message.MergeFromString(s) - return message - cls.FromString = staticmethod(FromString) - - -def _IsPresent(item): - """Given a (FieldDescriptor, value) tuple from _fields, return true if the - value should be included in the list returned by ListFields().""" - - if item[0].label == _FieldDescriptor.LABEL_REPEATED: - return bool(item[1]) - elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - return item[1]._is_present_in_parent - else: - return True - - -def _AddListFieldsMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - - def ListFields(self): - all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] - all_fields.sort(key = lambda item: item[0].number) - return all_fields - - cls.ListFields = ListFields - - -def _AddHasFieldMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - - singular_fields = {} - for field in message_descriptor.fields: - if field.label != _FieldDescriptor.LABEL_REPEATED: - singular_fields[field.name] = field - - def HasField(self, field_name): - try: - field = singular_fields[field_name] - except KeyError: - raise ValueError( - 'Protocol message has no singular "%s" field.' % field_name) - - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(field) - return value is not None and value._is_present_in_parent - else: - return field in self._fields - cls.HasField = HasField + result_class = MakeClass(descriptor) + new_msg = result_class() + new_msg.ParseFromString(byte_str) + return new_msg -def _AddClearFieldMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def ClearField(self, field_name): - try: - field = message_descriptor.fields_by_name[field_name] - except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) +def MakeClass(descriptor): + """Construct a class object for a protobuf described by descriptor. - if field in self._fields: - # Note: If the field is a sub-message, its listener will still point - # at us. That's fine, because the worst than can happen is that it - # will call _Modified() and invalidate our byte size. Big deal. - del self._fields[field] + Composite descriptors are handled by defining the new class as a member of the + parent class, recursing as deep as necessary. + This is the dynamic equivalent to: - # Always call _Modified() -- even if nothing was changed, this is - # a mutating method, and thus calling it should cause the field to become - # present in the parent message. - self._Modified() - - cls.ClearField = ClearField - - -def _AddClearExtensionMethod(cls): - """Helper for _AddMessageMethods().""" - def ClearExtension(self, extension_handle): - _VerifyExtensionHandle(self, extension_handle) - - # Similar to ClearField(), above. - if extension_handle in self._fields: - del self._fields[extension_handle] - self._Modified() - cls.ClearExtension = ClearExtension - - -def _AddClearMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def Clear(self): - # Clear fields. - self._fields = {} - self._Modified() - cls.Clear = Clear - - -def _AddHasExtensionMethod(cls): - """Helper for _AddMessageMethods().""" - def HasExtension(self, extension_handle): - _VerifyExtensionHandle(self, extension_handle) - if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: - raise KeyError('"%s" is repeated.' % extension_handle.full_name) - - if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - value = self._fields.get(extension_handle) - return value is not None and value._is_present_in_parent - else: - return extension_handle in self._fields - cls.HasExtension = HasExtension - - -def _AddEqualsMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def __eq__(self, other): - if (not isinstance(other, message_mod.Message) or - other.DESCRIPTOR != self.DESCRIPTOR): - return False - - if self is other: - return True - - return self.ListFields() == other.ListFields() - - cls.__eq__ = __eq__ - - -def _AddStrMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def __str__(self): - return text_format.MessageToString(self) - cls.__str__ = __str__ - - -def _AddSetListenerMethod(cls): - """Helper for _AddMessageMethods().""" - def SetListener(self, listener): - if listener is None: - self._listener = message_listener_mod.NullMessageListener() - else: - self._listener = listener - cls._SetListener = SetListener - - -def _BytesForNonRepeatedElement(value, field_number, field_type): - """Returns the number of bytes needed to serialize a non-repeated element. - The returned byte count includes space for tag information and any - other additional space associated with serializing value. + class Parent(message.Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = descriptor + class Child(message.Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = descriptor.nested_types[0] + + Sample usage: + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(proto2_string) + msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() Args: - value: Value we're serializing. - field_number: Field number of this value. (Since the field number - is stored as part of a varint-encoded tag, this has an impact - on the total bytes required to serialize the value). - field_type: The type of the field. One of the TYPE_* constants - within FieldDescriptor. - """ - try: - fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] - return fn(field_number, value) - except KeyError: - raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) - - -def _AddByteSizeMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - - def ByteSize(self): - if not self._cached_byte_size_dirty: - return self._cached_byte_size - - size = 0 - for field_descriptor, field_value in self.ListFields(): - size += field_descriptor._sizer(field_value) - - self._cached_byte_size = size - self._cached_byte_size_dirty = False - self._listener_for_children.dirty = False - return size - - cls.ByteSize = ByteSize - - -def _AddSerializeToStringMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - - def SerializeToString(self): - # Check if the message has all of its required fields set. - errors = [] - if not self.IsInitialized(): - raise message_mod.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) - return self.SerializePartialToString() - cls.SerializeToString = SerializeToString - - -def _AddSerializePartialToStringMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - - def SerializePartialToString(self): - out = StringIO() - self._InternalSerialize(out.write) - return out.getvalue() - cls.SerializePartialToString = SerializePartialToString - - def InternalSerialize(self, write_bytes): - for field_descriptor, field_value in self.ListFields(): - field_descriptor._encoder(write_bytes, field_value) - cls._InternalSerialize = InternalSerialize - - -def _AddMergeFromStringMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def MergeFromString(self, serialized): - length = len(serialized) - try: - if self._InternalParse(serialized, 0, length) != length: - # The only reason _InternalParse would return early is if it - # encountered an end-group tag. - raise message_mod.DecodeError('Unexpected end-group tag.') - except IndexError: - raise message_mod.DecodeError('Truncated message.') - except struct.error, e: - raise message_mod.DecodeError(e) - return length # Return this for legacy reasons. - cls.MergeFromString = MergeFromString - - local_ReadTag = decoder.ReadTag - local_SkipField = decoder.SkipField - decoders_by_tag = cls._decoders_by_tag - - def InternalParse(self, buffer, pos, end): - self._Modified() - field_dict = self._fields - while pos != end: - (tag_bytes, new_pos) = local_ReadTag(buffer, pos) - field_decoder = decoders_by_tag.get(tag_bytes) - if field_decoder is None: - new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) - if new_pos == -1: - return pos - pos = new_pos - else: - pos = field_decoder(buffer, new_pos, end, self, field_dict) - return pos - cls._InternalParse = InternalParse - - -def _AddIsInitializedMethod(message_descriptor, cls): - """Adds the IsInitialized and FindInitializationError methods to the - protocol message class.""" - - required_fields = [field for field in message_descriptor.fields - if field.label == _FieldDescriptor.LABEL_REQUIRED] - - def IsInitialized(self, errors=None): - """Checks if all required fields of a message are set. - - Args: - errors: A list which, if provided, will be populated with the field - paths of all missing required fields. - - Returns: - True iff the specified message has all required fields set. - """ - - # Performance is critical so we avoid HasField() and ListFields(). - - for field in required_fields: - if (field not in self._fields or - (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and - not self._fields[field]._is_present_in_parent)): - if errors is not None: - errors.extend(self.FindInitializationErrors()) - return False - - for field, value in self._fields.iteritems(): - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if field.label == _FieldDescriptor.LABEL_REPEATED: - for element in value: - if not element.IsInitialized(): - if errors is not None: - errors.extend(self.FindInitializationErrors()) - return False - elif value._is_present_in_parent and not value.IsInitialized(): - if errors is not None: - errors.extend(self.FindInitializationErrors()) - return False - - return True - - cls.IsInitialized = IsInitialized - - def FindInitializationErrors(self): - """Finds required fields which are not initialized. - - Returns: - A list of strings. Each string is a path to an uninitialized field from - the top-level message, e.g. "foo.bar[5].baz". - """ - - errors = [] # simplify things - - for field in required_fields: - if not self.HasField(field.name): - errors.append(field.name) - - for field, value in self.ListFields(): - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if field.is_extension: - name = "(%s)" % field.full_name - else: - name = field.name - - if field.label == _FieldDescriptor.LABEL_REPEATED: - for i in xrange(len(value)): - element = value[i] - prefix = "%s[%d]." % (name, i) - sub_errors = element.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] - else: - prefix = name + "." - sub_errors = value.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] - - return errors - - cls.FindInitializationErrors = FindInitializationErrors - - -def _AddMergeFromMethod(cls): - LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED - CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE - - def MergeFrom(self, msg): - assert msg is not self - self._Modified() - - fields = self._fields - - for field, value in msg._fields.iteritems(): - if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE: - field_value = fields.get(field) - if field_value is None: - # Construct a new object to represent this field. - field_value = field._default_constructor(self) - fields[field] = field_value - field_value.MergeFrom(value) - else: - self._fields[field] = value - cls.MergeFrom = MergeFrom - - -def _AddMessageMethods(message_descriptor, cls): - """Adds implementations of all Message methods to cls.""" - _AddListFieldsMethod(message_descriptor, cls) - _AddHasFieldMethod(message_descriptor, cls) - _AddClearFieldMethod(message_descriptor, cls) - if message_descriptor.is_extendable: - _AddClearExtensionMethod(cls) - _AddHasExtensionMethod(cls) - _AddClearMethod(message_descriptor, cls) - _AddEqualsMethod(message_descriptor, cls) - _AddStrMethod(message_descriptor, cls) - _AddSetListenerMethod(cls) - _AddByteSizeMethod(message_descriptor, cls) - _AddSerializeToStringMethod(message_descriptor, cls) - _AddSerializePartialToStringMethod(message_descriptor, cls) - _AddMergeFromStringMethod(message_descriptor, cls) - _AddIsInitializedMethod(message_descriptor, cls) - _AddMergeFromMethod(cls) - - -def _AddPrivateHelperMethods(cls): - """Adds implementation of private helper methods to cls.""" - - def Modified(self): - """Sets the _cached_byte_size_dirty bit to true, - and propagates this to our listener iff this was a state change. - """ - - # Note: Some callers check _cached_byte_size_dirty before calling - # _Modified() as an extra optimization. So, if this method is ever - # changed such that it does stuff even when _cached_byte_size_dirty is - # already true, the callers need to be updated. - if not self._cached_byte_size_dirty: - self._cached_byte_size_dirty = True - self._listener_for_children.dirty = True - self._is_present_in_parent = True - self._listener.Modified() - - cls._Modified = Modified - cls.SetInParent = Modified - - -class _Listener(object): - - """MessageListener implementation that a parent message registers with its - child message. - - In order to support semantics like: - - foo.bar.baz.qux = 23 - assert foo.HasField('bar') - - ...child objects must have back references to their parents. - This helper class is at the heart of this support. - """ - - def __init__(self, parent_message): - """Args: - parent_message: The message whose _Modified() method we should call when - we receive Modified() messages. - """ - # This listener establishes a back reference from a child (contained) object - # to its parent (containing) object. We make this a weak reference to avoid - # creating cyclic garbage when the client finishes with the 'parent' object - # in the tree. - if isinstance(parent_message, weakref.ProxyType): - self._parent_message_weakref = parent_message - else: - self._parent_message_weakref = weakref.proxy(parent_message) - - # As an optimization, we also indicate directly on the listener whether - # or not the parent message is dirty. This way we can avoid traversing - # up the tree in the common case. - self.dirty = False - - def Modified(self): - if self.dirty: - return - try: - # Propagate the signal to our parents iff this is the first field set. - self._parent_message_weakref._Modified() - except ReferenceError: - # We can get here if a client has kept a reference to a child object, - # and is now setting a field on it, but the child's parent has been - # garbage-collected. This is not an error. - pass - - -# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... -# TODO(robinson): Unify error handling of "unknown extension" crap. -# TODO(robinson): Support iteritems()-style iteration over all -# extensions with the "has" bits turned on? -class _ExtensionDict(object): - - """Dict-like container for supporting an indexable "Extensions" - field on proto instances. - - Note that in all cases we expect extension handles to be - FieldDescriptors. + descriptor: A descriptor.Descriptor object describing the protobuf. + Returns: + The Message class object described by the descriptor. """ + attributes = {} + for name, nested_type in descriptor.nested_types_by_name.items(): + attributes[name] = MakeClass(nested_type) - def __init__(self, extended_message): - """extended_message: Message instance for which we are the Extensions dict. - """ - - self._extended_message = extended_message - - def __getitem__(self, extension_handle): - """Returns the current value of the given extension handle.""" - - _VerifyExtensionHandle(self._extended_message, extension_handle) - - result = self._extended_message._fields.get(extension_handle) - if result is not None: - return result - - if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: - result = extension_handle._default_constructor(self._extended_message) - elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - result = extension_handle.message_type._concrete_class() - try: - result._SetListener(self._extended_message._listener_for_children) - except ReferenceError: - pass - else: - # Singular scalar -- just return the default without inserting into the - # dict. - return extension_handle.default_value - - # Atomically check if another thread has preempted us and, if not, swap - # in the new object we just created. If someone has preempted us, we - # take that object and discard ours. - # WARNING: We are relying on setdefault() being atomic. This is true - # in CPython but we haven't investigated others. This warning appears - # in several other locations in this file. - result = self._extended_message._fields.setdefault( - extension_handle, result) - - return result - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False + attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor - my_fields = self._extended_message.ListFields() - other_fields = other._extended_message.ListFields() - - # Get rid of non-extension fields. - my_fields = [ field for field in my_fields if field.is_extension ] - other_fields = [ field for field in other_fields if field.is_extension ] - - return my_fields == other_fields - - def __ne__(self, other): - return not self == other - - # Note that this is only meaningful for non-repeated, scalar extension - # fields. Note also that we may have to call _Modified() when we do - # successfully set a field this way, to set any necssary "has" bits in the - # ancestors of the extended message. - def __setitem__(self, extension_handle, value): - """If extension_handle specifies a non-repeated, scalar extension - field, sets the value of that field. - """ - - _VerifyExtensionHandle(self._extended_message, extension_handle) - - if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or - extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): - raise TypeError( - 'Cannot assign to extension "%s" because it is a repeated or ' - 'composite type.' % extension_handle.full_name) - - # It's slightly wasteful to lookup the type checker each time, - # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker( - extension_handle.cpp_type, extension_handle.type) - type_checker.CheckValue(value) - self._extended_message._fields[extension_handle] = value - self._extended_message._Modified() - - def _FindExtensionByName(self, name): - """Tries to find a known extension with the specified name. - - Args: - name: Extension full name. - - Returns: - Extension field descriptor. - """ - return self._extended_message._extensions_by_name.get(name, None) + return GeneratedProtocolMessageType(str(descriptor.name), (message.Message,), + attributes) diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py new file mode 100644 index 0000000..7466fec --- /dev/null +++ b/python/google/protobuf/symbol_database.py @@ -0,0 +1,185 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A database of Python protocol buffer generated symbols. + +SymbolDatabase makes it easy to create new instances of a registered type, given +only the type's protocol buffer symbol name. Once all symbols are registered, +they can be accessed using either the MessageFactory interface which +SymbolDatabase exposes, or the DescriptorPool interface of the underlying +pool. + +Example usage: + + db = symbol_database.SymbolDatabase() + + # Register symbols of interest, from one or multiple files. + db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR) + db.RegisterMessage(my_proto_pb2.MyMessage) + db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR) + + # The database can be used as a MessageFactory, to generate types based on + # their name: + types = db.GetMessages(['my_proto.proto']) + my_message_instance = types['MyMessage']() + + # The database's underlying descriptor pool can be queried, so it's not + # necessary to know a type's filename to be able to generate it: + filename = db.pool.FindFileContainingSymbol('MyMessage') + my_message_instance = db.GetMessages([filename])['MyMessage']() + + # This functionality is also provided directly via a convenience method: + my_message_instance = db.GetSymbol('MyMessage')() +""" + + +from google.protobuf import descriptor_pool + + +class SymbolDatabase(object): + """A database of Python generated symbols. + + SymbolDatabase also models message_factory.MessageFactory. + + The symbol database can be used to keep a global registry of all protocol + buffer types used within a program. + """ + + def __init__(self): + """Constructor.""" + + self._symbols = {} + self._symbols_by_file = {} + self.pool = descriptor_pool.DescriptorPool() + + def RegisterMessage(self, message): + """Registers the given message type in the local database. + + Args: + message: a message.Message, to be registered. + + Returns: + The provided message. + """ + + desc = message.DESCRIPTOR + self._symbols[desc.full_name] = message + if desc.file.name not in self._symbols_by_file: + self._symbols_by_file[desc.file.name] = {} + self._symbols_by_file[desc.file.name][desc.full_name] = message + self.pool.AddDescriptor(desc) + return message + + def RegisterEnumDescriptor(self, enum_descriptor): + """Registers the given enum descriptor in the local database. + + Args: + enum_descriptor: a descriptor.EnumDescriptor. + + Returns: + The provided descriptor. + """ + self.pool.AddEnumDescriptor(enum_descriptor) + return enum_descriptor + + def RegisterFileDescriptor(self, file_descriptor): + """Registers the given file descriptor in the local database. + + Args: + file_descriptor: a descriptor.FileDescriptor. + + Returns: + The provided descriptor. + """ + self.pool.AddFileDescriptor(file_descriptor) + + def GetSymbol(self, symbol): + """Tries to find a symbol in the local database. + + Currently, this method only returns message.Message instances, however, if + may be extended in future to support other symbol types. + + Args: + symbol: A str, a protocol buffer symbol. + + Returns: + A Python class corresponding to the symbol. + + Raises: + KeyError: if the symbol could not be found. + """ + + return self._symbols[symbol] + + def GetPrototype(self, descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + + return self.GetSymbol(descriptor.full_name) + + def GetMessages(self, files): + """Gets all the messages from a specified file. + + This will find and resolve dependencies, failing if they are not registered + in the symbol database. + + + Args: + files: The file names to extract messages from. + + Returns: + A dictionary mapping proto names to the message classes. This will include + any dependent messages as well as any messages defined in the same file as + a specified message. + + Raises: + KeyError: if a file could not be found. + """ + + result = {} + for f in files: + result.update(self._symbols_by_file[f]) + return result + +_DEFAULT = SymbolDatabase() + + +def Default(): + """Returns the default SymbolDatabase.""" + return _DEFAULT diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py new file mode 100644 index 0000000..ed0aabf --- /dev/null +++ b/python/google/protobuf/text_encoding.py @@ -0,0 +1,110 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#PY25 compatible for GAE. +# +"""Encoding related utilities.""" + +import re +import sys ##PY25 + +# Lookup table for utf8 +_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)] +_cescape_utf8_to_str[9] = r'\t' # optional escape +_cescape_utf8_to_str[10] = r'\n' # optional escape +_cescape_utf8_to_str[13] = r'\r' # optional escape +_cescape_utf8_to_str[39] = r"\'" # optional escape + +_cescape_utf8_to_str[34] = r'\"' # necessary escape +_cescape_utf8_to_str[92] = r'\\' # necessary escape + +# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) +_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] + + [chr(i) for i in xrange(32, 127)] + + [r'\%03o' % i for i in xrange(127, 256)]) +_cescape_byte_to_str[9] = r'\t' # optional escape +_cescape_byte_to_str[10] = r'\n' # optional escape +_cescape_byte_to_str[13] = r'\r' # optional escape +_cescape_byte_to_str[39] = r"\'" # optional escape + +_cescape_byte_to_str[34] = r'\"' # necessary escape +_cescape_byte_to_str[92] = r'\\' # necessary escape + + +def CEscape(text, as_utf8): + """Escape a bytes string for use in an ascii protocol buffer. + + text.encode('string_escape') does not seem to satisfy our needs as it + encodes unprintable characters using two-digit hex escapes whereas our + C++ unescaping function allows hex escapes to be any length. So, + "\0011".encode('string_escape') ends up being "\\x011", which will be + decoded in C++ as a single-character string with char code 0x11. + + Args: + text: A byte string to be escaped + as_utf8: Specifies if result should be returned in UTF-8 encoding + Returns: + Escaped string + """ + # PY3 hack: make Ord work for str and bytes: + # //platforms/networking/data uses unicode here, hence basestring. + Ord = ord if isinstance(text, basestring) else lambda x: x + if as_utf8: + return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text) + return ''.join(_cescape_byte_to_str[Ord(c)] for c in text) + + +_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') +_cescape_highbit_to_str = ([chr(i) for i in range(0, 127)] + + [r'\%03o' % i for i in range(127, 256)]) + + +def CUnescape(text): + """Unescape a text string with C-style escape sequences to UTF-8 bytes.""" + + def ReplaceHex(m): + # Only replace the match if the number of leading back slashes is odd. i.e. + # the slash itself is not escaped. + if len(m.group(1)) & 1: + return m.group(1) + 'x0' + m.group(2) + return m.group(0) + + # This is required because the 'string_escape' encoding doesn't + # allow single-digit hex escapes (like '\xf'). + result = _CUNESCAPE_HEX.sub(ReplaceHex, text) + + if sys.version_info[0] < 3: ##PY25 +##!PY25 if str is bytes: # PY2 + return result.decode('string_escape') + result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result) + return (result.encode('ascii') # Make it bytes to allow decode. + .decode('unicode_escape') + # Make it bytes again to return the proper type. + .encode('raw_unicode_escape')) diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index cc6ac90..50f76f2 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. + """Contains routines for printing protocol messages in text format.""" __author__ = 'kenton@google.com (Kenton Varda)' @@ -35,46 +39,92 @@ __author__ = 'kenton@google.com (Kenton Varda)' import cStringIO import re -from collections import deque from google.protobuf.internal import type_checkers from google.protobuf import descriptor +from google.protobuf import text_encoding + +__all__ = ['MessageToString', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge'] -__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', - 'PrintFieldValue', 'Merge' ] +_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), + type_checkers.Int32ValueChecker(), + type_checkers.Uint64ValueChecker(), + type_checkers.Int64ValueChecker()) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) +_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) -# Infinity and NaN are not explicitly supported by Python pre-2.6, and -# float('inf') does not work on Windows (pre-2.6). -_INFINITY = 1e10000 # overflows, thus will actually be infinity. -_NAN = _INFINITY * 0 +class Error(Exception): + """Top-level module error for text_format.""" -class ParseError(Exception): + +class ParseError(Error): """Thrown in case of ASCII parsing error.""" -def MessageToString(message): +def MessageToString(message, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, + float_format=None): + """Convert protobuf message to text format. + + Floating point values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using float_format='.15g'. + + Args: + message: The protocol buffers message. + as_utf8: Produce text output in UTF8 format. + as_one_line: Don't introduce newlines between fields. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify floating point number formatting + (per the "Format Specification Mini-Language"); otherwise, str() is used. + + Returns: + A string of the text formatted protocol buffer message. + """ out = cStringIO.StringIO() - PrintMessage(message, out) + PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, + pointy_brackets=pointy_brackets, + use_index_order=use_index_order, + float_format=float_format) result = out.getvalue() out.close() + if as_one_line: + return result.rstrip() return result -def PrintMessage(message, out, indent = 0): - for field, value in message.ListFields(): +def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, + float_format=None): + fields = message.ListFields() + if use_index_order: + fields.sort(key=lambda x: x[0].index) + for field, value in fields: if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: - PrintField(field, element, out, indent) + PrintField(field, element, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) else: - PrintField(field, value, out, indent) + PrintField(field, value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) -def PrintField(field, value, out, indent = 0): +def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, float_format=None): """Print a single field name/value pair. For repeated fields, the value should be a single element.""" - out.write(' ' * indent); + out.write(' ' * indent) if field.is_extension: out.write('[') if (field.containing_type.GetOptions().message_set_wire_format and @@ -96,54 +146,168 @@ def PrintField(field, value, out, indent = 0): # don't include it. out.write(': ') - PrintFieldValue(field, value, out, indent) - out.write('\n') + PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) + if as_one_line: + out.write(' ') + else: + out.write('\n') -def PrintFieldValue(field, value, out, indent = 0): +def PrintFieldValue(field, value, out, indent=0, as_utf8=False, + as_one_line=False, pointy_brackets=False, + float_format=None): """Print a single field value (not including name). For repeated fields, the value should be a single element.""" + if pointy_brackets: + openb = '<' + closeb = '>' + else: + openb = '{' + closeb = '}' + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - out.write(' {\n') - PrintMessage(value, out, indent + 2) - out.write(' ' * indent + '}') + if as_one_line: + out.write(' %s ' % openb) + PrintMessage(value, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) + out.write(closeb) + else: + out.write(' %s\n' % openb) + PrintMessage(value, out, indent + 2, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + float_format=float_format) + out.write(' ' * indent + closeb) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - out.write(field.enum_type.values_by_number[value].name) + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') - out.write(_CEscape(value)) + if isinstance(value, unicode): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We need to escape non-UTF8 chars in TYPE_BYTES field. + out_as_utf8 = False + else: + out_as_utf8 = as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) out.write('\"') elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: if value: - out.write("true") + out.write('true') else: - out.write("false") + out.write('false') + elif field.cpp_type in _FLOAT_TYPES and float_format is not None: + out.write('{1:{0}}'.format(float_format, value)) else: out.write(str(value)) +def _ParseOrMerge(lines, message, allow_multiple_scalars): + """Converts an ASCII representation of a protocol message into a message. + + Args: + lines: Lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". + + Raises: + ParseError: On ASCII parsing problems. + """ + tokenizer = _Tokenizer(lines) + while not tokenizer.AtEnd(): + _MergeField(tokenizer, message, allow_multiple_scalars) + + +def Parse(text, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + text: Message ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + + Raises: + ParseError: On ASCII parsing problems. + """ + if not isinstance(text, str): text = text.decode('utf-8') + return ParseLines(text.split('\n'), message) + + def Merge(text, message): - """Merges an ASCII representation of a protocol message into a message. + """Parses an ASCII representation of a protocol message into a message. + + Like Parse(), but allows repeated values for a non-repeated field, and uses + the last one. Args: text: Message ASCII representation. message: A protocol buffer message to merge into. + Returns: + The same message passed as argument. + Raises: ParseError: On ASCII parsing problems. """ - tokenizer = _Tokenizer(text) - while not tokenizer.AtEnd(): - _MergeField(tokenizer, message) + return MergeLines(text.split('\n'), message) + + +def ParseLines(lines, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + lines: An iterable of lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + + Raises: + ParseError: On ASCII parsing problems. + """ + _ParseOrMerge(lines, message, False) + return message + + +def MergeLines(lines, message): + """Parses an ASCII representation of a protocol message into a message. + + Args: + lines: An iterable of lines of a message's ASCII representation. + message: A protocol buffer message to merge into. + + Returns: + The same message passed as argument. + Raises: + ParseError: On ASCII parsing problems. + """ + _ParseOrMerge(lines, message, True) + return message -def _MergeField(tokenizer, message): + +def _MergeField(tokenizer, message, allow_multiple_scalars): """Merges a single protocol message field into a message. Args: tokenizer: A tokenizer to parse the field name and values. message: A protocol message to record the data. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". Raises: ParseError: In case of ASCII parsing problems. @@ -159,7 +323,9 @@ def _MergeField(tokenizer, message): raise tokenizer.ParseErrorPreviousToken( 'Message type "%s" does not have extensions.' % message_descriptor.full_name) + # pylint: disable=protected-access field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access if not field: raise tokenizer.ParseErrorPreviousToken( 'Extension "%s" not registered.' % name) @@ -208,23 +374,31 @@ def _MergeField(tokenizer, message): sub_message = message.Extensions[field] else: sub_message = getattr(message, field.name) - sub_message.SetInParent() + sub_message.SetInParent() while not tokenizer.TryConsume(end_token): if tokenizer.AtEnd(): raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) - _MergeField(tokenizer, sub_message) + _MergeField(tokenizer, sub_message, allow_multiple_scalars) else: - _MergeScalarField(tokenizer, message, field) + _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') -def _MergeScalarField(tokenizer, message, field): +def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): """Merges a single protocol message scalar field into a message. Args: tokenizer: A tokenizer to parse the field value. message: A protocol message to record the data. field: The descriptor of the field to be merged. + allow_multiple_scalars: Determines if repeated values for a non-repeated + field are permitted, e.g., the string "foo: 1 foo: 2" for a + required/optional field named "foo". Raises: ParseError: In case of ASCII parsing problems. @@ -257,24 +431,7 @@ def _MergeScalarField(tokenizer, message, field): elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: value = tokenizer.ConsumeByteString() elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: - # Enum can be specified by a number (the enum value), or by - # a string literal (the enum name). - enum_descriptor = field.enum_type - if tokenizer.LookingAtInteger(): - number = tokenizer.ConsumeInt32() - enum_value = enum_descriptor.values_by_number.get(number, None) - if enum_value is None: - raise tokenizer.ParseErrorPreviousToken( - 'Enum type "%s" has no value with number %d.' % ( - enum_descriptor.full_name, number)) - else: - identifier = tokenizer.ConsumeIdentifier() - enum_value = enum_descriptor.values_by_name.get(identifier, None) - if enum_value is None: - raise tokenizer.ParseErrorPreviousToken( - 'Enum type "%s" has no value named %s.' % ( - enum_descriptor.full_name, identifier)) - value = enum_value.number + value = tokenizer.ConsumeEnum(field) else: raise RuntimeError('Unknown field type %d' % field.type) @@ -285,9 +442,19 @@ def _MergeScalarField(tokenizer, message, field): getattr(message, field.name).append(value) else: if field.is_extension: - message.Extensions[field] = value + if not allow_multiple_scalars and message.HasExtension(field): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value else: - setattr(message, field.name, value) + if not allow_multiple_scalars and message.HasField(field.name): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) class _Tokenizer(object): @@ -305,26 +472,19 @@ class _Tokenizer(object): '[0-9+-][0-9a-zA-Z_.+-]*|' # a number '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string - _IDENTIFIER = re.compile('\w+') - _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(), - type_checkers.Int32ValueChecker(), - type_checkers.Uint64ValueChecker(), - type_checkers.Int64ValueChecker()] - _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE) - _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE) - - def __init__(self, text_message): - self._text_message = text_message + _IDENTIFIER = re.compile(r'\w+') + def __init__(self, lines): self._position = 0 self._line = -1 self._column = 0 self._token_start = None self.token = '' - self._lines = deque(text_message.split('\n')) + self._lines = iter(lines) self._current_line = '' self._previous_line = 0 self._previous_column = 0 + self._more_lines = True self._SkipWhitespace() self.NextToken() @@ -334,25 +494,27 @@ class _Tokenizer(object): Returns: True iff the end was reached. """ - return not self._lines and not self._current_line + return not self.token def _PopLine(self): - while not self._current_line: - if not self._lines: + while len(self._current_line) <= self._column: + try: + self._current_line = self._lines.next() + except StopIteration: self._current_line = '' + self._more_lines = False return - self._line += 1 - self._column = 0 - self._current_line = self._lines.popleft() + else: + self._line += 1 + self._column = 0 def _SkipWhitespace(self): while True: self._PopLine() - match = re.match(self._WHITESPACE, self._current_line) + match = self._WHITESPACE.match(self._current_line, self._column) if not match: break length = len(match.group(0)) - self._current_line = self._current_line[length:] self._column += length def TryConsume(self, token): @@ -381,17 +543,6 @@ class _Tokenizer(object): if not self.TryConsume(token): raise self._ParseError('Expected "%s".' % token) - def LookingAtInteger(self): - """Checks if the current token is an integer. - - Returns: - True iff the current token is an integer. - """ - if not self.token: - return False - c = self.token[0] - return (c >= '0' and c <= '9') or c == '-' or c == '+' - def ConsumeIdentifier(self): """Consumes protocol message field identifier. @@ -402,7 +553,7 @@ class _Tokenizer(object): ParseError: If an identifier couldn't be consumed. """ result = self.token - if not re.match(self._IDENTIFIER, result): + if not self._IDENTIFIER.match(result): raise self._ParseError('Expected identifier.') self.NextToken() return result @@ -417,9 +568,9 @@ class _Tokenizer(object): ParseError: If a signed 32bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=True, is_long=False) + result = ParseInteger(self.token, is_signed=True, is_long=False) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -433,9 +584,9 @@ class _Tokenizer(object): ParseError: If an unsigned 32bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=False, is_long=False) + result = ParseInteger(self.token, is_signed=False, is_long=False) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -449,9 +600,9 @@ class _Tokenizer(object): ParseError: If a signed 64bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=True, is_long=True) + result = ParseInteger(self.token, is_signed=True, is_long=True) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -465,9 +616,9 @@ class _Tokenizer(object): ParseError: If an unsigned 64bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=False, is_long=True) + result = ParseInteger(self.token, is_signed=False, is_long=True) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -480,21 +631,10 @@ class _Tokenizer(object): Raises: ParseError: If a floating point number couldn't be consumed. """ - text = self.token - if re.match(self._FLOAT_INFINITY, text): - self.NextToken() - if text.startswith('-'): - return -_INFINITY - return _INFINITY - - if re.match(self._FLOAT_NAN, text): - self.NextToken() - return _NAN - try: - result = float(text) + result = ParseFloat(self.token) except ValueError, e: - raise self._FloatParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -507,14 +647,12 @@ class _Tokenizer(object): Raises: ParseError: If a boolean value couldn't be consumed. """ - if self.token == 'true': - self.NextToken() - return True - elif self.token == 'false': - self.NextToken() - return False - else: - raise self._ParseError('Expected "true" or "false".') + try: + result = ParseBool(self.token) + except ValueError, e: + raise self._ParseError(str(e)) + self.NextToken() + return result def ConsumeString(self): """Consumes a string value. @@ -525,7 +663,11 @@ class _Tokenizer(object): Raises: ParseError: If a string value couldn't be consumed. """ - return unicode(self.ConsumeByteString(), 'utf-8') + the_bytes = self.ConsumeByteString() + try: + return unicode(the_bytes, 'utf-8') + except UnicodeDecodeError, e: + raise self._StringParseError(e) def ConsumeByteString(self): """Consumes a byte array value. @@ -536,10 +678,11 @@ class _Tokenizer(object): Raises: ParseError: If a byte array value couldn't be consumed. """ - list = [self._ConsumeSingleByteString()] - while len(self.token) > 0 and self.token[0] in ('\'', '"'): - list.append(self._ConsumeSingleByteString()) - return "".join(list) + the_list = [self._ConsumeSingleByteString()] + while self.token and self.token[0] in ('\'', '"'): + the_list.append(self._ConsumeSingleByteString()) + return ''.encode('latin1').join(the_list) ##PY25 +##!PY25 return b''.join(the_list) def _ConsumeSingleByteString(self): """Consume one token of a string literal. @@ -550,48 +693,24 @@ class _Tokenizer(object): """ text = self.token if len(text) < 1 or text[0] not in ('\'', '"'): - raise self._ParseError('Exptected string.') + raise self._ParseError('Expected string.') if len(text) < 2 or text[-1] != text[0]: raise self._ParseError('String missing ending quote.') try: - result = _CUnescape(text[1:-1]) + result = text_encoding.CUnescape(text[1:-1]) except ValueError, e: raise self._ParseError(str(e)) self.NextToken() return result - def _ParseInteger(self, text, is_signed=False, is_long=False): - """Parses an integer. - - Args: - text: The text to parse. - is_signed: True if a signed integer must be parsed. - is_long: True if a long integer must be parsed. - - Returns: - The integer value. - - Raises: - ValueError: Thrown Iff the text is not a valid integer. - """ - pos = 0 - if text.startswith('-'): - pos += 1 - - base = 10 - if text.startswith('0x', pos) or text.startswith('0X', pos): - base = 16 - elif text.startswith('0', pos): - base = 8 - - # Do the actual parsing. Exception handling is propagated to caller. - result = int(text, base) - - # Check if the integer is sane. Exceptions handled by callers. - checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] - checker.CheckValue(result) + def ConsumeEnum(self, field): + try: + result = ParseEnum(field, self.token) + except ValueError, e: + raise self._ParseError(str(e)) + self.NextToken() return result def ParseErrorPreviousToken(self, message): @@ -611,63 +730,144 @@ class _Tokenizer(object): return ParseError('%d:%d : %s' % ( self._line + 1, self._column + 1, message)) - def _IntegerParseError(self, e): - return self._ParseError('Couldn\'t parse integer: ' + str(e)) - - def _FloatParseError(self, e): - return self._ParseError('Couldn\'t parse number: ' + str(e)) + def _StringParseError(self, e): + return self._ParseError('Couldn\'t parse string: ' + str(e)) def NextToken(self): """Reads the next meaningful token.""" self._previous_line = self._line self._previous_column = self._column - if self.AtEnd(): - self.token = '' - return + self._column += len(self.token) + self._SkipWhitespace() - # Make sure there is data to work on. - self._PopLine() + if not self._more_lines: + self.token = '' + return - match = re.match(self._TOKEN, self._current_line) + match = self._TOKEN.match(self._current_line, self._column) if match: token = match.group(0) - self._current_line = self._current_line[len(token):] self.token = token else: - self.token = self._current_line[0] - self._current_line = self._current_line[1:] - self._SkipWhitespace() + self.token = self._current_line[self._column] + + +def ParseInteger(text, is_signed=False, is_long=False): + """Parses an integer. + + Args: + text: The text to parse. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer value. + + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + # Do the actual parsing. Exception handling is propagated to caller. + try: + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + if is_long: + result = long(text, 0) + else: + result = int(text, 0) + except ValueError: + raise ValueError('Couldn\'t parse integer: %s' % text) + + # Check if the integer is sane. Exceptions handled by callers. + checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] + checker.CheckValue(result) + return result + + +def ParseFloat(text): + """Parse a floating point number. + + Args: + text: Text to parse. + Returns: + The number parsed. -# text.encode('string_escape') does not seem to satisfy our needs as it -# encodes unprintable characters using two-digit hex escapes whereas our -# C++ unescaping function allows hex escapes to be any length. So, -# "\0011".encode('string_escape') ends up being "\\x011", which will be -# decoded in C++ as a single-character string with char code 0x11. -def _CEscape(text): - def escape(c): - o = ord(c) - if o == 10: return r"\n" # optional escape - if o == 13: return r"\r" # optional escape - if o == 9: return r"\t" # optional escape - if o == 39: return r"\'" # optional escape + Raises: + ValueError: If a floating point number couldn't be parsed. + """ + try: + # Assume Python compatible syntax. + return float(text) + except ValueError: + # Check alternative spellings. + if _FLOAT_INFINITY.match(text): + if text[0] == '-': + return float('-inf') + else: + return float('inf') + elif _FLOAT_NAN.match(text): + return float('nan') + else: + # assume '1.0f' format + try: + return float(text.rstrip('f')) + except ValueError: + raise ValueError('Couldn\'t parse float: %s' % text) + + +def ParseBool(text): + """Parse a boolean value. + + Args: + text: Text to parse. + + Returns: + Boolean values parsed - if o == 34: return r'\"' # necessary escape - if o == 92: return r"\\" # necessary escape + Raises: + ValueError: If text is not a valid boolean. + """ + if text in ('true', 't', '1'): + return True + elif text in ('false', 'f', '0'): + return False + else: + raise ValueError('Expected "true" or "false".') - if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes - return c - return "".join([escape(c) for c in text]) +def ParseEnum(field, value): + """Parse an enum value. -_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])') + The value can be specified by a number (the enum value), or by + a string literal (the enum name). + Args: + field: Enum field descriptor. + value: String value. -def _CUnescape(text): - def ReplaceHex(m): - return chr(int(m.group(0)[2:], 16)) - # This is required because the 'string_escape' encoding doesn't - # allow single-digit hex escapes (like '\xf'). - result = _CUNESCAPE_HEX.sub(ReplaceHex, text) - return result.decode('string_escape') + Returns: + Enum value number. + + Raises: + ValueError: If the enum value could not be parsed. + """ + enum_descriptor = field.enum_type + try: + number = int(value, 0) + except ValueError: + # Identifier. + enum_value = enum_descriptor.values_by_name.get(value, None) + if enum_value is None: + raise ValueError( + 'Enum type "%s" has no value named %s.' % ( + enum_descriptor.full_name, value)) + else: + # Numeric value. + enum_value = enum_descriptor.values_by_number.get(number, None) + if enum_value is None: + raise ValueError( + 'Enum type "%s" has no value with number %d.' % ( + enum_descriptor.full_name, number)) + return enum_value.number diff --git a/python/setup.py b/python/setup.py index 7242dae..9441d0e 100755 --- a/python/setup.py +++ b/python/setup.py @@ -1,25 +1,41 @@ #! /usr/bin/python # # See README for usage instructions. +import sys +import os +import subprocess # We must use setuptools, not distutils, because we need to use the # namespace_packages option for the "google" package. -from ez_setup import use_setuptools -use_setuptools() - -from setuptools import setup +try: + from setuptools import setup, Extension +except ImportError: + try: + from ez_setup import use_setuptools + use_setuptools() + from setuptools import setup, Extension + except ImportError: + sys.stderr.write( + "Could not import setuptools; make sure you have setuptools or " + "ez_setup installed.\n") + raise +from distutils.command.clean import clean as _clean +from distutils.command.build_py import build_py as _build_py from distutils.spawn import find_executable -import sys -import os -import subprocess maintainer_email = "protobuf@googlegroups.com" # Find the Protocol Compiler. -if os.path.exists("../src/protoc"): +if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']): + protoc = os.environ['PROTOC'] +elif os.path.exists("../src/protoc"): protoc = "../src/protoc" elif os.path.exists("../src/protoc.exe"): protoc = "../src/protoc.exe" +elif os.path.exists("../vsprojects/Debug/protoc.exe"): + protoc = "../vsprojects/Debug/protoc.exe" +elif os.path.exists("../vsprojects/Release/protoc.exe"): + protoc = "../vsprojects/Release/protoc.exe" else: protoc = find_executable("protoc") @@ -30,14 +46,14 @@ def generate_proto(source): output = source.replace(".proto", "_pb2.py").replace("../src/", "") - if not os.path.exists(source): - print "Can't find required file: " + source - sys.exit(-1) - if (not os.path.exists(output) or (os.path.exists(source) and os.path.getmtime(source) > os.path.getmtime(output))): - print "Generating %s..." % output + print ("Generating %s..." % output) + + if not os.path.exists(source): + sys.stderr.write("Can't find required file: %s\n" % source) + sys.exit(-1) if protoc == None: sys.stderr.write( @@ -49,74 +65,131 @@ def generate_proto(source): if subprocess.call(protoc_command) != 0: sys.exit(-1) -def MakeTestSuite(): - # This is apparently needed on some systems to make sure that the tests - # work even if a previous version is already installed. - if 'google' in sys.modules: - del sys.modules['google'] - +def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/unittest.proto") + generate_proto("../src/google/protobuf/unittest_custom_options.proto") generate_proto("../src/google/protobuf/unittest_import.proto") + generate_proto("../src/google/protobuf/unittest_import_public.proto") generate_proto("../src/google/protobuf/unittest_mset.proto") generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") + generate_proto("google/protobuf/internal/descriptor_pool_test1.proto") + generate_proto("google/protobuf/internal/descriptor_pool_test2.proto") + generate_proto("google/protobuf/internal/test_bad_identifiers.proto") + generate_proto("google/protobuf/internal/missing_enum_values.proto") generate_proto("google/protobuf/internal/more_extensions.proto") + generate_proto("google/protobuf/internal/more_extensions_dynamic.proto") generate_proto("google/protobuf/internal/more_messages.proto") + generate_proto("google/protobuf/internal/factory_test1.proto") + generate_proto("google/protobuf/internal/factory_test2.proto") + generate_proto("google/protobuf/pyext/python.proto") +def MakeTestSuite(): + # Test C++ implementation import unittest - import google.protobuf.internal.generator_test as generator_test - import google.protobuf.internal.descriptor_test as descriptor_test - import google.protobuf.internal.reflection_test as reflection_test - import google.protobuf.internal.service_reflection_test \ - as service_reflection_test - import google.protobuf.internal.text_format_test as text_format_test - import google.protobuf.internal.wire_format_test as wire_format_test + import google.protobuf.pyext.descriptor_cpp2_test as descriptor_cpp2_test + import google.protobuf.pyext.message_factory_cpp2_test \ + as message_factory_cpp2_test + import google.protobuf.pyext.reflection_cpp2_generated_test \ + as reflection_cpp2_generated_test loader = unittest.defaultTestLoader suite = unittest.TestSuite() - for test in [ generator_test, - descriptor_test, - reflection_test, - service_reflection_test, - text_format_test, - wire_format_test ]: + for test in [ descriptor_cpp2_test, + message_factory_cpp2_test, + reflection_cpp2_generated_test]: suite.addTest(loader.loadTestsFromModule(test)) - return suite -if __name__ == '__main__': - # TODO(kenton): Integrate this into setuptools somehow? - if len(sys.argv) >= 2 and sys.argv[1] == "clean": - # Delete generated _pb2.py files and .pyc files in the code tree. +class clean(_clean): + def run(self): + # Delete generated files in the code tree. for (dirpath, dirnames, filenames) in os.walk("."): for filename in filenames: filepath = os.path.join(dirpath, filename) - if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"): + if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \ + filepath.endswith(".so") or filepath.endswith(".o") or \ + filepath.endswith('google/protobuf/compiler/__init__.py'): os.remove(filepath) - else: + # _clean is an old-style class, so super() doesn't work. + _clean.run(self) + +class build_py(_build_py): + def run(self): # Generate necessary .proto file if it doesn't exist. - # TODO(kenton): Maybe we should hook this into a distutils command? generate_proto("../src/google/protobuf/descriptor.proto") + generate_proto("../src/google/protobuf/compiler/plugin.proto") + GenerateUnittestProtos() + + # Make sure google.protobuf/** are valid packages. + for path in ['', 'internal/', 'compiler/', 'pyext/']: + try: + open('google/protobuf/%s__init__.py' % path, 'a').close() + except EnvironmentError: + pass + # _build_py is an old-style class, so super() doesn't work. + _build_py.run(self) + # TODO(mrovner): Subclass to run 2to3 on some files only. + # Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's "Approach 2" + # section on how to get 2to3 to run on source files during install under + # Python 3. This class seems like a good place to put logic that calls + # python3's distutils.util.run_2to3 on the subset of the files we have in our + # release that are subject to conversion. + # See code reference in previous code review. + +if __name__ == '__main__': + ext_module_list = [] + cpp_impl = '--cpp_implementation' + if cpp_impl in sys.argv: + sys.argv.remove(cpp_impl) + # C++ implementation extension + ext_module_list.append(Extension( + "google.protobuf.pyext._message", + [ "google/protobuf/pyext/descriptor.cc", + "google/protobuf/pyext/message.cc", + "google/protobuf/pyext/extension_dict.cc", + "google/protobuf/pyext/repeated_scalar_container.cc", + "google/protobuf/pyext/repeated_composite_container.cc" ], + define_macros=[('GOOGLE_PROTOBUF_HAS_ONEOF', '1')], + include_dirs = [ ".", "../src"], + libraries = [ "protobuf" ], + library_dirs = [ '../src/.libs' ], + )) setup(name = 'protobuf', - version = '2.3.0', + version = '2.6.0', packages = [ 'google' ], namespace_packages = [ 'google' ], test_suite = 'setup.MakeTestSuite', + google_test_dir = "google/protobuf/internal", # Must list modules explicitly so that we don't install tests. py_modules = [ + 'google.protobuf.internal.api_implementation', 'google.protobuf.internal.containers', + 'google.protobuf.internal.cpp_message', 'google.protobuf.internal.decoder', 'google.protobuf.internal.encoder', + 'google.protobuf.internal.enum_type_wrapper', 'google.protobuf.internal.message_listener', + 'google.protobuf.internal.python_message', 'google.protobuf.internal.type_checkers', 'google.protobuf.internal.wire_format', 'google.protobuf.descriptor', 'google.protobuf.descriptor_pb2', + 'google.protobuf.compiler.plugin_pb2', 'google.protobuf.message', + 'google.protobuf.descriptor_database', + 'google.protobuf.descriptor_pool', + 'google.protobuf.message_factory', 'google.protobuf.reflection', 'google.protobuf.service', 'google.protobuf.service_reflection', - 'google.protobuf.text_format' ], + 'google.protobuf.symbol_database', + 'google.protobuf.text_encoding', + 'google.protobuf.text_format'], + cmdclass = { 'clean': clean, 'build_py': build_py }, + install_requires = ['setuptools'], + setup_requires = ['google-apputils'], + ext_modules = ext_module_list, url = 'http://code.google.com/p/protobuf/', maintainer = maintainer_email, maintainer_email = 'protobuf@googlegroups.com', |