summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authorJeff Davidson <jpd@google.com>2014-09-15 16:29:06 -0700
committerJeff Davidson <jpd@google.com>2015-01-15 14:10:53 -0800
commita3b2a6da25a76f17c73d31def3952feb0fd2296e (patch)
tree586f7d5e9a7e05af45d0e821188097c0faa96219 /python
parentc7c25812eb19d080087b71e08bfe35aff9f21433 (diff)
downloadexternal_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.zip
external_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.tar.gz
external_protobuf-a3b2a6da25a76f17c73d31def3952feb0fd2296e.tar.bz2
Update protobuf library from 2.3 to 2.6.
Copied in all files from the open source protobuf project at commit edc5994525c79cd1919859a370837a6ff7c8e308, removing files which have been renamed (COPYING.txt -> LICENSE, README.txt -> README.md). Removed 2.3 prebuilts, which is an approach that will not work due to incompatibility with the 2.6 runtime. Merged in micro/nano-specific changes in the following files: -Android.mk - updated list of C++/Java sources, bumped versions -java/README.txt - merged in micro/nano instructions, bumped versions -java/pom.xml - merged in micro/nano build rules, set packaging to jar -src/Makefile.am - merged in references to micro/nano generators -src/google/protobuf/compiler/javamicro/javamicro_file.h - imported google/protobuf/compiler/code_generator.h and removed redundant OutputDirectory class. -src/google/protobuf/compiler/javanano/javanano_file.h - same -Replaced instances of vector with std::vector as needed to get libprotobuf-cpp-full to compile. Plan to upstream this fix per discussion with protobuf maintainers. Reran autogen.sh to update ./configure and associated scripts. Change-Id: I949d32fb5126f1c05e2a6ed48f6636a4a9b15a48
Diffstat (limited to 'python')
-rw-r--r--python/README.txt36
-rwxr-xr-xpython/ez_setup.py27
-rwxr-xr-xpython/google/protobuf/descriptor.py299
-rw-r--r--python/google/protobuf/descriptor_database.py137
-rw-r--r--python/google/protobuf/descriptor_pool.py643
-rw-r--r--python/google/protobuf/internal/api_implementation.cc139
-rwxr-xr-xpython/google/protobuf/internal/api_implementation.py89
-rw-r--r--python/google/protobuf/internal/api_implementation_default_test.py63
-rwxr-xr-xpython/google/protobuf/internal/containers.py59
-rwxr-xr-xpython/google/protobuf/internal/cpp_message.py663
-rwxr-xr-xpython/google/protobuf/internal/decoder.py226
-rw-r--r--python/google/protobuf/internal/descriptor_database_test.py63
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test.py564
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test1.proto94
-rw-r--r--python/google/protobuf/internal/descriptor_pool_test2.proto70
-rw-r--r--python/google/protobuf/internal/descriptor_python_test.py54
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py423
-rwxr-xr-xpython/google/protobuf/internal/encoder.py118
-rw-r--r--python/google/protobuf/internal/enum_type_wrapper.py89
-rw-r--r--python/google/protobuf/internal/factory_test1.proto57
-rw-r--r--python/google/protobuf/internal/factory_test2.proto92
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py139
-rw-r--r--python/google/protobuf/internal/message_factory_python_test.py54
-rw-r--r--python/google/protobuf/internal/message_factory_test.py131
-rw-r--r--python/google/protobuf/internal/message_python_test.py54
-rwxr-xr-xpython/google/protobuf/internal/message_test.py611
-rw-r--r--python/google/protobuf/internal/missing_enum_values.proto50
-rw-r--r--python/google/protobuf/internal/more_extensions_dynamic.proto49
-rwxr-xr-xpython/google/protobuf/internal/python_message.py1247
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py868
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py6
-rw-r--r--python/google/protobuf/internal/symbol_database_test.py120
-rw-r--r--python/google/protobuf/internal/test_bad_identifiers.proto52
-rwxr-xr-xpython/google/protobuf/internal/test_util.py115
-rwxr-xr-xpython/google/protobuf/internal/text_encoding_test.py68
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py560
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py72
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py231
-rwxr-xr-xpython/google/protobuf/internal/wire_format_test.py8
-rwxr-xr-xpython/google/protobuf/message.py32
-rw-r--r--python/google/protobuf/message_factory.py155
-rw-r--r--python/google/protobuf/pyext/README6
-rw-r--r--python/google/protobuf/pyext/__init__.py0
-rw-r--r--python/google/protobuf/pyext/cpp_message.py61
-rw-r--r--python/google/protobuf/pyext/descriptor.cc357
-rw-r--r--python/google/protobuf/pyext/descriptor.h96
-rw-r--r--python/google/protobuf/pyext/descriptor_cpp2_test.py58
-rw-r--r--python/google/protobuf/pyext/extension_dict.cc338
-rw-r--r--python/google/protobuf/pyext/extension_dict.h123
-rw-r--r--python/google/protobuf/pyext/message.cc2561
-rw-r--r--python/google/protobuf/pyext/message.h305
-rw-r--r--python/google/protobuf/pyext/message_factory_cpp2_test.py56
-rw-r--r--python/google/protobuf/pyext/proto2_api_test.proto38
-rw-r--r--python/google/protobuf/pyext/python.proto66
-rw-r--r--python/google/protobuf/pyext/python_protobuf.h57
-rwxr-xr-xpython/google/protobuf/pyext/reflection_cpp2_generated_test.py94
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.cc763
-rw-r--r--python/google/protobuf/pyext/repeated_composite_container.h172
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.cc825
-rw-r--r--python/google/protobuf/pyext/repeated_scalar_container.h112
-rw-r--r--python/google/protobuf/pyext/scoped_pyobject_ptr.h95
-rwxr-xr-xpython/google/protobuf/reflection.py1069
-rw-r--r--python/google/protobuf/symbol_database.py185
-rw-r--r--python/google/protobuf/text_encoding.py110
-rwxr-xr-xpython/google/protobuf/text_format.py584
-rwxr-xr-xpython/setup.py157
66 files changed, 15256 insertions, 1659 deletions
diff --git a/python/README.txt b/python/README.txt
index 96f1a73..adfa46b 100644
--- a/python/README.txt
+++ b/python/README.txt
@@ -43,9 +43,13 @@ Installation
$ protoc --version
-4) Run the tests:
+4) Build and run the tests:
- $ python setup.py test
+ $ python setup.py build
+ $ python setup.py google_test
+
+ If you want to test c++ implementation, run:
+ $ python setup.py test --cpp_implementation
If some tests fail, this library may not work correctly on your
system. Continue at your own risk.
@@ -61,8 +65,13 @@ Installation
5) Install:
$ python setup.py install
+ or:
+ $ python setup.py install --cpp_implementation
This step may require superuser privileges.
+ NOTE: To use C++ implementation, you need to install C++ protobuf runtime
+ library of the same version and export the environment variable before this
+ step. See the "C++ Implementation" section below for more details.
Usage
=====
@@ -71,3 +80,26 @@ The complete documentation for Protocol Buffers is available via the
web at:
http://code.google.com/apis/protocolbuffers/
+
+C++ Implementation
+==================
+
+The C++ implementation for Python messages is built as a Python extension to
+improve the overall protobuf Python performance.
+
+To use the C++ implementation, you need to:
+1) Install the C++ protobuf runtime library, please see instructions in the
+ parent directory.
+2) Export an environment variable:
+
+ $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
+ $ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
+
+You need to export this variable before running setup.py script to build and
+install the extension. You must also set the variable at runtime, otherwise
+the pure-Python implementation will be used. In a future release, we will
+change the default so that C++ implementation is used whenever it is available.
+It is strongly recommended to run `python setup.py test` after setting the
+variable to "cpp", so the tests will be against C++ implemented Python
+messages.
+
diff --git a/python/ez_setup.py b/python/ez_setup.py
index b7a9849..3aec98e 100755
--- a/python/ez_setup.py
+++ b/python/ez_setup.py
@@ -2,7 +2,7 @@
# This file was obtained from:
# http://peak.telecommunity.com/dist/ez_setup.py
-# on 2009/4/17.
+# on 2011/1/21.
"""Bootstrap setuptools installation
@@ -19,7 +19,7 @@ the appropriate options to ``use_setuptools()``.
This file can also be run as a script to install or upgrade setuptools.
"""
import sys
-DEFAULT_VERSION = "0.6c9"
+DEFAULT_VERSION = "0.6c11"
DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
md5_data = {
@@ -33,6 +33,14 @@ md5_data = {
'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
+ 'setuptools-0.6c10-py2.3.egg': 'ce1e2ab5d3a0256456d9fc13800a7090',
+ 'setuptools-0.6c10-py2.4.egg': '57d6d9d6e9b80772c59a53a8433a5dd4',
+ 'setuptools-0.6c10-py2.5.egg': 'de46ac8b1c97c895572e5e8596aeb8c7',
+ 'setuptools-0.6c10-py2.6.egg': '58ea40aef06da02ce641495523a0b7f5',
+ 'setuptools-0.6c11-py2.3.egg': '2baeac6e13d414a9d28e7ba5b5a596de',
+ 'setuptools-0.6c11-py2.4.egg': 'bd639f9b0eac4c42497034dec2ec0c2b',
+ 'setuptools-0.6c11-py2.5.egg': '64c94f3bf7a72a13ec83e0b24f2749b2',
+ 'setuptools-0.6c11-py2.6.egg': 'bfa92100bd772d5a213eedd356d64086',
'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
@@ -99,6 +107,7 @@ def use_setuptools(
except ImportError:
return do_download()
try:
+ return do_download()
pkg_resources.require("setuptools>="+version); return
except pkg_resources.VersionConflict, e:
if was_imported:
@@ -109,11 +118,11 @@ def use_setuptools(
"\n\n(Currently using %r)"
) % (version, e.args[0])
sys.exit(2)
- else:
- del pkg_resources, sys.modules['pkg_resources'] # reload ok
- return do_download()
except pkg_resources.DistributionNotFound:
- return do_download()
+ pass
+
+ del pkg_resources, sys.modules['pkg_resources'] # reload ok
+ return do_download()
def download_setuptools(
version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
@@ -273,9 +282,3 @@ if __name__=='__main__':
update_md5(sys.argv[2:])
else:
main(sys.argv[1:])
-
-
-
-
-
-
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index aa4ab96..555498d 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -28,15 +28,9 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# TODO(robinson): We probably need to provide deep-copy methods for
-# descriptor types. When a FieldDescriptor is passed into
-# Descriptor.__init__(), we should make a deep copy and then set
-# containing_type on it. Alternatively, we could just get
-# rid of containing_type (iit's not needed for reflection.py, at least).
+# Needs to stay compatible with Python 2.5 due to GAE.
#
-# TODO(robinson): Print method?
-#
-# TODO(robinson): Useful __repr__?
+# Copyright 2007 Google Inc. All Rights Reserved.
"""Descriptors essentially contain exactly the information found in a .proto
file, in types that make this information accessible in Python.
@@ -44,11 +38,28 @@ file, in types that make this information accessible in Python.
__author__ = 'robinson@google.com (Will Robinson)'
+from google.protobuf.internal import api_implementation
+
+
+if api_implementation.Type() == 'cpp':
+ # Used by MakeDescriptor in cpp mode
+ import os
+ import uuid
+
+ if api_implementation.Version() == 2:
+ from google.protobuf.pyext import _message
+ else:
+ from google.protobuf.internal import cpp_message
+
class Error(Exception):
"""Base error for this module."""
+class TypeTransformationError(Error):
+ """Error transforming between python proto type and corresponding C++ type."""
+
+
class DescriptorBase(object):
"""Descriptors base class.
@@ -75,6 +86,18 @@ class DescriptorBase(object):
# Does this descriptor have non-default options?
self.has_options = options is not None
+ def _SetOptions(self, options, options_class_name):
+ """Sets the descriptor's options
+
+ This function is used in generated proto2 files to update descriptor
+ options. It must not be used outside proto2.
+ """
+ self._options = options
+ self._options_class_name = options_class_name
+
+ # Does this descriptor have non-default options?
+ self.has_options = options is not None
+
def GetOptions(self):
"""Retrieves descriptor options.
@@ -204,13 +227,21 @@ class Descriptor(_NestedDescriptorBase):
options: (descriptor_pb2.MessageOptions) Protocol message options or None
to use default message options.
+ oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields
+ in this message.
+ oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|,
+ but indexed by "name" attribute.
+
file: (FileDescriptor) Reference to file descriptor.
"""
+ # NOTE(tmarek): The file argument redefining a builtin is nothing we can
+ # fix right now since we don't know how many clients already rely on the
+ # name of the argument.
def __init__(self, name, full_name, filename, containing_type, fields,
nested_types, enum_types, extensions, options=None,
- is_extendable=True, extension_ranges=None, file=None,
- serialized_start=None, serialized_end=None):
+ is_extendable=True, extension_ranges=None, oneofs=None,
+ file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin
"""Arguments to __init__() are as described in the description
of Descriptor fields above.
@@ -220,7 +251,7 @@ class Descriptor(_NestedDescriptorBase):
super(Descriptor, self).__init__(
options, 'MessageOptions', name, full_name, file,
containing_type, serialized_start=serialized_start,
- serialized_end=serialized_start)
+ serialized_end=serialized_end)
# We have fields in addition to fields_by_name and fields_by_number,
# so that:
@@ -234,6 +265,8 @@ class Descriptor(_NestedDescriptorBase):
self.fields_by_name = dict((f.name, f) for f in fields)
self.nested_types = nested_types
+ for nested_type in nested_types:
+ nested_type.containing_type = self
self.nested_types_by_name = dict((t.name, t) for t in nested_types)
self.enum_types = enum_types
@@ -249,9 +282,28 @@ class Descriptor(_NestedDescriptorBase):
self.extensions_by_name = dict((f.name, f) for f in extensions)
self.is_extendable = is_extendable
self.extension_ranges = extension_ranges
+ self.oneofs = oneofs if oneofs is not None else []
+ self.oneofs_by_name = dict((o.name, o) for o in self.oneofs)
+ for oneof in self.oneofs:
+ oneof.containing_type = self
- self._serialized_start = serialized_start
- self._serialized_end = serialized_end
+ def EnumValueName(self, enum, value):
+ """Returns the string name of an enum value.
+
+ This is just a small helper method to simplify a common operation.
+
+ Args:
+ enum: string name of the Enum.
+ value: int, value of the enum.
+
+ Returns:
+ string name of the enum value.
+
+ Raises:
+ KeyError if either the Enum doesn't exist or the value is not a valid
+ value for the enum.
+ """
+ return self.enum_types_by_name[enum].values_by_number[value].name
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.DescriptorProto.
@@ -278,7 +330,7 @@ class FieldDescriptor(DescriptorBase):
"""Descriptor for a single field in a .proto file.
- A FieldDescriptor instance has the following attriubtes:
+ A FieldDescriptor instance has the following attributes:
name: (str) Name of this field, exactly as it appears in .proto.
full_name: (str) Name of this field, including containing scope. This is
@@ -319,6 +371,9 @@ class FieldDescriptor(DescriptorBase):
options: (descriptor_pb2.FieldOptions) Protocol message field options or
None to use default field options.
+
+ containing_oneof: (OneofDescriptor) If the field is a member of a oneof
+ union, contains its descriptor. Otherwise, None.
"""
# Must be consistent with C++ FieldDescriptor::Type enum in
@@ -361,6 +416,27 @@ class FieldDescriptor(DescriptorBase):
CPPTYPE_MESSAGE = 10
MAX_CPPTYPE = 10
+ _PYTHON_TO_CPP_PROTO_TYPE_MAP = {
+ TYPE_DOUBLE: CPPTYPE_DOUBLE,
+ TYPE_FLOAT: CPPTYPE_FLOAT,
+ TYPE_ENUM: CPPTYPE_ENUM,
+ TYPE_INT64: CPPTYPE_INT64,
+ TYPE_SINT64: CPPTYPE_INT64,
+ TYPE_SFIXED64: CPPTYPE_INT64,
+ TYPE_UINT64: CPPTYPE_UINT64,
+ TYPE_FIXED64: CPPTYPE_UINT64,
+ TYPE_INT32: CPPTYPE_INT32,
+ TYPE_SFIXED32: CPPTYPE_INT32,
+ TYPE_SINT32: CPPTYPE_INT32,
+ TYPE_UINT32: CPPTYPE_UINT32,
+ TYPE_FIXED32: CPPTYPE_UINT32,
+ TYPE_BYTES: CPPTYPE_STRING,
+ TYPE_STRING: CPPTYPE_STRING,
+ TYPE_BOOL: CPPTYPE_BOOL,
+ TYPE_MESSAGE: CPPTYPE_MESSAGE,
+ TYPE_GROUP: CPPTYPE_MESSAGE
+ }
+
# Must be consistent with C++ FieldDescriptor::Label enum in
# descriptor.h.
#
@@ -370,10 +446,16 @@ class FieldDescriptor(DescriptorBase):
LABEL_REPEATED = 3
MAX_LABEL = 3
+ # Must be consistent with C++ constants kMaxNumber, kFirstReservedNumber,
+ # and kLastReservedNumber in descriptor.h
+ MAX_FIELD_NUMBER = (1 << 29) - 1
+ FIRST_RESERVED_FIELD_NUMBER = 19000
+ LAST_RESERVED_FIELD_NUMBER = 19999
+
def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None,
- has_default_value=True):
+ has_default_value=True, containing_oneof=None):
"""The arguments are as described in the description of FieldDescriptor
attributes above.
@@ -396,6 +478,45 @@ class FieldDescriptor(DescriptorBase):
self.enum_type = enum_type
self.is_extension = is_extension
self.extension_scope = extension_scope
+ self.containing_oneof = containing_oneof
+ if api_implementation.Type() == 'cpp':
+ if is_extension:
+ if api_implementation.Version() == 2:
+ # pylint: disable=protected-access
+ self._cdescriptor = (
+ _message.Message._GetExtensionDescriptor(full_name))
+ # pylint: enable=protected-access
+ else:
+ self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name)
+ else:
+ if api_implementation.Version() == 2:
+ # pylint: disable=protected-access
+ self._cdescriptor = _message.Message._GetFieldDescriptor(full_name)
+ # pylint: enable=protected-access
+ else:
+ self._cdescriptor = cpp_message.GetFieldDescriptor(full_name)
+ else:
+ self._cdescriptor = None
+
+ @staticmethod
+ def ProtoTypeToCppProtoType(proto_type):
+ """Converts from a Python proto type to a C++ Proto Type.
+
+ The Python ProtocolBuffer classes specify both the 'Python' datatype and the
+ 'C++' datatype - and they're not the same. This helper method should
+ translate from one to another.
+
+ Args:
+ proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*)
+ Returns:
+ descriptor.FieldDescriptor.CPPTYPE_*, the C++ type.
+ Raises:
+ TypeTransformationError: when the Python proto type isn't known.
+ """
+ try:
+ return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type]
+ except KeyError:
+ raise TypeTransformationError('Unknown proto_type: %s' % proto_type)
class EnumDescriptor(_NestedDescriptorBase):
@@ -434,7 +555,7 @@ class EnumDescriptor(_NestedDescriptorBase):
super(EnumDescriptor, self).__init__(
options, 'EnumOptions', name, full_name, file,
containing_type, serialized_start=serialized_start,
- serialized_end=serialized_start)
+ serialized_end=serialized_end)
self.values = values
for value in self.values:
@@ -442,9 +563,6 @@ class EnumDescriptor(_NestedDescriptorBase):
self.values_by_name = dict((v.name, v) for v in values)
self.values_by_number = dict((v.number, v) for v in values)
- self._serialized_start = serialized_start
- self._serialized_end = serialized_end
-
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.EnumDescriptorProto.
@@ -479,6 +597,29 @@ class EnumValueDescriptor(DescriptorBase):
self.type = type
+class OneofDescriptor(object):
+ """Descriptor for a oneof field.
+
+ name: (str) Name of the oneof field.
+ full_name: (str) Full name of the oneof field, including package name.
+ index: (int) 0-based index giving the order of the oneof field inside
+ its containing type.
+ containing_type: (Descriptor) Descriptor of the protocol message
+ type that contains this field. Set by the Descriptor constructor
+ if we're passed into one.
+ fields: (list of FieldDescriptor) The list of field descriptors this
+ oneof can contain.
+ """
+
+ def __init__(self, name, full_name, index, containing_type, fields):
+ """Arguments are as described in the attribute description above."""
+ self.name = name
+ self.full_name = full_name
+ self.index = index
+ self.containing_type = containing_type
+ self.fields = fields
+
+
class ServiceDescriptor(_NestedDescriptorBase):
"""Descriptor for a service.
@@ -557,20 +698,43 @@ class MethodDescriptor(DescriptorBase):
class FileDescriptor(DescriptorBase):
"""Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
+ Note that enum_types_by_name, extensions_by_name, and dependencies
+ fields are only set by the message_factory module, and not by the
+ generated proto code.
+
name: name of file, relative to root of source tree.
package: name of the package
serialized_pb: (str) Byte string of serialized
descriptor_pb2.FileDescriptorProto.
+ dependencies: List of other FileDescriptors this FileDescriptor depends on.
+ message_types_by_name: Dict of message names of their descriptors.
+ enum_types_by_name: Dict of enum names and their descriptors.
+ extensions_by_name: Dict of extension names and their descriptors.
"""
- def __init__(self, name, package, options=None, serialized_pb=None):
+ def __init__(self, name, package, options=None, serialized_pb=None,
+ dependencies=None):
"""Constructor."""
super(FileDescriptor, self).__init__(options, 'FileOptions')
+ self.message_types_by_name = {}
self.name = name
self.package = package
self.serialized_pb = serialized_pb
+ self.enum_types_by_name = {}
+ self.extensions_by_name = {}
+ self.dependencies = (dependencies or [])
+
+ if (api_implementation.Type() == 'cpp' and
+ self.serialized_pb is not None):
+ if api_implementation.Version() == 2:
+ # pylint: disable=protected-access
+ _message.Message._BuildFile(self.serialized_pb)
+ # pylint: enable=protected-access
+ else:
+ cpp_message.BuildFile(self.serialized_pb)
+
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.FileDescriptorProto.
@@ -588,3 +752,98 @@ def _ParseOptions(message, string):
"""
message.ParseFromString(string)
return message
+
+
+def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True):
+ """Make a protobuf Descriptor given a DescriptorProto protobuf.
+
+ Handles nested descriptors. Note that this is limited to the scope of defining
+ a message inside of another message. Composite fields can currently only be
+ resolved if the message is defined in the same scope as the field.
+
+ Args:
+ desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
+ package: Optional package name for the new message Descriptor (string).
+ build_file_if_cpp: Update the C++ descriptor pool if api matches.
+ Set to False on recursion, so no duplicates are created.
+ Returns:
+ A Descriptor for protobuf messages.
+ """
+ if api_implementation.Type() == 'cpp' and build_file_if_cpp:
+ # The C++ implementation requires all descriptors to be backed by the same
+ # definition in the C++ descriptor pool. To do this, we build a
+ # FileDescriptorProto with the same definition as this descriptor and build
+ # it into the pool.
+ from google.protobuf import descriptor_pb2
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.message_type.add().MergeFrom(desc_proto)
+
+ # Generate a random name for this proto file to prevent conflicts with
+ # any imported ones. We need to specify a file name so BuildFile accepts
+ # our FileDescriptorProto, but it is not important what that file name
+ # is actually set to.
+ proto_name = str(uuid.uuid4())
+
+ if package:
+ file_descriptor_proto.name = os.path.join(package.replace('.', '/'),
+ proto_name + '.proto')
+ file_descriptor_proto.package = package
+ else:
+ file_descriptor_proto.name = proto_name + '.proto'
+
+ if api_implementation.Version() == 2:
+ # pylint: disable=protected-access
+ _message.Message._BuildFile(file_descriptor_proto.SerializeToString())
+ # pylint: enable=protected-access
+ else:
+ cpp_message.BuildFile(file_descriptor_proto.SerializeToString())
+
+ full_message_name = [desc_proto.name]
+ if package: full_message_name.insert(0, package)
+
+ # Create Descriptors for enum types
+ enum_types = {}
+ for enum_proto in desc_proto.enum_type:
+ full_name = '.'.join(full_message_name + [enum_proto.name])
+ enum_desc = EnumDescriptor(
+ enum_proto.name, full_name, None, [
+ EnumValueDescriptor(enum_val.name, ii, enum_val.number)
+ for ii, enum_val in enumerate(enum_proto.value)])
+ enum_types[full_name] = enum_desc
+
+ # Create Descriptors for nested types
+ nested_types = {}
+ for nested_proto in desc_proto.nested_type:
+ full_name = '.'.join(full_message_name + [nested_proto.name])
+ # Nested types are just those defined inside of the message, not all types
+ # used by fields in the message, so no loops are possible here.
+ nested_desc = MakeDescriptor(nested_proto,
+ package='.'.join(full_message_name),
+ build_file_if_cpp=False)
+ nested_types[full_name] = nested_desc
+
+ fields = []
+ for field_proto in desc_proto.field:
+ full_name = '.'.join(full_message_name + [field_proto.name])
+ enum_desc = None
+ nested_desc = None
+ if field_proto.HasField('type_name'):
+ type_name = field_proto.type_name
+ full_type_name = '.'.join(full_message_name +
+ [type_name[type_name.rfind('.')+1:]])
+ if full_type_name in nested_types:
+ nested_desc = nested_types[full_type_name]
+ elif full_type_name in enum_types:
+ enum_desc = enum_types[full_type_name]
+ # Else type_name references a non-local type, which isn't implemented
+ field = FieldDescriptor(
+ field_proto.name, full_name, field_proto.number - 1,
+ field_proto.number, field_proto.type,
+ FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type),
+ field_proto.label, None, nested_desc, enum_desc, None, False, None,
+ has_default_value=False)
+ fields.append(field)
+
+ desc_name = '.'.join(full_message_name)
+ return Descriptor(desc_proto.name, desc_name, None, None, fields,
+ nested_types.values(), enum_types.values(), [])
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
new file mode 100644
index 0000000..9f5a117
--- /dev/null
+++ b/python/google/protobuf/descriptor_database.py
@@ -0,0 +1,137 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Provides a container for DescriptorProtos."""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+
+class Error(Exception):
+ pass
+
+
+class DescriptorDatabaseConflictingDefinitionError(Error):
+ """Raised when a proto is added with the same name & different descriptor."""
+
+
+class DescriptorDatabase(object):
+ """A container accepting FileDescriptorProtos and maps DescriptorProtos."""
+
+ def __init__(self):
+ self._file_desc_protos_by_file = {}
+ self._file_desc_protos_by_symbol = {}
+
+ def Add(self, file_desc_proto):
+ """Adds the FileDescriptorProto and its types to this database.
+
+ Args:
+ file_desc_proto: The FileDescriptorProto to add.
+ Raises:
+ DescriptorDatabaseException: if an attempt is made to add a proto
+ with the same name but different definition than an exisiting
+ proto in the database.
+ """
+ proto_name = file_desc_proto.name
+ if proto_name not in self._file_desc_protos_by_file:
+ self._file_desc_protos_by_file[proto_name] = file_desc_proto
+ elif self._file_desc_protos_by_file[proto_name] != file_desc_proto:
+ raise DescriptorDatabaseConflictingDefinitionError(
+ '%s already added, but with different descriptor.' % proto_name)
+
+ package = file_desc_proto.package
+ for message in file_desc_proto.message_type:
+ self._file_desc_protos_by_symbol.update(
+ (name, file_desc_proto) for name in _ExtractSymbols(message, package))
+ for enum in file_desc_proto.enum_type:
+ self._file_desc_protos_by_symbol[
+ '.'.join((package, enum.name))] = file_desc_proto
+
+ def FindFileByName(self, name):
+ """Finds the file descriptor proto by file name.
+
+ Typically the file name is a relative path ending to a .proto file. The
+ proto with the given name will have to have been added to this database
+ using the Add method or else an error will be raised.
+
+ Args:
+ name: The file name to find.
+
+ Returns:
+ The file descriptor proto matching the name.
+
+ Raises:
+ KeyError if no file by the given name was added.
+ """
+
+ return self._file_desc_protos_by_file[name]
+
+ def FindFileContainingSymbol(self, symbol):
+ """Finds the file descriptor proto containing the specified symbol.
+
+ The symbol should be a fully qualified name including the file descriptor's
+ package and any containing messages. Some examples:
+
+ 'some.package.name.Message'
+ 'some.package.name.Message.NestedEnum'
+
+ The file descriptor proto containing the specified symbol must be added to
+ this database using the Add method or else an error will be raised.
+
+ Args:
+ symbol: The fully qualified symbol name.
+
+ Returns:
+ The file descriptor proto containing the symbol.
+
+ Raises:
+ KeyError if no file contains the specified symbol.
+ """
+
+ return self._file_desc_protos_by_symbol[symbol]
+
+
+def _ExtractSymbols(desc_proto, package):
+ """Pulls out all the symbols from a descriptor proto.
+
+ Args:
+ desc_proto: The proto to extract symbols from.
+ package: The package containing the descriptor type.
+
+ Yields:
+ The fully qualified name found in the descriptor.
+ """
+
+ message_name = '.'.join((package, desc_proto.name))
+ yield message_name
+ for nested_type in desc_proto.nested_type:
+ for symbol in _ExtractSymbols(nested_type, message_name):
+ yield symbol
+ for enum_type in desc_proto.enum_type:
+ yield '.'.join((message_name, enum_type.name))
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
new file mode 100644
index 0000000..372f458
--- /dev/null
+++ b/python/google/protobuf/descriptor_pool.py
@@ -0,0 +1,643 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Provides DescriptorPool to use as a container for proto2 descriptors.
+
+The DescriptorPool is used in conjection with a DescriptorDatabase to maintain
+a collection of protocol buffer descriptors for use when dynamically creating
+message types at runtime.
+
+For most applications protocol buffers should be used via modules generated by
+the protocol buffer compiler tool. This should only be used when the type of
+protocol buffers used in an application or library cannot be predetermined.
+
+Below is a straightforward example on how to use this class:
+
+ pool = DescriptorPool()
+ file_descriptor_protos = [ ... ]
+ for file_descriptor_proto in file_descriptor_protos:
+ pool.Add(file_descriptor_proto)
+ my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
+
+The message descriptor can be used in conjunction with the message_factory
+module in order to create a protocol buffer class that can be encoded and
+decoded.
+
+If you want to get a Python class for the specified proto, use the
+helper functions inside google.protobuf.message_factory
+directly instead of this class.
+"""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+import sys
+
+from google.protobuf import descriptor
+from google.protobuf import descriptor_database
+from google.protobuf import text_encoding
+
+
+def _NormalizeFullyQualifiedName(name):
+ """Remove leading period from fully-qualified type name.
+
+ Due to b/13860351 in descriptor_database.py, types in the root namespace are
+ generated with a leading period. This function removes that prefix.
+
+ Args:
+ name: A str, the fully-qualified symbol name.
+
+ Returns:
+ A str, the normalized fully-qualified symbol name.
+ """
+ return name.lstrip('.')
+
+
+class DescriptorPool(object):
+ """A collection of protobufs dynamically constructed by descriptor protos."""
+
+ def __init__(self, descriptor_db=None):
+ """Initializes a Pool of proto buffs.
+
+ The descriptor_db argument to the constructor is provided to allow
+ specialized file descriptor proto lookup code to be triggered on demand. An
+ example would be an implementation which will read and compile a file
+ specified in a call to FindFileByName() and not require the call to Add()
+ at all. Results from this database will be cached internally here as well.
+
+ Args:
+ descriptor_db: A secondary source of file descriptors.
+ """
+
+ self._internal_db = descriptor_database.DescriptorDatabase()
+ self._descriptor_db = descriptor_db
+ self._descriptors = {}
+ self._enum_descriptors = {}
+ self._file_descriptors = {}
+
+ def Add(self, file_desc_proto):
+ """Adds the FileDescriptorProto and its types to this pool.
+
+ Args:
+ file_desc_proto: The FileDescriptorProto to add.
+ """
+
+ self._internal_db.Add(file_desc_proto)
+
+ def AddDescriptor(self, desc):
+ """Adds a Descriptor to the pool, non-recursively.
+
+ If the Descriptor contains nested messages or enums, the caller must
+ explicitly register them. This method also registers the FileDescriptor
+ associated with the message.
+
+ Args:
+ desc: A Descriptor.
+ """
+ if not isinstance(desc, descriptor.Descriptor):
+ raise TypeError('Expected instance of descriptor.Descriptor.')
+
+ self._descriptors[desc.full_name] = desc
+ self.AddFileDescriptor(desc.file)
+
+ def AddEnumDescriptor(self, enum_desc):
+ """Adds an EnumDescriptor to the pool.
+
+ This method also registers the FileDescriptor associated with the message.
+
+ Args:
+ enum_desc: An EnumDescriptor.
+ """
+
+ if not isinstance(enum_desc, descriptor.EnumDescriptor):
+ raise TypeError('Expected instance of descriptor.EnumDescriptor.')
+
+ self._enum_descriptors[enum_desc.full_name] = enum_desc
+ self.AddFileDescriptor(enum_desc.file)
+
+ def AddFileDescriptor(self, file_desc):
+ """Adds a FileDescriptor to the pool, non-recursively.
+
+ If the FileDescriptor contains messages or enums, the caller must explicitly
+ register them.
+
+ Args:
+ file_desc: A FileDescriptor.
+ """
+
+ if not isinstance(file_desc, descriptor.FileDescriptor):
+ raise TypeError('Expected instance of descriptor.FileDescriptor.')
+ self._file_descriptors[file_desc.name] = file_desc
+
+ def FindFileByName(self, file_name):
+ """Gets a FileDescriptor by file name.
+
+ Args:
+ file_name: The path to the file to get a descriptor for.
+
+ Returns:
+ A FileDescriptor for the named file.
+
+ Raises:
+ KeyError: if the file can not be found in the pool.
+ """
+
+ try:
+ return self._file_descriptors[file_name]
+ except KeyError:
+ pass
+
+ try:
+ file_proto = self._internal_db.FindFileByName(file_name)
+ except KeyError:
+ _, error, _ = sys.exc_info() #PY25 compatible for GAE.
+ if self._descriptor_db:
+ file_proto = self._descriptor_db.FindFileByName(file_name)
+ else:
+ raise error
+ if not file_proto:
+ raise KeyError('Cannot find a file named %s' % file_name)
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
+
+ def FindFileContainingSymbol(self, symbol):
+ """Gets the FileDescriptor for the file containing the specified symbol.
+
+ Args:
+ symbol: The name of the symbol to search for.
+
+ Returns:
+ A FileDescriptor that contains the specified symbol.
+
+ Raises:
+ KeyError: if the file can not be found in the pool.
+ """
+
+ symbol = _NormalizeFullyQualifiedName(symbol)
+ try:
+ return self._descriptors[symbol].file
+ except KeyError:
+ pass
+
+ try:
+ return self._enum_descriptors[symbol].file
+ except KeyError:
+ pass
+
+ try:
+ file_proto = self._internal_db.FindFileContainingSymbol(symbol)
+ except KeyError:
+ _, error, _ = sys.exc_info() #PY25 compatible for GAE.
+ if self._descriptor_db:
+ file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
+ else:
+ raise error
+ if not file_proto:
+ raise KeyError('Cannot find a file containing %s' % symbol)
+ return self._ConvertFileProtoToFileDescriptor(file_proto)
+
+ def FindMessageTypeByName(self, full_name):
+ """Loads the named descriptor from the pool.
+
+ Args:
+ full_name: The full name of the descriptor to load.
+
+ Returns:
+ The descriptor for the named type.
+ """
+
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ if full_name not in self._descriptors:
+ self.FindFileContainingSymbol(full_name)
+ return self._descriptors[full_name]
+
+ def FindEnumTypeByName(self, full_name):
+ """Loads the named enum descriptor from the pool.
+
+ Args:
+ full_name: The full name of the enum descriptor to load.
+
+ Returns:
+ The enum descriptor for the named type.
+ """
+
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ if full_name not in self._enum_descriptors:
+ self.FindFileContainingSymbol(full_name)
+ return self._enum_descriptors[full_name]
+
+ def _ConvertFileProtoToFileDescriptor(self, file_proto):
+ """Creates a FileDescriptor from a proto or returns a cached copy.
+
+ This method also has the side effect of loading all the symbols found in
+ the file into the appropriate dictionaries in the pool.
+
+ Args:
+ file_proto: The proto to convert.
+
+ Returns:
+ A FileDescriptor matching the passed in proto.
+ """
+
+ if file_proto.name not in self._file_descriptors:
+ built_deps = list(self._GetDeps(file_proto.dependency))
+ direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
+
+ file_descriptor = descriptor.FileDescriptor(
+ name=file_proto.name,
+ package=file_proto.package,
+ options=file_proto.options,
+ serialized_pb=file_proto.SerializeToString(),
+ dependencies=direct_deps)
+ scope = {}
+
+ # This loop extracts all the message and enum types from all the
+ # dependencoes of the file_proto. This is necessary to create the
+ # scope of available message types when defining the passed in
+ # file proto.
+ for dependency in built_deps:
+ scope.update(self._ExtractSymbols(
+ dependency.message_types_by_name.values()))
+ scope.update((_PrefixWithDot(enum.full_name), enum)
+ for enum in dependency.enum_types_by_name.values())
+
+ for message_type in file_proto.message_type:
+ message_desc = self._ConvertMessageDescriptor(
+ message_type, file_proto.package, file_descriptor, scope)
+ file_descriptor.message_types_by_name[message_desc.name] = message_desc
+
+ for enum_type in file_proto.enum_type:
+ file_descriptor.enum_types_by_name[enum_type.name] = (
+ self._ConvertEnumDescriptor(enum_type, file_proto.package,
+ file_descriptor, None, scope))
+
+ for index, extension_proto in enumerate(file_proto.extension):
+ extension_desc = self.MakeFieldDescriptor(
+ extension_proto, file_proto.package, index, is_extension=True)
+ extension_desc.containing_type = self._GetTypeFromScope(
+ file_descriptor.package, extension_proto.extendee, scope)
+ self.SetFieldType(extension_proto, extension_desc,
+ file_descriptor.package, scope)
+ file_descriptor.extensions_by_name[extension_desc.name] = extension_desc
+
+ for desc_proto in file_proto.message_type:
+ self.SetAllFieldTypes(file_proto.package, desc_proto, scope)
+
+ if file_proto.package:
+ desc_proto_prefix = _PrefixWithDot(file_proto.package)
+ else:
+ desc_proto_prefix = ''
+
+ for desc_proto in file_proto.message_type:
+ desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope)
+ file_descriptor.message_types_by_name[desc_proto.name] = desc
+ self.Add(file_proto)
+ self._file_descriptors[file_proto.name] = file_descriptor
+
+ return self._file_descriptors[file_proto.name]
+
+ def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
+ scope=None):
+ """Adds the proto to the pool in the specified package.
+
+ Args:
+ desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
+ package: The package the proto should be located in.
+ file_desc: The file containing this message.
+ scope: Dict mapping short and full symbols to message and enum types.
+
+ Returns:
+ The added descriptor.
+ """
+
+ if package:
+ desc_name = '.'.join((package, desc_proto.name))
+ else:
+ desc_name = desc_proto.name
+
+ if file_desc is None:
+ file_name = None
+ else:
+ file_name = file_desc.name
+
+ if scope is None:
+ scope = {}
+
+ nested = [
+ self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope)
+ for nested in desc_proto.nested_type]
+ enums = [
+ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
+ for enum in desc_proto.enum_type]
+ fields = [self.MakeFieldDescriptor(field, desc_name, index)
+ for index, field in enumerate(desc_proto.field)]
+ extensions = [
+ self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True)
+ for index, extension in enumerate(desc_proto.extension)]
+ oneofs = [
+ descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
+ index, None, [])
+ for index, desc in enumerate(desc_proto.oneof_decl)]
+ extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
+ if extension_ranges:
+ is_extendable = True
+ else:
+ is_extendable = False
+ desc = descriptor.Descriptor(
+ name=desc_proto.name,
+ full_name=desc_name,
+ filename=file_name,
+ containing_type=None,
+ fields=fields,
+ oneofs=oneofs,
+ nested_types=nested,
+ enum_types=enums,
+ extensions=extensions,
+ options=desc_proto.options,
+ is_extendable=is_extendable,
+ extension_ranges=extension_ranges,
+ file=file_desc,
+ serialized_start=None,
+ serialized_end=None)
+ for nested in desc.nested_types:
+ nested.containing_type = desc
+ for enum in desc.enum_types:
+ enum.containing_type = desc
+ for field_index, field_desc in enumerate(desc_proto.field):
+ if field_desc.HasField('oneof_index'):
+ oneof_index = field_desc.oneof_index
+ oneofs[oneof_index].fields.append(fields[field_index])
+ fields[field_index].containing_oneof = oneofs[oneof_index]
+
+ scope[_PrefixWithDot(desc_name)] = desc
+ self._descriptors[desc_name] = desc
+ return desc
+
+ def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
+ containing_type=None, scope=None):
+ """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
+
+ Args:
+ enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
+ package: Optional package name for the new message EnumDescriptor.
+ file_desc: The file containing the enum descriptor.
+ containing_type: The type containing this enum.
+ scope: Scope containing available types.
+
+ Returns:
+ The added descriptor
+ """
+
+ if package:
+ enum_name = '.'.join((package, enum_proto.name))
+ else:
+ enum_name = enum_proto.name
+
+ if file_desc is None:
+ file_name = None
+ else:
+ file_name = file_desc.name
+
+ values = [self._MakeEnumValueDescriptor(value, index)
+ for index, value in enumerate(enum_proto.value)]
+ desc = descriptor.EnumDescriptor(name=enum_proto.name,
+ full_name=enum_name,
+ filename=file_name,
+ file=file_desc,
+ values=values,
+ containing_type=containing_type,
+ options=enum_proto.options)
+ scope['.%s' % enum_name] = desc
+ self._enum_descriptors[enum_name] = desc
+ return desc
+
+ def MakeFieldDescriptor(self, field_proto, message_name, index,
+ is_extension=False):
+ """Creates a field descriptor from a FieldDescriptorProto.
+
+ For message and enum type fields, this method will do a look up
+ in the pool for the appropriate descriptor for that type. If it
+ is unavailable, it will fall back to the _source function to
+ create it. If this type is still unavailable, construction will
+ fail.
+
+ Args:
+ field_proto: The proto describing the field.
+ message_name: The name of the containing message.
+ index: Index of the field
+ is_extension: Indication that this field is for an extension.
+
+ Returns:
+ An initialized FieldDescriptor object
+ """
+
+ if message_name:
+ full_name = '.'.join((message_name, field_proto.name))
+ else:
+ full_name = field_proto.name
+
+ return descriptor.FieldDescriptor(
+ name=field_proto.name,
+ full_name=full_name,
+ index=index,
+ number=field_proto.number,
+ type=field_proto.type,
+ cpp_type=None,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ label=field_proto.label,
+ has_default_value=False,
+ default_value=None,
+ is_extension=is_extension,
+ extension_scope=None,
+ options=field_proto.options)
+
+ def SetAllFieldTypes(self, package, desc_proto, scope):
+ """Sets all the descriptor's fields's types.
+
+ This method also sets the containing types on any extensions.
+
+ Args:
+ package: The current package of desc_proto.
+ desc_proto: The message descriptor to update.
+ scope: Enclosing scope of available types.
+ """
+
+ package = _PrefixWithDot(package)
+
+ main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
+
+ if package == '.':
+ nested_package = _PrefixWithDot(desc_proto.name)
+ else:
+ nested_package = '.'.join([package, desc_proto.name])
+
+ for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
+ self.SetFieldType(field_proto, field_desc, nested_package, scope)
+
+ for extension_proto, extension_desc in (
+ zip(desc_proto.extension, main_desc.extensions)):
+ extension_desc.containing_type = self._GetTypeFromScope(
+ nested_package, extension_proto.extendee, scope)
+ self.SetFieldType(extension_proto, extension_desc, nested_package, scope)
+
+ for nested_type in desc_proto.nested_type:
+ self.SetAllFieldTypes(nested_package, nested_type, scope)
+
+ def SetFieldType(self, field_proto, field_desc, package, scope):
+ """Sets the field's type, cpp_type, message_type and enum_type.
+
+ Args:
+ field_proto: Data about the field in proto format.
+ field_desc: The descriptor to modiy.
+ package: The package the field's container is in.
+ scope: Enclosing scope of available types.
+ """
+ if field_proto.type_name:
+ desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
+ else:
+ desc = None
+
+ if not field_proto.HasField('type'):
+ if isinstance(desc, descriptor.Descriptor):
+ field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
+ else:
+ field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
+
+ field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
+ field_proto.type)
+
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
+ or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
+ field_desc.message_type = desc
+
+ if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ field_desc.enum_type = desc
+
+ if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ field_desc.has_default_value = False
+ field_desc.default_value = []
+ elif field_proto.HasField('default_value'):
+ field_desc.has_default_value = True
+ if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
+ field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
+ field_desc.default_value = float(field_proto.default_value)
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
+ field_desc.default_value = field_proto.default_value
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ field_desc.default_value = field_proto.default_value.lower() == 'true'
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ field_desc.default_value = field_desc.enum_type.values_by_name[
+ field_proto.default_value].index
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ field_desc.default_value = text_encoding.CUnescape(
+ field_proto.default_value)
+ else:
+ field_desc.default_value = int(field_proto.default_value)
+ else:
+ field_desc.has_default_value = False
+ field_desc.default_value = None
+
+ field_desc.type = field_proto.type
+
+ def _MakeEnumValueDescriptor(self, value_proto, index):
+ """Creates a enum value descriptor object from a enum value proto.
+
+ Args:
+ value_proto: The proto describing the enum value.
+ index: The index of the enum value.
+
+ Returns:
+ An initialized EnumValueDescriptor object.
+ """
+
+ return descriptor.EnumValueDescriptor(
+ name=value_proto.name,
+ index=index,
+ number=value_proto.number,
+ options=value_proto.options,
+ type=None)
+
+ def _ExtractSymbols(self, descriptors):
+ """Pulls out all the symbols from descriptor protos.
+
+ Args:
+ descriptors: The messages to extract descriptors from.
+ Yields:
+ A two element tuple of the type name and descriptor object.
+ """
+
+ for desc in descriptors:
+ yield (_PrefixWithDot(desc.full_name), desc)
+ for symbol in self._ExtractSymbols(desc.nested_types):
+ yield symbol
+ for enum in desc.enum_types:
+ yield (_PrefixWithDot(enum.full_name), enum)
+
+ def _GetDeps(self, dependencies):
+ """Recursively finds dependencies for file protos.
+
+ Args:
+ dependencies: The names of the files being depended on.
+
+ Yields:
+ Each direct and indirect dependency.
+ """
+
+ for dependency in dependencies:
+ dep_desc = self.FindFileByName(dependency)
+ yield dep_desc
+ for parent_dep in dep_desc.dependencies:
+ yield parent_dep
+
+ def _GetTypeFromScope(self, package, type_name, scope):
+ """Finds a given type name in the current scope.
+
+ Args:
+ package: The package the proto should be located in.
+ type_name: The name of the type to be found in the scope.
+ scope: Dict mapping short and full symbols to message and enum types.
+
+ Returns:
+ The descriptor for the requested type.
+ """
+ if type_name not in scope:
+ components = _PrefixWithDot(package).split('.')
+ while components:
+ possible_match = '.'.join(components + [type_name])
+ if possible_match in scope:
+ type_name = possible_match
+ break
+ else:
+ components.pop(-1)
+ return scope[type_name]
+
+
+def _PrefixWithDot(name):
+ return name if name.startswith('.') else '.%s' % name
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc
new file mode 100644
index 0000000..ad6fd9c
--- /dev/null
+++ b/python/google/protobuf/internal/api_implementation.cc
@@ -0,0 +1,139 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#include <Python.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+// Version constant.
+// This is either 0 for python, 1 for CPP V1, 2 for CPP V2.
+//
+// 0 is default and is equivalent to
+// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+//
+// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to
+// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
+// and
+// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1
+//
+// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to
+// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
+// and
+// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
+#ifdef PYTHON_PROTO2_CPP_IMPL_V1
+#if PY_MAJOR_VERSION >= 3
+#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3."
+#endif
+static int kImplVersion = 1;
+#else
+#ifdef PYTHON_PROTO2_CPP_IMPL_V2
+static int kImplVersion = 2;
+#else
+#ifdef PYTHON_PROTO2_PYTHON_IMPL
+static int kImplVersion = 0;
+#else
+
+// The defaults are set here. Python 3 uses the fast C++ APIv2 by default.
+// Python 2 still uses the Python version by default until some compatibility
+// issues can be worked around.
+#if PY_MAJOR_VERSION >= 3
+static int kImplVersion = 2;
+#else
+static int kImplVersion = 0;
+#endif
+
+#endif // PYTHON_PROTO2_PYTHON_IMPL
+#endif // PYTHON_PROTO2_CPP_IMPL_V2
+#endif // PYTHON_PROTO2_CPP_IMPL_V1
+
+static const char* kImplVersionName = "api_version";
+
+static const char* kModuleName = "_api_implementation";
+static const char kModuleDocstring[] =
+"_api_implementation is a module that exposes compile-time constants that\n"
+"determine the default API implementation to use for Python proto2.\n"
+"\n"
+"It complements api_implementation.py by setting defaults using compile-time\n"
+"constants defined in C, such that one can set defaults at compilation\n"
+"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2).";
+
+#if PY_MAJOR_VERSION >= 3
+static struct PyModuleDef _module = {
+ PyModuleDef_HEAD_INIT,
+ kModuleName,
+ kModuleDocstring,
+ -1,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+#define INITFUNC PyInit__api_implementation
+#define INITFUNC_ERRORVAL NULL
+#else
+#define INITFUNC init_api_implementation
+#define INITFUNC_ERRORVAL
+#endif
+
+extern "C" {
+ PyMODINIT_FUNC INITFUNC() {
+#if PY_MAJOR_VERSION >= 3
+ PyObject *module = PyModule_Create(&_module);
+#else
+ PyObject *module = Py_InitModule3(
+ const_cast<char*>(kModuleName),
+ NULL,
+ const_cast<char*>(kModuleDocstring));
+#endif
+ if (module == NULL) {
+ return INITFUNC_ERRORVAL;
+ }
+
+ // Adds the module variable "api_version".
+ if (PyModule_AddIntConstant(
+ module,
+ const_cast<char*>(kImplVersionName),
+ kImplVersion))
+#if PY_MAJOR_VERSION < 3
+ return;
+#else
+ { Py_DECREF(module); return NULL; }
+
+ return module;
+#endif
+ }
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
new file mode 100755
index 0000000..cbb8574
--- /dev/null
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -0,0 +1,89 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Determine which implementation of the protobuf API is used in this process.
+"""
+
+import os
+import sys
+
+try:
+ # pylint: disable=g-import-not-at-top
+ from google.protobuf.internal import _api_implementation
+ # The compile-time constants in the _api_implementation module can be used to
+ # switch to a certain implementation of the Python API at build time.
+ _api_version = _api_implementation.api_version
+ del _api_implementation
+except ImportError:
+ _api_version = 0
+
+_default_implementation_type = (
+ 'python' if _api_version == 0 else 'cpp')
+_default_version_str = (
+ '1' if _api_version <= 1 else '2')
+
+# This environment variable can be used to switch to a certain implementation
+# of the Python API, overriding the compile-time constants in the
+# _api_implementation module. Right now only 'python' and 'cpp' are valid
+# values. Any other value will be ignored.
+_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION',
+ _default_implementation_type)
+
+if _implementation_type != 'python':
+ _implementation_type = 'cpp'
+
+# This environment variable can be used to switch between the two
+# 'cpp' implementations, overriding the compile-time constants in the
+# _api_implementation module. Right now only 1 and 2 are valid values. Any other
+# value will be ignored.
+_implementation_version_str = os.getenv(
+ 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION',
+ _default_version_str)
+
+if _implementation_version_str not in ('1', '2'):
+ raise ValueError(
+ "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" +
+ _implementation_version_str + "' (supported versions: 1, 2)"
+ )
+
+_implementation_version = int(_implementation_version_str)
+
+
+# Usage of this function is discouraged. Clients shouldn't care which
+# implementation of the API is in use. Note that there is no guarantee
+# that differences between APIs will be maintained.
+# Please don't use this function if possible.
+def Type():
+ return _implementation_type
+
+
+# See comment on 'Type' above.
+def Version():
+ return _implementation_version
diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py
new file mode 100644
index 0000000..b2b4128
--- /dev/null
+++ b/python/google/protobuf/internal/api_implementation_default_test.py
@@ -0,0 +1,63 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test that the api_implementation defaults are what we expect."""
+
+import os
+import sys
+# Clear environment implementation settings before the google3 imports.
+os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None)
+os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None)
+
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+
+
+class ApiImplementationDefaultTest(basetest.TestCase):
+
+ if sys.version_info.major <= 2:
+
+ def testThatPythonIsTheDefault(self):
+ """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
+ self.assertEqual('python', api_implementation.Type())
+
+ else:
+
+ def testThatCppApiV2IsTheDefault(self):
+ """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail."""
+ self.assertEqual('cpp', api_implementation.Type())
+ self.assertEqual(2, api_implementation.Version())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 5cc7d6d..5797e81 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -72,9 +72,20 @@ class BaseContainer(object):
# The concrete classes should define __eq__.
return not self == other
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
def __repr__(self):
return repr(self._values)
+ def sort(self, *args, **kwargs):
+ # Continue to support the old sort_function keyword argument.
+ # This is expected to be a rare occurrence, so use LBYL to avoid
+ # the overhead of actually catching KeyError.
+ if 'sort_function' in kwargs:
+ kwargs['cmp'] = kwargs.pop('sort_function')
+ self._values.sort(*args, **kwargs)
+
class RepeatedScalarFieldContainer(BaseContainer):
@@ -97,15 +108,13 @@ class RepeatedScalarFieldContainer(BaseContainer):
def append(self, value):
"""Appends an item to the list. Similar to list.append()."""
- self._type_checker.CheckValue(value)
- self._values.append(value)
+ self._values.append(self._type_checker.CheckValue(value))
if not self._message_listener.dirty:
self._message_listener.Modified()
def insert(self, key, value):
"""Inserts the item at the specified position. Similar to list.insert()."""
- self._type_checker.CheckValue(value)
- self._values.insert(key, value)
+ self._values.insert(key, self._type_checker.CheckValue(value))
if not self._message_listener.dirty:
self._message_listener.Modified()
@@ -116,8 +125,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
new_values = []
for elem in elem_seq:
- self._type_checker.CheckValue(elem)
- new_values.append(elem)
+ new_values.append(self._type_checker.CheckValue(elem))
self._values.extend(new_values)
self._message_listener.Modified()
@@ -135,9 +143,13 @@ class RepeatedScalarFieldContainer(BaseContainer):
def __setitem__(self, key, value):
"""Sets the item on the specified position."""
- self._type_checker.CheckValue(value)
- self._values[key] = value
- self._message_listener.Modified()
+ if isinstance(key, slice): # PY3
+ if key.step is not None:
+ raise ValueError('Extended slices not supported')
+ self.__setslice__(key.start, key.stop, value)
+ else:
+ self._values[key] = self._type_checker.CheckValue(value)
+ self._message_listener.Modified()
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
@@ -147,8 +159,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
"""Sets the subset of items from between the specified indices."""
new_values = []
for value in values:
- self._type_checker.CheckValue(value)
- new_values.append(value)
+ new_values.append(self._type_checker.CheckValue(value))
self._values[start:stop] = new_values
self._message_listener.Modified()
@@ -198,28 +209,42 @@ class RepeatedCompositeFieldContainer(BaseContainer):
super(RepeatedCompositeFieldContainer, self).__init__(message_listener)
self._message_descriptor = message_descriptor
- def add(self):
- new_element = self._message_descriptor._concrete_class()
+ def add(self, **kwargs):
+ """Adds a new element at the end of the list and returns it. Keyword
+ arguments may be used to initialize the element.
+ """
+ new_element = self._message_descriptor._concrete_class(**kwargs)
new_element._SetListener(self._message_listener)
self._values.append(new_element)
if not self._message_listener.dirty:
self._message_listener.Modified()
return new_element
- def MergeFrom(self, other):
- """Appends the contents of another repeated field of the same type to this
- one, copying each individual message.
+ def extend(self, elem_seq):
+ """Extends by appending the given sequence of elements of the same type
+ as this one, copying each individual message.
"""
message_class = self._message_descriptor._concrete_class
listener = self._message_listener
values = self._values
- for message in other._values:
+ for message in elem_seq:
new_element = message_class()
new_element._SetListener(listener)
new_element.MergeFrom(message)
values.append(new_element)
listener.Modified()
+ def MergeFrom(self, other):
+ """Appends the contents of another repeated field of the same type to this
+ one, copying each individual message.
+ """
+ self.extend(other._values)
+
+ def remove(self, elem):
+ """Removes an item from the list. Similar to list.remove()."""
+ self._values.remove(elem)
+ self._message_listener.Modified()
+
def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices."""
return self._values[start:stop]
diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py
new file mode 100755
index 0000000..8eb38ca
--- /dev/null
+++ b/python/google/protobuf/internal/cpp_message.py
@@ -0,0 +1,663 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains helper functions used to create protocol message classes from
+Descriptor objects at runtime backed by the protocol buffer C++ API.
+"""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+import copy_reg
+import operator
+from google.protobuf.internal import _net_proto2___python
+from google.protobuf.internal import enum_type_wrapper
+from google.protobuf import message
+
+
+_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED
+_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL
+_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE
+_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE
+
+
+def GetDescriptorPool():
+ """Creates a new DescriptorPool C++ object."""
+ return _net_proto2___python.NewCDescriptorPool()
+
+
+_pool = GetDescriptorPool()
+
+
+def GetFieldDescriptor(full_field_name):
+ """Searches for a field descriptor given a full field name."""
+ return _pool.FindFieldByName(full_field_name)
+
+
+def BuildFile(content):
+ """Registers a new proto file in the underlying C++ descriptor pool."""
+ _net_proto2___python.BuildFile(content)
+
+
+def GetExtensionDescriptor(full_extension_name):
+ """Searches for extension descriptor given a full field name."""
+ return _pool.FindExtensionByName(full_extension_name)
+
+
+def NewCMessage(full_message_name):
+ """Creates a new C++ protocol message by its name."""
+ return _net_proto2___python.NewCMessage(full_message_name)
+
+
+def ScalarProperty(cdescriptor):
+ """Returns a scalar property for the given descriptor."""
+
+ def Getter(self):
+ return self._cmsg.GetScalar(cdescriptor)
+
+ def Setter(self, value):
+ self._cmsg.SetScalar(cdescriptor, value)
+
+ return property(Getter, Setter)
+
+
+def CompositeProperty(cdescriptor, message_type):
+ """Returns a Python property the given composite field."""
+
+ def Getter(self):
+ sub_message = self._composite_fields.get(cdescriptor.name, None)
+ if sub_message is None:
+ cmessage = self._cmsg.NewSubMessage(cdescriptor)
+ sub_message = message_type._concrete_class(__cmessage=cmessage)
+ self._composite_fields[cdescriptor.name] = sub_message
+ return sub_message
+
+ return property(Getter)
+
+
+class RepeatedScalarContainer(object):
+ """Container for repeated scalar fields."""
+
+ __slots__ = ['_message', '_cfield_descriptor', '_cmsg']
+
+ def __init__(self, msg, cfield_descriptor):
+ self._message = msg
+ self._cmsg = msg._cmsg
+ self._cfield_descriptor = cfield_descriptor
+
+ def append(self, value):
+ self._cmsg.AddRepeatedScalar(
+ self._cfield_descriptor, value)
+
+ def extend(self, sequence):
+ for element in sequence:
+ self.append(element)
+
+ def insert(self, key, value):
+ values = self[slice(None, None, None)]
+ values.insert(key, value)
+ self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
+
+ def remove(self, value):
+ values = self[slice(None, None, None)]
+ values.remove(value)
+ self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
+
+ def __setitem__(self, key, value):
+ values = self[slice(None, None, None)]
+ values[key] = value
+ self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values)
+
+ def __getitem__(self, key):
+ return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key)
+
+ def __delitem__(self, key):
+ self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key)
+
+ def __len__(self):
+ return len(self[slice(None, None, None)])
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not operator.isSequenceType(other):
+ raise TypeError(
+ 'Can only compare repeated scalar fields against sequences.')
+ # We are presumably comparing against some other sequence type.
+ return other == self[slice(None, None, None)]
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
+ def sort(self, *args, **kwargs):
+ # Maintain compatibility with the previous interface.
+ if 'sort_function' in kwargs:
+ kwargs['cmp'] = kwargs.pop('sort_function')
+ self._cmsg.AssignRepeatedScalar(self._cfield_descriptor,
+ sorted(self, *args, **kwargs))
+
+
+def RepeatedScalarProperty(cdescriptor):
+ """Returns a Python property the given repeated scalar field."""
+
+ def Getter(self):
+ container = self._composite_fields.get(cdescriptor.name, None)
+ if container is None:
+ container = RepeatedScalarContainer(self, cdescriptor)
+ self._composite_fields[cdescriptor.name] = container
+ return container
+
+ def Setter(self, new_value):
+ raise AttributeError('Assignment not allowed to repeated field '
+ '"%s" in protocol message object.' % cdescriptor.name)
+
+ doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
+ return property(Getter, Setter, doc=doc)
+
+
+class RepeatedCompositeContainer(object):
+ """Container for repeated composite fields."""
+
+ __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg']
+
+ def __init__(self, msg, cfield_descriptor, subclass):
+ self._message = msg
+ self._cmsg = msg._cmsg
+ self._subclass = subclass
+ self._cfield_descriptor = cfield_descriptor
+
+ def add(self, **kwargs):
+ cmessage = self._cmsg.AddMessage(self._cfield_descriptor)
+ return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs)
+
+ def extend(self, elem_seq):
+ """Extends by appending the given sequence of elements of the same type
+ as this one, copying each individual message.
+ """
+ for message in elem_seq:
+ self.add().MergeFrom(message)
+
+ def remove(self, value):
+ # TODO(protocol-devel): This is inefficient as it needs to generate a
+ # message pointer for each message only to do index(). Move this to a C++
+ # extension function.
+ self.__delitem__(self[slice(None, None, None)].index(value))
+
+ def MergeFrom(self, other):
+ for message in other[:]:
+ self.add().MergeFrom(message)
+
+ def __getitem__(self, key):
+ cmessages = self._cmsg.GetRepeatedMessage(
+ self._cfield_descriptor, key)
+ subclass = self._subclass
+ if not isinstance(cmessages, list):
+ return subclass(__cmessage=cmessages, __owner=self._message)
+
+ return [subclass(__cmessage=m, __owner=self._message) for m in cmessages]
+
+ def __delitem__(self, key):
+ self._cmsg.DeleteRepeatedField(
+ self._cfield_descriptor, key)
+
+ def __len__(self):
+ return self._cmsg.FieldLength(self._cfield_descriptor)
+
+ def __eq__(self, other):
+ """Compares the current instance with another one."""
+ if self is other:
+ return True
+ if not isinstance(other, self.__class__):
+ raise TypeError('Can only compare repeated composite fields against '
+ 'other repeated composite fields.')
+ messages = self[slice(None, None, None)]
+ other_messages = other[slice(None, None, None)]
+ return messages == other_messages
+
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
+ def sort(self, cmp=None, key=None, reverse=False, **kwargs):
+ # Maintain compatibility with the old interface.
+ if cmp is None and 'sort_function' in kwargs:
+ cmp = kwargs.pop('sort_function')
+
+ # The cmp function, if provided, is passed the results of the key function,
+ # so we only need to wrap one of them.
+ if key is None:
+ index_key = self.__getitem__
+ else:
+ index_key = lambda i: key(self[i])
+
+ # Sort the list of current indexes by the underlying object.
+ indexes = range(len(self))
+ indexes.sort(cmp=cmp, key=index_key, reverse=reverse)
+
+ # Apply the transposition.
+ for dest, src in enumerate(indexes):
+ if dest == src:
+ continue
+ self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src)
+ # Don't swap the same value twice.
+ indexes[src] = src
+
+
+def RepeatedCompositeProperty(cdescriptor, message_type):
+ """Returns a Python property for the given repeated composite field."""
+
+ def Getter(self):
+ container = self._composite_fields.get(cdescriptor.name, None)
+ if container is None:
+ container = RepeatedCompositeContainer(
+ self, cdescriptor, message_type._concrete_class)
+ self._composite_fields[cdescriptor.name] = container
+ return container
+
+ def Setter(self, new_value):
+ raise AttributeError('Assignment not allowed to repeated field '
+ '"%s" in protocol message object.' % cdescriptor.name)
+
+ doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name
+ return property(Getter, Setter, doc=doc)
+
+
+class ExtensionDict(object):
+ """Extension dictionary added to each protocol message."""
+
+ def __init__(self, msg):
+ self._message = msg
+ self._cmsg = msg._cmsg
+ self._values = {}
+
+ def __setitem__(self, extension, value):
+ from google.protobuf import descriptor
+ if not isinstance(extension, descriptor.FieldDescriptor):
+ raise KeyError('Bad extension %r.' % (extension,))
+ cdescriptor = extension._cdescriptor
+ if (cdescriptor.label != _LABEL_OPTIONAL or
+ cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
+ raise TypeError('Extension %r is repeated and/or a composite type.' % (
+ extension.full_name,))
+ self._cmsg.SetScalar(cdescriptor, value)
+ self._values[extension] = value
+
+ def __getitem__(self, extension):
+ from google.protobuf import descriptor
+ if not isinstance(extension, descriptor.FieldDescriptor):
+ raise KeyError('Bad extension %r.' % (extension,))
+
+ cdescriptor = extension._cdescriptor
+ if (cdescriptor.label != _LABEL_REPEATED and
+ cdescriptor.cpp_type != _CPPTYPE_MESSAGE):
+ return self._cmsg.GetScalar(cdescriptor)
+
+ ext = self._values.get(extension, None)
+ if ext is not None:
+ return ext
+
+ ext = self._CreateNewHandle(extension)
+ self._values[extension] = ext
+ return ext
+
+ def ClearExtension(self, extension):
+ from google.protobuf import descriptor
+ if not isinstance(extension, descriptor.FieldDescriptor):
+ raise KeyError('Bad extension %r.' % (extension,))
+ self._cmsg.ClearFieldByDescriptor(extension._cdescriptor)
+ if extension in self._values:
+ del self._values[extension]
+
+ def HasExtension(self, extension):
+ from google.protobuf import descriptor
+ if not isinstance(extension, descriptor.FieldDescriptor):
+ raise KeyError('Bad extension %r.' % (extension,))
+ return self._cmsg.HasFieldByDescriptor(extension._cdescriptor)
+
+ def _FindExtensionByName(self, name):
+ """Tries to find a known extension with the specified name.
+
+ Args:
+ name: Extension full name.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._message._extensions_by_name.get(name, None)
+
+ def _CreateNewHandle(self, extension):
+ cdescriptor = extension._cdescriptor
+ if (cdescriptor.label != _LABEL_REPEATED and
+ cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
+ cmessage = self._cmsg.NewSubMessage(cdescriptor)
+ return extension.message_type._concrete_class(__cmessage=cmessage)
+
+ if cdescriptor.label == _LABEL_REPEATED:
+ if cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
+ return RepeatedCompositeContainer(
+ self._message, cdescriptor, extension.message_type._concrete_class)
+ else:
+ return RepeatedScalarContainer(self._message, cdescriptor)
+ # This shouldn't happen!
+ assert False
+ return None
+
+
+def NewMessage(bases, message_descriptor, dictionary):
+ """Creates a new protocol message *class*."""
+ _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
+ _AddEnumValues(message_descriptor, dictionary)
+ _AddDescriptors(message_descriptor, dictionary)
+ return bases
+
+
+def InitMessage(message_descriptor, cls):
+ """Constructs a new message instance (called before instance's __init__)."""
+ cls._extensions_by_name = {}
+ _AddInitMethod(message_descriptor, cls)
+ _AddMessageMethods(message_descriptor, cls)
+ _AddPropertiesForExtensions(message_descriptor, cls)
+ copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
+
+
+def _AddDescriptors(message_descriptor, dictionary):
+ """Sets up a new protocol message class dictionary.
+
+ Args:
+ message_descriptor: A Descriptor instance describing this message type.
+ dictionary: Class dictionary to which we'll add a '__slots__' entry.
+ """
+ dictionary['__descriptors'] = {}
+ for field in message_descriptor.fields:
+ dictionary['__descriptors'][field.name] = GetFieldDescriptor(
+ field.full_name)
+
+ dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [
+ '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS']
+
+
+def _AddEnumValues(message_descriptor, dictionary):
+ """Sets class-level attributes for all enum fields defined in this message.
+
+ Args:
+ message_descriptor: Descriptor object for this message type.
+ dictionary: Class dictionary that should be populated.
+ """
+ for enum_type in message_descriptor.enum_types:
+ dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type)
+ for enum_value in enum_type.values:
+ dictionary[enum_value.name] = enum_value.number
+
+
+def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary):
+ """Adds class attributes for the nested extensions."""
+ extension_dict = message_descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ assert extension_name not in dictionary
+ dictionary[extension_name] = extension_field
+
+
+def _AddInitMethod(message_descriptor, cls):
+ """Adds an __init__ method to cls."""
+
+ # Create and attach message field properties to the message class.
+ # This can be done just once per message class, since property setters and
+ # getters are passed the message instance.
+ # This makes message instantiation extremely fast, and at the same time it
+ # doesn't require the creation of property objects for each message instance,
+ # which saves a lot of memory.
+ for field in message_descriptor.fields:
+ field_cdescriptor = cls.__descriptors[field.name]
+ if field.label == _LABEL_REPEATED:
+ if field.cpp_type == _CPPTYPE_MESSAGE:
+ value = RepeatedCompositeProperty(field_cdescriptor, field.message_type)
+ else:
+ value = RepeatedScalarProperty(field_cdescriptor)
+ elif field.cpp_type == _CPPTYPE_MESSAGE:
+ value = CompositeProperty(field_cdescriptor, field.message_type)
+ else:
+ value = ScalarProperty(field_cdescriptor)
+ setattr(cls, field.name, value)
+
+ # Attach a constant with the field number.
+ constant_name = field.name.upper() + '_FIELD_NUMBER'
+ setattr(cls, constant_name, field.number)
+
+ def Init(self, **kwargs):
+ """Message constructor."""
+ cmessage = kwargs.pop('__cmessage', None)
+ if cmessage:
+ self._cmsg = cmessage
+ else:
+ self._cmsg = NewCMessage(message_descriptor.full_name)
+
+ # Keep a reference to the owner, as the owner keeps a reference to the
+ # underlying protocol buffer message.
+ owner = kwargs.pop('__owner', None)
+ if owner:
+ self._owner = owner
+
+ if message_descriptor.is_extendable:
+ self.Extensions = ExtensionDict(self)
+ else:
+ # Reference counting in the C++ code is broken and depends on
+ # the Extensions reference to keep this object alive during unit
+ # tests (see b/4856052). Remove this once b/4945904 is fixed.
+ self._HACK_REFCOUNTS = self
+ self._composite_fields = {}
+
+ for field_name, field_value in kwargs.iteritems():
+ field_cdescriptor = self.__descriptors.get(field_name, None)
+ if not field_cdescriptor:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+ if field_cdescriptor.label == _LABEL_REPEATED:
+ if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
+ field_name = getattr(self, field_name)
+ for val in field_value:
+ field_name.add().MergeFrom(val)
+ else:
+ getattr(self, field_name).extend(field_value)
+ elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE:
+ getattr(self, field_name).MergeFrom(field_value)
+ else:
+ setattr(self, field_name, field_value)
+
+ Init.__module__ = None
+ Init.__doc__ = None
+ cls.__init__ = Init
+
+
+def _IsMessageSetExtension(field):
+ """Checks if a field is a message set extension."""
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == _TYPE_MESSAGE and
+ field.message_type == field.extension_scope and
+ field.label == _LABEL_OPTIONAL)
+
+
+def _AddMessageMethods(message_descriptor, cls):
+ """Adds the methods to a protocol message class."""
+ if message_descriptor.is_extendable:
+
+ def ClearExtension(self, extension):
+ self.Extensions.ClearExtension(extension)
+
+ def HasExtension(self, extension):
+ return self.Extensions.HasExtension(extension)
+
+ def HasField(self, field_name):
+ return self._cmsg.HasField(field_name)
+
+ def ClearField(self, field_name):
+ child_cmessage = None
+ if field_name in self._composite_fields:
+ child_field = self._composite_fields[field_name]
+ del self._composite_fields[field_name]
+
+ child_cdescriptor = self.__descriptors[field_name]
+ # TODO(anuraag): Support clearing repeated message fields as well.
+ if (child_cdescriptor.label != _LABEL_REPEATED and
+ child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
+ child_field._owner = None
+ child_cmessage = child_field._cmsg
+
+ if child_cmessage is not None:
+ self._cmsg.ClearField(field_name, child_cmessage)
+ else:
+ self._cmsg.ClearField(field_name)
+
+ def Clear(self):
+ cmessages_to_release = []
+ for field_name, child_field in self._composite_fields.iteritems():
+ child_cdescriptor = self.__descriptors[field_name]
+ # TODO(anuraag): Support clearing repeated message fields as well.
+ if (child_cdescriptor.label != _LABEL_REPEATED and
+ child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE):
+ child_field._owner = None
+ cmessages_to_release.append((child_cdescriptor, child_field._cmsg))
+ self._composite_fields.clear()
+ self._cmsg.Clear(cmessages_to_release)
+
+ def IsInitialized(self, errors=None):
+ if self._cmsg.IsInitialized():
+ return True
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors());
+ return False
+
+ def SerializeToString(self):
+ if not self.IsInitialized():
+ raise message.EncodeError(
+ 'Message %s is missing required fields: %s' % (
+ self._cmsg.full_name, ','.join(self.FindInitializationErrors())))
+ return self._cmsg.SerializeToString()
+
+ def SerializePartialToString(self):
+ return self._cmsg.SerializePartialToString()
+
+ def ParseFromString(self, serialized):
+ self.Clear()
+ self.MergeFromString(serialized)
+
+ def MergeFromString(self, serialized):
+ byte_size = self._cmsg.MergeFromString(serialized)
+ if byte_size < 0:
+ raise message.DecodeError('Unable to merge from string.')
+ return byte_size
+
+ def MergeFrom(self, msg):
+ if not isinstance(msg, cls):
+ raise TypeError(
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s." % (cls.__name__, type(msg).__name__))
+ self._cmsg.MergeFrom(msg._cmsg)
+
+ def CopyFrom(self, msg):
+ self._cmsg.CopyFrom(msg._cmsg)
+
+ def ByteSize(self):
+ return self._cmsg.ByteSize()
+
+ def SetInParent(self):
+ return self._cmsg.SetInParent()
+
+ def ListFields(self):
+ all_fields = []
+ field_list = self._cmsg.ListFields()
+ fields_by_name = cls.DESCRIPTOR.fields_by_name
+ for is_extension, field_name in field_list:
+ if is_extension:
+ extension = cls._extensions_by_name[field_name]
+ all_fields.append((extension, self.Extensions[extension]))
+ else:
+ field_descriptor = fields_by_name[field_name]
+ all_fields.append(
+ (field_descriptor, getattr(self, field_name)))
+ all_fields.sort(key=lambda item: item[0].number)
+ return all_fields
+
+ def FindInitializationErrors(self):
+ return self._cmsg.FindInitializationErrors()
+
+ def __str__(self):
+ return str(self._cmsg)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ if not isinstance(other, self.__class__):
+ return False
+ return self.ListFields() == other.ListFields()
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
+ def __unicode__(self):
+ # Lazy import to prevent circular import when text_format imports this file.
+ from google.protobuf import text_format
+ return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
+
+ # Attach the local methods to the message class.
+ for key, value in locals().copy().iteritems():
+ if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'):
+ setattr(cls, key, value)
+
+ # Static methods:
+
+ def RegisterExtension(extension_handle):
+ extension_handle.containing_type = cls.DESCRIPTOR
+ cls._extensions_by_name[extension_handle.full_name] = extension_handle
+
+ if _IsMessageSetExtension(extension_handle):
+ # MessageSet extension. Also register under type name.
+ cls._extensions_by_name[
+ extension_handle.message_type.full_name] = extension_handle
+ cls.RegisterExtension = staticmethod(RegisterExtension)
+
+ def FromString(string):
+ msg = cls()
+ msg.MergeFromString(string)
+ return msg
+ cls.FromString = staticmethod(FromString)
+
+
+
+def _AddPropertiesForExtensions(message_descriptor, cls):
+ """Adds properties for all fields in this protocol message type."""
+ extension_dict = message_descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ constant_name = extension_name.upper() + '_FIELD_NUMBER'
+ setattr(cls, constant_name, extension_field.number)
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 461a30c..651ee0d 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#PY25 compatible for GAE.
+#
+# Copyright 2009 Google Inc. All Rights Reserved.
+
"""Code for decoding protocol buffer primitives.
This code is very similar to encoder.py -- read the docs for that module first.
@@ -81,17 +85,26 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
+import sys ##PY25
+_PY2 = sys.version_info[0] < 3 ##PY25
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
+# This will overflow and thus become IEEE-754 "infinity". We would use
+# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
+_POS_INF = 1e10000
+_NEG_INF = -_POS_INF
+_NAN = _POS_INF * 0
+
+
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError
-def _VarintDecoder(mask):
+def _VarintDecoder(mask, result_type):
"""Return an encoder for a basic varint value (does not include tag).
Decoded values will be bitwise-anded with the given mask before being
@@ -102,15 +115,18 @@ def _VarintDecoder(mask):
"""
local_ord = ord
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos])
+ b = local_ord(buffer[pos]) if py2 else buffer[pos]
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
result &= mask
+ result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
@@ -118,15 +134,17 @@ def _VarintDecoder(mask):
return DecodeVarint
-def _SignedVarintDecoder(mask):
+def _SignedVarintDecoder(mask, result_type):
"""Like _VarintDecoder() but decodes signed values."""
local_ord = ord
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos])
+ b = local_ord(buffer[pos]) if py2 else buffer[pos]
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
@@ -135,19 +153,23 @@ def _SignedVarintDecoder(mask):
result |= ~mask
else:
result &= mask
+ result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')
return DecodeVarint
+# We force 32-bit values to int and 64-bit values to long to make
+# alternate implementations where the distinction is more significant
+# (e.g. the C++ implementation) simpler.
-_DecodeVarint = _VarintDecoder((1 << 64) - 1)
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
+_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
# Use these versions for values which must be limited to 32 bits.
-_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
def ReadTag(buffer, pos):
@@ -161,8 +183,10 @@ def ReadTag(buffer, pos):
use that, but not in Python.
"""
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
start = pos
- while ord(buffer[pos]) & 0x80:
+ while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80:
pos += 1
pos += 1
return (buffer[start:pos], pos)
@@ -269,10 +293,161 @@ def _StructPackDecoder(wire_type, format):
return _SimpleDecoder(wire_type, InnerDecode)
+def _FloatDecoder():
+ """Returns a decoder for a float field.
+
+ This code works around a bug in struct.unpack for non-finite 32-bit
+ floating-point values.
+ """
+
+ local_unpack = struct.unpack
+ b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
+
+ def InnerDecode(buffer, pos):
+ # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
+ # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
+ new_pos = pos + 4
+ float_bytes = buffer[pos:new_pos]
+
+ # If this value has all its exponent bits set, then it's non-finite.
+ # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
+ # To avoid that, we parse it specially.
+ if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25
+##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF')
+ and (float_bytes[2:3] >= b('\x80'))): ##PY25
+##!PY25 and (float_bytes[2:3] >= b'\x80')):
+ # If at least one significand bit is set...
+ if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25
+##!PY25 if float_bytes[0:3] != b'\x00\x00\x80':
+ return (_NAN, new_pos)
+ # If sign bit is set...
+ if float_bytes[3:4] == b('\xFF'): ##PY25
+##!PY25 if float_bytes[3:4] == b'\xFF':
+ return (_NEG_INF, new_pos)
+ return (_POS_INF, new_pos)
+
+ # Note that we expect someone up-stack to catch struct.error and convert
+ # it to _DecodeError -- this way we don't have to set up exception-
+ # handling blocks every time we parse one value.
+ result = local_unpack('<f', float_bytes)[0]
+ return (result, new_pos)
+ return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
+
+
+def _DoubleDecoder():
+ """Returns a decoder for a double field.
+
+ This code works around a bug in struct.unpack for not-a-number.
+ """
+
+ local_unpack = struct.unpack
+ b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
+
+ def InnerDecode(buffer, pos):
+ # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
+ # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
+ new_pos = pos + 8
+ double_bytes = buffer[pos:new_pos]
+
+ # If this value has all its exponent bits set and at least one significand
+ # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
+ # as inf or -inf. To avoid that, we treat it specially.
+##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF')
+##!PY25 and (double_bytes[6:7] >= b'\xF0')
+##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
+ if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25
+ and (double_bytes[6:7] >= b('\xF0')) ##PY25
+ and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25
+ return (_NAN, new_pos)
+
+ # Note that we expect someone up-stack to catch struct.error and convert
+ # it to _DecodeError -- this way we don't have to set up exception-
+ # handling blocks every time we parse one value.
+ result = local_unpack('<d', double_bytes)[0]
+ return (result, new_pos)
+ return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
+
+
+def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
+ enum_type = key.enum_type
+ if is_packed:
+ local_DecodeVarint = _DecodeVarint
+ def DecodePackedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ (endpoint, pos) = local_DecodeVarint(buffer, pos)
+ endpoint += pos
+ if endpoint > end:
+ raise _DecodeError('Truncated message.')
+ while pos < endpoint:
+ value_start_pos = pos
+ (element, pos) = _DecodeSignedVarint32(buffer, pos)
+ if element in enum_type.values_by_number:
+ value.append(element)
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_VARINT)
+ message._unknown_fields.append(
+ (tag_bytes, buffer[value_start_pos:pos]))
+ if pos > endpoint:
+ if element in enum_type.values_by_number:
+ del value[-1] # Discard corrupt value.
+ else:
+ del message._unknown_fields[-1]
+ raise _DecodeError('Packed element was truncated.')
+ return pos
+ return DecodePackedField
+ elif is_repeated:
+ tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+ if element in enum_type.values_by_number:
+ value.append(element)
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ message._unknown_fields.append(
+ (tag_bytes, buffer[pos:new_pos]))
+ # Predict that the next tag is another copy of the same repeated
+ # field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+ # Prediction failed. Return.
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ value_start_pos = pos
+ (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ if enum_value in enum_type.values_by_number:
+ field_dict[key] = enum_value
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_VARINT)
+ message._unknown_fields.append(
+ (tag_bytes, buffer[value_start_pos:pos]))
+ return pos
+ return DecodeField
+
+
# --------------------------------------------------------------------
-Int32Decoder = EnumDecoder = _SimpleDecoder(
+Int32Decoder = _SimpleDecoder(
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
@@ -294,8 +469,8 @@ Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
-FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
-DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
+FloatDecoder = _FloatDecoder()
+DoubleDecoder = _DoubleDecoder()
BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
@@ -307,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
local_DecodeVarint = _DecodeVarint
local_unicode = unicode
+ def _ConvertToUnicode(byte_str):
+ try:
+ return local_unicode(byte_str, 'utf-8')
+ except UnicodeDecodeError, e:
+ # add more information to the error message and re-raise it.
+ e.reason = '%s in field: %s' % (e, key.full_name)
+ raise
+
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@@ -321,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
+ value.append(_ConvertToUnicode(buffer[pos:new_pos]))
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -334,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
+ field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
return new_pos
return DecodeField
@@ -503,6 +686,7 @@ def MessageSetItemDecoder(extensions_by_number):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
+ message_set_item_start = pos
type_id = -1
message_start = -1
message_end = -1
@@ -541,6 +725,11 @@ def MessageSetItemDecoder(extensions_by_number):
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
+ buffer[message_set_item_start:pos]))
return pos
@@ -552,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number):
def _SkipVarint(buffer, pos, end):
"""Skip a varint value. Returns the new position."""
-
- while ord(buffer[pos]) & 0x80:
+ # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
+ # With this code, ord(b'') raises TypeError. Both are handled in
+ # python_message.py to generate a 'Truncated message' error.
+ while ord(buffer[pos:pos+1]) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -620,7 +811,6 @@ def _FieldSkipper():
]
wiretype_mask = wire_format.TAG_TYPE_MASK
- local_ord = ord
def SkipField(buffer, pos, end, tag_bytes):
"""Skips a field with the specified tag.
@@ -633,7 +823,7 @@ def _FieldSkipper():
"""
# The wire type is always in the first byte since varints are little-endian.
- wire_type = local_ord(tag_bytes[0]) & wiretype_mask
+ wire_type = ord(tag_bytes[0:1]) & wiretype_mask
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
return SkipField
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
new file mode 100644
index 0000000..856f472
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -0,0 +1,63 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.descriptor_database."""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+from google.apputils import basetest
+from google.protobuf import descriptor_pb2
+from google.protobuf.internal import factory_test2_pb2
+from google.protobuf import descriptor_database
+
+
+class DescriptorDatabaseTest(basetest.TestCase):
+
+ def testAdd(self):
+ db = descriptor_database.DescriptorDatabase()
+ file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ db.Add(file_desc_proto)
+
+ self.assertEquals(file_desc_proto, db.FindFileByName(
+ 'google/protobuf/internal/factory_test2.proto'))
+ self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message'))
+ self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message'))
+ self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Enum'))
+ self.assertEquals(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum'))
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
new file mode 100644
index 0000000..7c1ce2e
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -0,0 +1,564 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.descriptor_pool."""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+import os
+import unittest
+
+from google.apputils import basetest
+from google.protobuf import unittest_pb2
+from google.protobuf import descriptor_pb2
+from google.protobuf.internal import api_implementation
+from google.protobuf.internal import descriptor_pool_test1_pb2
+from google.protobuf.internal import descriptor_pool_test2_pb2
+from google.protobuf.internal import factory_test1_pb2
+from google.protobuf.internal import factory_test2_pb2
+from google.protobuf import descriptor
+from google.protobuf import descriptor_database
+from google.protobuf import descriptor_pool
+
+
+class DescriptorPoolTest(basetest.TestCase):
+
+ def setUp(self):
+ self.pool = descriptor_pool.DescriptorPool()
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(self.factory_test1_fd)
+ self.pool.Add(self.factory_test2_fd)
+
+ def testFindFileByName(self):
+ name1 = 'google/protobuf/internal/factory_test1.proto'
+ file_desc1 = self.pool.FindFileByName(name1)
+ self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
+ self.assertEquals(name1, file_desc1.name)
+ self.assertEquals('google.protobuf.python.internal', file_desc1.package)
+ self.assertIn('Factory1Message', file_desc1.message_types_by_name)
+
+ name2 = 'google/protobuf/internal/factory_test2.proto'
+ file_desc2 = self.pool.FindFileByName(name2)
+ self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
+ self.assertEquals(name2, file_desc2.name)
+ self.assertEquals('google.protobuf.python.internal', file_desc2.package)
+ self.assertIn('Factory2Message', file_desc2.message_types_by_name)
+
+ def testFindFileByNameFailure(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindFileByName('Does not exist')
+
+ def testFindFileContainingSymbol(self):
+ file_desc1 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory1Message')
+ self.assertIsInstance(file_desc1, descriptor.FileDescriptor)
+ self.assertEquals('google/protobuf/internal/factory_test1.proto',
+ file_desc1.name)
+ self.assertEquals('google.protobuf.python.internal', file_desc1.package)
+ self.assertIn('Factory1Message', file_desc1.message_types_by_name)
+
+ file_desc2 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory2Message')
+ self.assertIsInstance(file_desc2, descriptor.FileDescriptor)
+ self.assertEquals('google/protobuf/internal/factory_test2.proto',
+ file_desc2.name)
+ self.assertEquals('google.protobuf.python.internal', file_desc2.package)
+ self.assertIn('Factory2Message', file_desc2.message_types_by_name)
+
+ def testFindFileContainingSymbolFailure(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindFileContainingSymbol('Does not exist')
+
+ def testFindMessageTypeByName(self):
+ msg1 = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory1Message')
+ self.assertIsInstance(msg1, descriptor.Descriptor)
+ self.assertEquals('Factory1Message', msg1.name)
+ self.assertEquals('google.protobuf.python.internal.Factory1Message',
+ msg1.full_name)
+ self.assertEquals(None, msg1.containing_type)
+
+ nested_msg1 = msg1.nested_types[0]
+ self.assertEquals('NestedFactory1Message', nested_msg1.name)
+ self.assertEquals(msg1, nested_msg1.containing_type)
+
+ nested_enum1 = msg1.enum_types[0]
+ self.assertEquals('NestedFactory1Enum', nested_enum1.name)
+ self.assertEquals(msg1, nested_enum1.containing_type)
+
+ self.assertEquals(nested_msg1, msg1.fields_by_name[
+ 'nested_factory_1_message'].message_type)
+ self.assertEquals(nested_enum1, msg1.fields_by_name[
+ 'nested_factory_1_enum'].enum_type)
+
+ msg2 = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message')
+ self.assertIsInstance(msg2, descriptor.Descriptor)
+ self.assertEquals('Factory2Message', msg2.name)
+ self.assertEquals('google.protobuf.python.internal.Factory2Message',
+ msg2.full_name)
+ self.assertIsNone(msg2.containing_type)
+
+ nested_msg2 = msg2.nested_types[0]
+ self.assertEquals('NestedFactory2Message', nested_msg2.name)
+ self.assertEquals(msg2, nested_msg2.containing_type)
+
+ nested_enum2 = msg2.enum_types[0]
+ self.assertEquals('NestedFactory2Enum', nested_enum2.name)
+ self.assertEquals(msg2, nested_enum2.containing_type)
+
+ self.assertEquals(nested_msg2, msg2.fields_by_name[
+ 'nested_factory_2_message'].message_type)
+ self.assertEquals(nested_enum2, msg2.fields_by_name[
+ 'nested_factory_2_enum'].enum_type)
+
+ self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value)
+ self.assertEquals(
+ 1776, msg2.fields_by_name['int_with_default'].default_value)
+
+ self.assertTrue(
+ msg2.fields_by_name['double_with_default'].has_default_value)
+ self.assertEquals(
+ 9.99, msg2.fields_by_name['double_with_default'].default_value)
+
+ self.assertTrue(
+ msg2.fields_by_name['string_with_default'].has_default_value)
+ self.assertEquals(
+ 'hello world', msg2.fields_by_name['string_with_default'].default_value)
+
+ self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value)
+ self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value)
+
+ self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value)
+ self.assertEquals(
+ 1, msg2.fields_by_name['enum_with_default'].default_value)
+
+ msg3 = self.pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')
+ self.assertEquals(nested_msg2, msg3)
+
+ self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value)
+ self.assertEquals(
+ b'a\xfb\x00c',
+ msg2.fields_by_name['bytes_with_default'].default_value)
+
+ self.assertEqual(1, len(msg2.oneofs))
+ self.assertEqual(1, len(msg2.oneofs_by_name))
+ self.assertEqual(2, len(msg2.oneofs[0].fields))
+ for name in ['oneof_int', 'oneof_string']:
+ self.assertEqual(msg2.oneofs[0],
+ msg2.fields_by_name[name].containing_oneof)
+ self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields)
+
+ def testFindMessageTypeByNameFailure(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindMessageTypeByName('Does not exist')
+
+ def testFindEnumTypeByName(self):
+ enum1 = self.pool.FindEnumTypeByName(
+ 'google.protobuf.python.internal.Factory1Enum')
+ self.assertIsInstance(enum1, descriptor.EnumDescriptor)
+ self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number)
+ self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number)
+
+ nested_enum1 = self.pool.FindEnumTypeByName(
+ 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum')
+ self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor)
+ self.assertEquals(
+ 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number)
+ self.assertEquals(
+ 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number)
+
+ enum2 = self.pool.FindEnumTypeByName(
+ 'google.protobuf.python.internal.Factory2Enum')
+ self.assertIsInstance(enum2, descriptor.EnumDescriptor)
+ self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number)
+ self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number)
+
+ nested_enum2 = self.pool.FindEnumTypeByName(
+ 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')
+ self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor)
+ self.assertEquals(
+ 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number)
+ self.assertEquals(
+ 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number)
+
+ def testFindEnumTypeByNameFailure(self):
+ with self.assertRaises(KeyError):
+ self.pool.FindEnumTypeByName('Does not exist')
+
+ def testUserDefinedDB(self):
+ db = descriptor_database.DescriptorDatabase()
+ self.pool = descriptor_pool.DescriptorPool(db)
+ db.Add(self.factory_test1_fd)
+ db.Add(self.factory_test2_fd)
+ self.testFindMessageTypeByName()
+
+ def testComplexNesting(self):
+ test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
+ test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
+ descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
+ self.pool.Add(test1_desc)
+ self.pool.Add(test2_desc)
+ TEST1_FILE.CheckFile(self, self.pool)
+ TEST2_FILE.CheckFile(self, self.pool)
+
+
+
+class ProtoFile(object):
+
+ def __init__(self, name, package, messages, dependencies=None):
+ self.name = name
+ self.package = package
+ self.messages = messages
+ self.dependencies = dependencies or []
+
+ def CheckFile(self, test, pool):
+ file_desc = pool.FindFileByName(self.name)
+ test.assertEquals(self.name, file_desc.name)
+ test.assertEquals(self.package, file_desc.package)
+ dependencies_names = [f.name for f in file_desc.dependencies]
+ test.assertEqual(self.dependencies, dependencies_names)
+ for name, msg_type in self.messages.items():
+ msg_type.CheckType(test, None, name, file_desc)
+
+
+class EnumType(object):
+
+ def __init__(self, values):
+ self.values = values
+
+ def CheckType(self, test, msg_desc, name, file_desc):
+ enum_desc = msg_desc.enum_types_by_name[name]
+ test.assertEqual(name, enum_desc.name)
+ expected_enum_full_name = '.'.join([msg_desc.full_name, name])
+ test.assertEqual(expected_enum_full_name, enum_desc.full_name)
+ test.assertEqual(msg_desc, enum_desc.containing_type)
+ test.assertEqual(file_desc, enum_desc.file)
+ for index, (value, number) in enumerate(self.values):
+ value_desc = enum_desc.values_by_name[value]
+ test.assertEqual(value, value_desc.name)
+ test.assertEqual(index, value_desc.index)
+ test.assertEqual(number, value_desc.number)
+ test.assertEqual(enum_desc, value_desc.type)
+ test.assertIn(value, msg_desc.enum_values_by_name)
+
+
+class MessageType(object):
+
+ def __init__(self, type_dict, field_list, is_extendable=False,
+ extensions=None):
+ self.type_dict = type_dict
+ self.field_list = field_list
+ self.is_extendable = is_extendable
+ self.extensions = extensions or []
+
+ def CheckType(self, test, containing_type_desc, name, file_desc):
+ if containing_type_desc is None:
+ desc = file_desc.message_types_by_name[name]
+ expected_full_name = '.'.join([file_desc.package, name])
+ else:
+ desc = containing_type_desc.nested_types_by_name[name]
+ expected_full_name = '.'.join([containing_type_desc.full_name, name])
+
+ test.assertEqual(name, desc.name)
+ test.assertEqual(expected_full_name, desc.full_name)
+ test.assertEqual(containing_type_desc, desc.containing_type)
+ test.assertEqual(desc.file, file_desc)
+ test.assertEqual(self.is_extendable, desc.is_extendable)
+ for name, subtype in self.type_dict.items():
+ subtype.CheckType(test, desc, name, file_desc)
+
+ for index, (name, field) in enumerate(self.field_list):
+ field.CheckField(test, desc, name, index)
+
+ for index, (name, field) in enumerate(self.extensions):
+ field.CheckField(test, desc, name, index)
+
+
+class EnumField(object):
+
+ def __init__(self, number, type_name, default_value):
+ self.number = number
+ self.type_name = type_name
+ self.default_value = default_value
+
+ def CheckField(self, test, msg_desc, name, index):
+ field_desc = msg_desc.fields_by_name[name]
+ enum_desc = msg_desc.enum_types_by_name[self.type_name]
+ test.assertEqual(name, field_desc.name)
+ expected_field_full_name = '.'.join([msg_desc.full_name, name])
+ test.assertEqual(expected_field_full_name, field_desc.full_name)
+ test.assertEqual(index, field_desc.index)
+ test.assertEqual(self.number, field_desc.number)
+ test.assertEqual(descriptor.FieldDescriptor.TYPE_ENUM, field_desc.type)
+ test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM,
+ field_desc.cpp_type)
+ test.assertTrue(field_desc.has_default_value)
+ test.assertEqual(enum_desc.values_by_name[self.default_value].index,
+ field_desc.default_value)
+ test.assertEqual(msg_desc, field_desc.containing_type)
+ test.assertEqual(enum_desc, field_desc.enum_type)
+
+
+class MessageField(object):
+
+ def __init__(self, number, type_name):
+ self.number = number
+ self.type_name = type_name
+
+ def CheckField(self, test, msg_desc, name, index):
+ field_desc = msg_desc.fields_by_name[name]
+ field_type_desc = msg_desc.nested_types_by_name[self.type_name]
+ test.assertEqual(name, field_desc.name)
+ expected_field_full_name = '.'.join([msg_desc.full_name, name])
+ test.assertEqual(expected_field_full_name, field_desc.full_name)
+ test.assertEqual(index, field_desc.index)
+ test.assertEqual(self.number, field_desc.number)
+ test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type)
+ test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE,
+ field_desc.cpp_type)
+ test.assertFalse(field_desc.has_default_value)
+ test.assertEqual(msg_desc, field_desc.containing_type)
+ test.assertEqual(field_type_desc, field_desc.message_type)
+
+
+class StringField(object):
+
+ def __init__(self, number, default_value):
+ self.number = number
+ self.default_value = default_value
+
+ def CheckField(self, test, msg_desc, name, index):
+ field_desc = msg_desc.fields_by_name[name]
+ test.assertEqual(name, field_desc.name)
+ expected_field_full_name = '.'.join([msg_desc.full_name, name])
+ test.assertEqual(expected_field_full_name, field_desc.full_name)
+ test.assertEqual(index, field_desc.index)
+ test.assertEqual(self.number, field_desc.number)
+ test.assertEqual(descriptor.FieldDescriptor.TYPE_STRING, field_desc.type)
+ test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_STRING,
+ field_desc.cpp_type)
+ test.assertTrue(field_desc.has_default_value)
+ test.assertEqual(self.default_value, field_desc.default_value)
+
+
+class ExtensionField(object):
+
+ def __init__(self, number, extended_type):
+ self.number = number
+ self.extended_type = extended_type
+
+ def CheckField(self, test, msg_desc, name, index):
+ field_desc = msg_desc.extensions_by_name[name]
+ test.assertEqual(name, field_desc.name)
+ expected_field_full_name = '.'.join([msg_desc.full_name, name])
+ test.assertEqual(expected_field_full_name, field_desc.full_name)
+ test.assertEqual(self.number, field_desc.number)
+ test.assertEqual(index, field_desc.index)
+ test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type)
+ test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE,
+ field_desc.cpp_type)
+ test.assertFalse(field_desc.has_default_value)
+ test.assertTrue(field_desc.is_extension)
+ test.assertEqual(msg_desc, field_desc.extension_scope)
+ test.assertEqual(msg_desc, field_desc.message_type)
+ test.assertEqual(self.extended_type, field_desc.containing_type.name)
+
+
+class AddDescriptorTest(basetest.TestCase):
+
+ def _TestMessage(self, prefix):
+ pool = descriptor_pool.DescriptorPool()
+ pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes',
+ pool.FindMessageTypeByName(
+ prefix + 'protobuf_unittest.TestAllTypes').full_name)
+
+ # AddDescriptor is not recursive.
+ with self.assertRaises(KeyError):
+ pool.FindMessageTypeByName(
+ prefix + 'protobuf_unittest.TestAllTypes.NestedMessage')
+
+ pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes.NestedMessage',
+ pool.FindMessageTypeByName(
+ prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
+
+ # Files are implicitly also indexed when messages are added.
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ pool.FindFileByName(
+ 'google/protobuf/unittest.proto').name)
+
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ pool.FindFileContainingSymbol(
+ prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name)
+
+ def testMessage(self):
+ self._TestMessage('')
+ self._TestMessage('.')
+
+ def _TestEnum(self, prefix):
+ pool = descriptor_pool.DescriptorPool()
+ pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
+ self.assertEquals(
+ 'protobuf_unittest.ForeignEnum',
+ pool.FindEnumTypeByName(
+ prefix + 'protobuf_unittest.ForeignEnum').full_name)
+
+ # AddEnumDescriptor is not recursive.
+ with self.assertRaises(KeyError):
+ pool.FindEnumTypeByName(
+ prefix + 'protobuf_unittest.ForeignEnum.NestedEnum')
+
+ pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes.NestedEnum',
+ pool.FindEnumTypeByName(
+ prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
+
+ # Files are implicitly also indexed when enums are added.
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ pool.FindFileByName(
+ 'google/protobuf/unittest.proto').name)
+
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ pool.FindFileContainingSymbol(
+ prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name)
+
+ def testEnum(self):
+ self._TestEnum('')
+ self._TestEnum('.')
+
+ def testFile(self):
+ pool = descriptor_pool.DescriptorPool()
+ pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR)
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ pool.FindFileByName(
+ 'google/protobuf/unittest.proto').name)
+
+ # AddFileDescriptor is not recursive; messages and enums within files must
+ # be explicitly registered.
+ with self.assertRaises(KeyError):
+ pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes')
+
+
+TEST1_FILE = ProtoFile(
+ 'google/protobuf/internal/descriptor_pool_test1.proto',
+ 'google.protobuf.python.internal',
+ {
+ 'DescriptorPoolTest1': MessageType({
+ 'NestedEnum': EnumType([('ALPHA', 1), ('BETA', 2)]),
+ 'NestedMessage': MessageType({
+ 'NestedEnum': EnumType([('EPSILON', 5), ('ZETA', 6)]),
+ 'DeepNestedMessage': MessageType({
+ 'NestedEnum': EnumType([('ETA', 7), ('THETA', 8)]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'ETA')),
+ ('nested_field', StringField(2, 'theta')),
+ ]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'ZETA')),
+ ('nested_field', StringField(2, 'beta')),
+ ('deep_nested_message', MessageField(3, 'DeepNestedMessage')),
+ ])
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'BETA')),
+ ('nested_message', MessageField(2, 'NestedMessage')),
+ ], is_extendable=True),
+
+ 'DescriptorPoolTest2': MessageType({
+ 'NestedEnum': EnumType([('GAMMA', 3), ('DELTA', 4)]),
+ 'NestedMessage': MessageType({
+ 'NestedEnum': EnumType([('IOTA', 9), ('KAPPA', 10)]),
+ 'DeepNestedMessage': MessageType({
+ 'NestedEnum': EnumType([('LAMBDA', 11), ('MU', 12)]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'MU')),
+ ('nested_field', StringField(2, 'lambda')),
+ ]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'IOTA')),
+ ('nested_field', StringField(2, 'delta')),
+ ('deep_nested_message', MessageField(3, 'DeepNestedMessage')),
+ ])
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'GAMMA')),
+ ('nested_message', MessageField(2, 'NestedMessage')),
+ ]),
+ })
+
+
+TEST2_FILE = ProtoFile(
+ 'google/protobuf/internal/descriptor_pool_test2.proto',
+ 'google.protobuf.python.internal',
+ {
+ 'DescriptorPoolTest3': MessageType({
+ 'NestedEnum': EnumType([('NU', 13), ('XI', 14)]),
+ 'NestedMessage': MessageType({
+ 'NestedEnum': EnumType([('OMICRON', 15), ('PI', 16)]),
+ 'DeepNestedMessage': MessageType({
+ 'NestedEnum': EnumType([('RHO', 17), ('SIGMA', 18)]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'RHO')),
+ ('nested_field', StringField(2, 'sigma')),
+ ]),
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'PI')),
+ ('nested_field', StringField(2, 'nu')),
+ ('deep_nested_message', MessageField(3, 'DeepNestedMessage')),
+ ])
+ }, [
+ ('nested_enum', EnumField(1, 'NestedEnum', 'XI')),
+ ('nested_message', MessageField(2, 'NestedMessage')),
+ ], extensions=[
+ ('descriptor_pool_test',
+ ExtensionField(1001, 'DescriptorPoolTest1')),
+ ]),
+ },
+ dependencies=['google/protobuf/internal/descriptor_pool_test1.proto'])
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto
new file mode 100644
index 0000000..c11dcc0
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_pool_test1.proto
@@ -0,0 +1,94 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package google.protobuf.python.internal;
+
+
+message DescriptorPoolTest1 {
+ extensions 1000 to max;
+
+ enum NestedEnum {
+ ALPHA = 1;
+ BETA = 2;
+ }
+
+ optional NestedEnum nested_enum = 1 [default = BETA];
+
+ message NestedMessage {
+ enum NestedEnum {
+ EPSILON = 5;
+ ZETA = 6;
+ }
+ optional NestedEnum nested_enum = 1 [default = ZETA];
+ optional string nested_field = 2 [default = "beta"];
+ optional DeepNestedMessage deep_nested_message = 3;
+
+ message DeepNestedMessage {
+ enum NestedEnum {
+ ETA = 7;
+ THETA = 8;
+ }
+ optional NestedEnum nested_enum = 1 [default = ETA];
+ optional string nested_field = 2 [default = "theta"];
+ }
+ }
+
+ optional NestedMessage nested_message = 2;
+}
+
+message DescriptorPoolTest2 {
+ enum NestedEnum {
+ GAMMA = 3;
+ DELTA = 4;
+ }
+
+ optional NestedEnum nested_enum = 1 [default = GAMMA];
+
+ message NestedMessage {
+ enum NestedEnum {
+ IOTA = 9;
+ KAPPA = 10;
+ }
+ optional NestedEnum nested_enum = 1 [default = IOTA];
+ optional string nested_field = 2 [default = "delta"];
+ optional DeepNestedMessage deep_nested_message = 3;
+
+ message DeepNestedMessage {
+ enum NestedEnum {
+ LAMBDA = 11;
+ MU = 12;
+ }
+ optional NestedEnum nested_enum = 1 [default = MU];
+ optional string nested_field = 2 [default = "lambda"];
+ }
+ }
+
+ optional NestedMessage nested_message = 2;
+}
diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto
new file mode 100644
index 0000000..d97d39b
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_pool_test2.proto
@@ -0,0 +1,70 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package google.protobuf.python.internal;
+
+import "google/protobuf/internal/descriptor_pool_test1.proto";
+
+
+message DescriptorPoolTest3 {
+
+ extend DescriptorPoolTest1 {
+ optional DescriptorPoolTest3 descriptor_pool_test = 1001;
+ }
+
+ enum NestedEnum {
+ NU = 13;
+ XI = 14;
+ }
+
+ optional NestedEnum nested_enum = 1 [default = XI];
+
+ message NestedMessage {
+ enum NestedEnum {
+ OMICRON = 15;
+ PI = 16;
+ }
+ optional NestedEnum nested_enum = 1 [default = PI];
+ optional string nested_field = 2 [default = "nu"];
+ optional DeepNestedMessage deep_nested_message = 3;
+
+ message DeepNestedMessage {
+ enum NestedEnum {
+ RHO = 17;
+ SIGMA = 18;
+ }
+ optional NestedEnum nested_enum = 1 [default = RHO];
+ optional string nested_field = 2 [default = "sigma"];
+ }
+ }
+
+ optional NestedMessage nested_message = 2;
+}
+
diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py
new file mode 100644
index 0000000..b3a1571
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_python_test.py
@@ -0,0 +1,54 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Unittest for descriptor.py for the pure Python implementation."""
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+
+# We must set the implementation version above before the google3 imports.
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+# Run all tests from the original module by putting them in our namespace.
+# pylint: disable=wildcard-import
+from google.protobuf.internal.descriptor_test import *
+
+
+class ConfirmPurePythonTest(basetest.TestCase):
+
+ def testImplementationSetting(self):
+ self.assertEqual('python', api_implementation.Type())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 05c2745..d20d945 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -34,7 +34,8 @@
__author__ = 'robinson@google.com (Will Robinson)'
-import unittest
+from google.apputils import basetest
+from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
@@ -47,7 +48,7 @@ name: 'TestEmptyMessage'
"""
-class DescriptorTest(unittest.TestCase):
+class DescriptorTest(basetest.TestCase):
def setUp(self):
self.my_file = descriptor.FileDescriptor(
@@ -101,6 +102,15 @@ class DescriptorTest(unittest.TestCase):
self.my_method
])
+ def testEnumValueName(self):
+ self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4),
+ 'FOREIGN_FOO')
+
+ self.assertEqual(
+ self.my_message.enum_types_by_name[
+ 'ForeignEnum'].values_by_number[4].name,
+ self.my_message.EnumValueName('ForeignEnum', 4))
+
def testEnumFixups(self):
self.assertEqual(self.my_enum, self.my_enum.values[0].type)
@@ -125,6 +135,257 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_service.GetOptions(),
descriptor_pb2.ServiceOptions())
+ def testSimpleCustomOptions(self):
+ file_descriptor = unittest_custom_options_pb2.DESCRIPTOR
+ message_descriptor =\
+ unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR
+ field_descriptor = message_descriptor.fields_by_name["field1"]
+ enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"]
+ enum_value_descriptor =\
+ message_descriptor.enum_values_by_name["ANENUM_VAL2"]
+ service_descriptor =\
+ unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR
+ method_descriptor = service_descriptor.FindMethodByName("Foo")
+
+ file_options = file_descriptor.GetOptions()
+ file_opt1 = unittest_custom_options_pb2.file_opt1
+ self.assertEqual(9876543210, file_options.Extensions[file_opt1])
+ message_options = message_descriptor.GetOptions()
+ message_opt1 = unittest_custom_options_pb2.message_opt1
+ self.assertEqual(-56, message_options.Extensions[message_opt1])
+ field_options = field_descriptor.GetOptions()
+ field_opt1 = unittest_custom_options_pb2.field_opt1
+ self.assertEqual(8765432109, field_options.Extensions[field_opt1])
+ field_opt2 = unittest_custom_options_pb2.field_opt2
+ self.assertEqual(42, field_options.Extensions[field_opt2])
+ enum_options = enum_descriptor.GetOptions()
+ enum_opt1 = unittest_custom_options_pb2.enum_opt1
+ self.assertEqual(-789, enum_options.Extensions[enum_opt1])
+ enum_value_options = enum_value_descriptor.GetOptions()
+ enum_value_opt1 = unittest_custom_options_pb2.enum_value_opt1
+ self.assertEqual(123, enum_value_options.Extensions[enum_value_opt1])
+
+ service_options = service_descriptor.GetOptions()
+ service_opt1 = unittest_custom_options_pb2.service_opt1
+ self.assertEqual(-9876543210, service_options.Extensions[service_opt1])
+ method_options = method_descriptor.GetOptions()
+ method_opt1 = unittest_custom_options_pb2.method_opt1
+ self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2,
+ method_options.Extensions[method_opt1])
+
+ def testDifferentCustomOptionTypes(self):
+ kint32min = -2**31
+ kint64min = -2**63
+ kint32max = 2**31 - 1
+ kint64max = 2**63 - 1
+ kuint32max = 2**32 - 1
+ kuint64max = 2**64 - 1
+
+ message_descriptor =\
+ unittest_custom_options_pb2.CustomOptionMinIntegerValues.DESCRIPTOR
+ message_options = message_descriptor.GetOptions()
+ self.assertEqual(False, message_options.Extensions[
+ unittest_custom_options_pb2.bool_opt])
+ self.assertEqual(kint32min, message_options.Extensions[
+ unittest_custom_options_pb2.int32_opt])
+ self.assertEqual(kint64min, message_options.Extensions[
+ unittest_custom_options_pb2.int64_opt])
+ self.assertEqual(0, message_options.Extensions[
+ unittest_custom_options_pb2.uint32_opt])
+ self.assertEqual(0, message_options.Extensions[
+ unittest_custom_options_pb2.uint64_opt])
+ self.assertEqual(kint32min, message_options.Extensions[
+ unittest_custom_options_pb2.sint32_opt])
+ self.assertEqual(kint64min, message_options.Extensions[
+ unittest_custom_options_pb2.sint64_opt])
+ self.assertEqual(0, message_options.Extensions[
+ unittest_custom_options_pb2.fixed32_opt])
+ self.assertEqual(0, message_options.Extensions[
+ unittest_custom_options_pb2.fixed64_opt])
+ self.assertEqual(kint32min, message_options.Extensions[
+ unittest_custom_options_pb2.sfixed32_opt])
+ self.assertEqual(kint64min, message_options.Extensions[
+ unittest_custom_options_pb2.sfixed64_opt])
+
+ message_descriptor =\
+ unittest_custom_options_pb2.CustomOptionMaxIntegerValues.DESCRIPTOR
+ message_options = message_descriptor.GetOptions()
+ self.assertEqual(True, message_options.Extensions[
+ unittest_custom_options_pb2.bool_opt])
+ self.assertEqual(kint32max, message_options.Extensions[
+ unittest_custom_options_pb2.int32_opt])
+ self.assertEqual(kint64max, message_options.Extensions[
+ unittest_custom_options_pb2.int64_opt])
+ self.assertEqual(kuint32max, message_options.Extensions[
+ unittest_custom_options_pb2.uint32_opt])
+ self.assertEqual(kuint64max, message_options.Extensions[
+ unittest_custom_options_pb2.uint64_opt])
+ self.assertEqual(kint32max, message_options.Extensions[
+ unittest_custom_options_pb2.sint32_opt])
+ self.assertEqual(kint64max, message_options.Extensions[
+ unittest_custom_options_pb2.sint64_opt])
+ self.assertEqual(kuint32max, message_options.Extensions[
+ unittest_custom_options_pb2.fixed32_opt])
+ self.assertEqual(kuint64max, message_options.Extensions[
+ unittest_custom_options_pb2.fixed64_opt])
+ self.assertEqual(kint32max, message_options.Extensions[
+ unittest_custom_options_pb2.sfixed32_opt])
+ self.assertEqual(kint64max, message_options.Extensions[
+ unittest_custom_options_pb2.sfixed64_opt])
+
+ message_descriptor =\
+ unittest_custom_options_pb2.CustomOptionOtherValues.DESCRIPTOR
+ message_options = message_descriptor.GetOptions()
+ self.assertEqual(-100, message_options.Extensions[
+ unittest_custom_options_pb2.int32_opt])
+ self.assertAlmostEqual(12.3456789, message_options.Extensions[
+ unittest_custom_options_pb2.float_opt], 6)
+ self.assertAlmostEqual(1.234567890123456789, message_options.Extensions[
+ unittest_custom_options_pb2.double_opt])
+ self.assertEqual("Hello, \"World\"", message_options.Extensions[
+ unittest_custom_options_pb2.string_opt])
+ self.assertEqual(b"Hello\0World", message_options.Extensions[
+ unittest_custom_options_pb2.bytes_opt])
+ dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum
+ self.assertEqual(
+ dummy_enum.TEST_OPTION_ENUM_TYPE2,
+ message_options.Extensions[unittest_custom_options_pb2.enum_opt])
+
+ message_descriptor =\
+ unittest_custom_options_pb2.SettingRealsFromPositiveInts.DESCRIPTOR
+ message_options = message_descriptor.GetOptions()
+ self.assertAlmostEqual(12, message_options.Extensions[
+ unittest_custom_options_pb2.float_opt], 6)
+ self.assertAlmostEqual(154, message_options.Extensions[
+ unittest_custom_options_pb2.double_opt])
+
+ message_descriptor =\
+ unittest_custom_options_pb2.SettingRealsFromNegativeInts.DESCRIPTOR
+ message_options = message_descriptor.GetOptions()
+ self.assertAlmostEqual(-12, message_options.Extensions[
+ unittest_custom_options_pb2.float_opt], 6)
+ self.assertAlmostEqual(-154, message_options.Extensions[
+ unittest_custom_options_pb2.double_opt])
+
+ def testComplexExtensionOptions(self):
+ descriptor =\
+ unittest_custom_options_pb2.VariousComplexOptions.DESCRIPTOR
+ options = descriptor.GetOptions()
+ self.assertEqual(42, options.Extensions[
+ unittest_custom_options_pb2.complex_opt1].foo)
+ self.assertEqual(324, options.Extensions[
+ unittest_custom_options_pb2.complex_opt1].Extensions[
+ unittest_custom_options_pb2.quux])
+ self.assertEqual(876, options.Extensions[
+ unittest_custom_options_pb2.complex_opt1].Extensions[
+ unittest_custom_options_pb2.corge].qux)
+ self.assertEqual(987, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].baz)
+ self.assertEqual(654, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].Extensions[
+ unittest_custom_options_pb2.grault])
+ self.assertEqual(743, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].bar.foo)
+ self.assertEqual(1999, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].bar.Extensions[
+ unittest_custom_options_pb2.quux])
+ self.assertEqual(2008, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].bar.Extensions[
+ unittest_custom_options_pb2.corge].qux)
+ self.assertEqual(741, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].Extensions[
+ unittest_custom_options_pb2.garply].foo)
+ self.assertEqual(1998, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].Extensions[
+ unittest_custom_options_pb2.garply].Extensions[
+ unittest_custom_options_pb2.quux])
+ self.assertEqual(2121, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].Extensions[
+ unittest_custom_options_pb2.garply].Extensions[
+ unittest_custom_options_pb2.corge].qux)
+ self.assertEqual(1971, options.Extensions[
+ unittest_custom_options_pb2.ComplexOptionType2
+ .ComplexOptionType4.complex_opt4].waldo)
+ self.assertEqual(321, options.Extensions[
+ unittest_custom_options_pb2.complex_opt2].fred.waldo)
+ self.assertEqual(9, options.Extensions[
+ unittest_custom_options_pb2.complex_opt3].qux)
+ self.assertEqual(22, options.Extensions[
+ unittest_custom_options_pb2.complex_opt3].complexoptiontype5.plugh)
+ self.assertEqual(24, options.Extensions[
+ unittest_custom_options_pb2.complexopt6].xyzzy)
+
+ # Check that aggregate options were parsed and saved correctly in
+ # the appropriate descriptors.
+ def testAggregateOptions(self):
+ file_descriptor = unittest_custom_options_pb2.DESCRIPTOR
+ message_descriptor =\
+ unittest_custom_options_pb2.AggregateMessage.DESCRIPTOR
+ field_descriptor = message_descriptor.fields_by_name["fieldname"]
+ enum_descriptor = unittest_custom_options_pb2.AggregateEnum.DESCRIPTOR
+ enum_value_descriptor = enum_descriptor.values_by_name["VALUE"]
+ service_descriptor =\
+ unittest_custom_options_pb2.AggregateService.DESCRIPTOR
+ method_descriptor = service_descriptor.FindMethodByName("Method")
+
+ # Tests for the different types of data embedded in fileopt
+ file_options = file_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.fileopt]
+ self.assertEqual(100, file_options.i)
+ self.assertEqual("FileAnnotation", file_options.s)
+ self.assertEqual("NestedFileAnnotation", file_options.sub.s)
+ self.assertEqual("FileExtensionAnnotation", file_options.file.Extensions[
+ unittest_custom_options_pb2.fileopt].s)
+ self.assertEqual("EmbeddedMessageSetElement", file_options.mset.Extensions[
+ unittest_custom_options_pb2.AggregateMessageSetElement
+ .message_set_extension].s)
+
+ # Simple tests for all the other types of annotations
+ self.assertEqual(
+ "MessageAnnotation",
+ message_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.msgopt].s)
+ self.assertEqual(
+ "FieldAnnotation",
+ field_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.fieldopt].s)
+ self.assertEqual(
+ "EnumAnnotation",
+ enum_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.enumopt].s)
+ self.assertEqual(
+ "EnumValueAnnotation",
+ enum_value_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.enumvalopt].s)
+ self.assertEqual(
+ "ServiceAnnotation",
+ service_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.serviceopt].s)
+ self.assertEqual(
+ "MethodAnnotation",
+ method_descriptor.GetOptions().Extensions[
+ unittest_custom_options_pb2.methodopt].s)
+
+ def testNestedOptions(self):
+ nested_message =\
+ unittest_custom_options_pb2.NestedOptionType.NestedMessage.DESCRIPTOR
+ self.assertEqual(1001, nested_message.GetOptions().Extensions[
+ unittest_custom_options_pb2.message_opt1])
+ nested_field = nested_message.fields_by_name["nested_field"]
+ self.assertEqual(1002, nested_field.GetOptions().Extensions[
+ unittest_custom_options_pb2.field_opt1])
+ outer_message =\
+ unittest_custom_options_pb2.NestedOptionType.DESCRIPTOR
+ nested_enum = outer_message.enum_types_by_name["NestedEnum"]
+ self.assertEqual(1003, nested_enum.GetOptions().Extensions[
+ unittest_custom_options_pb2.enum_opt1])
+ nested_enum_value = outer_message.enum_values_by_name["NESTED_ENUM_VALUE"]
+ self.assertEqual(1004, nested_enum_value.GetOptions().Extensions[
+ unittest_custom_options_pb2.enum_value_opt1])
+ nested_extension = outer_message.extensions_by_name["nested_extension"]
+ self.assertEqual(1005, nested_extension.GetOptions().Extensions[
+ unittest_custom_options_pb2.field_opt2])
+
def testFileDescriptorReferences(self):
self.assertEqual(self.my_enum.file, self.my_file)
self.assertEqual(self.my_message.file, self.my_file)
@@ -134,7 +395,7 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_file.package, 'protobuf_unittest')
-class DescriptorCopyToProtoTest(unittest.TestCase):
+class DescriptorCopyToProtoTest(basetest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
@@ -269,45 +530,49 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
descriptor_pb2.DescriptorProto,
TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
- def testCopyToProto_FileDescriptor(self):
- UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
- name: 'google/protobuf/unittest_import.proto'
- package: 'protobuf_unittest_import'
- message_type: <
- name: 'ImportMessage'
- field: <
- name: 'd'
- number: 1
- label: 1 # Optional
- type: 5 # TYPE_INT32
- >
- >
- """ +
- """enum_type: <
- name: 'ImportEnum'
- value: <
- name: 'IMPORT_FOO'
- number: 7
- >
- value: <
- name: 'IMPORT_BAR'
- number: 8
- >
- value: <
- name: 'IMPORT_BAZ'
- number: 9
- >
- >
- options: <
- java_package: 'com.google.protobuf.test'
- optimize_for: 1 # SPEED
- >
- """)
-
- self._InternalTestCopyToProto(
- unittest_import_pb2.DESCRIPTOR,
- descriptor_pb2.FileDescriptorProto,
- UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
+ # Disable this test so we can make changes to the proto file.
+ # TODO(xiaofeng): Enable this test after cl/55530659 is submitted.
+ #
+ # def testCopyToProto_FileDescriptor(self):
+ # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
+ # name: 'google/protobuf/unittest_import.proto'
+ # package: 'protobuf_unittest_import'
+ # dependency: 'google/protobuf/unittest_import_public.proto'
+ # message_type: <
+ # name: 'ImportMessage'
+ # field: <
+ # name: 'd'
+ # number: 1
+ # label: 1 # Optional
+ # type: 5 # TYPE_INT32
+ # >
+ # >
+ # """ +
+ # """enum_type: <
+ # name: 'ImportEnum'
+ # value: <
+ # name: 'IMPORT_FOO'
+ # number: 7
+ # >
+ # value: <
+ # name: 'IMPORT_BAR'
+ # number: 8
+ # >
+ # value: <
+ # name: 'IMPORT_BAZ'
+ # number: 9
+ # >
+ # >
+ # options: <
+ # java_package: 'com.google.protobuf.test'
+ # optimize_for: 1 # SPEED
+ # >
+ # public_dependency: 0
+ # """)
+ # self._InternalTestCopyToProto(
+ # unittest_import_pb2.DESCRIPTOR,
+ # descriptor_pb2.FileDescriptorProto,
+ # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
def testCopyToProto_ServiceDescriptor(self):
TEST_SERVICE_ASCII = """
@@ -323,12 +588,82 @@ class DescriptorCopyToProtoTest(unittest.TestCase):
output_type: '.protobuf_unittest.BarResponse'
>
"""
-
self._InternalTestCopyToProto(
unittest_pb2.TestService.DESCRIPTOR,
descriptor_pb2.ServiceDescriptorProto,
TEST_SERVICE_ASCII)
+class MakeDescriptorTest(basetest.TestCase):
+
+ def testMakeDescriptorWithNestedFields(self):
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.name = 'Foo2'
+ message_type = file_descriptor_proto.message_type.add()
+ message_type.name = file_descriptor_proto.name
+ nested_type = message_type.nested_type.add()
+ nested_type.name = 'Sub'
+ enum_type = nested_type.enum_type.add()
+ enum_type.name = 'FOO'
+ enum_type_val = enum_type.value.add()
+ enum_type_val.name = 'BAR'
+ enum_type_val.number = 3
+ field = message_type.field.add()
+ field.number = 1
+ field.name = 'uint64_field'
+ field.label = descriptor.FieldDescriptor.LABEL_REQUIRED
+ field.type = descriptor.FieldDescriptor.TYPE_UINT64
+ field = message_type.field.add()
+ field.number = 2
+ field.name = 'nested_message_field'
+ field.label = descriptor.FieldDescriptor.LABEL_REQUIRED
+ field.type = descriptor.FieldDescriptor.TYPE_MESSAGE
+ field.type_name = 'Sub'
+ enum_field = nested_type.field.add()
+ enum_field.number = 2
+ enum_field.name = 'bar_field'
+ enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED
+ enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM
+ enum_field.type_name = 'Foo2.Sub.FOO'
+
+ result = descriptor.MakeDescriptor(message_type)
+ self.assertEqual(result.fields[0].cpp_type,
+ descriptor.FieldDescriptor.CPPTYPE_UINT64)
+ self.assertEqual(result.fields[1].cpp_type,
+ descriptor.FieldDescriptor.CPPTYPE_MESSAGE)
+ self.assertEqual(result.fields[1].message_type.containing_type,
+ result)
+ self.assertEqual(result.nested_types[0].fields[0].full_name,
+ 'Foo2.Sub.bar_field')
+ self.assertEqual(result.nested_types[0].fields[0].enum_type,
+ result.nested_types[0].enum_types[0])
+
+ def testMakeDescriptorWithUnsignedIntField(self):
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.name = 'Foo'
+ message_type = file_descriptor_proto.message_type.add()
+ message_type.name = file_descriptor_proto.name
+ enum_type = message_type.enum_type.add()
+ enum_type.name = 'FOO'
+ enum_type_val = enum_type.value.add()
+ enum_type_val.name = 'BAR'
+ enum_type_val.number = 3
+ field = message_type.field.add()
+ field.number = 1
+ field.name = 'uint64_field'
+ field.label = descriptor.FieldDescriptor.LABEL_REQUIRED
+ field.type = descriptor.FieldDescriptor.TYPE_UINT64
+ enum_field = message_type.field.add()
+ enum_field.number = 2
+ enum_field.name = 'bar_field'
+ enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED
+ enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM
+ enum_field.type_name = 'Foo.FOO'
+
+ result = descriptor.MakeDescriptor(message_type)
+ self.assertEqual(result.fields[0].cpp_type,
+ descriptor.FieldDescriptor.CPPTYPE_UINT64)
+
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index aa05d5b..0a7c041 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#PY25 compatible for GAE.
+#
+# Copyright 2009 Google Inc. All Rights Reserved.
+
"""Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type
@@ -67,9 +71,17 @@ sizer rather than when calling them. In particular:
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
+import sys ##PY25
+_PY2 = sys.version_info[0] < 3 ##PY25
from google.protobuf.internal import wire_format
+# This will overflow and thus become IEEE-754 "infinity". We would use
+# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
+_POS_INF = 1e10000
+_NEG_INF = -_POS_INF
+
+
def _VarintSize(value):
"""Compute the size of a varint value."""
if value <= 0x7f: return 1
@@ -334,7 +346,8 @@ def MessageSetItemSizer(field_number):
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
- local_chr = chr
+ local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
+##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeVarint(write, value):
bits = value & 0x7f
value >>= 7
@@ -351,7 +364,8 @@ def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
- local_chr = chr
+ local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25
+##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,))
def EncodeSignedVarint(write, value):
if value < 0:
value += (1 << 64)
@@ -376,7 +390,8 @@ def _VarintBytes(value):
pieces = []
_EncodeVarint(pieces.append, value)
- return "".join(pieces)
+ return "".encode("latin1").join(pieces) ##PY25
+##!PY25 return b"".join(pieces)
def TagBytes(field_number, wire_type):
@@ -502,6 +517,90 @@ def _StructPackEncoder(wire_type, format):
return SpecificEncoder
+def _FloatingPointEncoder(wire_type, format):
+ """Return a constructor for an encoder for float fields.
+
+ This is like StructPackEncoder, but catches errors that may be due to
+ passing non-finite floating-point values to struct.pack, and makes a
+ second attempt to encode those values.
+
+ Args:
+ wire_type: The field's wire type, for encoding tags.
+ format: The format string to pass to struct.pack().
+ """
+
+ b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25
+ value_size = struct.calcsize(format)
+ if value_size == 4:
+ def EncodeNonFiniteOrRaise(write, value):
+ # Remember that the serialized form uses little-endian byte order.
+ if value == _POS_INF:
+ write(b('\x00\x00\x80\x7F')) ##PY25
+##!PY25 write(b'\x00\x00\x80\x7F')
+ elif value == _NEG_INF:
+ write(b('\x00\x00\x80\xFF')) ##PY25
+##!PY25 write(b'\x00\x00\x80\xFF')
+ elif value != value: # NaN
+ write(b('\x00\x00\xC0\x7F')) ##PY25
+##!PY25 write(b'\x00\x00\xC0\x7F')
+ else:
+ raise
+ elif value_size == 8:
+ def EncodeNonFiniteOrRaise(write, value):
+ if value == _POS_INF:
+ write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25
+##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ elif value == _NEG_INF:
+ write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25
+##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ elif value != value: # NaN
+ write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25
+##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ else:
+ raise
+ else:
+ raise ValueError('Can\'t encode floating-point values that are '
+ '%d bytes long (only 4 or 8)' % value_size)
+
+ def SpecificEncoder(field_number, is_repeated, is_packed):
+ local_struct_pack = struct.pack
+ if is_packed:
+ tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ local_EncodeVarint = _EncodeVarint
+ def EncodePackedField(write, value):
+ write(tag_bytes)
+ local_EncodeVarint(write, len(value) * value_size)
+ for element in value:
+ # This try/except block is going to be faster than any code that
+ # we could write to check whether element is finite.
+ try:
+ write(local_struct_pack(format, element))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, element)
+ return EncodePackedField
+ elif is_repeated:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeRepeatedField(write, value):
+ for element in value:
+ write(tag_bytes)
+ try:
+ write(local_struct_pack(format, element))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, element)
+ return EncodeRepeatedField
+ else:
+ tag_bytes = TagBytes(field_number, wire_type)
+ def EncodeField(write, value):
+ write(tag_bytes)
+ try:
+ write(local_struct_pack(format, value))
+ except SystemError:
+ EncodeNonFiniteOrRaise(write, value)
+ return EncodeField
+
+ return SpecificEncoder
+
+
# ====================================================================
# Here we declare an encoder constructor for each field type. These work
# very similarly to sizer constructors, described earlier.
@@ -525,15 +624,17 @@ Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
-FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
-DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
+FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
+DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field."""
- false_byte = chr(0)
- true_byte = chr(1)
+##!PY25 false_byte = b'\x00'
+##!PY25 true_byte = b'\x01'
+ false_byte = '\x00'.encode('latin1') ##PY25
+ true_byte = '\x01'.encode('latin1') ##PY25
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
@@ -669,7 +770,8 @@ def MessageSetItemEncoder(field_number):
}
}
"""
- start_bytes = "".join([
+ start_bytes = "".encode("latin1").join([ ##PY25
+##!PY25 start_bytes = b"".join([
TagBytes(1, wire_format.WIRETYPE_START_GROUP),
TagBytes(2, wire_format.WIRETYPE_VARINT),
_VarintBytes(field_number),
diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py
new file mode 100644
index 0000000..7b28645
--- /dev/null
+++ b/python/google/protobuf/internal/enum_type_wrapper.py
@@ -0,0 +1,89 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""A simple wrapper around enum types to expose utility functions.
+
+Instances are created as properties with the same name as the enum they wrap
+on proto classes. For usage, see:
+ reflection_test.py
+"""
+
+__author__ = 'rabsatt@google.com (Kevin Rabsatt)'
+
+
+class EnumTypeWrapper(object):
+ """A utility for finding the names of enum values."""
+
+ DESCRIPTOR = None
+
+ def __init__(self, enum_type):
+ """Inits EnumTypeWrapper with an EnumDescriptor."""
+ self._enum_type = enum_type
+ self.DESCRIPTOR = enum_type;
+
+ def Name(self, number):
+ """Returns a string containing the name of an enum value."""
+ if number in self._enum_type.values_by_number:
+ return self._enum_type.values_by_number[number].name
+ raise ValueError('Enum %s has no name defined for value %d' % (
+ self._enum_type.name, number))
+
+ def Value(self, name):
+ """Returns the value coresponding to the given enum name."""
+ if name in self._enum_type.values_by_name:
+ return self._enum_type.values_by_name[name].number
+ raise ValueError('Enum %s has no value defined for name %s' % (
+ self._enum_type.name, name))
+
+ def keys(self):
+ """Return a list of the string names in the enum.
+
+ These are returned in the order they were defined in the .proto file.
+ """
+
+ return [value_descriptor.name
+ for value_descriptor in self._enum_type.values]
+
+ def values(self):
+ """Return a list of the integer values in the enum.
+
+ These are returned in the order they were defined in the .proto file.
+ """
+
+ return [value_descriptor.number
+ for value_descriptor in self._enum_type.values]
+
+ def items(self):
+ """Return a list of the (name, value) pairs of the enum.
+
+ These are returned in the order they were defined in the .proto file.
+ """
+ return [(value_descriptor.name, value_descriptor.number)
+ for value_descriptor in self._enum_type.values]
diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto
new file mode 100644
index 0000000..03dcb2c
--- /dev/null
+++ b/python/google/protobuf/internal/factory_test1.proto
@@ -0,0 +1,57 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: matthewtoia@google.com (Matt Toia)
+
+
+package google.protobuf.python.internal;
+
+
+enum Factory1Enum {
+ FACTORY_1_VALUE_0 = 0;
+ FACTORY_1_VALUE_1 = 1;
+}
+
+message Factory1Message {
+ optional Factory1Enum factory_1_enum = 1;
+ enum NestedFactory1Enum {
+ NESTED_FACTORY_1_VALUE_0 = 0;
+ NESTED_FACTORY_1_VALUE_1 = 1;
+ }
+ optional NestedFactory1Enum nested_factory_1_enum = 2;
+ message NestedFactory1Message {
+ optional string value = 1;
+ }
+ optional NestedFactory1Message nested_factory_1_message = 3;
+ optional int32 scalar_value = 4;
+ repeated string list_value = 5;
+
+ extensions 1000 to max;
+}
diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto
new file mode 100644
index 0000000..a8c6812
--- /dev/null
+++ b/python/google/protobuf/internal/factory_test2.proto
@@ -0,0 +1,92 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: matthewtoia@google.com (Matt Toia)
+
+
+package google.protobuf.python.internal;
+
+import "google/protobuf/internal/factory_test1.proto";
+
+
+enum Factory2Enum {
+ FACTORY_2_VALUE_0 = 0;
+ FACTORY_2_VALUE_1 = 1;
+}
+
+message Factory2Message {
+ required int32 mandatory = 1;
+ optional Factory2Enum factory_2_enum = 2;
+ enum NestedFactory2Enum {
+ NESTED_FACTORY_2_VALUE_0 = 0;
+ NESTED_FACTORY_2_VALUE_1 = 1;
+ }
+ optional NestedFactory2Enum nested_factory_2_enum = 3;
+ message NestedFactory2Message {
+ optional string value = 1;
+ }
+ optional NestedFactory2Message nested_factory_2_message = 4;
+ optional Factory1Message factory_1_message = 5;
+ optional Factory1Enum factory_1_enum = 6;
+ optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7;
+ optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8;
+ optional Factory2Message circular_message = 9;
+ optional string scalar_value = 10;
+ repeated string list_value = 11;
+ repeated group Grouped = 12 {
+ optional string part_1 = 13;
+ optional string part_2 = 14;
+ }
+ optional LoopMessage loop = 15;
+ optional int32 int_with_default = 16 [default = 1776];
+ optional double double_with_default = 17 [default = 9.99];
+ optional string string_with_default = 18 [default = "hello world"];
+ optional bool bool_with_default = 19 [default = false];
+ optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1];
+ optional bytes bytes_with_default = 21 [default = "a\373\000c"];
+
+
+ extend Factory1Message {
+ optional string one_more_field = 1001;
+ }
+
+ oneof oneof_field {
+ int32 oneof_int = 22;
+ string oneof_string = 23;
+ }
+}
+
+message LoopMessage {
+ optional Factory2Message loop = 1;
+}
+
+extend Factory1Message {
+ optional string another_field = 1002;
+}
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
index 78360b5..5818060 100755
--- a/python/google/protobuf/internal/generator_test.py
+++ b/python/google/protobuf/internal/generator_test.py
@@ -41,17 +41,21 @@ further ensures that we can use Python protocol message objects as we expect.
__author__ = 'robinson@google.com (Will Robinson)'
-import unittest
+from google.apputils import basetest
+from google.protobuf.internal import test_bad_identifiers_pb2
+from google.protobuf import unittest_custom_options_pb2
from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_import_public_pb2
from google.protobuf import unittest_mset_pb2
-from google.protobuf import unittest_pb2
from google.protobuf import unittest_no_generic_services_pb2
-
+from google.protobuf import unittest_pb2
+from google.protobuf import service
+from google.protobuf import symbol_database
MAX_EXTENSION = 536870912
-class GeneratorTest(unittest.TestCase):
+class GeneratorTest(basetest.TestCase):
def testNestedMessageDescriptor(self):
field_name = 'optional_nested_message'
@@ -99,6 +103,7 @@ class GeneratorTest(unittest.TestCase):
self.assertTrue(isinf(message.neg_inf_float))
self.assertTrue(message.neg_inf_float < 0)
self.assertTrue(isnan(message.nan_float))
+ self.assertEqual("? ? ?? ?? ??? ??/ ??-", message.cpp_trigraph)
def testHasDefaultValues(self):
desc = unittest_pb2.TestAllTypes.DESCRIPTOR
@@ -140,6 +145,13 @@ class GeneratorTest(unittest.TestCase):
proto = unittest_mset_pb2.TestMessageSet()
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+ def testMessageWithCustomOptions(self):
+ proto = unittest_custom_options_pb2.TestMessageWithCustomOptions()
+ enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions()
+ self.assertTrue(enum_options is not None)
+ # TODO(gps): We really should test for the presense of the enum_opt1
+ # extension and for its value to be set to -789.
+
def testNestedTypes(self):
self.assertEquals(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
@@ -206,15 +218,126 @@ class GeneratorTest(unittest.TestCase):
'google/protobuf/unittest.proto')
self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
+ self.assertEqual(unittest_pb2.DESCRIPTOR.dependencies,
+ [unittest_import_pb2.DESCRIPTOR])
+ self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies,
+ [unittest_import_public_pb2.DESCRIPTOR])
def testNoGenericServices(self):
- # unittest_no_generic_services.proto should contain defs for everything
- # except services.
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
- self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService"))
+ # Make sure unittest_no_generic_services_pb2 has no services subclassing
+ # Proto2 Service class.
+ if hasattr(unittest_no_generic_services_pb2, "TestService"):
+ self.assertFalse(issubclass(unittest_no_generic_services_pb2.TestService,
+ service.Service))
+
+ def testMessageTypesByName(self):
+ file_type = unittest_pb2.DESCRIPTOR
+ self.assertEqual(
+ unittest_pb2._TESTALLTYPES,
+ file_type.message_types_by_name[unittest_pb2._TESTALLTYPES.name])
+
+ # Nested messages shouldn't be included in the message_types_by_name
+ # dictionary (like in the C++ API).
+ self.assertFalse(
+ unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in
+ file_type.message_types_by_name)
+
+ def testEnumTypesByName(self):
+ file_type = unittest_pb2.DESCRIPTOR
+ self.assertEqual(
+ unittest_pb2._FOREIGNENUM,
+ file_type.enum_types_by_name[unittest_pb2._FOREIGNENUM.name])
+
+ def testExtensionsByName(self):
+ file_type = unittest_pb2.DESCRIPTOR
+ self.assertEqual(
+ unittest_pb2.my_extension_string,
+ file_type.extensions_by_name[unittest_pb2.my_extension_string.name])
+
+ def testPublicImports(self):
+ # Test public imports as embedded message.
+ all_type_proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(0, all_type_proto.optional_public_import_message.e)
+
+ # PublicImportMessage is actually defined in unittest_import_public_pb2
+ # module, and is public imported by unittest_import_pb2 module.
+ public_import_proto = unittest_import_pb2.PublicImportMessage()
+ self.assertEqual(0, public_import_proto.e)
+ self.assertTrue(unittest_import_public_pb2.PublicImportMessage is
+ unittest_import_pb2.PublicImportMessage)
+
+ def testBadIdentifiers(self):
+ # We're just testing that the code was imported without problems.
+ message = test_bad_identifiers_pb2.TestBadIdentifiers()
+ self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message],
+ "foo")
+ self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor],
+ "bar")
+ self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection],
+ "baz")
+ self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service],
+ "qux")
+
+ def testOneof(self):
+ desc = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.assertEqual(1, len(desc.oneofs))
+ self.assertEqual('oneof_field', desc.oneofs[0].name)
+ self.assertEqual(0, desc.oneofs[0].index)
+ self.assertIs(desc, desc.oneofs[0].containing_type)
+ self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field'])
+ nested_names = set(['oneof_uint32', 'oneof_nested_message',
+ 'oneof_string', 'oneof_bytes'])
+ self.assertSameElements(
+ nested_names,
+ [field.name for field in desc.oneofs[0].fields])
+ for field_name, field_desc in desc.fields_by_name.iteritems():
+ if field_name in nested_names:
+ self.assertIs(desc.oneofs[0], field_desc.containing_oneof)
+ else:
+ self.assertIsNone(field_desc.containing_oneof)
+
+
+class SymbolDatabaseRegistrationTest(basetest.TestCase):
+ """Checks that messages, enums and files are correctly registered."""
+
+ def testGetSymbol(self):
+ self.assertEquals(
+ unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol(
+ 'protobuf_unittest.TestAllTypes'))
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.NestedMessage,
+ symbol_database.Default().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.NestedMessage'))
+ with self.assertRaises(KeyError):
+ symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage')
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.OptionalGroup,
+ symbol_database.Default().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.OptionalGroup'))
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.RepeatedGroup,
+ symbol_database.Default().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.RepeatedGroup'))
+
+ def testEnums(self):
+ self.assertEquals(
+ 'protobuf_unittest.ForeignEnum',
+ symbol_database.Default().pool.FindEnumTypeByName(
+ 'protobuf_unittest.ForeignEnum').full_name)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes.NestedEnum',
+ symbol_database.Default().pool.FindEnumTypeByName(
+ 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
+
+ def testFindFileByName(self):
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ symbol_database.Default().pool.FindFileByName(
+ 'google/protobuf/unittest.proto').name)
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py
new file mode 100644
index 0000000..6a2053f
--- /dev/null
+++ b/python/google/protobuf/internal/message_factory_python_test.py
@@ -0,0 +1,54 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for ..public.message_factory for the pure Python implementation."""
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+
+# We must set the implementation version above before the google3 imports.
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+# Run all tests from the original module by putting them in our namespace.
+# pylint: disable=wildcard-import
+from google.protobuf.internal.message_factory_test import *
+
+
+class ConfirmPurePythonTest(basetest.TestCase):
+
+ def testImplementationSetting(self):
+ self.assertEqual('python', api_implementation.Type())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
new file mode 100644
index 0000000..c53d77b
--- /dev/null
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -0,0 +1,131 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.message_factory."""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+from google.apputils import basetest
+from google.protobuf import descriptor_pb2
+from google.protobuf.internal import factory_test1_pb2
+from google.protobuf.internal import factory_test2_pb2
+from google.protobuf import descriptor_database
+from google.protobuf import descriptor_pool
+from google.protobuf import message_factory
+
+
+class MessageFactoryTest(basetest.TestCase):
+
+ def setUp(self):
+ self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test1_pb2.DESCRIPTOR.serialized_pb)
+ self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
+ factory_test2_pb2.DESCRIPTOR.serialized_pb)
+
+ def _ExerciseDynamicClass(self, cls):
+ msg = cls()
+ msg.mandatory = 42
+ msg.nested_factory_2_enum = 0
+ msg.nested_factory_2_message.value = 'nested message value'
+ msg.factory_1_message.factory_1_enum = 1
+ msg.factory_1_message.nested_factory_1_enum = 0
+ msg.factory_1_message.nested_factory_1_message.value = (
+ 'nested message value')
+ msg.factory_1_message.scalar_value = 22
+ msg.factory_1_message.list_value.extend([u'one', u'two', u'three'])
+ msg.factory_1_message.list_value.append(u'four')
+ msg.factory_1_enum = 1
+ msg.nested_factory_1_enum = 0
+ msg.nested_factory_1_message.value = 'nested message value'
+ msg.circular_message.mandatory = 1
+ msg.circular_message.circular_message.mandatory = 2
+ msg.circular_message.scalar_value = 'one deep'
+ msg.scalar_value = 'zero deep'
+ msg.list_value.extend([u'four', u'three', u'two'])
+ msg.list_value.append(u'one')
+ msg.grouped.add()
+ msg.grouped[0].part_1 = 'hello'
+ msg.grouped[0].part_2 = 'world'
+ msg.grouped.add(part_1='testing', part_2='123')
+ msg.loop.loop.mandatory = 2
+ msg.loop.loop.loop.loop.mandatory = 4
+ serialized = msg.SerializeToString()
+ converted = factory_test2_pb2.Factory2Message.FromString(serialized)
+ reserialized = converted.SerializeToString()
+ self.assertEquals(serialized, reserialized)
+ result = cls.FromString(reserialized)
+ self.assertEquals(msg, result)
+
+ def testGetPrototype(self):
+ db = descriptor_database.DescriptorDatabase()
+ pool = descriptor_pool.DescriptorPool(db)
+ db.Add(self.factory_test1_fd)
+ db.Add(self.factory_test2_fd)
+ factory = message_factory.MessageFactory()
+ cls = factory.GetPrototype(pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message'))
+ self.assertIsNot(cls, factory_test2_pb2.Factory2Message)
+ self._ExerciseDynamicClass(cls)
+ cls2 = factory.GetPrototype(pool.FindMessageTypeByName(
+ 'google.protobuf.python.internal.Factory2Message'))
+ self.assertIs(cls, cls2)
+
+ def testGetMessages(self):
+ # performed twice because multiple calls with the same input must be allowed
+ for _ in range(2):
+ messages = message_factory.GetMessages([self.factory_test2_fd,
+ self.factory_test1_fd])
+ self.assertContainsSubset(
+ ['google.protobuf.python.internal.Factory2Message',
+ 'google.protobuf.python.internal.Factory1Message'],
+ messages.keys())
+ self._ExerciseDynamicClass(
+ messages['google.protobuf.python.internal.Factory2Message'])
+ self.assertContainsSubset(
+ ['google.protobuf.python.internal.Factory2Message.one_more_field',
+ 'google.protobuf.python.internal.another_field'],
+ (messages['google.protobuf.python.internal.Factory1Message']
+ ._extensions_by_name.keys()))
+ factory_msg1 = messages['google.protobuf.python.internal.Factory1Message']
+ msg1 = messages['google.protobuf.python.internal.Factory1Message']()
+ ext1 = factory_msg1._extensions_by_name[
+ 'google.protobuf.python.internal.Factory2Message.one_more_field']
+ ext2 = factory_msg1._extensions_by_name[
+ 'google.protobuf.python.internal.another_field']
+ msg1.Extensions[ext1] = 'test1'
+ msg1.Extensions[ext2] = 'test2'
+ self.assertEquals('test1', msg1.Extensions[ext1])
+ self.assertEquals('test2', msg1.Extensions[ext2])
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/message_python_test.py
new file mode 100644
index 0000000..baf1504
--- /dev/null
+++ b/python/google/protobuf/internal/message_python_test.py
@@ -0,0 +1,54 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for ..public.message for the pure Python implementation."""
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
+
+# We must set the implementation version above before the google3 imports.
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+# Run all tests from the original module by putting them in our namespace.
+# pylint: disable=wildcard-import
+from google.protobuf.internal.message_test import *
+
+
+class ConfirmPurePythonTest(basetest.TestCase):
+
+ def testImplementationSetting(self):
+ self.assertEqual('python', api_implementation.Type())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 73a9a3a..f4c4ae0 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -43,47 +43,634 @@ abstract interface.
__author__ = 'gps@google.com (Gregory P. Smith)'
-import unittest
-from google.protobuf import unittest_import_pb2
+import copy
+import math
+import operator
+import pickle
+import sys
+
+from google.apputils import basetest
from google.protobuf import unittest_pb2
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
+from google.protobuf import message
+
+# Python pre-2.6 does not have isinf() or isnan() functions, so we have
+# to provide our own.
+def isnan(val):
+ # NaN is never equal to itself.
+ return val != val
+def isinf(val):
+ # Infinity times zero equals NaN.
+ return not isnan(val) and isnan(val * 0)
+def IsPosInf(val):
+ return isinf(val) and (val > 0)
+def IsNegInf(val):
+ return isinf(val) and (val < 0)
-class MessageTest(unittest.TestCase):
+class MessageTest(basetest.TestCase):
+
+ def testBadUtf8String(self):
+ if api_implementation.Type() != 'python':
+ self.skipTest("Skipping testBadUtf8String, currently only the python "
+ "api implementation raises UnicodeDecodeError when a "
+ "string field contains bad utf-8.")
+ bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
+ with self.assertRaises(UnicodeDecodeError) as context:
+ unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
+ self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
+ str(context.exception))
def testGoldenMessage(self):
- golden_data = test_util.GoldenFile('golden_message').read()
+ golden_data = test_util.GoldenFileData(
+ 'golden_message_oneof_implemented')
golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data)
test_util.ExpectAllFieldsSet(self, golden_message)
- self.assertTrue(golden_message.SerializeToString() == golden_data)
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenExtensions(self):
- golden_data = test_util.GoldenFile('golden_message').read()
+ golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(all_set)
self.assertEquals(all_set, golden_message)
- self.assertTrue(golden_message.SerializeToString() == golden_data)
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenPackedMessage(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(all_set)
self.assertEquals(all_set, golden_message)
- self.assertTrue(all_set.SerializeToString() == golden_data)
+ self.assertEqual(golden_data, all_set.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
+ golden_data = test_util.GoldenFileData('golden_packed_fields_message')
golden_message = unittest_pb2.TestPackedExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedExtensions()
test_util.SetAllPackedExtensions(all_set)
self.assertEquals(all_set, golden_message)
- self.assertTrue(all_set.SerializeToString() == golden_data)
+ self.assertEqual(golden_data, all_set.SerializeToString())
+ golden_copy = copy.deepcopy(golden_message)
+ self.assertEqual(golden_data, golden_copy.SerializeToString())
+
+ def testPickleSupport(self):
+ golden_data = test_util.GoldenFileData('golden_message')
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ pickled_message = pickle.dumps(golden_message)
+
+ unpickled_message = pickle.loads(pickled_message)
+ self.assertEquals(unpickled_message, golden_message)
+
+
+ def testPickleIncompleteProto(self):
+ golden_message = unittest_pb2.TestRequired(a=1)
+ pickled_message = pickle.dumps(golden_message)
+
+ unpickled_message = pickle.loads(pickled_message)
+ self.assertEquals(unpickled_message, golden_message)
+ self.assertEquals(unpickled_message.a, 1)
+ # This is still an incomplete proto - so serializing should fail
+ self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
+
+ def testPositiveInfinity(self):
+ golden_data = (b'\x5D\x00\x00\x80\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
+ b'\xCD\x02\x00\x00\x80\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(IsPosInf(golden_message.optional_float))
+ self.assertTrue(IsPosInf(golden_message.optional_double))
+ self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
+ self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+
+ def testNegativeInfinity(self):
+ golden_data = (b'\x5D\x00\x00\x80\xFF'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
+ b'\xCD\x02\x00\x00\x80\xFF'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(IsNegInf(golden_message.optional_float))
+ self.assertTrue(IsNegInf(golden_message.optional_double))
+ self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
+ self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+
+ def testNotANumber(self):
+ golden_data = (b'\x5D\x00\x00\xC0\x7F'
+ b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
+ b'\xCD\x02\x00\x00\xC0\x7F'
+ b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(isnan(golden_message.optional_float))
+ self.assertTrue(isnan(golden_message.optional_double))
+ self.assertTrue(isnan(golden_message.repeated_float[0]))
+ self.assertTrue(isnan(golden_message.repeated_double[0]))
+
+ # The protocol buffer may serialize to any one of multiple different
+ # representations of a NaN. Rather than verify a specific representation,
+ # verify the serialized string can be converted into a correctly
+ # behaving protocol buffer.
+ serialized = golden_message.SerializeToString()
+ message = unittest_pb2.TestAllTypes()
+ message.ParseFromString(serialized)
+ self.assertTrue(isnan(message.optional_float))
+ self.assertTrue(isnan(message.optional_double))
+ self.assertTrue(isnan(message.repeated_float[0]))
+ self.assertTrue(isnan(message.repeated_double[0]))
+
+ def testPositiveInfinityPacked(self):
+ golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
+ golden_message = unittest_pb2.TestPackedTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(IsPosInf(golden_message.packed_float[0]))
+ self.assertTrue(IsPosInf(golden_message.packed_double[0]))
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+
+ def testNegativeInfinityPacked(self):
+ golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
+ golden_message = unittest_pb2.TestPackedTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(IsNegInf(golden_message.packed_float[0]))
+ self.assertTrue(IsNegInf(golden_message.packed_double[0]))
+ self.assertEqual(golden_data, golden_message.SerializeToString())
+
+ def testNotANumberPacked(self):
+ golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
+ b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
+ golden_message = unittest_pb2.TestPackedTypes()
+ golden_message.ParseFromString(golden_data)
+ self.assertTrue(isnan(golden_message.packed_float[0]))
+ self.assertTrue(isnan(golden_message.packed_double[0]))
+
+ serialized = golden_message.SerializeToString()
+ message = unittest_pb2.TestPackedTypes()
+ message.ParseFromString(serialized)
+ self.assertTrue(isnan(message.packed_float[0]))
+ self.assertTrue(isnan(message.packed_double[0]))
+
+ def testExtremeFloatValues(self):
+ message = unittest_pb2.TestAllTypes()
+
+ # Most positive exponent, no significand bits set.
+ kMostPosExponentNoSigBits = math.pow(2, 127)
+ message.optional_float = kMostPosExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
+
+ # Most positive exponent, one significand bit set.
+ kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
+ message.optional_float = kMostPosExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
+
+ # Repeat last two cases with values of same magnitude, but negative.
+ message.optional_float = -kMostPosExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
+
+ message.optional_float = -kMostPosExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
+
+ # Most negative exponent, no significand bits set.
+ kMostNegExponentNoSigBits = math.pow(2, -127)
+ message.optional_float = kMostNegExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
+
+ # Most negative exponent, one significand bit set.
+ kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
+ message.optional_float = kMostNegExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
+
+ # Repeat last two cases with values of the same magnitude, but negative.
+ message.optional_float = -kMostNegExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
+
+ message.optional_float = -kMostNegExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
+
+ def testExtremeDoubleValues(self):
+ message = unittest_pb2.TestAllTypes()
+
+ # Most positive exponent, no significand bits set.
+ kMostPosExponentNoSigBits = math.pow(2, 1023)
+ message.optional_double = kMostPosExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
+
+ # Most positive exponent, one significand bit set.
+ kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
+ message.optional_double = kMostPosExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
+
+ # Repeat last two cases with values of same magnitude, but negative.
+ message.optional_double = -kMostPosExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
+
+ message.optional_double = -kMostPosExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
+
+ # Most negative exponent, no significand bits set.
+ kMostNegExponentNoSigBits = math.pow(2, -1023)
+ message.optional_double = kMostNegExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
+
+ # Most negative exponent, one significand bit set.
+ kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
+ message.optional_double = kMostNegExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
+
+ # Repeat last two cases with values of the same magnitude, but negative.
+ message.optional_double = -kMostNegExponentNoSigBits
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
+
+ message.optional_double = -kMostNegExponentOneSigBit
+ message.ParseFromString(message.SerializeToString())
+ self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
+
+ def testFloatPrinting(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_float = 2.0
+ self.assertEqual(str(message), 'optional_float: 2.0\n')
+
+ def testHighPrecisionFloatPrinting(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_double = 0.12345678912345678
+ if sys.version_info.major >= 3:
+ self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
+ else:
+ self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
+
+ def testUnknownFieldPrinting(self):
+ populated = unittest_pb2.TestAllTypes()
+ test_util.SetAllNonLazyFields(populated)
+ empty = unittest_pb2.TestEmptyMessage()
+ empty.ParseFromString(populated.SerializeToString())
+ self.assertEqual(str(empty), '')
+
+ def testSortingRepeatedScalarFieldsDefaultComparator(self):
+ """Check some different types with the default comparator."""
+ message = unittest_pb2.TestAllTypes()
+
+ # TODO(mattp): would testing more scalar types strengthen test?
+ message.repeated_int32.append(1)
+ message.repeated_int32.append(3)
+ message.repeated_int32.append(2)
+ message.repeated_int32.sort()
+ self.assertEqual(message.repeated_int32[0], 1)
+ self.assertEqual(message.repeated_int32[1], 2)
+ self.assertEqual(message.repeated_int32[2], 3)
+
+ message.repeated_float.append(1.1)
+ message.repeated_float.append(1.3)
+ message.repeated_float.append(1.2)
+ message.repeated_float.sort()
+ self.assertAlmostEqual(message.repeated_float[0], 1.1)
+ self.assertAlmostEqual(message.repeated_float[1], 1.2)
+ self.assertAlmostEqual(message.repeated_float[2], 1.3)
+
+ message.repeated_string.append('a')
+ message.repeated_string.append('c')
+ message.repeated_string.append('b')
+ message.repeated_string.sort()
+ self.assertEqual(message.repeated_string[0], 'a')
+ self.assertEqual(message.repeated_string[1], 'b')
+ self.assertEqual(message.repeated_string[2], 'c')
+
+ message.repeated_bytes.append(b'a')
+ message.repeated_bytes.append(b'c')
+ message.repeated_bytes.append(b'b')
+ message.repeated_bytes.sort()
+ self.assertEqual(message.repeated_bytes[0], b'a')
+ self.assertEqual(message.repeated_bytes[1], b'b')
+ self.assertEqual(message.repeated_bytes[2], b'c')
+
+ def testSortingRepeatedScalarFieldsCustomComparator(self):
+ """Check some different types with custom comparator."""
+ message = unittest_pb2.TestAllTypes()
+
+ message.repeated_int32.append(-3)
+ message.repeated_int32.append(-2)
+ message.repeated_int32.append(-1)
+ message.repeated_int32.sort(key=abs)
+ self.assertEqual(message.repeated_int32[0], -1)
+ self.assertEqual(message.repeated_int32[1], -2)
+ self.assertEqual(message.repeated_int32[2], -3)
+
+ message.repeated_string.append('aaa')
+ message.repeated_string.append('bb')
+ message.repeated_string.append('c')
+ message.repeated_string.sort(key=len)
+ self.assertEqual(message.repeated_string[0], 'c')
+ self.assertEqual(message.repeated_string[1], 'bb')
+ self.assertEqual(message.repeated_string[2], 'aaa')
+
+ def testSortingRepeatedCompositeFieldsCustomComparator(self):
+ """Check passing a custom comparator to sort a repeated composite field."""
+ message = unittest_pb2.TestAllTypes()
+
+ message.repeated_nested_message.add().bb = 1
+ message.repeated_nested_message.add().bb = 3
+ message.repeated_nested_message.add().bb = 2
+ message.repeated_nested_message.add().bb = 6
+ message.repeated_nested_message.add().bb = 5
+ message.repeated_nested_message.add().bb = 4
+ message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
+ self.assertEqual(message.repeated_nested_message[0].bb, 1)
+ self.assertEqual(message.repeated_nested_message[1].bb, 2)
+ self.assertEqual(message.repeated_nested_message[2].bb, 3)
+ self.assertEqual(message.repeated_nested_message[3].bb, 4)
+ self.assertEqual(message.repeated_nested_message[4].bb, 5)
+ self.assertEqual(message.repeated_nested_message[5].bb, 6)
+
+ def testRepeatedCompositeFieldSortArguments(self):
+ """Check sorting a repeated composite field using list.sort() arguments."""
+ message = unittest_pb2.TestAllTypes()
+
+ get_bb = operator.attrgetter('bb')
+ cmp_bb = lambda a, b: cmp(a.bb, b.bb)
+ message.repeated_nested_message.add().bb = 1
+ message.repeated_nested_message.add().bb = 3
+ message.repeated_nested_message.add().bb = 2
+ message.repeated_nested_message.add().bb = 6
+ message.repeated_nested_message.add().bb = 5
+ message.repeated_nested_message.add().bb = 4
+ message.repeated_nested_message.sort(key=get_bb)
+ self.assertEqual([k.bb for k in message.repeated_nested_message],
+ [1, 2, 3, 4, 5, 6])
+ message.repeated_nested_message.sort(key=get_bb, reverse=True)
+ self.assertEqual([k.bb for k in message.repeated_nested_message],
+ [6, 5, 4, 3, 2, 1])
+ if sys.version_info.major >= 3: return # No cmp sorting in PY3.
+ message.repeated_nested_message.sort(sort_function=cmp_bb)
+ self.assertEqual([k.bb for k in message.repeated_nested_message],
+ [1, 2, 3, 4, 5, 6])
+ message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
+ self.assertEqual([k.bb for k in message.repeated_nested_message],
+ [6, 5, 4, 3, 2, 1])
+
+ def testRepeatedScalarFieldSortArguments(self):
+ """Check sorting a scalar field using list.sort() arguments."""
+ message = unittest_pb2.TestAllTypes()
+
+ message.repeated_int32.append(-3)
+ message.repeated_int32.append(-2)
+ message.repeated_int32.append(-1)
+ message.repeated_int32.sort(key=abs)
+ self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
+ message.repeated_int32.sort(key=abs, reverse=True)
+ self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
+ if sys.version_info.major < 3: # No cmp sorting in PY3.
+ abs_cmp = lambda a, b: cmp(abs(a), abs(b))
+ message.repeated_int32.sort(sort_function=abs_cmp)
+ self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
+ message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
+ self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
+
+ message.repeated_string.append('aaa')
+ message.repeated_string.append('bb')
+ message.repeated_string.append('c')
+ message.repeated_string.sort(key=len)
+ self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
+ message.repeated_string.sort(key=len, reverse=True)
+ self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
+ if sys.version_info.major < 3: # No cmp sorting in PY3.
+ len_cmp = lambda a, b: cmp(len(a), len(b))
+ message.repeated_string.sort(sort_function=len_cmp)
+ self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
+ message.repeated_string.sort(cmp=len_cmp, reverse=True)
+ self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
+
+ def testRepeatedFieldsComparable(self):
+ m1 = unittest_pb2.TestAllTypes()
+ m2 = unittest_pb2.TestAllTypes()
+ m1.repeated_int32.append(0)
+ m1.repeated_int32.append(1)
+ m1.repeated_int32.append(2)
+ m2.repeated_int32.append(0)
+ m2.repeated_int32.append(1)
+ m2.repeated_int32.append(2)
+ m1.repeated_nested_message.add().bb = 1
+ m1.repeated_nested_message.add().bb = 2
+ m1.repeated_nested_message.add().bb = 3
+ m2.repeated_nested_message.add().bb = 1
+ m2.repeated_nested_message.add().bb = 2
+ m2.repeated_nested_message.add().bb = 3
+
+ if sys.version_info.major >= 3: return # No cmp() in PY3.
+
+ # These comparisons should not raise errors.
+ _ = m1 < m2
+ _ = m1.repeated_nested_message < m2.repeated_nested_message
+
+ # Make sure cmp always works. If it wasn't defined, these would be
+ # id() comparisons and would all fail.
+ self.assertEqual(cmp(m1, m2), 0)
+ self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
+ self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
+ self.assertEqual(cmp(m1.repeated_nested_message,
+ m2.repeated_nested_message), 0)
+ with self.assertRaises(TypeError):
+ # Can't compare repeated composite containers to lists.
+ cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
+
+ # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
+
+ def testParsingMerge(self):
+ """Check the merge behavior when a required or optional field appears
+ multiple times in the input."""
+ messages = [
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes(),
+ unittest_pb2.TestAllTypes() ]
+ messages[0].optional_int32 = 1
+ messages[1].optional_int64 = 2
+ messages[2].optional_int32 = 3
+ messages[2].optional_string = 'hello'
+
+ merged_message = unittest_pb2.TestAllTypes()
+ merged_message.optional_int32 = 3
+ merged_message.optional_int64 = 2
+ merged_message.optional_string = 'hello'
+
+ generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
+ generator.field1.extend(messages)
+ generator.field2.extend(messages)
+ generator.field3.extend(messages)
+ generator.ext1.extend(messages)
+ generator.ext2.extend(messages)
+ generator.group1.add().field1.MergeFrom(messages[0])
+ generator.group1.add().field1.MergeFrom(messages[1])
+ generator.group1.add().field1.MergeFrom(messages[2])
+ generator.group2.add().field1.MergeFrom(messages[0])
+ generator.group2.add().field1.MergeFrom(messages[1])
+ generator.group2.add().field1.MergeFrom(messages[2])
+
+ data = generator.SerializeToString()
+ parsing_merge = unittest_pb2.TestParsingMerge()
+ parsing_merge.ParseFromString(data)
+
+ # Required and optional fields should be merged.
+ self.assertEqual(parsing_merge.required_all_types, merged_message)
+ self.assertEqual(parsing_merge.optional_all_types, merged_message)
+ self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
+ merged_message)
+ self.assertEqual(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.optional_ext],
+ merged_message)
+
+ # Repeated fields should not be merged.
+ self.assertEqual(len(parsing_merge.repeated_all_types), 3)
+ self.assertEqual(len(parsing_merge.repeatedgroup), 3)
+ self.assertEqual(len(parsing_merge.Extensions[
+ unittest_pb2.TestParsingMerge.repeated_ext]), 3)
+
+ def ensureNestedMessageExists(self, msg, attribute):
+ """Make sure that a nested message object exists.
+
+ As soon as a nested message attribute is accessed, it will be present in the
+ _fields dict, without being marked as actually being set.
+ """
+ getattr(msg, attribute)
+ self.assertFalse(msg.HasField(attribute))
+
+ def testOneofGetCaseNonexistingField(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
+
+ def testOneofSemantics(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ m.oneof_uint32 = 11
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+
+ m.oneof_string = u'foo'
+ self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertTrue(m.HasField('oneof_string'))
+
+ m.oneof_nested_message.bb = 11
+ self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_string'))
+ self.assertTrue(m.HasField('oneof_nested_message'))
+
+ m.oneof_bytes = b'bb'
+ self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
+ self.assertFalse(m.HasField('oneof_nested_message'))
+ self.assertTrue(m.HasField('oneof_bytes'))
+
+ def testOneofCompositeFieldReadAccess(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+
+ self.ensureNestedMessageExists(m, 'oneof_nested_message')
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+ self.assertEqual(11, m.oneof_uint32)
+
+ def testOneofHasField(self):
+ m = unittest_pb2.TestAllTypes()
+ self.assertFalse(m.HasField('oneof_field'))
+ m.oneof_uint32 = 11
+ self.assertTrue(m.HasField('oneof_field'))
+ m.oneof_bytes = b'bb'
+ self.assertTrue(m.HasField('oneof_field'))
+ m.ClearField('oneof_bytes')
+ self.assertFalse(m.HasField('oneof_field'))
+
+ def testOneofClearField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ m.ClearField('oneof_field')
+ self.assertFalse(m.HasField('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ def testOneofClearSetField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ m.ClearField('oneof_uint32')
+ self.assertFalse(m.HasField('oneof_field'))
+ self.assertFalse(m.HasField('oneof_uint32'))
+ self.assertIs(None, m.WhichOneof('oneof_field'))
+
+ def testOneofClearUnsetField(self):
+ m = unittest_pb2.TestAllTypes()
+ m.oneof_uint32 = 11
+ self.ensureNestedMessageExists(m, 'oneof_nested_message')
+ m.ClearField('oneof_nested_message')
+ self.assertEqual(11, m.oneof_uint32)
+ self.assertTrue(m.HasField('oneof_field'))
+ self.assertTrue(m.HasField('oneof_uint32'))
+ self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
+
+
+
+ def testSortEmptyRepeatedCompositeContainer(self):
+ """Exercise a scenario that has led to segfaults in the past.
+ """
+ m = unittest_pb2.TestAllTypes()
+ m.repeated_nested_message.sort()
+
+ def testHasFieldOnRepeatedField(self):
+ """Using HasField on a repeated field should raise an exception.
+ """
+ m = unittest_pb2.TestAllTypes()
+ with self.assertRaises(ValueError) as _:
+ m.HasField('repeated_int32')
+
+
+class ValidTypeNamesTest(basetest.TestCase):
+
+ def assertImportFromName(self, msg, base_name):
+ # Parse <type 'module.class_name'> to extra 'some.name' as a string.
+ tp_name = str(type(msg)).split("'")[1]
+ valid_names = ('Repeated%sContainer' % base_name,
+ 'Repeated%sFieldContainer' % base_name)
+ self.assertTrue(any(tp_name.endswith(v) for v in valid_names),
+ '%r does end with any of %r' % (tp_name, valid_names))
+
+ parts = tp_name.split('.')
+ class_name = parts[-1]
+ module_name = '.'.join(parts[:-1])
+ __import__(module_name, fromlist=[class_name])
+
+ def testTypeNamesCanBeImported(self):
+ # If import doesn't work, pickling won't work either.
+ pb = unittest_pb2.TestAllTypes()
+ self.assertImportFromName(pb.repeated_int32, 'Scalar')
+ self.assertImportFromName(pb.repeated_nested_message, 'Composite')
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto
new file mode 100644
index 0000000..c9ae58b
--- /dev/null
+++ b/python/google/protobuf/internal/missing_enum_values.proto
@@ -0,0 +1,50 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package google.protobuf.python.internal;
+
+message TestEnumValues {
+ enum NestedEnum {
+ ZERO = 0;
+ ONE = 1;
+ }
+ optional NestedEnum optional_nested_enum = 1;
+ repeated NestedEnum repeated_nested_enum = 2;
+ repeated NestedEnum packed_nested_enum = 3 [packed = true];
+}
+
+message TestMissingEnumValues {
+ enum NestedEnum {
+ TWO = 2;
+ }
+ optional NestedEnum optional_nested_enum = 1;
+ repeated NestedEnum repeated_nested_enum = 2;
+ repeated NestedEnum packed_nested_enum = 3 [packed = true];
+}
diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto
new file mode 100644
index 0000000..df98ac4
--- /dev/null
+++ b/python/google/protobuf/internal/more_extensions_dynamic.proto
@@ -0,0 +1,49 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: jasonh@google.com (Jason Hsueh)
+//
+// This file is used to test a corner case in the CPP implementation where the
+// generated C++ type is available for the extendee, but the extension is
+// defined in a file whose C++ type is not in the binary.
+
+
+import "google/protobuf/internal/more_extensions.proto";
+
+package google.protobuf.internal;
+
+message DynamicMessageType {
+ optional int32 a = 1;
+}
+
+extend ExtendedMessage {
+ optional int32 dynamic_int32_extension = 100;
+ optional DynamicMessageType dynamic_message_extension = 101;
+}
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
new file mode 100755
index 0000000..9ee352d
--- /dev/null
+++ b/python/google/protobuf/internal/python_message.py
@@ -0,0 +1,1247 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Keep it Python2.5 compatible for GAE.
+#
+# Copyright 2007 Google Inc. All Rights Reserved.
+#
+# This code is meant to work on Python 2.4 and above only.
+#
+# TODO(robinson): Helpers for verbose, common checks like seeing if a
+# descriptor's cpp_type is CPPTYPE_MESSAGE.
+
+"""Contains a metaclass and helper functions used to create
+protocol message classes from Descriptor objects at runtime.
+
+Recall that a metaclass is the "type" of a class.
+(A class is to a metaclass what an instance is to a class.)
+
+In this case, we use the GeneratedProtocolMessageType metaclass
+to inject all the useful functionality into the classes
+output by the protocol compiler at compile-time.
+
+The upshot of all this is that the real implementation
+details for ALL pure-Python protocol buffers are *here in
+this file*.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import sys
+if sys.version_info[0] < 3:
+ try:
+ from cStringIO import StringIO as BytesIO
+ except ImportError:
+ from StringIO import StringIO as BytesIO
+ import copy_reg as copyreg
+else:
+ from io import BytesIO
+ import copyreg
+import struct
+import weakref
+
+# We use "as" to avoid name collisions with variables.
+from google.protobuf.internal import containers
+from google.protobuf.internal import decoder
+from google.protobuf.internal import encoder
+from google.protobuf.internal import enum_type_wrapper
+from google.protobuf.internal import message_listener as message_listener_mod
+from google.protobuf.internal import type_checkers
+from google.protobuf.internal import wire_format
+from google.protobuf import descriptor as descriptor_mod
+from google.protobuf import message as message_mod
+from google.protobuf import text_format
+
+_FieldDescriptor = descriptor_mod.FieldDescriptor
+
+
+def NewMessage(bases, descriptor, dictionary):
+ _AddClassAttributesForNestedExtensions(descriptor, dictionary)
+ _AddSlots(descriptor, dictionary)
+ return bases
+
+
+def InitMessage(descriptor, cls):
+ cls._decoders_by_tag = {}
+ cls._extensions_by_name = {}
+ cls._extensions_by_number = {}
+ if (descriptor.has_options and
+ descriptor.GetOptions().message_set_wire_format):
+ cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+ decoder.MessageSetItemDecoder(cls._extensions_by_number))
+
+ # Attach stuff to each FieldDescriptor for quick lookup later on.
+ for field in descriptor.fields:
+ _AttachFieldHelpers(cls, field)
+
+ _AddEnumValues(descriptor, cls)
+ _AddInitMethod(descriptor, cls)
+ _AddPropertiesForFields(descriptor, cls)
+ _AddPropertiesForExtensions(descriptor, cls)
+ _AddStaticMethods(cls)
+ _AddMessageMethods(descriptor, cls)
+ _AddPrivateHelperMethods(descriptor, cls)
+ copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
+
+
+# Stateless helpers for GeneratedProtocolMessageType below.
+# Outside clients should not access these directly.
+#
+# I opted not to make any of these methods on the metaclass, to make it more
+# clear that I'm not really using any state there and to keep clients from
+# thinking that they have direct access to these construction helpers.
+
+
+def _PropertyName(proto_field_name):
+ """Returns the name of the public property attribute which
+ clients can use to get and (in some cases) set the value
+ of a protocol message field.
+
+ Args:
+ proto_field_name: The protocol message field name, exactly
+ as it appears (or would appear) in a .proto file.
+ """
+ # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
+ # nnorwitz makes my day by writing:
+ # """
+ # FYI. See the keyword module in the stdlib. This could be as simple as:
+ #
+ # if keyword.iskeyword(proto_field_name):
+ # return proto_field_name + "_"
+ # return proto_field_name
+ # """
+ # Kenton says: The above is a BAD IDEA. People rely on being able to use
+ # getattr() and setattr() to reflectively manipulate field values. If we
+ # rename the properties, then every such user has to also make sure to apply
+ # the same transformation. Note that currently if you name a field "yield",
+ # you can still access it just fine using getattr/setattr -- it's not even
+ # that cumbersome to do so.
+ # TODO(kenton): Remove this method entirely if/when everyone agrees with my
+ # position.
+ return proto_field_name
+
+
+def _VerifyExtensionHandle(message, extension_handle):
+ """Verify that the given extension handle is valid."""
+
+ if not isinstance(extension_handle, _FieldDescriptor):
+ raise KeyError('HasExtension() expects an extension handle, got: %s' %
+ extension_handle)
+
+ if not extension_handle.is_extension:
+ raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
+
+ if not extension_handle.containing_type:
+ raise KeyError('"%s" is missing a containing_type.'
+ % extension_handle.full_name)
+
+ if extension_handle.containing_type is not message.DESCRIPTOR:
+ raise KeyError('Extension "%s" extends message type "%s", but this '
+ 'message is of type "%s".' %
+ (extension_handle.full_name,
+ extension_handle.containing_type.full_name,
+ message.DESCRIPTOR.full_name))
+
+
+def _AddSlots(message_descriptor, dictionary):
+ """Adds a __slots__ entry to dictionary, containing the names of all valid
+ attributes for this message type.
+
+ Args:
+ message_descriptor: A Descriptor instance describing this message type.
+ dictionary: Class dictionary to which we'll add a '__slots__' entry.
+ """
+ dictionary['__slots__'] = ['_cached_byte_size',
+ '_cached_byte_size_dirty',
+ '_fields',
+ '_unknown_fields',
+ '_is_present_in_parent',
+ '_listener',
+ '_listener_for_children',
+ '__weakref__',
+ '_oneofs']
+
+
+def _IsMessageSetExtension(field):
+ return (field.is_extension and
+ field.containing_type.has_options and
+ field.containing_type.GetOptions().message_set_wire_format and
+ field.type == _FieldDescriptor.TYPE_MESSAGE and
+ field.message_type == field.extension_scope and
+ field.label == _FieldDescriptor.LABEL_OPTIONAL)
+
+
+def _AttachFieldHelpers(cls, field_descriptor):
+ is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
+ is_packed = (field_descriptor.has_options and
+ field_descriptor.GetOptions().packed)
+
+ if _IsMessageSetExtension(field_descriptor):
+ field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
+ sizer = encoder.MessageSetItemSizer(field_descriptor.number)
+ else:
+ field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed)
+ sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed)
+
+ field_descriptor._encoder = field_encoder
+ field_descriptor._sizer = sizer
+ field_descriptor._default_constructor = _DefaultValueConstructorForField(
+ field_descriptor)
+
+ def AddDecoder(wiretype, is_packed):
+ tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+ cls._decoders_by_tag[tag_bytes] = (
+ type_checkers.TYPE_TO_DECODER[field_descriptor.type](
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor))
+
+ AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
+ False)
+
+ if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
+ # To support wire compatibility of adding packed = true, add a decoder for
+ # packed values regardless of the field's options.
+ AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
+
+
+def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
+ extension_dict = descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ assert extension_name not in dictionary
+ dictionary[extension_name] = extension_field
+
+
+def _AddEnumValues(descriptor, cls):
+ """Sets class-level attributes for all enum fields defined in this message.
+
+ Also exporting a class-level object that can name enum values.
+
+ Args:
+ descriptor: Descriptor object for this message type.
+ cls: Class we're constructing for this message type.
+ """
+ for enum_type in descriptor.enum_types:
+ setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
+ for enum_value in enum_type.values:
+ setattr(cls, enum_value.name, enum_value.number)
+
+
+def _DefaultValueConstructorForField(field):
+ """Returns a function which returns a default value for a field.
+
+ Args:
+ field: FieldDescriptor object for this field.
+
+ The returned function has one argument:
+ message: Message instance containing this field, or a weakref proxy
+ of same.
+
+ That function in turn returns a default value for this field. The default
+ value may refer back to |message| via a weak reference.
+ """
+
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if field.has_default_value and field.default_value != []:
+ raise ValueError('Repeated field default value not empty list: %s' % (
+ field.default_value))
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ # We can't look at _concrete_class yet since it might not have
+ # been set. (Depends on order in which we initialize the classes).
+ message_type = field.message_type
+ def MakeRepeatedMessageDefault(message):
+ return containers.RepeatedCompositeFieldContainer(
+ message._listener_for_children, field.message_type)
+ return MakeRepeatedMessageDefault
+ else:
+ type_checker = type_checkers.GetTypeChecker(field)
+ def MakeRepeatedScalarDefault(message):
+ return containers.RepeatedScalarFieldContainer(
+ message._listener_for_children, type_checker)
+ return MakeRepeatedScalarDefault
+
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ # _concrete_class may not yet be initialized.
+ message_type = field.message_type
+ def MakeSubMessageDefault(message):
+ result = message_type._concrete_class()
+ result._SetListener(message._listener_for_children)
+ return result
+ return MakeSubMessageDefault
+
+ def MakeScalarDefault(message):
+ # TODO(protobuf-team): This may be broken since there may not be
+ # default_value. Combine with has_default_value somehow.
+ return field.default_value
+ return MakeScalarDefault
+
+
+def _AddInitMethod(message_descriptor, cls):
+ """Adds an __init__ method to cls."""
+ fields = message_descriptor.fields
+ def init(self, **kwargs):
+ self._cached_byte_size = 0
+ self._cached_byte_size_dirty = len(kwargs) > 0
+ self._fields = {}
+ # Contains a mapping from oneof field descriptors to the descriptor
+ # of the currently set field in that oneof field.
+ self._oneofs = {}
+
+ # _unknown_fields is () when empty for efficiency, and will be turned into
+ # a list if fields are added.
+ self._unknown_fields = ()
+ self._is_present_in_parent = False
+ self._listener = message_listener_mod.NullMessageListener()
+ self._listener_for_children = _Listener(self)
+ for field_name, field_value in kwargs.iteritems():
+ field = _GetFieldByName(message_descriptor, field_name)
+ if field is None:
+ raise TypeError("%s() got an unexpected keyword argument '%s'" %
+ (message_descriptor.name, field_name))
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ copy = field._default_constructor(self)
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
+ for val in field_value:
+ copy.add().MergeFrom(val)
+ else: # Scalar
+ copy.extend(field_value)
+ self._fields[field] = copy
+ elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ copy = field._default_constructor(self)
+ copy.MergeFrom(field_value)
+ self._fields[field] = copy
+ else:
+ setattr(self, field_name, field_value)
+
+ init.__module__ = None
+ init.__doc__ = None
+ cls.__init__ = init
+
+
+def _GetFieldByName(message_descriptor, field_name):
+ """Returns a field descriptor by field name.
+
+ Args:
+ message_descriptor: A Descriptor describing all fields in message.
+ field_name: The name of the field to retrieve.
+ Returns:
+ The field descriptor associated with the field name.
+ """
+ try:
+ return message_descriptor.fields_by_name[field_name]
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+
+
+def _AddPropertiesForFields(descriptor, cls):
+ """Adds properties for all fields in this protocol message type."""
+ for field in descriptor.fields:
+ _AddPropertiesForField(field, cls)
+
+ if descriptor.is_extendable:
+ # _ExtensionDict is just an adaptor with no state so we allocate a new one
+ # every time it is accessed.
+ cls.Extensions = property(lambda self: _ExtensionDict(self))
+
+
+def _AddPropertiesForField(field, cls):
+ """Adds a public property for a protocol message field.
+ Clients can use this property to get and (in the case
+ of non-repeated scalar fields) directly set the value
+ of a protocol message field.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ # Catch it if we add other types that we should
+ # handle specially here.
+ assert _FieldDescriptor.MAX_CPPTYPE == 10
+
+ constant_name = field.name.upper() + "_FIELD_NUMBER"
+ setattr(cls, constant_name, field.number)
+
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ _AddPropertiesForRepeatedField(field, cls)
+ elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ _AddPropertiesForNonRepeatedCompositeField(field, cls)
+ else:
+ _AddPropertiesForNonRepeatedScalarField(field, cls)
+
+
+def _AddPropertiesForRepeatedField(field, cls):
+ """Adds a public property for a "repeated" protocol message field. Clients
+ can use this property to get the value of the field, which will be either a
+ _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
+ below).
+
+ Note that when clients add values to these containers, we perform
+ type-checking in the case of repeated scalar fields, and we also set any
+ necessary "has" bits as a side-effect.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ proto_field_name = field.name
+ property_name = _PropertyName(proto_field_name)
+
+ def getter(self):
+ field_value = self._fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = field._default_constructor(self)
+
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ field_value = self._fields.setdefault(field, field_value)
+ return field_value
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ # We define a setter just so we can throw an exception with a more
+ # helpful error message.
+ def setter(self, new_value):
+ raise AttributeError('Assignment not allowed to repeated field '
+ '"%s" in protocol message object.' % proto_field_name)
+
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForNonRepeatedScalarField(field, cls):
+ """Adds a public property for a nonrepeated, scalar protocol message field.
+ Clients can use this property to get and directly set the value of the field.
+ Note that when the client sets the value of a field by using this property,
+ all necessary "has" bits are set as a side-effect, and we also perform
+ type-checking.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ proto_field_name = field.name
+ property_name = _PropertyName(proto_field_name)
+ type_checker = type_checkers.GetTypeChecker(field)
+ default_value = field.default_value
+ valid_values = set()
+
+ def getter(self):
+ # TODO(protobuf-team): This may be broken since there may not be
+ # default_value. Combine with has_default_value somehow.
+ return self._fields.get(field, default_value)
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+ def field_setter(self, new_value):
+ # pylint: disable=protected-access
+ self._fields[field] = type_checker.CheckValue(new_value)
+ # Check _cached_byte_size_dirty inline to improve performance, since scalar
+ # setters are called frequently.
+ if not self._cached_byte_size_dirty:
+ self._Modified()
+
+ if field.containing_oneof is not None:
+ def setter(self, new_value):
+ field_setter(self, new_value)
+ self._UpdateOneofState(field)
+ else:
+ setter = field_setter
+
+ setter.__module__ = None
+ setter.__doc__ = 'Setter for %s.' % proto_field_name
+
+ # Add a property to encapsulate the getter/setter.
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForNonRepeatedCompositeField(field, cls):
+ """Adds a public property for a nonrepeated, composite protocol message field.
+ A composite field is a "group" or "message" field.
+
+ Clients can use this property to get the value of the field, but cannot
+ assign to the property directly.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ # TODO(robinson): Remove duplication with similar method
+ # for non-repeated scalars.
+ proto_field_name = field.name
+ property_name = _PropertyName(proto_field_name)
+
+ # TODO(komarek): Can anyone explain to me why we cache the message_type this
+ # way, instead of referring to field.message_type inside of getter(self)?
+ # What if someone sets message_type later on (which makes for simpler
+ # dyanmic proto descriptor and class creation code).
+ message_type = field.message_type
+
+ def getter(self):
+ field_value = self._fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = message_type._concrete_class() # use field.message_type?
+ field_value._SetListener(
+ _OneofListener(self, field)
+ if field.containing_oneof is not None
+ else self._listener_for_children)
+
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ field_value = self._fields.setdefault(field, field_value)
+ return field_value
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ # We define a setter just so we can throw an exception with a more
+ # helpful error message.
+ def setter(self, new_value):
+ raise AttributeError('Assignment not allowed to composite field '
+ '"%s" in protocol message object.' % proto_field_name)
+
+ # Add a property to encapsulate the getter.
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForExtensions(descriptor, cls):
+ """Adds properties for all fields in this protocol message type."""
+ extension_dict = descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ constant_name = extension_name.upper() + "_FIELD_NUMBER"
+ setattr(cls, constant_name, extension_field.number)
+
+
+def _AddStaticMethods(cls):
+ # TODO(robinson): This probably needs to be thread-safe(?)
+ def RegisterExtension(extension_handle):
+ extension_handle.containing_type = cls.DESCRIPTOR
+ _AttachFieldHelpers(cls, extension_handle)
+
+ # Try to insert our extension, failing if an extension with the same number
+ # already exists.
+ actual_handle = cls._extensions_by_number.setdefault(
+ extension_handle.number, extension_handle)
+ if actual_handle is not extension_handle:
+ raise AssertionError(
+ 'Extensions "%s" and "%s" both try to extend message type "%s" with '
+ 'field number %d.' %
+ (extension_handle.full_name, actual_handle.full_name,
+ cls.DESCRIPTOR.full_name, extension_handle.number))
+
+ cls._extensions_by_name[extension_handle.full_name] = extension_handle
+
+ handle = extension_handle # avoid line wrapping
+ if _IsMessageSetExtension(handle):
+ # MessageSet extension. Also register under type name.
+ cls._extensions_by_name[
+ extension_handle.message_type.full_name] = extension_handle
+
+ cls.RegisterExtension = staticmethod(RegisterExtension)
+
+ def FromString(s):
+ message = cls()
+ message.MergeFromString(s)
+ return message
+ cls.FromString = staticmethod(FromString)
+
+
+def _IsPresent(item):
+ """Given a (FieldDescriptor, value) tuple from _fields, return true if the
+ value should be included in the list returned by ListFields()."""
+
+ if item[0].label == _FieldDescriptor.LABEL_REPEATED:
+ return bool(item[1])
+ elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ return item[1]._is_present_in_parent
+ else:
+ return True
+
+
+def _AddListFieldsMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def ListFields(self):
+ all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
+ all_fields.sort(key = lambda item: item[0].number)
+ return all_fields
+
+ cls.ListFields = ListFields
+
+
+def _AddHasFieldMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ singular_fields = {}
+ for field in message_descriptor.fields:
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ singular_fields[field.name] = field
+ # Fields inside oneofs are never repeated (enforced by the compiler).
+ for field in message_descriptor.oneofs:
+ singular_fields[field.name] = field
+
+ def HasField(self, field_name):
+ try:
+ field = singular_fields[field_name]
+ except KeyError:
+ raise ValueError(
+ 'Protocol message has no singular "%s" field.' % field_name)
+
+ if isinstance(field, descriptor_mod.OneofDescriptor):
+ try:
+ return HasField(self, self._oneofs[field].name)
+ except KeyError:
+ return False
+ else:
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(field)
+ return value is not None and value._is_present_in_parent
+ else:
+ return field in self._fields
+
+ cls.HasField = HasField
+
+
+def _AddClearFieldMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def ClearField(self, field_name):
+ try:
+ field = message_descriptor.fields_by_name[field_name]
+ except KeyError:
+ try:
+ field = message_descriptor.oneofs_by_name[field_name]
+ if field in self._oneofs:
+ field = self._oneofs[field]
+ else:
+ return
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+
+ if field in self._fields:
+ # Note: If the field is a sub-message, its listener will still point
+ # at us. That's fine, because the worst than can happen is that it
+ # will call _Modified() and invalidate our byte size. Big deal.
+ del self._fields[field]
+
+ if self._oneofs.get(field.containing_oneof, None) is field:
+ del self._oneofs[field.containing_oneof]
+
+ # Always call _Modified() -- even if nothing was changed, this is
+ # a mutating method, and thus calling it should cause the field to become
+ # present in the parent message.
+ self._Modified()
+
+ cls.ClearField = ClearField
+
+
+def _AddClearExtensionMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def ClearExtension(self, extension_handle):
+ _VerifyExtensionHandle(self, extension_handle)
+
+ # Similar to ClearField(), above.
+ if extension_handle in self._fields:
+ del self._fields[extension_handle]
+ self._Modified()
+ cls.ClearExtension = ClearExtension
+
+
+def _AddClearMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def Clear(self):
+ # Clear fields.
+ self._fields = {}
+ self._unknown_fields = ()
+ self._Modified()
+ cls.Clear = Clear
+
+
+def _AddHasExtensionMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def HasExtension(self, extension_handle):
+ _VerifyExtensionHandle(self, extension_handle)
+ if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
+ raise KeyError('"%s" is repeated.' % extension_handle.full_name)
+
+ if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ value = self._fields.get(extension_handle)
+ return value is not None and value._is_present_in_parent
+ else:
+ return extension_handle in self._fields
+ cls.HasExtension = HasExtension
+
+
+def _AddEqualsMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __eq__(self, other):
+ if (not isinstance(other, message_mod.Message) or
+ other.DESCRIPTOR != self.DESCRIPTOR):
+ return False
+
+ if self is other:
+ return True
+
+ if not self.ListFields() == other.ListFields():
+ return False
+
+ # Sort unknown fields because their order shouldn't affect equality test.
+ unknown_fields = list(self._unknown_fields)
+ unknown_fields.sort()
+ other_unknown_fields = list(other._unknown_fields)
+ other_unknown_fields.sort()
+
+ return unknown_fields == other_unknown_fields
+
+ cls.__eq__ = __eq__
+
+
+def _AddStrMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __str__(self):
+ return text_format.MessageToString(self)
+ cls.__str__ = __str__
+
+
+def _AddUnicodeMethod(unused_message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def __unicode__(self):
+ return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
+ cls.__unicode__ = __unicode__
+
+
+def _AddSetListenerMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def SetListener(self, listener):
+ if listener is None:
+ self._listener = message_listener_mod.NullMessageListener()
+ else:
+ self._listener = listener
+ cls._SetListener = SetListener
+
+
+def _BytesForNonRepeatedElement(value, field_number, field_type):
+ """Returns the number of bytes needed to serialize a non-repeated element.
+ The returned byte count includes space for tag information and any
+ other additional space associated with serializing value.
+
+ Args:
+ value: Value we're serializing.
+ field_number: Field number of this value. (Since the field number
+ is stored as part of a varint-encoded tag, this has an impact
+ on the total bytes required to serialize the value).
+ field_type: The type of the field. One of the TYPE_* constants
+ within FieldDescriptor.
+ """
+ try:
+ fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
+ return fn(field_number, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
+
+
+def _AddByteSizeMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def ByteSize(self):
+ if not self._cached_byte_size_dirty:
+ return self._cached_byte_size
+
+ size = 0
+ for field_descriptor, field_value in self.ListFields():
+ size += field_descriptor._sizer(field_value)
+
+ for tag_bytes, value_bytes in self._unknown_fields:
+ size += len(tag_bytes) + len(value_bytes)
+
+ self._cached_byte_size = size
+ self._cached_byte_size_dirty = False
+ self._listener_for_children.dirty = False
+ return size
+
+ cls.ByteSize = ByteSize
+
+
+def _AddSerializeToStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def SerializeToString(self):
+ # Check if the message has all of its required fields set.
+ errors = []
+ if not self.IsInitialized():
+ raise message_mod.EncodeError(
+ 'Message %s is missing required fields: %s' % (
+ self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
+ return self.SerializePartialToString()
+ cls.SerializeToString = SerializeToString
+
+
+def _AddSerializePartialToStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def SerializePartialToString(self):
+ out = BytesIO()
+ self._InternalSerialize(out.write)
+ return out.getvalue()
+ cls.SerializePartialToString = SerializePartialToString
+
+ def InternalSerialize(self, write_bytes):
+ for field_descriptor, field_value in self.ListFields():
+ field_descriptor._encoder(write_bytes, field_value)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ write_bytes(tag_bytes)
+ write_bytes(value_bytes)
+ cls._InternalSerialize = InternalSerialize
+
+
+def _AddMergeFromStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def MergeFromString(self, serialized):
+ length = len(serialized)
+ try:
+ if self._InternalParse(serialized, 0, length) != length:
+ # The only reason _InternalParse would return early is if it
+ # encountered an end-group tag.
+ raise message_mod.DecodeError('Unexpected end-group tag.')
+ except (IndexError, TypeError):
+ # Now ord(buf[p:p+1]) == ord('') gets TypeError.
+ raise message_mod.DecodeError('Truncated message.')
+ except struct.error, e:
+ raise message_mod.DecodeError(e)
+ return length # Return this for legacy reasons.
+ cls.MergeFromString = MergeFromString
+
+ local_ReadTag = decoder.ReadTag
+ local_SkipField = decoder.SkipField
+ decoders_by_tag = cls._decoders_by_tag
+
+ def InternalParse(self, buffer, pos, end):
+ self._Modified()
+ field_dict = self._fields
+ unknown_field_list = self._unknown_fields
+ while pos != end:
+ (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
+ field_decoder = decoders_by_tag.get(tag_bytes)
+ if field_decoder is None:
+ value_start_pos = new_pos
+ new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ if not unknown_field_list:
+ unknown_field_list = self._unknown_fields = []
+ unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
+ pos = new_pos
+ else:
+ pos = field_decoder(buffer, new_pos, end, self, field_dict)
+ return pos
+ cls._InternalParse = InternalParse
+
+
+def _AddIsInitializedMethod(message_descriptor, cls):
+ """Adds the IsInitialized and FindInitializationError methods to the
+ protocol message class."""
+
+ required_fields = [field for field in message_descriptor.fields
+ if field.label == _FieldDescriptor.LABEL_REQUIRED]
+
+ def IsInitialized(self, errors=None):
+ """Checks if all required fields of a message are set.
+
+ Args:
+ errors: A list which, if provided, will be populated with the field
+ paths of all missing required fields.
+
+ Returns:
+ True iff the specified message has all required fields set.
+ """
+
+ # Performance is critical so we avoid HasField() and ListFields().
+
+ for field in required_fields:
+ if (field not in self._fields or
+ (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
+ not self._fields[field]._is_present_in_parent)):
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
+
+ for field, value in list(self._fields.items()): # dict can change size!
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for element in value:
+ if not element.IsInitialized():
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
+ elif value._is_present_in_parent and not value.IsInitialized():
+ if errors is not None:
+ errors.extend(self.FindInitializationErrors())
+ return False
+
+ return True
+
+ cls.IsInitialized = IsInitialized
+
+ def FindInitializationErrors(self):
+ """Finds required fields which are not initialized.
+
+ Returns:
+ A list of strings. Each string is a path to an uninitialized field from
+ the top-level message, e.g. "foo.bar[5].baz".
+ """
+
+ errors = [] # simplify things
+
+ for field in required_fields:
+ if not self.HasField(field.name):
+ errors.append(field.name)
+
+ for field, value in self.ListFields():
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.is_extension:
+ name = "(%s)" % field.full_name
+ else:
+ name = field.name
+
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for i in xrange(len(value)):
+ element = value[i]
+ prefix = "%s[%d]." % (name, i)
+ sub_errors = element.FindInitializationErrors()
+ errors += [ prefix + error for error in sub_errors ]
+ else:
+ prefix = name + "."
+ sub_errors = value.FindInitializationErrors()
+ errors += [ prefix + error for error in sub_errors ]
+
+ return errors
+
+ cls.FindInitializationErrors = FindInitializationErrors
+
+
+def _AddMergeFromMethod(cls):
+ LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
+ CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
+
+ def MergeFrom(self, msg):
+ if not isinstance(msg, cls):
+ raise TypeError(
+ "Parameter to MergeFrom() must be instance of same class: "
+ "expected %s got %s." % (cls.__name__, type(msg).__name__))
+
+ assert msg is not self
+ self._Modified()
+
+ fields = self._fields
+
+ for field, value in msg._fields.iteritems():
+ if field.label == LABEL_REPEATED:
+ field_value = fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = field._default_constructor(self)
+ fields[field] = field_value
+ field_value.MergeFrom(value)
+ elif field.cpp_type == CPPTYPE_MESSAGE:
+ if value._is_present_in_parent:
+ field_value = fields.get(field)
+ if field_value is None:
+ # Construct a new object to represent this field.
+ field_value = field._default_constructor(self)
+ fields[field] = field_value
+ field_value.MergeFrom(value)
+ else:
+ self._fields[field] = value
+
+ if msg._unknown_fields:
+ if not self._unknown_fields:
+ self._unknown_fields = []
+ self._unknown_fields.extend(msg._unknown_fields)
+
+ cls.MergeFrom = MergeFrom
+
+
+def _AddWhichOneofMethod(message_descriptor, cls):
+ def WhichOneof(self, oneof_name):
+ """Returns the name of the currently set field inside a oneof, or None."""
+ try:
+ field = message_descriptor.oneofs_by_name[oneof_name]
+ except KeyError:
+ raise ValueError(
+ 'Protocol message has no oneof "%s" field.' % oneof_name)
+
+ nested_field = self._oneofs.get(field, None)
+ if nested_field is not None and self.HasField(nested_field.name):
+ return nested_field.name
+ else:
+ return None
+
+ cls.WhichOneof = WhichOneof
+
+
+def _AddMessageMethods(message_descriptor, cls):
+ """Adds implementations of all Message methods to cls."""
+ _AddListFieldsMethod(message_descriptor, cls)
+ _AddHasFieldMethod(message_descriptor, cls)
+ _AddClearFieldMethod(message_descriptor, cls)
+ if message_descriptor.is_extendable:
+ _AddClearExtensionMethod(cls)
+ _AddHasExtensionMethod(cls)
+ _AddClearMethod(message_descriptor, cls)
+ _AddEqualsMethod(message_descriptor, cls)
+ _AddStrMethod(message_descriptor, cls)
+ _AddUnicodeMethod(message_descriptor, cls)
+ _AddSetListenerMethod(cls)
+ _AddByteSizeMethod(message_descriptor, cls)
+ _AddSerializeToStringMethod(message_descriptor, cls)
+ _AddSerializePartialToStringMethod(message_descriptor, cls)
+ _AddMergeFromStringMethod(message_descriptor, cls)
+ _AddIsInitializedMethod(message_descriptor, cls)
+ _AddMergeFromMethod(cls)
+ _AddWhichOneofMethod(message_descriptor, cls)
+
+def _AddPrivateHelperMethods(message_descriptor, cls):
+ """Adds implementation of private helper methods to cls."""
+
+ def Modified(self):
+ """Sets the _cached_byte_size_dirty bit to true,
+ and propagates this to our listener iff this was a state change.
+ """
+
+ # Note: Some callers check _cached_byte_size_dirty before calling
+ # _Modified() as an extra optimization. So, if this method is ever
+ # changed such that it does stuff even when _cached_byte_size_dirty is
+ # already true, the callers need to be updated.
+ if not self._cached_byte_size_dirty:
+ self._cached_byte_size_dirty = True
+ self._listener_for_children.dirty = True
+ self._is_present_in_parent = True
+ self._listener.Modified()
+
+ def _UpdateOneofState(self, field):
+ """Sets field as the active field in its containing oneof.
+
+ Will also delete currently active field in the oneof, if it is different
+ from the argument. Does not mark the message as modified.
+ """
+ other_field = self._oneofs.setdefault(field.containing_oneof, field)
+ if other_field is not field:
+ del self._fields[other_field]
+ self._oneofs[field.containing_oneof] = field
+
+ cls._Modified = Modified
+ cls.SetInParent = Modified
+ cls._UpdateOneofState = _UpdateOneofState
+
+
+class _Listener(object):
+
+ """MessageListener implementation that a parent message registers with its
+ child message.
+
+ In order to support semantics like:
+
+ foo.bar.baz.qux = 23
+ assert foo.HasField('bar')
+
+ ...child objects must have back references to their parents.
+ This helper class is at the heart of this support.
+ """
+
+ def __init__(self, parent_message):
+ """Args:
+ parent_message: The message whose _Modified() method we should call when
+ we receive Modified() messages.
+ """
+ # This listener establishes a back reference from a child (contained) object
+ # to its parent (containing) object. We make this a weak reference to avoid
+ # creating cyclic garbage when the client finishes with the 'parent' object
+ # in the tree.
+ if isinstance(parent_message, weakref.ProxyType):
+ self._parent_message_weakref = parent_message
+ else:
+ self._parent_message_weakref = weakref.proxy(parent_message)
+
+ # As an optimization, we also indicate directly on the listener whether
+ # or not the parent message is dirty. This way we can avoid traversing
+ # up the tree in the common case.
+ self.dirty = False
+
+ def Modified(self):
+ if self.dirty:
+ return
+ try:
+ # Propagate the signal to our parents iff this is the first field set.
+ self._parent_message_weakref._Modified()
+ except ReferenceError:
+ # We can get here if a client has kept a reference to a child object,
+ # and is now setting a field on it, but the child's parent has been
+ # garbage-collected. This is not an error.
+ pass
+
+
+class _OneofListener(_Listener):
+ """Special listener implementation for setting composite oneof fields."""
+
+ def __init__(self, parent_message, field):
+ """Args:
+ parent_message: The message whose _Modified() method we should call when
+ we receive Modified() messages.
+ field: The descriptor of the field being set in the parent message.
+ """
+ super(_OneofListener, self).__init__(parent_message)
+ self._field = field
+
+ def Modified(self):
+ """Also updates the state of the containing oneof in the parent message."""
+ try:
+ self._parent_message_weakref._UpdateOneofState(self._field)
+ super(_OneofListener, self).Modified()
+ except ReferenceError:
+ pass
+
+
+# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
+# TODO(robinson): Unify error handling of "unknown extension" crap.
+# TODO(robinson): Support iteritems()-style iteration over all
+# extensions with the "has" bits turned on?
+class _ExtensionDict(object):
+
+ """Dict-like container for supporting an indexable "Extensions"
+ field on proto instances.
+
+ Note that in all cases we expect extension handles to be
+ FieldDescriptors.
+ """
+
+ def __init__(self, extended_message):
+ """extended_message: Message instance for which we are the Extensions dict.
+ """
+
+ self._extended_message = extended_message
+
+ def __getitem__(self, extension_handle):
+ """Returns the current value of the given extension handle."""
+
+ _VerifyExtensionHandle(self._extended_message, extension_handle)
+
+ result = self._extended_message._fields.get(extension_handle)
+ if result is not None:
+ return result
+
+ if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
+ result = extension_handle._default_constructor(self._extended_message)
+ elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ result = extension_handle.message_type._concrete_class()
+ try:
+ result._SetListener(self._extended_message._listener_for_children)
+ except ReferenceError:
+ pass
+ else:
+ # Singular scalar -- just return the default without inserting into the
+ # dict.
+ return extension_handle.default_value
+
+ # Atomically check if another thread has preempted us and, if not, swap
+ # in the new object we just created. If someone has preempted us, we
+ # take that object and discard ours.
+ # WARNING: We are relying on setdefault() being atomic. This is true
+ # in CPython but we haven't investigated others. This warning appears
+ # in several other locations in this file.
+ result = self._extended_message._fields.setdefault(
+ extension_handle, result)
+
+ return result
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+
+ my_fields = self._extended_message.ListFields()
+ other_fields = other._extended_message.ListFields()
+
+ # Get rid of non-extension fields.
+ my_fields = [ field for field in my_fields if field.is_extension ]
+ other_fields = [ field for field in other_fields if field.is_extension ]
+
+ return my_fields == other_fields
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
+ # Note that this is only meaningful for non-repeated, scalar extension
+ # fields. Note also that we may have to call _Modified() when we do
+ # successfully set a field this way, to set any necssary "has" bits in the
+ # ancestors of the extended message.
+ def __setitem__(self, extension_handle, value):
+ """If extension_handle specifies a non-repeated, scalar extension
+ field, sets the value of that field.
+ """
+
+ _VerifyExtensionHandle(self._extended_message, extension_handle)
+
+ if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
+ extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
+ raise TypeError(
+ 'Cannot assign to extension "%s" because it is a repeated or '
+ 'composite type.' % extension_handle.full_name)
+
+ # It's slightly wasteful to lookup the type checker each time,
+ # but we expect this to be a vanishingly uncommon case anyway.
+ type_checker = type_checkers.GetTypeChecker(
+ extension_handle)
+ # pylint: disable=protected-access
+ self._extended_message._fields[extension_handle] = (
+ type_checker.CheckValue(value))
+ self._extended_message._Modified()
+
+ def _FindExtensionByName(self, name):
+ """Tries to find a known extension with the specified name.
+
+ Args:
+ name: Extension full name.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._extended_message._extensions_by_name.get(name, None)
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 2c9fa30..b3c414c 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -37,12 +37,12 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)'
+import copy
+import gc
import operator
import struct
-import unittest
-# TODO(robinson): When we split this test in two, only some of these imports
-# will be necessary in each test.
+from google.apputils import basetest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -50,6 +50,8 @@ from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
+from google.protobuf import text_format
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
@@ -102,12 +104,12 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(unittest.TestCase):
+class ReflectionTest(basetest.TestCase):
- def assertIs(self, values, others):
+ def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
for i in range(len(values)):
- self.assertTrue(values[i] is others[i])
+ self.assertEqual(values[i], others[i])
def testScalarConstructor(self):
# Constructor with only scalar types should succeed.
@@ -200,6 +202,41 @@ class ReflectionTest(unittest.TestCase):
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
+ def testConstructorTypeError(self):
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
+
+ def testConstructorInvalidatesCachedByteSize(self):
+ message = unittest_pb2.TestAllTypes(optional_int32 = 12)
+ self.assertEquals(2, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(
+ optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertEquals(3, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
+ self.assertEquals(3, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(
+ repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
+ self.assertEquals(3, message.ByteSize())
+
def testSimpleHasBits(self):
# Test a scalar.
proto = unittest_pb2.TestAllTypes()
@@ -284,12 +321,6 @@ class ReflectionTest(unittest.TestCase):
# ...and ensure that the scalar field has returned to its default.
self.assertEqual(0, getattr(composite_field, scalar_field_name))
- # Finally, ensure that modifications to the old composite field object
- # don't have any effect on the parent.
- #
- # (NOTE that when we clear the composite field in the parent, we actually
- # don't recursively clear down the tree. Instead, we just disconnect the
- # cleared composite from the tree.)
self.assertTrue(old_composite_field is not composite_field)
setattr(old_composite_field, scalar_field_name, new_val)
self.assertTrue(not composite_field.HasField(scalar_field_name))
@@ -319,6 +350,64 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(not proto.HasField('optional_nested_message'))
self.assertEqual(0, proto.optional_nested_message.bb)
+ def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ proto.ClearField('optional_nested_message')
+ del proto
+ del nested
+ # Force a garbage collect so that the underlying CMessages are freed along
+ # with the Messages they point to. This is to make sure we're not deleting
+ # default message instances.
+ gc.collect()
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+
+ def testDisconnectingNestedMessageAfterSettingField(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ nested.bb = 5
+ self.assertTrue(proto.HasField('optional_nested_message'))
+ proto.ClearField('optional_nested_message') # Should disconnect from parent
+ self.assertEqual(5, nested.bb)
+ self.assertEqual(0, proto.optional_nested_message.bb)
+ self.assertTrue(nested is not proto.optional_nested_message)
+ nested.bb = 23
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+
+ def testDisconnectingNestedMessageBeforeGettingField(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ proto.ClearField('optional_nested_message')
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+
+ def testDisconnectingNestedMessageAfterMerge(self):
+ # This test exercises the code path that does not use ReleaseMessage().
+ # The underlying fear is that if we use ReleaseMessage() incorrectly,
+ # we will have memory leaks. It's hard to check that that doesn't happen,
+ # but at least we can exercise that code path to make sure it works.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.optional_nested_message.bb = 5
+ proto1.MergeFrom(proto2)
+ self.assertTrue(proto1.HasField('optional_nested_message'))
+ proto1.ClearField('optional_nested_message')
+ self.assertTrue(not proto1.HasField('optional_nested_message'))
+
+ def testDisconnectingLazyNestedMessage(self):
+ # This test exercises releasing a nested message that is lazy. This test
+ # only exercises real code in the C++ implementation as Python does not
+ # support lazy parsing, but the current C++ implementation results in
+ # memory corruption and a crash.
+ if api_implementation.Type() != 'python':
+ return
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_lazy_message.bb = 5
+ proto.ClearField('optional_lazy_message')
+ del proto
+ gc.collect()
+
def testHasBitsWhenModifyingRepeatedFields(self):
# Test nesting when we add an element to a repeated field in a submessage.
proto = unittest_pb2.TestNestedMessageHasBits()
@@ -446,7 +535,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(0.0, proto.optional_double)
self.assertEqual(False, proto.optional_bool)
self.assertEqual('', proto.optional_string)
- self.assertEqual('', proto.optional_bytes)
+ self.assertEqual(b'', proto.optional_bytes)
self.assertEqual(41, proto.default_int32)
self.assertEqual(42, proto.default_int64)
@@ -462,7 +551,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(52e3, proto.default_double)
self.assertEqual(True, proto.default_bool)
self.assertEqual('hello', proto.default_string)
- self.assertEqual('world', proto.default_bytes)
+ self.assertEqual(b'world', proto.default_bytes)
self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_BAR,
@@ -479,6 +568,17 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
+ def testClearRemovesChildren(self):
+ # Make sure there aren't any implementation bugs that are only partially
+ # clearing the message (which can happen in the more complex C++
+ # implementation which has parallel message lists).
+ proto = unittest_pb2.TestRequiredForeign()
+ for i in range(10):
+ proto.repeated_message.add()
+ proto2 = unittest_pb2.TestRequiredForeign()
+ proto.CopyFrom(proto2)
+ self.assertRaises(IndexError, lambda: proto.repeated_message[5])
+
def testDisallowedAssignments(self):
# It's illegal to assign values directly to repeated fields
# or to nonrepeated composite fields. Ensure that this fails.
@@ -500,7 +600,6 @@ class ReflectionTest(unittest.TestCase):
# proto.nonexistent_field = 23 should fail as well.
self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
- # TODO(robinson): Add type-safety check for enums.
def testSingleScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
@@ -508,11 +607,37 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+ def testIntegerTypes(self):
+ def TestGetAndDeserialize(field_name, value, expected_type):
+ proto = unittest_pb2.TestAllTypes()
+ setattr(proto, field_name, value)
+ self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.ParseFromString(proto.SerializeToString())
+ self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
+
+ TestGetAndDeserialize('optional_int32', 1, int)
+ TestGetAndDeserialize('optional_int32', 1 << 30, int)
+ TestGetAndDeserialize('optional_uint32', 1 << 30, int)
+ if struct.calcsize('L') == 4:
+ # Python only has signed ints, so 32-bit python can't fit an uint32
+ # in an int.
+ TestGetAndDeserialize('optional_uint32', 1 << 31, long)
+ else:
+ # 64-bit python can fit uint32 inside an int
+ TestGetAndDeserialize('optional_uint32', 1 << 31, int)
+ TestGetAndDeserialize('optional_int64', 1 << 30, long)
+ TestGetAndDeserialize('optional_int64', 1 << 60, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 30, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 60, long)
+
def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
setattr(pb, field_name, expected_min)
+ self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
+ self.assertEqual(expected_max, getattr(pb, field_name))
self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
@@ -520,7 +645,10 @@ class ReflectionTest(unittest.TestCase):
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
- TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
+
+ pb = unittest_pb2.TestAllTypes()
+ pb.optional_nested_enum = 1
+ self.assertEqual(1, pb.optional_nested_enum)
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
@@ -534,11 +662,19 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
+ # Repeated enums tests.
+ #proto.repeated_nested_enum.append(0)
+
def testSingleScalarGettersAndSetters(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(0, proto.optional_int32)
proto.optional_int32 = 1
self.assertEqual(1, proto.optional_int32)
+
+ proto.optional_uint64 = 0xffffffffffff
+ self.assertEqual(0xffffffffffff, proto.optional_uint64)
+ proto.optional_uint64 = 0xffffffffffffffff
+ self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
# TODO(robinson): Test all other scalar field types.
def testSingleScalarClearField(self):
@@ -561,6 +697,77 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
+ def testEnum_Name(self):
+ self.assertEqual('FOREIGN_FOO',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
+ self.assertEqual('FOREIGN_BAR',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
+ self.assertEqual('FOREIGN_BAZ',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
+ self.assertRaises(ValueError,
+ unittest_pb2.ForeignEnum.Name, 11312)
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual('FOO',
+ proto.NestedEnum.Name(proto.FOO))
+ self.assertEqual('FOO',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
+ self.assertEqual('BAR',
+ proto.NestedEnum.Name(proto.BAR))
+ self.assertEqual('BAR',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
+ self.assertEqual('BAZ',
+ proto.NestedEnum.Name(proto.BAZ))
+ self.assertEqual('BAZ',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
+ self.assertRaises(ValueError,
+ proto.NestedEnum.Name, 11312)
+ self.assertRaises(ValueError,
+ unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
+
+ def testEnum_Value(self):
+ self.assertEqual(unittest_pb2.FOREIGN_FOO,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
+ self.assertEqual(unittest_pb2.FOREIGN_BAR,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
+ self.assertRaises(ValueError,
+ unittest_pb2.ForeignEnum.Value, 'FO')
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(proto.FOO,
+ proto.NestedEnum.Value('FOO'))
+ self.assertEqual(proto.FOO,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
+ self.assertEqual(proto.BAR,
+ proto.NestedEnum.Value('BAR'))
+ self.assertEqual(proto.BAR,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
+ self.assertEqual(proto.BAZ,
+ proto.NestedEnum.Value('BAZ'))
+ self.assertEqual(proto.BAZ,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
+ self.assertRaises(ValueError,
+ proto.NestedEnum.Value, 'Foo')
+ self.assertRaises(ValueError,
+ unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
+
+ def testEnum_KeysAndValues(self):
+ self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
+ unittest_pb2.ForeignEnum.keys())
+ self.assertEqual([4, 5, 6],
+ unittest_pb2.ForeignEnum.values())
+ self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
+ ('FOREIGN_BAZ', 6)],
+ unittest_pb2.ForeignEnum.items())
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
+ self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
+ self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
+ proto.NestedEnum.items())
+
def testRepeatedScalars(self):
proto = unittest_pb2.TestAllTypes()
@@ -619,11 +826,38 @@ class ReflectionTest(unittest.TestCase):
del proto.repeated_int32[2:]
self.assertEqual([5, 35], proto.repeated_int32)
+ # Test extending.
+ proto.repeated_int32.extend([3, 13])
+ self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
+
# Test clearing.
proto.ClearField('repeated_int32')
self.assertTrue(not proto.repeated_int32)
self.assertEqual(0, len(proto.repeated_int32))
+ proto.repeated_int32.append(1)
+ self.assertEqual(1, proto.repeated_int32[-1])
+ # Test assignment to a negative index.
+ proto.repeated_int32[-1] = 2
+ self.assertEqual(2, proto.repeated_int32[-1])
+
+ # Test deletion at negative indices.
+ proto.repeated_int32[:] = [0, 1, 2, 3]
+ del proto.repeated_int32[-1]
+ self.assertEqual([0, 1, 2], proto.repeated_int32)
+
+ del proto.repeated_int32[-2]
+ self.assertEqual([0, 2], proto.repeated_int32)
+
+ self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
+ self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
+
+ del proto.repeated_int32[-2:-1]
+ self.assertEqual([2], proto.repeated_int32)
+
+ del proto.repeated_int32[100:10000]
+ self.assertEqual([2], proto.repeated_int32)
+
def testRepeatedScalarsRemove(self):
proto = unittest_pb2.TestAllTypes()
@@ -661,7 +895,7 @@ class ReflectionTest(unittest.TestCase):
m1 = proto.repeated_nested_message.add()
self.assertTrue(proto.repeated_nested_message)
self.assertEqual(2, len(proto.repeated_nested_message))
- self.assertIs([m0, m1], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1], proto.repeated_nested_message)
self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
# Test out-of-bounds indices.
@@ -680,32 +914,86 @@ class ReflectionTest(unittest.TestCase):
m2 = proto.repeated_nested_message.add()
m3 = proto.repeated_nested_message.add()
m4 = proto.repeated_nested_message.add()
- self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
- self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
+ self.assertListsEqual(
+ [m1, m2, m3], proto.repeated_nested_message[1:4])
+ self.assertListsEqual(
+ [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
+ self.assertListsEqual(
+ [m0, m1], proto.repeated_nested_message[:2])
+ self.assertListsEqual(
+ [m2, m3, m4], proto.repeated_nested_message[2:])
+ self.assertEqual(
+ m0, proto.repeated_nested_message[0])
+ self.assertListsEqual(
+ [m0], proto.repeated_nested_message[:1])
# Test that we can use the field as an iterator.
result = []
for i in proto.repeated_nested_message:
result.append(i)
- self.assertIs([m0, m1, m2, m3, m4], result)
+ self.assertListsEqual([m0, m1, m2, m3, m4], result)
# Test single deletion.
del proto.repeated_nested_message[2]
- self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
# Test slice deletion.
del proto.repeated_nested_message[2:]
- self.assertIs([m0, m1], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1], proto.repeated_nested_message)
+
+ # Test extending.
+ n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
+ n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
+ proto.repeated_nested_message.extend([n1,n2])
+ self.assertEqual(4, len(proto.repeated_nested_message))
+ self.assertEqual(n1, proto.repeated_nested_message[2])
+ self.assertEqual(n2, proto.repeated_nested_message[3])
# Test clearing.
proto.ClearField('repeated_nested_message')
self.assertTrue(not proto.repeated_nested_message)
self.assertEqual(0, len(proto.repeated_nested_message))
+ # Test constructing an element while adding it.
+ proto.repeated_nested_message.add(bb=23)
+ self.assertEqual(1, len(proto.repeated_nested_message))
+ self.assertEqual(23, proto.repeated_nested_message[0].bb)
+
+ def testRepeatedCompositeRemove(self):
+ proto = unittest_pb2.TestAllTypes()
+
+ self.assertEqual(0, len(proto.repeated_nested_message))
+ m0 = proto.repeated_nested_message.add()
+ # Need to set some differentiating variable so m0 != m1 != m2:
+ m0.bb = len(proto.repeated_nested_message)
+ m1 = proto.repeated_nested_message.add()
+ m1.bb = len(proto.repeated_nested_message)
+ self.assertTrue(m0 != m1)
+ m2 = proto.repeated_nested_message.add()
+ m2.bb = len(proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
+
+ self.assertEqual(3, len(proto.repeated_nested_message))
+ proto.repeated_nested_message.remove(m0)
+ self.assertEqual(2, len(proto.repeated_nested_message))
+ self.assertEqual(m1, proto.repeated_nested_message[0])
+ self.assertEqual(m2, proto.repeated_nested_message[1])
+
+ # Removing m0 again or removing None should raise error
+ self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
+ self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
+ self.assertEqual(2, len(proto.repeated_nested_message))
+
+ proto.repeated_nested_message.remove(m2)
+ self.assertEqual(1, len(proto.repeated_nested_message))
+ self.assertEqual(m1, proto.repeated_nested_message[0])
+
def testHandWrittenReflection(self):
- # TODO(robinson): We probably need a better way to specify
- # protocol types by hand. But then again, this isn't something
- # we expect many people to do. Hmm.
+ # Hand written extensions are only supported by the pure-Python
+ # implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
FieldDescriptor = descriptor.FieldDescriptor
foo_field_descriptor = FieldDescriptor(
name='foo_field', full_name='MyProto.foo_field',
@@ -730,6 +1018,68 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
+ def testDescriptorProtoSupport(self):
+ # Hand written descriptors/reflection are only supported by the pure-Python
+ # implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
+ def AddDescriptorField(proto, field_name, field_type):
+ AddDescriptorField.field_index += 1
+ new_field = proto.field.add()
+ new_field.name = field_name
+ new_field.type = field_type
+ new_field.number = AddDescriptorField.field_index
+ new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+
+ AddDescriptorField.field_index = 0
+
+ desc_proto = descriptor_pb2.DescriptorProto()
+ desc_proto.name = 'Car'
+ fdp = descriptor_pb2.FieldDescriptorProto
+ AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
+ AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
+ AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
+ AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
+ # Add a repeated field
+ AddDescriptorField.field_index += 1
+ new_field = desc_proto.field.add()
+ new_field.name = 'owners'
+ new_field.type = fdp.TYPE_STRING
+ new_field.number = AddDescriptorField.field_index
+ new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
+
+ desc = descriptor.MakeDescriptor(desc_proto)
+ self.assertTrue(desc.fields_by_name.has_key('name'))
+ self.assertTrue(desc.fields_by_name.has_key('year'))
+ self.assertTrue(desc.fields_by_name.has_key('automatic'))
+ self.assertTrue(desc.fields_by_name.has_key('price'))
+ self.assertTrue(desc.fields_by_name.has_key('owners'))
+
+ class CarMessage(message.Message):
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ DESCRIPTOR = desc
+
+ prius = CarMessage()
+ prius.name = 'prius'
+ prius.year = 2010
+ prius.automatic = True
+ prius.price = 25134.75
+ prius.owners.extend(['bob', 'susan'])
+
+ serialized_prius = prius.SerializeToString()
+ new_prius = reflection.ParseMessage(desc, serialized_prius)
+ self.assertTrue(new_prius is not prius)
+ self.assertEqual(prius, new_prius)
+
+ # these are unnecessary assuming message equality works as advertised but
+ # explicitly check to be safe since we're mucking about in metaclass foo
+ self.assertEqual(prius.name, new_prius.name)
+ self.assertEqual(prius.year, new_prius.year)
+ self.assertEqual(prius.automatic, new_prius.automatic)
+ self.assertEqual(prius.price, new_prius.price)
+ self.assertEqual(prius.owners, new_prius.owners)
+
def testTopLevelExtensionsForOptionalScalar(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.optional_int32_extension
@@ -819,6 +1169,14 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(required is not extendee_proto.Extensions[extension])
self.assertTrue(not extendee_proto.HasExtension(extension))
+ def testRegisteredExtensions(self):
+ self.assertTrue('protobuf_unittest.optional_int32_extension' in
+ unittest_pb2.TestAllExtensions._extensions_by_name)
+ self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ # Make sure extensions haven't been registered into types that shouldn't
+ # have any.
+ self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
# extension in B should change a.HasField('b') to True
@@ -868,7 +1226,7 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(not toplevel.HasField('submessage'))
foreign = toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension].add()
- self.assertTrue(foreign is toplevel.submessage.Extensions[
+ self.assertEqual(foreign, toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension][0])
self.assertTrue(toplevel.HasField('submessage'))
@@ -971,6 +1329,12 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(123, proto2.repeated_nested_message[1].bb)
self.assertEqual(321, proto2.repeated_nested_message[2].bb)
+ proto3 = unittest_pb2.TestAllTypes()
+ proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
+ self.assertEqual(999, proto3.repeated_nested_message[0].bb)
+ self.assertEqual(123, proto3.repeated_nested_message[1].bb)
+ self.assertEqual(321, proto3.repeated_nested_message[2].bb)
+
def testMergeFromAllFields(self):
# With all fields set.
proto1 = unittest_pb2.TestAllTypes()
@@ -1035,6 +1399,19 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(222, ext2[1].bb)
self.assertEqual(333, ext2[2].bb)
+ def testMergeFromBug(self):
+ message1 = unittest_pb2.TestAllTypes()
+ message2 = unittest_pb2.TestAllTypes()
+
+ # Cause optional_nested_message to be instantiated within message1, even
+ # though it is not considered to be "present".
+ message1.optional_nested_message
+ self.assertFalse(message1.HasField('optional_nested_message'))
+
+ # Merge into message2. This should not instantiate the field is message2.
+ message2.MergeFrom(message1)
+ self.assertFalse(message2.HasField('optional_nested_message'))
+
def testCopyFromSingularField(self):
# Test copy with just a singular field.
proto1 = unittest_pb2.TestAllTypes()
@@ -1087,9 +1464,36 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(2, proto1.optional_int32)
self.assertEqual('important-text', proto1.optional_string)
+ def testCopyFromBadType(self):
+ # The python implementation doesn't raise an exception in this
+ # case. In theory it should.
+ if api_implementation.Type() == 'python':
+ return
+ proto1 = unittest_pb2.TestAllTypes()
+ proto2 = unittest_pb2.TestAllExtensions()
+ self.assertRaises(TypeError, proto1.CopyFrom, proto2)
+
+ def testDeepCopy(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optional_int32 = 1
+ proto2 = copy.deepcopy(proto1)
+ self.assertEqual(1, proto2.optional_int32)
+
+ proto1.repeated_int32.append(2)
+ proto1.repeated_int32.append(3)
+ container = copy.deepcopy(proto1.repeated_int32)
+ self.assertEqual([2, 3], container)
+
+ # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
+
def testClear(self):
proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto)
+ # C++ implementation does not support lazy fields right now so leave it
+ # out for now.
+ if api_implementation.Type() == 'python':
+ test_util.SetAllFields(proto)
+ else:
+ test_util.SetAllNonLazyFields(proto)
# Clear the message.
proto.Clear()
self.assertEquals(proto.ByteSize(), 0)
@@ -1105,6 +1509,45 @@ class ReflectionTest(unittest.TestCase):
empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto)
+ def testDisconnectingBeforeClear(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ proto.Clear()
+ self.assertTrue(nested is not proto.optional_nested_message)
+ nested.bb = 23
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ nested.bb = 5
+ foreign = proto.optional_foreign_message
+ foreign.c = 6
+
+ proto.Clear()
+ self.assertTrue(nested is not proto.optional_nested_message)
+ self.assertTrue(foreign is not proto.optional_foreign_message)
+ self.assertEqual(5, nested.bb)
+ self.assertEqual(6, foreign.c)
+ nested.bb = 15
+ foreign.c = 16
+ self.assertFalse(proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+ self.assertFalse(proto.HasField('optional_foreign_message'))
+ self.assertEqual(0, proto.optional_foreign_message.c)
+
+ def testOneOf(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.oneof_uint32 = 10
+ proto.oneof_nested_message.bb = 11
+ self.assertEqual(11, proto.oneof_nested_message.bb)
+ self.assertFalse(proto.HasField('oneof_uint32'))
+ nested = proto.oneof_nested_message
+ proto.oneof_string = 'abc'
+ self.assertEqual('abc', proto.oneof_string)
+ self.assertEqual(11, nested.bb)
+ self.assertFalse(proto.HasField('oneof_nested_message'))
+
def assertInitialized(self, proto):
self.assertTrue(proto.IsInitialized())
# Neither method should raise an exception.
@@ -1175,6 +1618,40 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Errors are only available from the most recent C++ implementation.')
+ def testFileDescriptorErrors(self):
+ file_name = 'test_file_descriptor_errors.proto'
+ package_name = 'test_file_descriptor_errors.proto'
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.name = file_name
+ file_descriptor_proto.package = package_name
+ m1 = file_descriptor_proto.message_type.add()
+ m1.name = 'msg1'
+ # Compiles the proto into the C++ descriptor pool
+ descriptor.FileDescriptor(
+ file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ # Add a FileDescriptorProto that has duplicate symbols
+ another_file_name = 'another_test_file_descriptor_errors.proto'
+ file_descriptor_proto.name = another_file_name
+ m2 = file_descriptor_proto.message_type.add()
+ m2.name = 'msg2'
+ with self.assertRaises(TypeError) as cm:
+ descriptor.FileDescriptor(
+ another_file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
+ getattr(cm.expected, '__name__', cm.expected))
+ self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
+ # Error message will say something about this definition being a
+ # duplicate, though we don't check the message exactly to avoid a
+ # dependency on the C++ logging code.
+ self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
+
def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes()
@@ -1192,16 +1669,15 @@ class ReflectionTest(unittest.TestCase):
proto.optional_string = str('Testing')
self.assertEqual(proto.optional_string, unicode('Testing'))
- # Values of type 'str' are also accepted as long as they can be encoded in
- # UTF-8.
- self.assertEqual(type(proto.optional_string), str)
-
# Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
self.assertRaises(ValueError,
- setattr, proto, 'optional_string', str('a\x80a'))
- # Assign a 'str' object which contains a UTF-8 encoded string.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', 'Тест')
+ setattr, proto, 'optional_string', b'a\x80a')
+ if str is bytes: # PY2
+ # Assign a 'str' object which contains a UTF-8 encoded string.
+ self.assertRaises(ValueError,
+ setattr, proto, 'optional_string', 'Тест')
+ else:
+ proto.optional_string = 'Тест'
# No exception thrown.
proto.optional_string = 'abc'
@@ -1224,7 +1700,8 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(proto.ByteSize(), len(serialized))
raw = unittest_mset_pb2.RawMessageSet()
- raw.MergeFromString(serialized)
+ bytes_read = raw.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_read)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
@@ -1232,18 +1709,37 @@ class ReflectionTest(unittest.TestCase):
# Check that the type_id is the same as the tag ID in the .proto file.
self.assertEqual(raw.item[0].type_id, 1547769)
- # Check the actually bytes on the wire.
+ # Check the actual bytes on the wire.
self.assertTrue(
raw.item[0].message.endswith(test_utf8_bytes))
- message2.MergeFromString(raw.item[0].message)
+ bytes_read = message2.MergeFromString(raw.item[0].message)
+ self.assertEqual(len(raw.item[0].message), bytes_read)
self.assertEqual(type(message2.str), unicode)
self.assertEqual(message2.str, test_utf8)
- # How about if the bytes on the wire aren't a valid UTF-8 encoded string.
- bytes = raw.item[0].message.replace(
- test_utf8_bytes, len(test_utf8_bytes) * '\xff')
- self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
+ # The pure Python API throws an exception on MergeFromString(),
+ # if any of the string fields of the message can't be UTF-8 decoded.
+ # The C++ implementation of the API has no way to check that on
+ # MergeFromString and thus has no way to throw the exception.
+ #
+ # The pure Python API always returns objects of type 'unicode' (UTF-8
+ # encoded), or 'bytes' (in 7 bit ASCII).
+ badbytes = raw.item[0].message.replace(
+ test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
+
+ unicode_decode_failed = False
+ try:
+ message2.MergeFromString(badbytes)
+ except UnicodeDecodeError:
+ unicode_decode_failed = True
+ string_field = message2.str
+ self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
+
+ def testBytesInTextFormat(self):
+ proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
+ self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
+ unicode(proto))
def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
@@ -1257,16 +1753,19 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.MergeFromString('')
+ bytes_read = proto.optional_nested_message.MergeFromString(b'')
+ self.assertEqual(0, bytes_read)
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.ParseFromString('')
+ proto.optional_nested_message.ParseFromString(b'')
self.assertTrue(proto.HasField('optional_nested_message'))
serialized = proto.SerializeToString()
proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertTrue(proto2.HasField('optional_nested_message'))
def testSetInParent(self):
@@ -1280,12 +1779,15 @@ class ReflectionTest(unittest.TestCase):
# into separate TestCase classes.
-class TestAllTypesEqualityTest(unittest.TestCase):
+class TestAllTypesEqualityTest(basetest.TestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
self.second_proto = unittest_pb2.TestAllTypes()
+ def testNotHashable(self):
+ self.assertRaises(TypeError, hash, self.first_proto)
+
def testSelfEquality(self):
self.assertEqual(self.first_proto, self.first_proto)
@@ -1293,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(unittest.TestCase):
+class FullProtosEqualityTest(basetest.TestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1303,6 +1805,9 @@ class FullProtosEqualityTest(unittest.TestCase):
test_util.SetAllFields(self.first_proto)
test_util.SetAllFields(self.second_proto)
+ def testNotHashable(self):
+ self.assertRaises(TypeError, hash, self.first_proto)
+
def testNoneNotEqual(self):
self.assertNotEqual(self.first_proto, None)
self.assertNotEqual(None, self.second_proto)
@@ -1371,15 +1876,12 @@ class FullProtosEqualityTest(unittest.TestCase):
self.first_proto.ClearField('optional_nested_message')
self.second_proto.optional_nested_message.ClearField('bb')
self.assertNotEqual(self.first_proto, self.second_proto)
- # TODO(robinson): Replace next two lines with method
- # to set the "has" bit without changing the value,
- # if/when such a method exists.
self.first_proto.optional_nested_message.bb = 0
self.first_proto.optional_nested_message.ClearField('bb')
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(unittest.TestCase):
+class ExtensionEqualityTest(basetest.TestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1412,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(unittest.TestCase):
+class MutualRecursionEqualityTest(basetest.TestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1424,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(unittest.TestCase):
+class ByteSizeTest(basetest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -1438,6 +1940,14 @@ class ByteSizeTest(unittest.TestCase):
def testEmptyMessage(self):
self.assertEqual(0, self.proto.ByteSize())
+ def testSizedOnKwargs(self):
+ # Use a separate message to ensure testing right after creation.
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(0, proto.ByteSize())
+ proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
+ # One byte for the tag, one to encode varint 1.
+ self.assertEqual(2, proto_kwargs.ByteSize())
+
def testVarints(self):
def Test(i, expected_varint_size):
self.proto.Clear()
@@ -1629,10 +2139,13 @@ class ByteSizeTest(unittest.TestCase):
self.assertEqual(3, self.proto.ByteSize())
self.proto.ClearField('optional_foreign_message')
self.assertEqual(0, self.proto.ByteSize())
- child = self.proto.optional_foreign_message
- self.proto.ClearField('optional_foreign_message')
- child.c = 128
- self.assertEqual(0, self.proto.ByteSize())
+
+ if api_implementation.Type() == 'python':
+ # This is only possible in pure-Python implementation of the API.
+ child = self.proto.optional_foreign_message
+ self.proto.ClearField('optional_foreign_message')
+ child.c = 128
+ self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.optional_message_extension
@@ -1698,7 +2211,6 @@ class ByteSizeTest(unittest.TestCase):
self.assertEqual(19, self.packed_extended_proto.ByteSize())
-# TODO(robinson): We need cross-language serialization consistency tests.
# Issues to be sure to cover include:
# * Handling of unrecognized tags ("uninterpreted_bytes").
# * Handling of MessageSets.
@@ -1710,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(unittest.TestCase):
+class SerializationTest(basetest.TestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
second_proto = unittest_pb2.TestAllTypes()
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllFields(self):
@@ -1726,7 +2240,9 @@ class SerializationTest(unittest.TestCase):
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllExtensions(self):
@@ -1734,7 +2250,19 @@ class SerializationTest(unittest.TestCase):
second_proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(first_proto)
serialized = first_proto.SerializeToString()
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeWithOptionalGroup(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ second_proto = unittest_pb2.TestAllTypes()
+ first_proto.optionalgroup.a = 242
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeNegativeValues(self):
@@ -1753,6 +2281,10 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
def testParseTruncated(self):
+ # This test is only applicable for the Python implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
@@ -1822,7 +2354,9 @@ class SerializationTest(unittest.TestCase):
second_proto.optional_int32 = 100
second_proto.optional_nested_message.bb = 999
- second_proto.MergeFromString(serialized)
+ bytes_parsed = second_proto.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_parsed)
+
# Ensure that we append to repeated fields.
self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
# Ensure that we overwrite nonrepeatd scalars.
@@ -1847,20 +2381,28 @@ class SerializationTest(unittest.TestCase):
raw = unittest_mset_pb2.RawMessageSet()
self.assertEqual(False,
raw.DESCRIPTOR.GetOptions().message_set_wire_format)
- raw.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ raw.MergeFromString(serialized))
self.assertEqual(2, len(raw.item))
message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.MergeFromString(raw.item[0].message)
+ self.assertEqual(
+ len(raw.item[0].message),
+ message1.MergeFromString(raw.item[0].message))
self.assertEqual(123, message1.i)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
- message2.MergeFromString(raw.item[1].message)
+ self.assertEqual(
+ len(raw.item[1].message),
+ message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
# Deserialize using the MessageSet wire format.
proto2 = unittest_mset_pb2.TestMessageSet()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
@@ -1900,7 +2442,9 @@ class SerializationTest(unittest.TestCase):
# Parse message using the message set wire format.
proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto.MergeFromString(serialized))
# Check that the message parsed well.
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
@@ -1918,7 +2462,9 @@ class SerializationTest(unittest.TestCase):
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
# Now test with a int64 field set.
proto = unittest_pb2.TestAllTypes()
@@ -1928,13 +2474,15 @@ class SerializationTest(unittest.TestCase):
# unknown.
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
def _CheckRaises(self, exc_class, callable_obj, exception):
"""This method checks if the excpetion type and message are as expected."""
try:
callable_obj()
- except exc_class, ex:
+ except exc_class as ex:
# Check if the exception message is the right one.
self.assertEqual(exception, str(ex))
return
@@ -1946,15 +2494,22 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: a,b,c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: '
+ 'a,b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
+ proto2 = unittest_pb2.TestRequired()
+ self.assertFalse(proto2.HasField('a'))
+ # proto2 ParseFromString does not check that required fields are set.
+ proto2.ParseFromString(partial)
+ self.assertFalse(proto2.HasField('a'))
+
proto.a = 1
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: b,c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1962,7 +2517,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1972,11 +2527,15 @@ class SerializationTest(unittest.TestCase):
partial = proto.SerializePartialToString()
proto2 = unittest_pb2.TestRequired()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
- proto2.ParseFromString(partial)
+ self.assertEqual(
+ len(partial),
+ proto2.MergeFromString(partial))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
@@ -1991,7 +2550,8 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: '
+ 'Message protobuf_unittest.TestRequiredForeign '
+ 'is missing required fields: '
'optional_message.b,optional_message.c')
proto.optional_message.b = 2
@@ -2003,7 +2563,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: '
+ 'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
'repeated_message[0].b,repeated_message[0].c,'
'repeated_message[1].a,repeated_message[1].c')
@@ -2043,7 +2603,9 @@ class SerializationTest(unittest.TestCase):
second_proto.packed_double.extend([1.0, 2.0])
second_proto.packed_sint32.append(4)
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual([3, 1, 2], second_proto.packed_int32)
self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
self.assertEqual([4], second_proto.packed_sint32)
@@ -2076,7 +2638,10 @@ class SerializationTest(unittest.TestCase):
unpacked = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(unpacked)
packed = unittest_pb2.TestPackedTypes()
- packed.MergeFromString(unpacked.SerializeToString())
+ serialized = unpacked.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ packed.MergeFromString(serialized))
expected = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(expected)
self.assertEqual(expected, packed)
@@ -2085,7 +2650,10 @@ class SerializationTest(unittest.TestCase):
packed = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(packed)
unpacked = unittest_pb2.TestUnpackedTypes()
- unpacked.MergeFromString(packed.SerializeToString())
+ serialized = packed.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ unpacked.MergeFromString(serialized))
expected = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(expected)
self.assertEqual(expected, unpacked)
@@ -2137,7 +2705,7 @@ class SerializationTest(unittest.TestCase):
optional_int32=1,
optional_string='foo',
optional_bool=True,
- optional_bytes='bar',
+ optional_bytes=b'bar',
optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
@@ -2155,7 +2723,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1, proto.optional_int32)
self.assertEqual('foo', proto.optional_string)
self.assertEqual(True, proto.optional_bool)
- self.assertEqual('bar', proto.optional_bytes)
+ self.assertEqual(b'bar', proto.optional_bytes)
self.assertEqual(1, proto.optional_nested_message.bb)
self.assertEqual(1, proto.optional_foreign_message.c)
self.assertEqual(unittest_pb2.TestAllTypes.FOO,
@@ -2205,7 +2773,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(unittest.TestCase):
+class OptionsTest(basetest.TestCase):
def testMessageOptions(self):
proto = unittest_mset_pb2.TestMessageSet()
@@ -2232,5 +2800,135 @@ class OptionsTest(unittest.TestCase):
+class ClassAPITest(basetest.TestCase):
+
+ def testMakeClassWithNestedDescriptor(self):
+ leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
+ containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
+ containing_type=None, fields=[],
+ nested_types=[leaf_desc], enum_types=[],
+ extensions=[])
+ sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
+ '', containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
+ containing_type=None, fields=[],
+ nested_types=[child_desc, sibling_desc],
+ enum_types=[], extensions=[])
+ message_class = reflection.MakeClass(parent_desc)
+ self.assertIn('child', message_class.__dict__)
+ self.assertIn('sibling', message_class.__dict__)
+ self.assertIn('leaf', message_class.child.__dict__)
+
+ def _GetSerializedFileDescriptor(self, name):
+ """Get a serialized representation of a test FileDescriptorProto.
+
+ Args:
+ name: All calls to this must use a unique message name, to avoid
+ collisions in the cpp descriptor pool.
+ Returns:
+ A string containing the serialized form of a test FileDescriptorProto.
+ """
+ file_descriptor_str = (
+ 'message_type {'
+ ' name: "' + name + '"'
+ ' field {'
+ ' name: "flat"'
+ ' number: 1'
+ ' label: LABEL_REPEATED'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' field {'
+ ' name: "bar"'
+ ' number: 2'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Bar"'
+ ' }'
+ ' nested_type {'
+ ' name: "Bar"'
+ ' field {'
+ ' name: "baz"'
+ ' number: 3'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Baz"'
+ ' }'
+ ' nested_type {'
+ ' name: "Baz"'
+ ' enum_type {'
+ ' name: "deep_enum"'
+ ' value {'
+ ' name: "VALUE_A"'
+ ' number: 0'
+ ' }'
+ ' }'
+ ' field {'
+ ' name: "deep"'
+ ' number: 4'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' }'
+ ' }'
+ '}')
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ text_format.Merge(file_descriptor_str, file_descriptor)
+ return file_descriptor.SerializeToString()
+
+ def testParsingFlatClassWithExplicitClassDeclaration(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+
+ class MessageClass(message.Message):
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ DESCRIPTOR = msg_descriptor
+ msg = MessageClass()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingFlatClass(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingNestedClass(self):
+ """Test that the generated class can parse a nested message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'bar {'
+ ' baz {'
+ ' deep: 4'
+ ' }'
+ '}')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.bar.baz.deep, 4)
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index e04f825..ef0981d 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -34,13 +34,13 @@
__author__ = 'petar@google.com (Petar Petrov)'
-import unittest
+from google.apputils import basetest
from google.protobuf import unittest_pb2
from google.protobuf import service_reflection
from google.protobuf import service
-class FooUnitTest(unittest.TestCase):
+class FooUnitTest(basetest.TestCase):
def testService(self):
class MockRpcChannel(service.RpcChannel):
@@ -133,4 +133,4 @@ class FooUnitTest(unittest.TestCase):
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py
new file mode 100644
index 0000000..80bc8d6
--- /dev/null
+++ b/python/google/protobuf/internal/symbol_database_test.py
@@ -0,0 +1,120 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.symbol_database."""
+
+from google.apputils import basetest
+from google.protobuf import unittest_pb2
+from google.protobuf import symbol_database
+
+
+class SymbolDatabaseTest(basetest.TestCase):
+
+ def _Database(self):
+ db = symbol_database.SymbolDatabase()
+ # Register representative types from unittest_pb2.
+ db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR)
+ db.RegisterMessage(unittest_pb2.TestAllTypes)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup)
+ db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup)
+ db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR)
+ db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR)
+ return db
+
+ def testGetPrototype(self):
+ instance = self._Database().GetPrototype(
+ unittest_pb2.TestAllTypes.DESCRIPTOR)
+ self.assertTrue(instance is unittest_pb2.TestAllTypes)
+
+ def testGetMessages(self):
+ messages = self._Database().GetMessages(
+ ['google/protobuf/unittest.proto'])
+ self.assertTrue(
+ unittest_pb2.TestAllTypes is
+ messages['protobuf_unittest.TestAllTypes'])
+
+ def testGetSymbol(self):
+ self.assertEquals(
+ unittest_pb2.TestAllTypes, self._Database().GetSymbol(
+ 'protobuf_unittest.TestAllTypes'))
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.NestedMessage'))
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.OptionalGroup'))
+ self.assertEquals(
+ unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol(
+ 'protobuf_unittest.TestAllTypes.RepeatedGroup'))
+
+ def testEnums(self):
+ # Check registration of types in the pool.
+ self.assertEquals(
+ 'protobuf_unittest.ForeignEnum',
+ self._Database().pool.FindEnumTypeByName(
+ 'protobuf_unittest.ForeignEnum').full_name)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes.NestedEnum',
+ self._Database().pool.FindEnumTypeByName(
+ 'protobuf_unittest.TestAllTypes.NestedEnum').full_name)
+
+ def testFindMessageTypeByName(self):
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes',
+ self._Database().pool.FindMessageTypeByName(
+ 'protobuf_unittest.TestAllTypes').full_name)
+ self.assertEquals(
+ 'protobuf_unittest.TestAllTypes.NestedMessage',
+ self._Database().pool.FindMessageTypeByName(
+ 'protobuf_unittest.TestAllTypes.NestedMessage').full_name)
+
+ def testFindFindContainingSymbol(self):
+ # Lookup based on either enum or message.
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ self._Database().pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes.NestedEnum').name)
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ self._Database().pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes').name)
+
+ def testFindFileByName(self):
+ self.assertEquals(
+ 'google/protobuf/unittest.proto',
+ self._Database().pool.FindFileByName(
+ 'google/protobuf/unittest.proto').name)
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto
new file mode 100644
index 0000000..6a82299
--- /dev/null
+++ b/python/google/protobuf/internal/test_bad_identifiers.proto
@@ -0,0 +1,52 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+
+
+package protobuf_unittest;
+
+option py_generic_services = true;
+
+message TestBadIdentifiers {
+ extensions 100 to max;
+}
+
+// Make sure these reasonable extension names don't conflict with internal
+// variables.
+extend TestBadIdentifiers {
+ optional string message = 100 [default="foo"];
+ optional string descriptor = 101 [default="bar"];
+ optional string reflection = 102 [default="baz"];
+ optional string service = 103 [default="qux"];
+}
+
+message AnotherMessage {}
+service AnotherService {}
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 1df1619..350d1c6 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -42,8 +42,8 @@ from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
-def SetAllFields(message):
- """Sets every field in the message to a unique value.
+def SetAllNonLazyFields(message):
+ """Sets every non-lazy field in the message to a unique value.
Args:
message: A unittest_pb2.TestAllTypes instance.
@@ -66,26 +66,21 @@ def SetAllFields(message):
message.optional_float = 111
message.optional_double = 112
message.optional_bool = True
- # TODO(robinson): Firmly spec out and test how
- # protos interact with unicode. One specific example:
- # what happens if we change the literal below to
- # u'115'? What *should* happen? Still some discussion
- # to finish with Kenton about bytes vs. strings
- # and forcing everything to be utf8. :-/
- message.optional_string = '115'
- message.optional_bytes = '116'
+ message.optional_string = u'115'
+ message.optional_bytes = b'116'
message.optionalgroup.a = 117
message.optional_nested_message.bb = 118
message.optional_foreign_message.c = 119
message.optional_import_message.d = 120
+ message.optional_public_import_message.e = 126
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
- message.optional_string_piece = '124'
- message.optional_cord = '125'
+ message.optional_string_piece = u'124'
+ message.optional_cord = u'125'
#
# Repeated fields.
@@ -104,20 +99,21 @@ def SetAllFields(message):
message.repeated_float.append(211)
message.repeated_double.append(212)
message.repeated_bool.append(True)
- message.repeated_string.append('215')
- message.repeated_bytes.append('216')
+ message.repeated_string.append(u'215')
+ message.repeated_bytes.append(b'216')
message.repeatedgroup.add().a = 217
message.repeated_nested_message.add().bb = 218
message.repeated_foreign_message.add().c = 219
message.repeated_import_message.add().d = 220
+ message.repeated_lazy_message.add().bb = 227
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
- message.repeated_string_piece.append('224')
- message.repeated_cord.append('225')
+ message.repeated_string_piece.append(u'224')
+ message.repeated_cord.append(u'225')
# Add a second one of each field.
message.repeated_int32.append(301)
@@ -133,20 +129,21 @@ def SetAllFields(message):
message.repeated_float.append(311)
message.repeated_double.append(312)
message.repeated_bool.append(False)
- message.repeated_string.append('315')
- message.repeated_bytes.append('316')
+ message.repeated_string.append(u'315')
+ message.repeated_bytes.append(b'316')
message.repeatedgroup.add().a = 317
message.repeated_nested_message.add().bb = 318
message.repeated_foreign_message.add().c = 319
message.repeated_import_message.add().d = 320
+ message.repeated_lazy_message.add().bb = 327
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
- message.repeated_string_piece.append('324')
- message.repeated_cord.append('325')
+ message.repeated_string_piece.append(u'324')
+ message.repeated_cord.append(u'325')
#
# Fields that have defaults.
@@ -166,7 +163,7 @@ def SetAllFields(message):
message.default_double = 412
message.default_bool = False
message.default_string = '415'
- message.default_bytes = '416'
+ message.default_bytes = b'416'
message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
@@ -175,6 +172,16 @@ def SetAllFields(message):
message.default_string_piece = '424'
message.default_cord = '425'
+ message.oneof_uint32 = 601
+ message.oneof_nested_message.bb = 602
+ message.oneof_string = '603'
+ message.oneof_bytes = b'604'
+
+
+def SetAllFields(message):
+ SetAllNonLazyFields(message)
+ message.optional_lazy_message.bb = 127
+
def SetAllExtensions(message):
"""Sets every extension in the message to a unique value.
@@ -204,21 +211,23 @@ def SetAllExtensions(message):
extensions[pb2.optional_float_extension] = 111
extensions[pb2.optional_double_extension] = 112
extensions[pb2.optional_bool_extension] = True
- extensions[pb2.optional_string_extension] = '115'
- extensions[pb2.optional_bytes_extension] = '116'
+ extensions[pb2.optional_string_extension] = u'115'
+ extensions[pb2.optional_bytes_extension] = b'116'
extensions[pb2.optionalgroup_extension].a = 117
extensions[pb2.optional_nested_message_extension].bb = 118
extensions[pb2.optional_foreign_message_extension].c = 119
extensions[pb2.optional_import_message_extension].d = 120
+ extensions[pb2.optional_public_import_message_extension].e = 126
+ extensions[pb2.optional_lazy_message_extension].bb = 127
extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ
extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ
- extensions[pb2.optional_string_piece_extension] = '124'
- extensions[pb2.optional_cord_extension] = '125'
+ extensions[pb2.optional_string_piece_extension] = u'124'
+ extensions[pb2.optional_cord_extension] = u'125'
#
# Repeated fields.
@@ -237,20 +246,21 @@ def SetAllExtensions(message):
extensions[pb2.repeated_float_extension].append(211)
extensions[pb2.repeated_double_extension].append(212)
extensions[pb2.repeated_bool_extension].append(True)
- extensions[pb2.repeated_string_extension].append('215')
- extensions[pb2.repeated_bytes_extension].append('216')
+ extensions[pb2.repeated_string_extension].append(u'215')
+ extensions[pb2.repeated_bytes_extension].append(b'216')
extensions[pb2.repeatedgroup_extension].add().a = 217
extensions[pb2.repeated_nested_message_extension].add().bb = 218
extensions[pb2.repeated_foreign_message_extension].add().c = 219
extensions[pb2.repeated_import_message_extension].add().d = 220
+ extensions[pb2.repeated_lazy_message_extension].add().bb = 227
extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR)
extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR)
extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR)
- extensions[pb2.repeated_string_piece_extension].append('224')
- extensions[pb2.repeated_cord_extension].append('225')
+ extensions[pb2.repeated_string_piece_extension].append(u'224')
+ extensions[pb2.repeated_cord_extension].append(u'225')
# Append a second one of each field.
extensions[pb2.repeated_int32_extension].append(301)
@@ -266,20 +276,21 @@ def SetAllExtensions(message):
extensions[pb2.repeated_float_extension].append(311)
extensions[pb2.repeated_double_extension].append(312)
extensions[pb2.repeated_bool_extension].append(False)
- extensions[pb2.repeated_string_extension].append('315')
- extensions[pb2.repeated_bytes_extension].append('316')
+ extensions[pb2.repeated_string_extension].append(u'315')
+ extensions[pb2.repeated_bytes_extension].append(b'316')
extensions[pb2.repeatedgroup_extension].add().a = 317
extensions[pb2.repeated_nested_message_extension].add().bb = 318
extensions[pb2.repeated_foreign_message_extension].add().c = 319
extensions[pb2.repeated_import_message_extension].add().d = 320
+ extensions[pb2.repeated_lazy_message_extension].add().bb = 327
extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ)
extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ)
extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ)
- extensions[pb2.repeated_string_piece_extension].append('324')
- extensions[pb2.repeated_cord_extension].append('325')
+ extensions[pb2.repeated_string_piece_extension].append(u'324')
+ extensions[pb2.repeated_cord_extension].append(u'325')
#
# Fields with defaults.
@@ -298,16 +309,21 @@ def SetAllExtensions(message):
extensions[pb2.default_float_extension] = 411
extensions[pb2.default_double_extension] = 412
extensions[pb2.default_bool_extension] = False
- extensions[pb2.default_string_extension] = '415'
- extensions[pb2.default_bytes_extension] = '416'
+ extensions[pb2.default_string_extension] = u'415'
+ extensions[pb2.default_bytes_extension] = b'416'
extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO
extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO
extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO
- extensions[pb2.default_string_piece_extension] = '424'
+ extensions[pb2.default_string_piece_extension] = u'424'
extensions[pb2.default_cord_extension] = '425'
+ extensions[pb2.oneof_uint32_extension] = 601
+ extensions[pb2.oneof_nested_message_extension].bb = 602
+ extensions[pb2.oneof_string_extension] = u'603'
+ extensions[pb2.oneof_bytes_extension] = b'604'
+
def SetAllFieldsAndExtensions(message):
"""Sets every field and extension in the message to a unique value.
@@ -346,7 +362,7 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized):
message.my_float = 1.0
expected_strings.append(message.SerializeToString())
message.Clear()
- expected = ''.join(expected_strings)
+ expected = b''.join(expected_strings)
if expected != serialized:
raise ValueError('Expected %r, found %r' % (expected, serialized))
@@ -401,12 +417,14 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(112, message.optional_double)
test_case.assertEqual(True, message.optional_bool)
test_case.assertEqual('115', message.optional_string)
- test_case.assertEqual('116', message.optional_bytes)
+ test_case.assertEqual(b'116', message.optional_bytes)
test_case.assertEqual(117, message.optionalgroup.a)
test_case.assertEqual(118, message.optional_nested_message.bb)
test_case.assertEqual(119, message.optional_foreign_message.c)
test_case.assertEqual(120, message.optional_import_message.d)
+ test_case.assertEqual(126, message.optional_public_import_message.e)
+ test_case.assertEqual(127, message.optional_lazy_message.bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.optional_nested_enum)
@@ -458,12 +476,13 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(212, message.repeated_double[0])
test_case.assertEqual(True, message.repeated_bool[0])
test_case.assertEqual('215', message.repeated_string[0])
- test_case.assertEqual('216', message.repeated_bytes[0])
+ test_case.assertEqual(b'216', message.repeated_bytes[0])
test_case.assertEqual(217, message.repeatedgroup[0].a)
test_case.assertEqual(218, message.repeated_nested_message[0].bb)
test_case.assertEqual(219, message.repeated_foreign_message[0].c)
test_case.assertEqual(220, message.repeated_import_message[0].d)
+ test_case.assertEqual(227, message.repeated_lazy_message[0].bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
message.repeated_nested_enum[0])
@@ -486,12 +505,13 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(312, message.repeated_double[1])
test_case.assertEqual(False, message.repeated_bool[1])
test_case.assertEqual('315', message.repeated_string[1])
- test_case.assertEqual('316', message.repeated_bytes[1])
+ test_case.assertEqual(b'316', message.repeated_bytes[1])
test_case.assertEqual(317, message.repeatedgroup[1].a)
test_case.assertEqual(318, message.repeated_nested_message[1].bb)
test_case.assertEqual(319, message.repeated_foreign_message[1].c)
test_case.assertEqual(320, message.repeated_import_message[1].d)
+ test_case.assertEqual(327, message.repeated_lazy_message[1].bb)
test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
message.repeated_nested_enum[1])
@@ -536,7 +556,7 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(412, message.default_double)
test_case.assertEqual(False, message.default_bool)
test_case.assertEqual('415', message.default_string)
- test_case.assertEqual('416', message.default_bytes)
+ test_case.assertEqual(b'416', message.default_bytes)
test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
message.default_nested_enum)
@@ -545,6 +565,7 @@ def ExpectAllFieldsSet(test_case, message):
test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
message.default_import_enum)
+
def GoldenFile(filename):
"""Finds the given golden file and returns a file object representing it."""
@@ -558,9 +579,15 @@ def GoldenFile(filename):
path = os.path.join(path, '..')
raise RuntimeError(
- 'Could not find golden files. This test must be run from within the '
- 'protobuf source package so that it can read test data files from the '
- 'C++ source tree.')
+ 'Could not find golden files. This test must be run from within the '
+ 'protobuf source package so that it can read test data files from the '
+ 'C++ source tree.')
+
+
+def GoldenFileData(filename):
+ """Finds the given golden file and returns its contents."""
+ with GoldenFile(filename) as f:
+ return f.read()
def SetAllPackedFields(message):
diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py
new file mode 100755
index 0000000..ba0e45d
--- /dev/null
+++ b/python/google/protobuf/internal/text_encoding_test.py
@@ -0,0 +1,68 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.text_encoding."""
+
+from google.apputils import basetest
+from google.protobuf import text_encoding
+
+TEST_VALUES = [
+ ("foo\\rbar\\nbaz\\t",
+ "foo\\rbar\\nbaz\\t",
+ b"foo\rbar\nbaz\t"),
+ ("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'",
+ "\\'full of \\\"sound\\\" and \\\"fury\\\"\\'",
+ b"'full of \"sound\" and \"fury\"'"),
+ ("signi\\\\fying\\\\ nothing\\\\",
+ "signi\\\\fying\\\\ nothing\\\\",
+ b"signi\\fying\\ nothing\\"),
+ ("\\010\\t\\n\\013\\014\\r",
+ "\x08\\t\\n\x0b\x0c\\r",
+ b"\010\011\012\013\014\015")]
+
+
+class TextEncodingTestCase(basetest.TestCase):
+ def testCEscape(self):
+ for escaped, escaped_utf8, unescaped in TEST_VALUES:
+ self.assertEquals(escaped,
+ text_encoding.CEscape(unescaped, as_utf8=False))
+ self.assertEquals(escaped_utf8,
+ text_encoding.CEscape(unescaped, as_utf8=True))
+
+ def testCUnescape(self):
+ for escaped, escaped_utf8, unescaped in TEST_VALUES:
+ self.assertEquals(unescaped, text_encoding.CUnescape(escaped))
+ self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8))
+
+
+if __name__ == "__main__":
+ basetest.main()
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index e0991cb..d27ff7a 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -34,48 +34,71 @@
__author__ = 'kenton@google.com (Kenton Varda)'
-import difflib
+import re
-import unittest
+from google.apputils import basetest
from google.protobuf import text_format
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import test_util
from google.protobuf import unittest_pb2
from google.protobuf import unittest_mset_pb2
+class TextFormatTest(basetest.TestCase):
-class TextFormatTest(unittest.TestCase):
def ReadGolden(self, golden_filename):
- f = test_util.GoldenFile(golden_filename)
- golden_lines = f.readlines()
- f.close()
- return golden_lines
+ with test_util.GoldenFile(golden_filename) as f:
+ return (f.readlines() if str is bytes else # PY3
+ [golden_line.decode('utf-8') for golden_line in f])
def CompareToGoldenFile(self, text, golden_filename):
golden_lines = self.ReadGolden(golden_filename)
- self.CompareToGoldenLines(text, golden_lines)
+ self.assertMultiLineEqual(text, ''.join(golden_lines))
def CompareToGoldenText(self, text, golden_text):
- self.CompareToGoldenLines(text, golden_text.splitlines(1))
-
- def CompareToGoldenLines(self, text, golden_lines):
- actual_lines = text.splitlines(1)
- self.assertEqual(golden_lines, actual_lines,
- "Text doesn't match golden. Diff:\n" +
- ''.join(difflib.ndiff(golden_lines, actual_lines)))
+ self.assertMultiLineEqual(text, golden_text)
def testPrintAllFields(self):
message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message)
self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_data.txt')
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_data_oneof_implemented.txt')
+
+ def testPrintInIndexOrder(self):
+ message = unittest_pb2.TestFieldOrderings()
+ message.my_string = '115'
+ message.my_int = 101
+ message.my_float = 111
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, use_index_order=True)),
+ 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n')
def testPrintAllExtensions(self):
message = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(message)
self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_extensions_data.txt')
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_extensions_data.txt')
+
+ def testPrintAllFieldsPointy(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, pointy_brackets=True)),
+ 'text_format_unittest_data_pointy_oneof.txt')
+
+ def testPrintAllExtensionsPointy(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(
+ message, pointy_brackets=True)),
+ 'text_format_unittest_extensions_data_pointy.txt')
def testPrintMessageSet(self):
message = unittest_mset_pb2.TestMessageSetContainer()
@@ -83,33 +106,179 @@ class TextFormatTest(unittest.TestCase):
ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
message.message_set.Extensions[ext1].i = 23
message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(text_format.MessageToString(message),
- 'message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
def testPrintExotic(self):
message = unittest_pb2.TestAllTypes()
- message.repeated_int64.append(-9223372036854775808);
- message.repeated_uint64.append(18446744073709551615);
- message.repeated_double.append(123.456);
- message.repeated_double.append(1.23e22);
- message.repeated_double.append(1.23e-18);
- message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"');
+ message.repeated_int64.append(-9223372036854775808)
+ message.repeated_uint64.append(18446744073709551615)
+ message.repeated_double.append(123.456)
+ message.repeated_double.append(1.23e22)
+ message.repeated_double.append(1.23e-18)
+ message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
+ message.repeated_string.append(u'\u00fc\ua71f')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'repeated_int64: -9223372036854775808\n'
+ 'repeated_uint64: 18446744073709551615\n'
+ 'repeated_double: 123.456\n'
+ 'repeated_double: 1.23e+22\n'
+ 'repeated_double: 1.23e-18\n'
+ 'repeated_string:'
+ ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
+ 'repeated_string: "\\303\\274\\352\\234\\237"\n')
+
+ def testPrintExoticUnicodeSubclass(self):
+ class UnicodeSub(unicode):
+ pass
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
+ self.CompareToGoldenText(
+ text_format.MessageToString(message),
+ 'repeated_string: "\\303\\274\\352\\234\\237"\n')
+
+ def testPrintNestedMessageAsOneLine(self):
+ message = unittest_pb2.TestAllTypes()
+ msg = message.repeated_nested_message.add()
+ msg.bb = 42
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'repeated_nested_message { bb: 42 }')
+
+ def testPrintRepeatedFieldsAsOneLine(self):
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_int32.append(1)
+ message.repeated_int32.append(1)
+ message.repeated_int32.append(3)
+ message.repeated_string.append("Google")
+ message.repeated_string.append("Zurich")
self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'repeated_int64: -9223372036854775808\n'
- 'repeated_uint64: 18446744073709551615\n'
- 'repeated_double: 123.456\n'
- 'repeated_double: 1.23e+22\n'
- 'repeated_double: 1.23e-18\n'
- 'repeated_string: '
- '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
+ text_format.MessageToString(message, as_one_line=True),
+ 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
+ 'repeated_string: "Google" repeated_string: "Zurich"')
+
+ def testPrintNestedNewLineInStringAsOneLine(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_string = "a\nnew\nline"
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'optional_string: "a\\nnew\\nline"')
+
+ def testPrintMessageSetAsOneLine(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(
+ text_format.MessageToString(message, as_one_line=True),
+ 'message_set {'
+ ' [protobuf_unittest.TestMessageSetExtension1] {'
+ ' i: 23'
+ ' }'
+ ' [protobuf_unittest.TestMessageSetExtension2] {'
+ ' str: \"foo\"'
+ ' }'
+ ' }')
+
+ def testPrintExoticAsOneLine(self):
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_int64.append(-9223372036854775808)
+ message.repeated_uint64.append(18446744073709551615)
+ message.repeated_double.append(123.456)
+ message.repeated_double.append(1.23e22)
+ message.repeated_double.append(1.23e-18)
+ message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
+ message.repeated_string.append(u'\u00fc\ua71f')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(
+ text_format.MessageToString(message, as_one_line=True)),
+ 'repeated_int64: -9223372036854775808'
+ ' repeated_uint64: 18446744073709551615'
+ ' repeated_double: 123.456'
+ ' repeated_double: 1.23e+22'
+ ' repeated_double: 1.23e-18'
+ ' repeated_string: '
+ '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
+ ' repeated_string: "\\303\\274\\352\\234\\237"')
+
+ def testRoundTripExoticAsOneLine(self):
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_int64.append(-9223372036854775808)
+ message.repeated_uint64.append(18446744073709551615)
+ message.repeated_double.append(123.456)
+ message.repeated_double.append(1.23e22)
+ message.repeated_double.append(1.23e-18)
+ message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
+ message.repeated_string.append(u'\u00fc\ua71f')
+
+ # Test as_utf8 = False.
+ wire_text = text_format.MessageToString(
+ message, as_one_line=True, as_utf8=False)
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.Parse(wire_text, parsed_message)
+ self.assertIs(r, parsed_message)
+ self.assertEquals(message, parsed_message)
+
+ # Test as_utf8 = True.
+ wire_text = text_format.MessageToString(
+ message, as_one_line=True, as_utf8=True)
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.Parse(wire_text, parsed_message)
+ self.assertIs(r, parsed_message)
+ self.assertEquals(message, parsed_message,
+ '\n%s != %s' % (message, parsed_message))
+
+ def testPrintRawUtf8String(self):
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_string.append(u'\u00fc\ua71f')
+ text = text_format.MessageToString(message, as_utf8=True)
+ self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
+ parsed_message = unittest_pb2.TestAllTypes()
+ text_format.Parse(text, parsed_message)
+ self.assertEquals(message, parsed_message,
+ '\n%s != %s' % (message, parsed_message))
+
+ def testPrintFloatFormat(self):
+ # Check that float_format argument is passed to sub-message formatting.
+ message = unittest_pb2.NestedTestAllTypes()
+ # We use 1.25 as it is a round number in binary. The proto 32-bit float
+ # will not gain additional imprecise digits as a 64-bit Python float and
+ # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit:
+ # >>> struct.unpack('f', struct.pack('f', 1.2))[0]
+ # 1.2000000476837158
+ # >>> struct.unpack('f', struct.pack('f', 1.25))[0]
+ # 1.25
+ message.payload.optional_float = 1.25
+ # Check rounding at 15 significant digits
+ message.payload.optional_double = -.000003456789012345678
+ # Check no decimal point.
+ message.payload.repeated_float.append(-5642)
+ # Check no trailing zeros.
+ message.payload.repeated_double.append(.000078900)
+ formatted_fields = ['optional_float: 1.25',
+ 'optional_double: -3.45678901234568e-6',
+ 'repeated_float: -5642',
+ 'repeated_double: 7.89e-5']
+ text_message = text_format.MessageToString(message, float_format='.15g')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_message),
+ 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields))
+ # as_one_line=True is a separate code branch where float_format is passed.
+ text_message = text_format.MessageToString(message, as_one_line=True,
+ float_format='.15g')
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_message),
+ 'payload {{ {} {} {} {} }}'.format(*formatted_fields))
def testMessageToString(self):
message = unittest_pb2.ForeignMessage()
@@ -119,52 +288,57 @@ class TextFormatTest(unittest.TestCase):
def RemoveRedundantZeros(self, text):
# Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
# these zeros in order to match the golden file.
- return text.replace('e+0','e+').replace('e+0','e+') \
+ text = text.replace('e+0','e+').replace('e+0','e+') \
.replace('e-0','e-').replace('e-0','e-')
+ # Floating point fields are printed with .0 suffix even if they are
+ # actualy integer numbers.
+ text = re.compile('\.0$', re.MULTILINE).sub('', text)
+ return text
- def testMergeGolden(self):
+ def testParseGolden(self):
golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
parsed_message = unittest_pb2.TestAllTypes()
- text_format.Merge(golden_text, parsed_message)
+ r = text_format.Parse(golden_text, parsed_message)
+ self.assertIs(r, parsed_message)
message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message)
self.assertEquals(message, parsed_message)
- def testMergeGoldenExtensions(self):
+ def testParseGoldenExtensions(self):
golden_text = '\n'.join(self.ReadGolden(
'text_format_unittest_extensions_data.txt'))
parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Merge(golden_text, parsed_message)
+ text_format.Parse(golden_text, parsed_message)
message = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(message)
self.assertEquals(message, parsed_message)
- def testMergeAllFields(self):
+ def testParseAllFields(self):
message = unittest_pb2.TestAllTypes()
test_util.SetAllFields(message)
ascii_text = text_format.MessageToString(message)
parsed_message = unittest_pb2.TestAllTypes()
- text_format.Merge(ascii_text, parsed_message)
+ text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
test_util.ExpectAllFieldsSet(self, message)
- def testMergeAllExtensions(self):
+ def testParseAllExtensions(self):
message = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(message)
ascii_text = text_format.MessageToString(message)
parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Merge(ascii_text, parsed_message)
+ text_format.Parse(ascii_text, parsed_message)
self.assertEqual(message, parsed_message)
- def testMergeMessageSet(self):
+ def testParseMessageSet(self):
message = unittest_pb2.TestAllTypes()
text = ('repeated_uint64: 1\n'
'repeated_uint64: 2\n')
- text_format.Merge(text, message)
+ text_format.Parse(text, message)
self.assertEqual(1, message.repeated_uint64[0])
self.assertEqual(2, message.repeated_uint64[1])
@@ -177,13 +351,13 @@ class TextFormatTest(unittest.TestCase):
' str: \"foo\"\n'
' }\n'
'}\n')
- text_format.Merge(text, message)
+ text_format.Parse(text, message)
ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
self.assertEquals(23, message.message_set.Extensions[ext1].i)
self.assertEquals('foo', message.message_set.Extensions[ext2].str)
- def testMergeExotic(self):
+ def testParseExotic(self):
message = unittest_pb2.TestAllTypes()
text = ('repeated_int64: -9223372036854775808\n'
'repeated_uint64: 18446744073709551615\n'
@@ -191,9 +365,12 @@ class TextFormatTest(unittest.TestCase):
'repeated_double: 1.23e+22\n'
'repeated_double: 1.23e-18\n'
'repeated_string: \n'
- '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n'
- 'repeated_string: "foo" \'corge\' "grault"')
- text_format.Merge(text, message)
+ '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
+ 'repeated_string: "foo" \'corge\' "grault"\n'
+ 'repeated_string: "\\303\\274\\352\\234\\237"\n'
+ 'repeated_string: "\\xc3\\xbc"\n'
+ 'repeated_string: "\xc3\xbc"\n')
+ text_format.Parse(text, message)
self.assertEqual(-9223372036854775808, message.repeated_int64[0])
self.assertEqual(18446744073709551615, message.repeated_uint64[0])
@@ -201,95 +378,217 @@ class TextFormatTest(unittest.TestCase):
self.assertEqual(1.23e22, message.repeated_double[1])
self.assertEqual(1.23e-18, message.repeated_double[2])
self.assertEqual(
- '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0])
+ '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
self.assertEqual('foocorgegrault', message.repeated_string[1])
+ self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
+ self.assertEqual(u'\u00fc', message.repeated_string[3])
+
+ def testParseTrailingCommas(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_int64: 100;\n'
+ 'repeated_int64: 200;\n'
+ 'repeated_int64: 300,\n'
+ 'repeated_string: "one",\n'
+ 'repeated_string: "two";\n')
+ text_format.Parse(text, message)
+
+ self.assertEqual(100, message.repeated_int64[0])
+ self.assertEqual(200, message.repeated_int64[1])
+ self.assertEqual(300, message.repeated_int64[2])
+ self.assertEqual(u'one', message.repeated_string[0])
+ self.assertEqual(u'two', message.repeated_string[1])
+
+ def testParseEmptyText(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ''
+ text_format.Parse(text, message)
+ self.assertEquals(unittest_pb2.TestAllTypes(), message)
+
+ def testParseInvalidUtf8(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'repeated_string: "\\xc3\\xc3"'
+ self.assertRaises(text_format.ParseError, text_format.Parse, text, message)
+
+ def testParseSingleWord(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'foo'
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
+ '"foo".'),
+ text_format.Parse, text, message)
- def testMergeUnknownField(self):
+ def testParseUnknownField(self):
message = unittest_pb2.TestAllTypes()
text = 'unknown_field: 8\n'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError,
('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
'"unknown_field".'),
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
- def testMergeBadExtension(self):
+ def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.',
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError,
('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
'extensions.'),
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
- def testMergeGroupNotClosed(self):
+ def testParseGroupNotClosed(self):
message = unittest_pb2.TestAllTypes()
text = 'RepeatedGroup: <'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError, '1:16 : Expected ">".',
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
text = 'RepeatedGroup: {'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError, '1:16 : Expected "}".',
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
- def testMergeEmptyGroup(self):
+ def testParseEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
text = 'OptionalGroup: {}'
- text_format.Merge(text, message)
+ text_format.Parse(text, message)
self.assertTrue(message.HasField('optionalgroup'))
message.Clear()
message = unittest_pb2.TestAllTypes()
text = 'OptionalGroup: <>'
- text_format.Merge(text, message)
+ text_format.Parse(text, message)
self.assertTrue(message.HasField('optionalgroup'))
- def testMergeBadEnumValue(self):
+ def testParseBadEnumValue(self):
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: BARR'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError,
('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value named BARR.'),
- text_format.Merge, text, message)
+ text_format.Parse, text, message)
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: 100'
- self.assertRaisesWithMessage(
+ self.assertRaisesWithLiteralMatch(
text_format.ParseError,
('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value with number 100.'),
- text_format.Merge, text, message)
-
- def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs):
- """Same as assertRaises, but also compares the exception message."""
- if hasattr(e_class, '__name__'):
- exc_name = e_class.__name__
- else:
- exc_name = str(e_class)
-
- try:
- func(*args, **kwargs)
- except e_class, expr:
- if str(expr) != e:
- msg = '%s raised, but with wrong message: "%s" instead of "%s"'
- raise self.failureException(msg % (exc_name,
- str(expr).encode('string_escape'),
- e.encode('string_escape')))
- return
- else:
- raise self.failureException('%s not raised' % exc_name)
-
-
-class TokenizerTest(unittest.TestCase):
+ text_format.Parse, text, message)
+
+ def testParseBadIntValue(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'optional_int32: bork'
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:17 : Couldn\'t parse integer: bork'),
+ text_format.Parse, text, message)
+
+ def testParseStringFieldUnescape(self):
+ message = unittest_pb2.TestAllTypes()
+ text = r'''repeated_string: "\xf\x62"
+ repeated_string: "\\xf\\x62"
+ repeated_string: "\\\xf\\\x62"
+ repeated_string: "\\\\xf\\\\x62"
+ repeated_string: "\\\\\xf\\\\\x62"
+ repeated_string: "\x5cx20"'''
+ text_format.Parse(text, message)
+
+ SLASH = '\\'
+ self.assertEqual('\x0fb', message.repeated_string[0])
+ self.assertEqual(SLASH + 'xf' + SLASH + 'x62', message.repeated_string[1])
+ self.assertEqual(SLASH + '\x0f' + SLASH + 'b', message.repeated_string[2])
+ self.assertEqual(SLASH + SLASH + 'xf' + SLASH + SLASH + 'x62',
+ message.repeated_string[3])
+ self.assertEqual(SLASH + SLASH + '\x0f' + SLASH + SLASH + 'b',
+ message.repeated_string[4])
+ self.assertEqual(SLASH + 'x20', message.repeated_string[5])
+
+ def testMergeRepeatedScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_int32: 42 '
+ 'optional_int32: 67')
+ r = text_format.Merge(text, message)
+ self.assertIs(r, message)
+ self.assertEqual(67, message.optional_int32)
+
+ def testParseRepeatedScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_int32: 42 '
+ 'optional_int32: 67')
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
+ 'have multiple "optional_int32" fields.'),
+ text_format.Parse, text, message)
+
+ def testMergeRepeatedNestedMessageScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { bb: 1 } '
+ 'optional_nested_message { bb: 2 }')
+ r = text_format.Merge(text, message)
+ self.assertTrue(r is message)
+ self.assertEqual(2, message.optional_nested_message.bb)
+
+ def testParseRepeatedNestedMessageScalars(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('optional_nested_message { bb: 1 } '
+ 'optional_nested_message { bb: 2 }')
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
+ 'should not have multiple "bb" fields.'),
+ text_format.Parse, text, message)
+
+ def testMergeRepeatedExtensionScalars(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = ('[protobuf_unittest.optional_int32_extension]: 42 '
+ '[protobuf_unittest.optional_int32_extension]: 67')
+ text_format.Merge(text, message)
+ self.assertEqual(
+ 67,
+ message.Extensions[unittest_pb2.optional_int32_extension])
+
+ def testParseRepeatedExtensionScalars(self):
+ message = unittest_pb2.TestAllExtensions()
+ text = ('[protobuf_unittest.optional_int32_extension]: 42 '
+ '[protobuf_unittest.optional_int32_extension]: 67')
+ self.assertRaisesWithLiteralMatch(
+ text_format.ParseError,
+ ('1:96 : Message type "protobuf_unittest.TestAllExtensions" '
+ 'should not have multiple '
+ '"protobuf_unittest.optional_int32_extension" extensions.'),
+ text_format.Parse, text, message)
+
+ def testParseLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.ParseLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEquals(message, parsed_message)
+
+ def testMergeLinesGolden(self):
+ opened = self.ReadGolden('text_format_unittest_data.txt')
+ parsed_message = unittest_pb2.TestAllTypes()
+ r = text_format.MergeLines(opened, parsed_message)
+ self.assertIs(r, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEqual(message, parsed_message)
+
+
+class TokenizerTest(basetest.TestCase):
def testSimpleTokenCases(self):
text = ('identifier1:"string1"\n \n\n'
@@ -297,8 +596,9 @@ class TokenizerTest(unittest.TestCase):
'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n'
'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
- 'ID12: 2222222222222222222')
- tokenizer = text_format._Tokenizer(text)
+ 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
+ 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ')
+ tokenizer = text_format._Tokenizer(text.splitlines())
methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
':',
(tokenizer.ConsumeString, 'string1'),
@@ -325,10 +625,10 @@ class TokenizerTest(unittest.TestCase):
'{',
(tokenizer.ConsumeIdentifier, 'A'),
':',
- (tokenizer.ConsumeFloat, text_format._INFINITY),
+ (tokenizer.ConsumeFloat, float('inf')),
(tokenizer.ConsumeIdentifier, 'B'),
':',
- (tokenizer.ConsumeFloat, -text_format._INFINITY),
+ (tokenizer.ConsumeFloat, -float('inf')),
(tokenizer.ConsumeIdentifier, 'C'),
':',
(tokenizer.ConsumeBool, True),
@@ -347,7 +647,25 @@ class TokenizerTest(unittest.TestCase):
(tokenizer.ConsumeInt32, -22),
(tokenizer.ConsumeIdentifier, 'ID12'),
':',
- (tokenizer.ConsumeUint64, 2222222222222222222)]
+ (tokenizer.ConsumeUint64, 2222222222222222222),
+ (tokenizer.ConsumeIdentifier, 'ID13'),
+ ':',
+ (tokenizer.ConsumeFloat, 1.23456),
+ (tokenizer.ConsumeIdentifier, 'ID14'),
+ ':',
+ (tokenizer.ConsumeFloat, 1.2e+2),
+ (tokenizer.ConsumeIdentifier, 'false_bool'),
+ ':',
+ (tokenizer.ConsumeBool, False),
+ (tokenizer.ConsumeIdentifier, 'true_BOOL'),
+ ':',
+ (tokenizer.ConsumeBool, True),
+ (tokenizer.ConsumeIdentifier, 'true_bool1'),
+ ':',
+ (tokenizer.ConsumeBool, True),
+ (tokenizer.ConsumeIdentifier, 'false_BOOL1'),
+ ':',
+ (tokenizer.ConsumeBool, False)]
i = 0
while not tokenizer.AtEnd():
@@ -366,7 +684,7 @@ class TokenizerTest(unittest.TestCase):
int64_max = (1 << 63) - 1
uint32_max = (1 << 32) - 1
text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
self.assertEqual(-1, tokenizer.ConsumeInt32())
@@ -380,7 +698,7 @@ class TokenizerTest(unittest.TestCase):
self.assertTrue(tokenizer.AtEnd())
text = '-0 -0 0 0'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertEqual(0, tokenizer.ConsumeUint32())
self.assertEqual(0, tokenizer.ConsumeUint64())
self.assertEqual(0, tokenizer.ConsumeUint32())
@@ -389,40 +707,30 @@ class TokenizerTest(unittest.TestCase):
def testConsumeByteString(self):
text = '"string1\''
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = 'string1"'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\xt"'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\"'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
text = '\n"\\x"'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
def testConsumeBool(self):
text = 'not-a-bool'
- tokenizer = text_format._Tokenizer(text)
+ tokenizer = text_format._Tokenizer(text.splitlines())
self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
- def testInfNan(self):
- # Make sure our infinity and NaN definitions are sound.
- self.assertEquals(float, type(text_format._INFINITY))
- self.assertEquals(float, type(text_format._NAN))
- self.assertTrue(text_format._NAN != text_format._NAN)
-
- inf_times_zero = text_format._INFINITY * 0
- self.assertTrue(inf_times_zero != inf_times_zero)
- self.assertTrue(text_format._INFINITY > 0)
-
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 2b3cd4d..8e1b3cc 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#PY25 compatible for GAE.
+#
+# Copyright 2008 Google Inc. All Rights Reserved.
+
"""Provides type checking routines.
This module defines type checking utilities in the forms of dictionaries:
@@ -45,6 +49,9 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
__author__ = 'robinson@google.com (Will Robinson)'
+import sys ##PY25
+if sys.version < '2.6': bytes = str ##PY25
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
@@ -53,21 +60,22 @@ from google.protobuf import descriptor
_FieldDescriptor = descriptor.FieldDescriptor
-def GetTypeChecker(cpp_type, field_type):
+def GetTypeChecker(field):
"""Returns a type checker for a message field of the specified types.
Args:
- cpp_type: C++ type of the field (see descriptor.py).
- field_type: Protocol message field type (see descriptor.py).
+ field: FieldDescriptor object for this field.
Returns:
An instance of TypeChecker which can be used to verify the types
of values assigned to a field of the specified type.
"""
- if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and
- field_type == _FieldDescriptor.TYPE_STRING):
+ if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and
+ field.type == _FieldDescriptor.TYPE_STRING):
return UnicodeValueChecker()
- return _VALUE_CHECKERS[cpp_type]
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
+ return EnumValueChecker(field.enum_type)
+ return _VALUE_CHECKERS[field.cpp_type]
# None of the typecheckers below make any attempt to guard against people
@@ -85,10 +93,15 @@ class TypeChecker(object):
self._acceptable_types = acceptable_types
def CheckValue(self, proposed_value):
+ """Type check the provided value and return it.
+
+ The returned value might have been normalized to another type.
+ """
if not isinstance(proposed_value, self._acceptable_types):
message = ('%.1024r has type %s, but expected one of: %s' %
(proposed_value, type(proposed_value), self._acceptable_types))
raise TypeError(message)
+ return proposed_value
# IntValueChecker and its subclasses perform integer type-checks
@@ -104,28 +117,54 @@ class IntValueChecker(object):
raise TypeError(message)
if not self._MIN <= proposed_value <= self._MAX:
raise ValueError('Value out of range: %d' % proposed_value)
+ # We force 32-bit values to int and 64-bit values to long to make
+ # alternate implementations where the distinction is more significant
+ # (e.g. the C++ implementation) simpler.
+ proposed_value = self._TYPE(proposed_value)
+ return proposed_value
+
+
+class EnumValueChecker(object):
+
+ """Checker used for enum fields. Performs type-check and range check."""
+
+ def __init__(self, enum_type):
+ self._enum_type = enum_type
+
+ def CheckValue(self, proposed_value):
+ if not isinstance(proposed_value, (int, long)):
+ message = ('%.1024r has type %s, but expected one of: %s' %
+ (proposed_value, type(proposed_value), (int, long)))
+ raise TypeError(message)
+ if proposed_value not in self._enum_type.values_by_number:
+ raise ValueError('Unknown enum value: %d' % proposed_value)
+ return proposed_value
class UnicodeValueChecker(object):
- """Checker used for string fields."""
+ """Checker used for string fields.
+
+ Always returns a unicode value, even if the input is of type str.
+ """
def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (str, unicode)):
+ if not isinstance(proposed_value, (bytes, unicode)):
message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (str, unicode)))
+ (proposed_value, type(proposed_value), (bytes, unicode)))
raise TypeError(message)
- # If the value is of type 'str' make sure that it is in 7-bit ASCII
+ # If the value is of type 'bytes' make sure that it is in 7-bit ASCII
# encoding.
- if isinstance(proposed_value, str):
+ if isinstance(proposed_value, bytes):
try:
- unicode(proposed_value, 'ascii')
+ proposed_value = proposed_value.decode('ascii')
except UnicodeDecodeError:
- raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII '
+ raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII '
'encoding. Non-ASCII strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
+ return proposed_value
class Int32ValueChecker(IntValueChecker):
@@ -133,21 +172,25 @@ class Int32ValueChecker(IntValueChecker):
# efficient.
_MIN = -2147483648
_MAX = 2147483647
+ _TYPE = int
class Uint32ValueChecker(IntValueChecker):
_MIN = 0
_MAX = (1 << 32) - 1
+ _TYPE = int
class Int64ValueChecker(IntValueChecker):
_MIN = -(1 << 63)
_MAX = (1 << 63) - 1
+ _TYPE = long
class Uint64ValueChecker(IntValueChecker):
_MIN = 0
_MAX = (1 << 64) - 1
+ _TYPE = long
# Type-checkers for all scalar CPPTYPEs.
@@ -161,8 +204,7 @@ _VALUE_CHECKERS = {
_FieldDescriptor.CPPTYPE_FLOAT: TypeChecker(
float, int, long),
_FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int),
- _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(),
- _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str),
+ _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes),
}
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
new file mode 100755
index 0000000..8f3354c
--- /dev/null
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -0,0 +1,231 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for preservation of unknown fields in the pure Python implementation."""
+
+__author__ = 'bohdank@google.com (Bohdan Koval)'
+
+from google.apputils import basetest
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import encoder
+from google.protobuf.internal import missing_enum_values_pb2
+from google.protobuf.internal import test_util
+from google.protobuf.internal import type_checkers
+
+
+class UnknownFieldsTest(basetest.TestCase):
+
+ def setUp(self):
+ self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ self.all_fields = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(self.all_fields)
+ self.all_fields_data = self.all_fields.SerializeToString()
+ self.empty_message = unittest_pb2.TestEmptyMessage()
+ self.empty_message.ParseFromString(self.all_fields_data)
+ self.unknown_fields = self.empty_message._unknown_fields
+
+ def GetField(self, name):
+ field_descriptor = self.descriptor.fields_by_name[name]
+ wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
+ field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
+ result_dict = {}
+ for tag_bytes, value in self.unknown_fields:
+ if tag_bytes == field_tag:
+ decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes]
+ decoder(value, 0, len(value), self.all_fields, result_dict)
+ return result_dict[field_descriptor]
+
+ def testEnum(self):
+ value = self.GetField('optional_nested_enum')
+ self.assertEqual(self.all_fields.optional_nested_enum, value)
+
+ def testRepeatedEnum(self):
+ value = self.GetField('repeated_nested_enum')
+ self.assertEqual(self.all_fields.repeated_nested_enum, value)
+
+ def testVarint(self):
+ value = self.GetField('optional_int32')
+ self.assertEqual(self.all_fields.optional_int32, value)
+
+ def testFixed32(self):
+ value = self.GetField('optional_fixed32')
+ self.assertEqual(self.all_fields.optional_fixed32, value)
+
+ def testFixed64(self):
+ value = self.GetField('optional_fixed64')
+ self.assertEqual(self.all_fields.optional_fixed64, value)
+
+ def testLengthDelimited(self):
+ value = self.GetField('optional_string')
+ self.assertEqual(self.all_fields.optional_string, value)
+
+ def testGroup(self):
+ value = self.GetField('optionalgroup')
+ self.assertEqual(self.all_fields.optionalgroup, value)
+
+ def testSerialize(self):
+ data = self.empty_message.SerializeToString()
+
+ # Don't use assertEqual because we don't want to dump raw binary data to
+ # stdout.
+ self.assertTrue(data == self.all_fields_data)
+
+ def testCopyFrom(self):
+ message = unittest_pb2.TestEmptyMessage()
+ message.CopyFrom(self.empty_message)
+ self.assertEqual(self.unknown_fields, message._unknown_fields)
+
+ def testMergeFrom(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_int32 = 1
+ message.optional_uint32 = 2
+ source = unittest_pb2.TestEmptyMessage()
+ source.ParseFromString(message.SerializeToString())
+
+ message.ClearField('optional_int32')
+ message.optional_int64 = 3
+ message.optional_uint32 = 4
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ unknown_fields = destination._unknown_fields[:]
+
+ destination.MergeFrom(source)
+ self.assertEqual(unknown_fields + source._unknown_fields,
+ destination._unknown_fields)
+
+ def testClear(self):
+ self.empty_message.Clear()
+ self.assertEqual(0, len(self.empty_message._unknown_fields))
+
+ def testByteSize(self):
+ self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
+
+ def testUnknownExtensions(self):
+ message = unittest_pb2.TestEmptyMessageWithExtensions()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(self.empty_message._unknown_fields,
+ message._unknown_fields)
+
+ def testListFields(self):
+ # Make sure ListFields doesn't return unknown fields.
+ self.assertEqual(0, len(self.empty_message.ListFields()))
+
+ def testSerializeMessageSetWireFormatUnknownExtension(self):
+ # Create a message using the message set wire format with an unknown
+ # message.
+ raw = unittest_mset_pb2.RawMessageSet()
+
+ # Add an unknown extension.
+ item = raw.item.add()
+ item.type_id = 1545009
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.i = 12345
+ item.message = message1.SerializeToString()
+
+ serialized = raw.SerializeToString()
+
+ # Parse message using the message set wire format.
+ proto = unittest_mset_pb2.TestMessageSet()
+ proto.MergeFromString(serialized)
+
+ # Verify that the unknown extension is serialized unchanged
+ reserialized = proto.SerializeToString()
+ new_raw = unittest_mset_pb2.RawMessageSet()
+ new_raw.MergeFromString(reserialized)
+ self.assertEqual(raw, new_raw)
+
+ def testEquals(self):
+ message = unittest_pb2.TestEmptyMessage()
+ message.ParseFromString(self.all_fields_data)
+ self.assertEqual(self.empty_message, message)
+
+ self.all_fields.ClearField('optional_string')
+ message.ParseFromString(self.all_fields.SerializeToString())
+ self.assertNotEqual(self.empty_message, message)
+
+
+class UnknownFieldsTest(basetest.TestCase):
+
+ def setUp(self):
+ self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
+
+ self.message = missing_enum_values_pb2.TestEnumValues()
+ self.message.optional_nested_enum = (
+ missing_enum_values_pb2.TestEnumValues.ZERO)
+ self.message.repeated_nested_enum.extend([
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
+ self.message.packed_nested_enum.extend([
+ missing_enum_values_pb2.TestEnumValues.ZERO,
+ missing_enum_values_pb2.TestEnumValues.ONE,
+ ])
+ self.message_data = self.message.SerializeToString()
+ self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
+ self.missing_message.ParseFromString(self.message_data)
+ self.unknown_fields = self.missing_message._unknown_fields
+
+ def GetField(self, name):
+ field_descriptor = self.descriptor.fields_by_name[name]
+ wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
+ field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
+ result_dict = {}
+ for tag_bytes, value in self.unknown_fields:
+ if tag_bytes == field_tag:
+ decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
+ tag_bytes]
+ decoder(value, 0, len(value), self.message, result_dict)
+ return result_dict[field_descriptor]
+
+ def testUnknownEnumValue(self):
+ self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
+ value = self.GetField('optional_nested_enum')
+ self.assertEqual(self.message.optional_nested_enum, value)
+
+ def testUnknownRepeatedEnumValue(self):
+ value = self.GetField('repeated_nested_enum')
+ self.assertEqual(self.message.repeated_nested_enum, value)
+
+ def testUnknownPackedEnumValue(self):
+ value = self.GetField('packed_nested_enum')
+ self.assertEqual(self.message.packed_nested_enum, value)
+
+ def testRoundTrip(self):
+ new_message = missing_enum_values_pb2.TestEnumValues()
+ new_message.ParseFromString(self.missing_message.SerializeToString())
+ self.assertEqual(self.message, new_message)
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
index 7600778..9362c72 100755
--- a/python/google/protobuf/internal/wire_format_test.py
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -34,12 +34,12 @@
__author__ = 'robinson@google.com (Will Robinson)'
-import unittest
+from google.apputils import basetest
from google.protobuf import message
from google.protobuf.internal import wire_format
-class WireFormatTest(unittest.TestCase):
+class WireFormatTest(basetest.TestCase):
def testPackTag(self):
field_number = 0xabc
@@ -195,7 +195,7 @@ class WireFormatTest(unittest.TestCase):
# Test UTF-8 string byte size calculation.
# 1 byte for tag, 1 byte for length, 8 bytes for content.
self.assertEqual(10, wire_format.StringByteSize(
- 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8')))
+ 5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8')))
class MockMessage(object):
def __init__(self, byte_size):
@@ -250,4 +250,4 @@ class WireFormatTest(unittest.TestCase):
if __name__ == '__main__':
- unittest.main()
+ basetest.main()
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index f839847..37b0af1 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -67,14 +67,28 @@ class Message(object):
DESCRIPTOR = None
+ def __deepcopy__(self, memo=None):
+ clone = type(self)()
+ clone.MergeFrom(self)
+ return clone
+
def __eq__(self, other_msg):
+ """Recursively compares two messages by value and structure."""
raise NotImplementedError
def __ne__(self, other_msg):
# Can't just say self != other_msg, since that would infinitely recurse. :)
return not self == other_msg
+ def __hash__(self):
+ raise TypeError('unhashable object')
+
def __str__(self):
+ """Outputs a human-readable representation of the message."""
+ raise NotImplementedError
+
+ def __unicode__(self):
+ """Outputs a human-readable representation of the message."""
raise NotImplementedError
def MergeFrom(self, other_msg):
@@ -163,7 +177,11 @@ class Message(object):
raise NotImplementedError
def ParseFromString(self, serialized):
- """Like MergeFromString(), except we clear the object first."""
+ """Parse serialized protocol buffer data into this message.
+
+ Like MergeFromString(), except we clear the object first and
+ do not return the value that MergeFromString returns.
+ """
self.Clear()
self.MergeFromString(serialized)
@@ -215,6 +233,9 @@ class Message(object):
raise NotImplementedError
def HasField(self, field_name):
+ """Checks if a certain field is set for the message. Note if the
+ field_name is not defined in the message descriptor, ValueError will be
+ raised."""
raise NotImplementedError
def ClearField(self, field_name):
@@ -252,3 +273,12 @@ class Message(object):
via a previous _SetListener() call.
"""
raise NotImplementedError
+
+ def __getstate__(self):
+ """Support the pickle protocol."""
+ return dict(serialized=self.SerializePartialToString())
+
+ def __setstate__(self, state):
+ """Support the pickle protocol."""
+ self.__init__()
+ self.ParseFromString(state['serialized'])
diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
new file mode 100644
index 0000000..9004ffd
--- /dev/null
+++ b/python/google/protobuf/message_factory.py
@@ -0,0 +1,155 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#PY25 compatible for GAE.
+#
+# Copyright 2012 Google Inc. All Rights Reserved.
+
+"""Provides a factory class for generating dynamic messages.
+
+The easiest way to use this class is if you have access to the FileDescriptor
+protos containing the messages you want to create you can just do the following:
+
+message_classes = message_factory.GetMessages(iterable_of_file_descriptors)
+my_proto_instance = message_classes['some.proto.package.MessageName']()
+"""
+
+__author__ = 'matthewtoia@google.com (Matt Toia)'
+
+import sys ##PY25
+from google.protobuf import descriptor_database
+from google.protobuf import descriptor_pool
+from google.protobuf import message
+from google.protobuf import reflection
+
+
+class MessageFactory(object):
+ """Factory for creating Proto2 messages from descriptors in a pool."""
+
+ def __init__(self, pool=None):
+ """Initializes a new factory."""
+ self.pool = (pool or descriptor_pool.DescriptorPool(
+ descriptor_database.DescriptorDatabase()))
+
+ # local cache of all classes built from protobuf descriptors
+ self._classes = {}
+
+ def GetPrototype(self, descriptor):
+ """Builds a proto2 message class based on the passed in descriptor.
+
+ Passing a descriptor with a fully qualified name matching a previous
+ invocation will cause the same class to be returned.
+
+ Args:
+ descriptor: The descriptor to build from.
+
+ Returns:
+ A class describing the passed in descriptor.
+ """
+ if descriptor.full_name not in self._classes:
+ descriptor_name = descriptor.name
+ if sys.version_info[0] < 3: ##PY25
+##!PY25 if str is bytes: # PY2
+ descriptor_name = descriptor.name.encode('ascii', 'ignore')
+ result_class = reflection.GeneratedProtocolMessageType(
+ descriptor_name,
+ (message.Message,),
+ {'DESCRIPTOR': descriptor, '__module__': None})
+ # If module not set, it wrongly points to the reflection.py module.
+ self._classes[descriptor.full_name] = result_class
+ for field in descriptor.fields:
+ if field.message_type:
+ self.GetPrototype(field.message_type)
+ for extension in result_class.DESCRIPTOR.extensions:
+ if extension.containing_type.full_name not in self._classes:
+ self.GetPrototype(extension.containing_type)
+ extended_class = self._classes[extension.containing_type.full_name]
+ extended_class.RegisterExtension(extension)
+ return self._classes[descriptor.full_name]
+
+ def GetMessages(self, files):
+ """Gets all the messages from a specified file.
+
+ This will find and resolve dependencies, failing if the descriptor
+ pool cannot satisfy them.
+
+ Args:
+ files: The file names to extract messages from.
+
+ Returns:
+ A dictionary mapping proto names to the message classes. This will include
+ any dependent messages as well as any messages defined in the same file as
+ a specified message.
+ """
+ result = {}
+ for file_name in files:
+ file_desc = self.pool.FindFileByName(file_name)
+ for name, msg in file_desc.message_types_by_name.iteritems():
+ if file_desc.package:
+ full_name = '.'.join([file_desc.package, name])
+ else:
+ full_name = msg.name
+ result[full_name] = self.GetPrototype(
+ self.pool.FindMessageTypeByName(full_name))
+
+ # While the extension FieldDescriptors are created by the descriptor pool,
+ # the python classes created in the factory need them to be registered
+ # explicitly, which is done below.
+ #
+ # The call to RegisterExtension will specifically check if the
+ # extension was already registered on the object and either
+ # ignore the registration if the original was the same, or raise
+ # an error if they were different.
+
+ for name, extension in file_desc.extensions_by_name.iteritems():
+ if extension.containing_type.full_name not in self._classes:
+ self.GetPrototype(extension.containing_type)
+ extended_class = self._classes[extension.containing_type.full_name]
+ extended_class.RegisterExtension(extension)
+ return result
+
+
+_FACTORY = MessageFactory()
+
+
+def GetMessages(file_protos):
+ """Builds a dictionary of all the messages available in a set of files.
+
+ Args:
+ file_protos: A sequence of file protos to build messages out of.
+
+ Returns:
+ A dictionary mapping proto names to the message classes. This will include
+ any dependent messages as well as any messages defined in the same file as
+ a specified message.
+ """
+ for file_proto in file_protos:
+ _FACTORY.pool.Add(file_proto)
+ return _FACTORY.GetMessages([file_proto.name for file_proto in file_protos])
diff --git a/python/google/protobuf/pyext/README b/python/google/protobuf/pyext/README
new file mode 100644
index 0000000..6d61cb4
--- /dev/null
+++ b/python/google/protobuf/pyext/README
@@ -0,0 +1,6 @@
+This is the 'v2' C++ implementation for python proto2.
+
+It is active when:
+
+PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp
+PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2
diff --git a/python/google/protobuf/pyext/__init__.py b/python/google/protobuf/pyext/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/python/google/protobuf/pyext/__init__.py
diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py
new file mode 100644
index 0000000..ba87f8e
--- /dev/null
+++ b/python/google/protobuf/pyext/cpp_message.py
@@ -0,0 +1,61 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Protocol message implementation hooks for C++ implementation.
+
+Contains helper functions used to create protocol message classes from
+Descriptor objects at runtime backed by the protocol buffer C++ API.
+"""
+
+__author__ = 'tibell@google.com (Johan Tibell)'
+
+from google.protobuf.pyext import _message
+from google.protobuf import message
+
+
+def NewMessage(bases, message_descriptor, dictionary):
+ """Creates a new protocol message *class*."""
+ new_bases = []
+ for base in bases:
+ if base is message.Message:
+ # _message.Message must come before message.Message as it
+ # overrides methods in that class.
+ new_bases.append(_message.Message)
+ new_bases.append(base)
+ return tuple(new_bases)
+
+
+def InitMessage(message_descriptor, cls):
+ """Constructs a new message instance (called before instance's __init__)."""
+
+ def SubInit(self, **kwargs):
+ super(cls, self).__init__(message_descriptor, **kwargs)
+ cls.__init__ = SubInit
+ cls.AddDescriptors(message_descriptor)
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
new file mode 100644
index 0000000..cbf42c0
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -0,0 +1,357 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: petar@google.com (Petar Petrov)
+
+#include <Python.h>
+#include <string>
+
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#define C(str) const_cast<char*>(str)
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #define PyInt_FromLong PyLong_FromLong
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #else
+ #define PyString_AsString(ob) \
+ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
+ #endif
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+
+#ifndef PyVarObject_HEAD_INIT
+#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
+#endif
+#ifndef Py_TYPE
+#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
+#endif
+
+
+static google::protobuf::DescriptorPool* g_descriptor_pool = NULL;
+
+namespace cfield_descriptor {
+
+static void Dealloc(CFieldDescriptor* self) {
+ Py_CLEAR(self->descriptor_field);
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+static PyObject* GetFullName(CFieldDescriptor* self, void *closure) {
+ return PyString_FromStringAndSize(
+ self->descriptor->full_name().c_str(),
+ self->descriptor->full_name().size());
+}
+
+static PyObject* GetName(CFieldDescriptor *self, void *closure) {
+ return PyString_FromStringAndSize(
+ self->descriptor->name().c_str(),
+ self->descriptor->name().size());
+}
+
+static PyObject* GetCppType(CFieldDescriptor *self, void *closure) {
+ return PyInt_FromLong(self->descriptor->cpp_type());
+}
+
+static PyObject* GetLabel(CFieldDescriptor *self, void *closure) {
+ return PyInt_FromLong(self->descriptor->label());
+}
+
+static PyObject* GetID(CFieldDescriptor *self, void *closure) {
+ return PyLong_FromVoidPtr(self);
+}
+
+static PyGetSetDef Getters[] = {
+ { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL},
+ { C("name"), (getter)GetName, NULL, "last name", NULL},
+ { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL},
+ { C("label"), (getter)GetLabel, NULL, "Label", NULL},
+ { C("id"), (getter)GetID, NULL, "ID", NULL},
+ {NULL}
+};
+
+} // namespace cfield_descriptor
+
+PyTypeObject CFieldDescriptor_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ C("google.protobuf.internal."
+ "_net_proto2___python."
+ "CFieldDescriptor"), // tp_name
+ sizeof(CFieldDescriptor), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)cfield_descriptor::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ C("A Field Descriptor"), // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ 0, // tp_methods
+ 0, // tp_members
+ cfield_descriptor::Getters, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ PyType_GenericAlloc, // tp_alloc
+ PyType_GenericNew, // tp_new
+ PyObject_Del, // tp_free
+};
+
+namespace cdescriptor_pool {
+
+static void Dealloc(CDescriptorPool* self) {
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+static PyObject* NewCDescriptor(
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ CFieldDescriptor* cfield_descriptor = PyObject_New(
+ CFieldDescriptor, &CFieldDescriptor_Type);
+ if (cfield_descriptor == NULL) {
+ return NULL;
+ }
+ cfield_descriptor->descriptor = field_descriptor;
+ cfield_descriptor->descriptor_field = NULL;
+
+ return reinterpret_cast<PyObject*>(cfield_descriptor);
+}
+
+PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name) {
+ const char* full_field_name = PyString_AsString(name);
+ if (full_field_name == NULL) {
+ return NULL;
+ }
+
+ const google::protobuf::FieldDescriptor* field_descriptor = NULL;
+
+ field_descriptor = self->pool->FindFieldByName(full_field_name);
+
+ if (field_descriptor == NULL) {
+ PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s",
+ full_field_name);
+ return NULL;
+ }
+
+ return NewCDescriptor(field_descriptor);
+}
+
+PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg) {
+ const char* full_field_name = PyString_AsString(arg);
+ if (full_field_name == NULL) {
+ return NULL;
+ }
+
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ self->pool->FindExtensionByName(full_field_name);
+ if (field_descriptor == NULL) {
+ PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s",
+ full_field_name);
+ return NULL;
+ }
+
+ return NewCDescriptor(field_descriptor);
+}
+
+static PyMethodDef Methods[] = {
+ { C("FindFieldByName"),
+ (PyCFunction)FindFieldByName,
+ METH_O,
+ C("Searches for a field descriptor by full name.") },
+ { C("FindExtensionByName"),
+ (PyCFunction)FindExtensionByName,
+ METH_O,
+ C("Searches for extension descriptor by full name.") },
+ {NULL}
+};
+
+} // namespace cdescriptor_pool
+
+PyTypeObject CDescriptorPool_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ C("google.protobuf.internal."
+ "_net_proto2___python."
+ "CFieldDescriptor"), // tp_name
+ sizeof(CDescriptorPool), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)cdescriptor_pool::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ C("A Descriptor Pool"), // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ cdescriptor_pool::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ 0, // tp_init
+ PyType_GenericAlloc, // tp_alloc
+ PyType_GenericNew, // tp_new
+ PyObject_Del, // tp_free
+};
+
+google::protobuf::DescriptorPool* GetDescriptorPool() {
+ if (g_descriptor_pool == NULL) {
+ g_descriptor_pool = new google::protobuf::DescriptorPool(
+ google::protobuf::DescriptorPool::generated_pool());
+ }
+ return g_descriptor_pool;
+}
+
+PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args) {
+ CDescriptorPool* cdescriptor_pool = PyObject_New(
+ CDescriptorPool, &CDescriptorPool_Type);
+ if (cdescriptor_pool == NULL) {
+ return NULL;
+ }
+ cdescriptor_pool->pool = GetDescriptorPool();
+ return reinterpret_cast<PyObject*>(cdescriptor_pool);
+}
+
+
+// Collects errors that occur during proto file building to allow them to be
+// propagated in the python exception instead of only living in ERROR logs.
+class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector {
+ public:
+ BuildFileErrorCollector() : error_message(""), had_errors(false) {}
+
+ void AddError(const string& filename, const string& element_name,
+ const Message* descriptor, ErrorLocation location,
+ const string& message) {
+ // Replicates the logging behavior that happens in the C++ implementation
+ // when an error collector is not passed in.
+ if (!had_errors) {
+ error_message +=
+ ("Invalid proto descriptor for file \"" + filename + "\":\n");
+ }
+ // As this only happens on failure and will result in the program not
+ // running at all, no effort is made to optimize this string manipulation.
+ error_message += (" " + element_name + ": " + message + "\n");
+ }
+
+ string error_message;
+ bool had_errors;
+};
+
+PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) {
+ char* message_type;
+ Py_ssize_t message_len;
+
+ if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) {
+ return NULL;
+ }
+
+ google::protobuf::FileDescriptorProto file_proto;
+ if (!file_proto.ParseFromArray(message_type, message_len)) {
+ PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!");
+ return NULL;
+ }
+
+ if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName(
+ file_proto.name()) != NULL) {
+ Py_RETURN_NONE;
+ }
+
+ BuildFileErrorCollector error_collector;
+ const google::protobuf::FileDescriptor* descriptor =
+ GetDescriptorPool()->BuildFileCollectingErrors(file_proto,
+ &error_collector);
+ if (descriptor == NULL) {
+ PyErr_Format(PyExc_TypeError,
+ "Couldn't build proto file into descriptor pool!\n%s",
+ error_collector.error_message.c_str());
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
+}
+
+bool InitDescriptor() {
+ CFieldDescriptor_Type.tp_new = PyType_GenericNew;
+ if (PyType_Ready(&CFieldDescriptor_Type) < 0)
+ return false;
+
+ CDescriptorPool_Type.tp_new = PyType_GenericNew;
+ if (PyType_Ready(&CDescriptorPool_Type) < 0)
+ return false;
+
+ return true;
+}
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
new file mode 100644
index 0000000..d114425
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -0,0 +1,96 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: petar@google.com (Petar Petrov)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
+
+#include <Python.h>
+#include <structmember.h>
+
+#include <google/protobuf/descriptor.h>
+
+#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
+typedef int Py_ssize_t;
+#define PY_SSIZE_T_MAX INT_MAX
+#define PY_SSIZE_T_MIN INT_MIN
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+typedef struct CFieldDescriptor {
+ PyObject_HEAD
+
+ // The proto2 descriptor that this object represents.
+ const google::protobuf::FieldDescriptor* descriptor;
+
+ // Reference to the original field object in the Python DESCRIPTOR.
+ PyObject* descriptor_field;
+} CFieldDescriptor;
+
+typedef struct {
+ PyObject_HEAD
+
+ const google::protobuf::DescriptorPool* pool;
+} CDescriptorPool;
+
+extern PyTypeObject CFieldDescriptor_Type;
+
+extern PyTypeObject CDescriptorPool_Type;
+
+namespace cdescriptor_pool {
+
+// Looks up a field by name. Returns a CDescriptor corresponding to
+// the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindFieldByName(CDescriptorPool* self, PyObject* name);
+
+// Looks up an extension by name. Returns a CDescriptor corresponding
+// to the field on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* FindExtensionByName(CDescriptorPool* self, PyObject* arg);
+
+} // namespace cdescriptor_pool
+
+PyObject* Python_NewCDescriptorPool(PyObject* ignored, PyObject* args);
+PyObject* Python_BuildFile(PyObject* ignored, PyObject* args);
+bool InitDescriptor();
+google::protobuf::DescriptorPool* GetDescriptorPool();
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
diff --git a/python/google/protobuf/pyext/descriptor_cpp2_test.py b/python/google/protobuf/pyext/descriptor_cpp2_test.py
new file mode 100644
index 0000000..3a3ff29
--- /dev/null
+++ b/python/google/protobuf/pyext/descriptor_cpp2_test.py
@@ -0,0 +1,58 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.pyext behavior."""
+
+__author__ = 'anuraag@google.com (Anuraag Agrawal)'
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
+
+# We must set the implementation version above before the google3 imports.
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+# Run all tests from the original module by putting them in our namespace.
+# pylint: disable=wildcard-import
+from google.protobuf.internal.descriptor_test import *
+
+
+class ConfirmCppApi2Test(basetest.TestCase):
+
+ def testImplementationSetting(self):
+ self.assertEqual('cpp', api_implementation.Type())
+ self.assertEqual(2, api_implementation.Version())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
new file mode 100644
index 0000000..1e14b42
--- /dev/null
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -0,0 +1,338 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#include <google/protobuf/pyext/extension_dict.h>
+
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/repeated_composite_container.h>
+#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/stubs/shared_ptr.h>
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+extern google::protobuf::DynamicMessageFactory* global_message_factory;
+
+namespace extension_dict {
+
+// TODO(tibell): Always use self->message for clarity, just like in
+// RepeatedCompositeContainer.
+static google::protobuf::Message* GetMessage(ExtensionDict* self) {
+ if (self->parent != NULL) {
+ return self->parent->message;
+ } else {
+ return self->message;
+ }
+}
+
+CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension) {
+ PyObject* cdescriptor = PyObject_GetAttrString(extension, "_cdescriptor");
+ if (cdescriptor == NULL) {
+ PyErr_SetString(PyExc_KeyError, "Unregistered extension.");
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor");
+ Py_DECREF(cdescriptor);
+ return NULL;
+ }
+ CFieldDescriptor* descriptor =
+ reinterpret_cast<CFieldDescriptor*>(cdescriptor);
+ return descriptor;
+}
+
+PyObject* len(ExtensionDict* self) {
+#if PY_MAJOR_VERSION >= 3
+ return PyLong_FromLong(PyDict_Size(self->values));
+#else
+ return PyInt_FromLong(PyDict_Size(self->values));
+#endif
+}
+
+// TODO(tibell): Use VisitCompositeField.
+int ReleaseExtension(ExtensionDict* self,
+ PyObject* extension,
+ const google::protobuf::FieldDescriptor* descriptor) {
+ if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (repeated_composite_container::Release(
+ reinterpret_cast<RepeatedCompositeContainer*>(
+ extension)) < 0) {
+ return -1;
+ }
+ } else {
+ if (repeated_scalar_container::Release(
+ reinterpret_cast<RepeatedScalarContainer*>(
+ extension)) < 0) {
+ return -1;
+ }
+ }
+ } else if (descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (cmessage::ReleaseSubMessage(
+ GetMessage(self), descriptor,
+ reinterpret_cast<CMessage*>(extension)) < 0) {
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+PyObject* subscript(ExtensionDict* self, PyObject* key) {
+ CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
+ key);
+ if (cdescriptor == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
+ const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
+ descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
+ return cmessage::InternalGetScalar(self->parent, descriptor);
+ }
+
+ PyObject* value = PyDict_GetItem(self->values, key);
+ if (value != NULL) {
+ Py_INCREF(value);
+ return value;
+ }
+
+ if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
+ descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyObject* sub_message = cmessage::InternalGetSubMessage(
+ self->parent, cdescriptor);
+ if (sub_message == NULL) {
+ return NULL;
+ }
+ PyDict_SetItem(self->values, key, sub_message);
+ return sub_message;
+ }
+
+ if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ // COPIED
+ PyObject* py_container = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type),
+ NULL);
+ if (py_container == NULL) {
+ return NULL;
+ }
+ RepeatedCompositeContainer* container =
+ reinterpret_cast<RepeatedCompositeContainer*>(py_container);
+ PyObject* field = cdescriptor->descriptor_field;
+ PyObject* message_type = PyObject_GetAttrString(field, "message_type");
+ PyObject* concrete_class = PyObject_GetAttrString(message_type,
+ "_concrete_class");
+ container->owner = self->owner;
+ container->parent = self->parent;
+ container->message = self->parent->message;
+ container->parent_field = cdescriptor;
+ container->subclass_init = concrete_class;
+ Py_DECREF(message_type);
+ PyDict_SetItem(self->values, key, py_container);
+ return py_container;
+ } else {
+ // COPIED
+ ScopedPyObjectPtr init_args(PyTuple_Pack(2, self->parent, cdescriptor));
+ PyObject* py_container = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type),
+ init_args);
+ if (py_container == NULL) {
+ return NULL;
+ }
+ PyDict_SetItem(self->values, key, py_container);
+ return py_container;
+ }
+ }
+ PyErr_SetString(PyExc_ValueError, "control reached unexpected line");
+ return NULL;
+}
+
+int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
+ CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
+ key);
+ if (cdescriptor == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
+ const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
+ if (descriptor->label() != FieldDescriptor::LABEL_OPTIONAL ||
+ descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyErr_SetString(PyExc_TypeError, "Extension is repeated and/or composite "
+ "type");
+ return -1;
+ }
+ cmessage::AssureWritable(self->parent);
+ if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
+ return -1;
+ }
+ // TODO(tibell): We shouldn't write scalars to the cache.
+ PyDict_SetItem(self->values, key, value);
+ return 0;
+}
+
+PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) {
+ CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
+ extension);
+ if (cdescriptor == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
+ PyObject* value = PyDict_GetItem(self->values, extension);
+ if (value != NULL) {
+ if (ReleaseExtension(self, value, cdescriptor->descriptor) < 0) {
+ return NULL;
+ }
+ }
+ if (cmessage::ClearFieldByDescriptor(self->parent,
+ cdescriptor->descriptor) == NULL) {
+ return NULL;
+ }
+ if (PyDict_DelItem(self->values, extension) < 0) {
+ PyErr_Clear();
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* HasExtension(ExtensionDict* self, PyObject* extension) {
+ CFieldDescriptor* cdescriptor = InternalGetCDescriptorFromExtension(
+ extension);
+ if (cdescriptor == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
+ PyObject* result = cmessage::HasFieldByDescriptor(
+ self->parent, cdescriptor->descriptor);
+ return result;
+}
+
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) {
+ ScopedPyObjectPtr extensions_by_name(PyObject_GetAttrString(
+ reinterpret_cast<PyObject*>(self->parent), "_extensions_by_name"));
+ if (extensions_by_name == NULL) {
+ return NULL;
+ }
+ PyObject* result = PyDict_GetItem(extensions_by_name, name);
+ if (result == NULL) {
+ Py_RETURN_NONE;
+ } else {
+ Py_INCREF(result);
+ return result;
+ }
+}
+
+int init(ExtensionDict* self, PyObject* args, PyObject* kwargs) {
+ self->parent = NULL;
+ self->message = NULL;
+ self->values = PyDict_New();
+ return 0;
+}
+
+void dealloc(ExtensionDict* self) {
+ Py_CLEAR(self->values);
+ self->owner.reset();
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+static PyMappingMethods MpMethods = {
+ (lenfunc)len, /* mp_length */
+ (binaryfunc)subscript, /* mp_subscript */
+ (objobjargproc)ass_subscript,/* mp_ass_subscript */
+};
+
+#define EDMETHOD(name, args, doc) { #name, (PyCFunction)name, args, doc }
+static PyMethodDef Methods[] = {
+ EDMETHOD(ClearExtension, METH_O, "Clears an extension from the object."),
+ EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."),
+ EDMETHOD(_FindExtensionByName, METH_O,
+ "Finds an extension by name."),
+ { NULL, NULL }
+};
+
+} // namespace extension_dict
+
+PyTypeObject ExtensionDict_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "google.protobuf.internal."
+ "cpp._message.ExtensionDict", // tp_name
+ sizeof(ExtensionDict), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)extension_dict::dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ &extension_dict::MpMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "An extension dict", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ 0, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ extension_dict::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ (initproc)extension_dict::init, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
new file mode 100644
index 0000000..1343001
--- /dev/null
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -0,0 +1,123 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+
+namespace google {
+namespace protobuf {
+
+class Message;
+class FieldDescriptor;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CMessage;
+struct CFieldDescriptor;
+
+typedef struct ExtensionDict {
+ PyObject_HEAD;
+ shared_ptr<Message> owner;
+ CMessage* parent;
+ Message* message;
+ PyObject* values;
+} ExtensionDict;
+
+extern PyTypeObject ExtensionDict_Type;
+
+namespace extension_dict {
+
+// Gets the _cdescriptor reference to a CFieldDescriptor object given a
+// python descriptor object.
+//
+// Returns a new reference.
+CFieldDescriptor* InternalGetCDescriptorFromExtension(PyObject* extension);
+
+// Gets the number of extension values in this ExtensionDict as a python object.
+//
+// Returns a new reference.
+PyObject* len(ExtensionDict* self);
+
+// Releases extensions referenced outside this dictionary to keep outside
+// references alive.
+//
+// Returns 0 on success, -1 on failure.
+int ReleaseExtension(ExtensionDict* self,
+ PyObject* extension,
+ const google::protobuf::FieldDescriptor* descriptor);
+
+// Gets an extension from the dict for the given extension descriptor.
+//
+// Returns a new reference.
+PyObject* subscript(ExtensionDict* self, PyObject* key);
+
+// Assigns a value to an extension in the dict. Can only be used for singular
+// simple types.
+//
+// Returns 0 on success, -1 on failure.
+int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value);
+
+// Clears an extension from the dict. Will release the extension if there
+// is still an external reference left to it.
+//
+// Returns None on success.
+PyObject* ClearExtension(ExtensionDict* self,
+ PyObject* extension);
+
+// Checks if the dict has an extension.
+//
+// Returns a new python boolean reference.
+PyObject* HasExtension(ExtensionDict* self, PyObject* extension);
+
+// Gets an extension from the dict given the extension name as opposed to
+// descriptor.
+//
+// Returns a new reference.
+PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name);
+
+} // namespace extension_dict
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
new file mode 100644
index 0000000..c45cbf0
--- /dev/null
+++ b/python/google/protobuf/pyext/message.cc
@@ -0,0 +1,2561 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#include <google/protobuf/pyext/message.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+#include <string>
+#include <vector>
+
+#ifndef PyVarObject_HEAD_INIT
+#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size,
+#endif
+#ifndef Py_TYPE
+#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
+#endif
+#include <google/protobuf/descriptor.pb.h>
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/extension_dict.h>
+#include <google/protobuf/pyext/repeated_composite_container.h>
+#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyInt_Check PyLong_Check
+ #define PyInt_AsLong PyLong_AsLong
+ #define PyInt_FromLong PyLong_FromLong
+ #define PyInt_FromSize_t PyLong_FromSize_t
+ #define PyString_Check PyUnicode_Check
+ #define PyString_FromString PyUnicode_FromString
+ #define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #else
+ #define PyString_AsString(ob) \
+ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
+ #endif
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+// Forward declarations
+namespace cmessage {
+static PyObject* GetDescriptor(CMessage* self, PyObject* name);
+static string GetMessageName(CMessage* self);
+int InternalReleaseFieldByDescriptor(
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* composite_field,
+ google::protobuf::Message* parent_message);
+} // namespace cmessage
+
+// ---------------------------------------------------------------------
+// Visiting the composite children of a CMessage
+
+struct ChildVisitor {
+ // Returns 0 on success, -1 on failure.
+ int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
+ return 0;
+ }
+
+ // Returns 0 on success, -1 on failure.
+ int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
+ return 0;
+ }
+
+ // Returns 0 on success, -1 on failure.
+ int VisitCMessage(CMessage* cmessage,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ return 0;
+ }
+};
+
+// Apply a function to a composite field. Does nothing if child is of
+// non-composite type.
+template<class Visitor>
+static int VisitCompositeField(const FieldDescriptor* descriptor,
+ PyObject* child,
+ Visitor visitor) {
+ if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ RepeatedCompositeContainer* container =
+ reinterpret_cast<RepeatedCompositeContainer*>(child);
+ if (visitor.VisitRepeatedCompositeContainer(container) == -1)
+ return -1;
+ } else {
+ RepeatedScalarContainer* container =
+ reinterpret_cast<RepeatedScalarContainer*>(child);
+ if (visitor.VisitRepeatedScalarContainer(container) == -1)
+ return -1;
+ }
+ } else if (descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ CMessage* cmsg = reinterpret_cast<CMessage*>(child);
+ if (visitor.VisitCMessage(cmsg, descriptor) == -1)
+ return -1;
+ }
+ // The ExtensionDict might contain non-composite fields, which we
+ // skip here.
+ return 0;
+}
+
+// Visit each composite field and extension field of this CMessage.
+// Returns -1 on error and 0 on success.
+template<class Visitor>
+int ForEachCompositeField(CMessage* self, Visitor visitor) {
+ Py_ssize_t pos = 0;
+ PyObject* key;
+ PyObject* field;
+
+ // Visit normal fields.
+ while (PyDict_Next(self->composite_fields, &pos, &key, &field)) {
+ PyObject* cdescriptor = cmessage::GetDescriptor(self, key);
+ if (cdescriptor != NULL) {
+ const google::protobuf::FieldDescriptor* descriptor =
+ reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor;
+ if (VisitCompositeField(descriptor, field, visitor) == -1)
+ return -1;
+ }
+ }
+
+ // Visit extension fields.
+ if (self->extensions != NULL) {
+ while (PyDict_Next(self->extensions->values, &pos, &key, &field)) {
+ CFieldDescriptor* cdescriptor =
+ extension_dict::InternalGetCDescriptorFromExtension(key);
+ if (cdescriptor == NULL)
+ return -1;
+ if (VisitCompositeField(cdescriptor->descriptor, field, visitor) == -1)
+ return -1;
+ }
+ }
+
+ return 0;
+}
+
+// ---------------------------------------------------------------------
+
+// Constants used for integer type range checking.
+PyObject* kPythonZero;
+PyObject* kint32min_py;
+PyObject* kint32max_py;
+PyObject* kuint32max_py;
+PyObject* kint64min_py;
+PyObject* kint64max_py;
+PyObject* kuint64max_py;
+
+PyObject* EnumTypeWrapper_class;
+PyObject* EncodeError_class;
+PyObject* DecodeError_class;
+PyObject* PickleError_class;
+
+// Constant PyString values used for GetAttr/GetItem.
+static PyObject* kDESCRIPTOR;
+static PyObject* k__descriptors;
+static PyObject* kfull_name;
+static PyObject* kname;
+static PyObject* kmessage_type;
+static PyObject* kis_extendable;
+static PyObject* kextensions_by_name;
+static PyObject* k_extensions_by_name;
+static PyObject* k_extensions_by_number;
+static PyObject* k_concrete_class;
+static PyObject* kfields_by_name;
+
+static CDescriptorPool* descriptor_pool;
+
+/* Is 64bit */
+void FormatTypeError(PyObject* arg, char* expected_types) {
+ PyObject* repr = PyObject_Repr(arg);
+ if (repr) {
+ PyErr_Format(PyExc_TypeError,
+ "%.100s has type %.100s, but expected one of: %s",
+ PyString_AsString(repr),
+ Py_TYPE(arg)->tp_name,
+ expected_types);
+ Py_DECREF(repr);
+ }
+}
+
+template<class T>
+bool CheckAndGetInteger(
+ PyObject* arg, T* value, PyObject* min, PyObject* max) {
+ bool is_long = PyLong_Check(arg);
+#if PY_MAJOR_VERSION < 3
+ if (!PyInt_Check(arg) && !is_long) {
+ FormatTypeError(arg, "int, long");
+ return false;
+ }
+ if (PyObject_Compare(min, arg) > 0 || PyObject_Compare(max, arg) < 0) {
+#else
+ if (!is_long) {
+ FormatTypeError(arg, "int");
+ return false;
+ }
+ if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 ||
+ PyObject_RichCompareBool(max, arg, Py_GE) != 1) {
+#endif
+ PyObject *s = PyObject_Str(arg);
+ if (s) {
+ PyErr_Format(PyExc_ValueError,
+ "Value out of range: %s",
+ PyString_AsString(s));
+ Py_DECREF(s);
+ }
+ return false;
+ }
+#if PY_MAJOR_VERSION < 3
+ if (!is_long) {
+ *value = static_cast<T>(PyInt_AsLong(arg));
+ } else // NOLINT
+#endif
+ {
+ if (min == kPythonZero) {
+ *value = static_cast<T>(PyLong_AsUnsignedLongLong(arg));
+ } else {
+ *value = static_cast<T>(PyLong_AsLongLong(arg));
+ }
+ }
+ return true;
+}
+
+// These are referenced by repeated_scalar_container, and must
+// be explicitly instantiated.
+template bool CheckAndGetInteger<int32>(
+ PyObject*, int32*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<int64>(
+ PyObject*, int64*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<uint32>(
+ PyObject*, uint32*, PyObject*, PyObject*);
+template bool CheckAndGetInteger<uint64>(
+ PyObject*, uint64*, PyObject*, PyObject*);
+
+bool CheckAndGetDouble(PyObject* arg, double* value) {
+ if (!PyInt_Check(arg) && !PyLong_Check(arg) &&
+ !PyFloat_Check(arg)) {
+ FormatTypeError(arg, "int, long, float");
+ return false;
+ }
+ *value = PyFloat_AsDouble(arg);
+ return true;
+}
+
+bool CheckAndGetFloat(PyObject* arg, float* value) {
+ double double_value;
+ if (!CheckAndGetDouble(arg, &double_value)) {
+ return false;
+ }
+ *value = static_cast<float>(double_value);
+ return true;
+}
+
+bool CheckAndGetBool(PyObject* arg, bool* value) {
+ if (!PyInt_Check(arg) && !PyBool_Check(arg) && !PyLong_Check(arg)) {
+ FormatTypeError(arg, "int, long, bool");
+ return false;
+ }
+ *value = static_cast<bool>(PyInt_AsLong(arg));
+ return true;
+}
+
+bool CheckAndSetString(
+ PyObject* arg, google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* descriptor,
+ const google::protobuf::Reflection* reflection,
+ bool append,
+ int index) {
+ GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING ||
+ descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES);
+ if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) {
+ if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) {
+ FormatTypeError(arg, "bytes, unicode");
+ return false;
+ }
+
+ if (PyBytes_Check(arg)) {
+ PyObject* unicode = PyUnicode_FromEncodedObject(arg, "ascii", NULL);
+ if (unicode == NULL) {
+ PyObject* repr = PyObject_Repr(arg);
+ PyErr_Format(PyExc_ValueError,
+ "%s has type str, but isn't in 7-bit ASCII "
+ "encoding. Non-ASCII strings must be converted to "
+ "unicode objects before being added.",
+ PyString_AsString(repr));
+ Py_DECREF(repr);
+ return false;
+ } else {
+ Py_DECREF(unicode);
+ }
+ }
+ } else if (!PyBytes_Check(arg)) {
+ FormatTypeError(arg, "bytes");
+ return false;
+ }
+
+ PyObject* encoded_string = NULL;
+ if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) {
+ if (PyBytes_Check(arg)) {
+#if PY_MAJOR_VERSION < 3
+ encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL);
+#else
+ encoded_string = arg; // Already encoded.
+ Py_INCREF(encoded_string);
+#endif
+ } else {
+ encoded_string = PyUnicode_AsEncodedObject(arg, "utf-8", NULL);
+ }
+ } else {
+ // In this case field type is "bytes".
+ encoded_string = arg;
+ Py_INCREF(encoded_string);
+ }
+
+ if (encoded_string == NULL) {
+ return false;
+ }
+
+ char* value;
+ Py_ssize_t value_len;
+ if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) {
+ Py_DECREF(encoded_string);
+ return false;
+ }
+
+ string value_string(value, value_len);
+ if (append) {
+ reflection->AddString(message, descriptor, value_string);
+ } else if (index < 0) {
+ reflection->SetString(message, descriptor, value_string);
+ } else {
+ reflection->SetRepeatedString(message, descriptor, index, value_string);
+ }
+ Py_DECREF(encoded_string);
+ return true;
+}
+
+PyObject* ToStringObject(
+ const google::protobuf::FieldDescriptor* descriptor, string value) {
+ if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) {
+ return PyBytes_FromStringAndSize(value.c_str(), value.length());
+ }
+
+ PyObject* result = PyUnicode_DecodeUTF8(value.c_str(), value.length(), NULL);
+ // If the string can't be decoded in UTF-8, just return a string object that
+ // contains the raw bytes. This can't happen if the value was assigned using
+ // the members of the Python message object, but can happen if the values were
+ // parsed from the wire (binary).
+ if (result == NULL) {
+ PyErr_Clear();
+ result = PyBytes_FromStringAndSize(value.c_str(), value.length());
+ }
+ return result;
+}
+
+google::protobuf::DynamicMessageFactory* global_message_factory;
+
+namespace cmessage {
+
+static int MaybeReleaseOverlappingOneofField(
+ CMessage* cmessage,
+ const google::protobuf::FieldDescriptor* field) {
+#ifdef GOOGLE_PROTOBUF_HAS_ONEOF
+ google::protobuf::Message* message = cmessage->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ if (!field->containing_oneof() ||
+ !reflection->HasOneof(*message, field->containing_oneof()) ||
+ reflection->HasField(*message, field)) {
+ // No other field in this oneof, no need to release.
+ return 0;
+ }
+
+ const OneofDescriptor* oneof = field->containing_oneof();
+ const FieldDescriptor* existing_field =
+ reflection->GetOneofFieldDescriptor(*message, oneof);
+ if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ // Non-message fields don't need to be released.
+ return 0;
+ }
+ const char* field_name = existing_field->name().c_str();
+ PyObject* child_message = PyDict_GetItemString(
+ cmessage->composite_fields, field_name);
+ if (child_message == NULL) {
+ // No python reference to this field so no need to release.
+ return 0;
+ }
+
+ if (InternalReleaseFieldByDescriptor(
+ existing_field, child_message, message) < 0) {
+ return -1;
+ }
+ return PyDict_DelItemString(cmessage->composite_fields, field_name);
+#else
+ return 0;
+#endif
+}
+
+// ---------------------------------------------------------------------
+// Making a message writable
+
+static google::protobuf::Message* GetMutableMessage(
+ CMessage* parent,
+ const google::protobuf::FieldDescriptor* parent_field) {
+ google::protobuf::Message* parent_message = parent->message;
+ const google::protobuf::Reflection* reflection = parent_message->GetReflection();
+ if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) {
+ return NULL;
+ }
+ return reflection->MutableMessage(
+ parent_message, parent_field, global_message_factory);
+}
+
+struct FixupMessageReference : public ChildVisitor {
+ // message must outlive this object.
+ explicit FixupMessageReference(google::protobuf::Message* message) :
+ message_(message) {}
+
+ int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
+ container->message = message_;
+ return 0;
+ }
+
+ int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
+ container->message = message_;
+ return 0;
+ }
+
+ private:
+ google::protobuf::Message* message_;
+};
+
+int AssureWritable(CMessage* self) {
+ if (self == NULL || !self->read_only) {
+ return 0;
+ }
+
+ if (self->parent == NULL) {
+ // If parent is NULL but we are trying to modify a read-only message, this
+ // is a reference to a constant default instance that needs to be replaced
+ // with a mutable top-level message.
+ const Message* prototype = global_message_factory->GetPrototype(
+ self->message->GetDescriptor());
+ self->message = prototype->New();
+ self->owner.reset(self->message);
+ } else {
+ // Otherwise, we need a mutable child message.
+ if (AssureWritable(self->parent) == -1)
+ return -1;
+
+ // Make self->message writable.
+ google::protobuf::Message* parent_message = self->parent->message;
+ google::protobuf::Message* mutable_message = GetMutableMessage(
+ self->parent,
+ self->parent_field->descriptor);
+ if (mutable_message == NULL) {
+ return -1;
+ }
+ self->message = mutable_message;
+ }
+ self->read_only = false;
+
+ // When a CMessage is made writable its Message pointer is updated
+ // to point to a new mutable Message. When that happens we need to
+ // update any references to the old, read-only CMessage. There are
+ // three places such references occur: RepeatedScalarContainer,
+ // RepeatedCompositeContainer, and ExtensionDict.
+ if (self->extensions != NULL)
+ self->extensions->message = self->message;
+ if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1)
+ return -1;
+
+ return 0;
+}
+
+// --- Globals:
+
+static PyObject* GetDescriptor(CMessage* self, PyObject* name) {
+ PyObject* descriptors =
+ PyDict_GetItem(Py_TYPE(self)->tp_dict, k__descriptors);
+ if (descriptors == NULL) {
+ PyErr_SetString(PyExc_TypeError, "No __descriptors");
+ return NULL;
+ }
+
+ return PyDict_GetItem(descriptors, name);
+}
+
+static const google::protobuf::Message* CreateMessage(const char* message_type) {
+ string message_name(message_type);
+ const google::protobuf::Descriptor* descriptor =
+ GetDescriptorPool()->FindMessageTypeByName(message_name);
+ if (descriptor == NULL) {
+ PyErr_SetString(PyExc_TypeError, message_type);
+ return NULL;
+ }
+ return global_message_factory->GetPrototype(descriptor);
+}
+
+// If cmessage_list is not NULL, this function releases values into the
+// container CMessages instead of just removing. Repeated composite container
+// needs to do this to make sure CMessages stay alive if they're still
+// referenced after deletion. Repeated scalar container doesn't need to worry.
+int InternalDeleteRepeatedField(
+ google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* slice,
+ PyObject* cmessage_list) {
+ Py_ssize_t length, from, to, step, slice_length;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ int min, max;
+ length = reflection->FieldSize(*message, field_descriptor);
+
+ if (PyInt_Check(slice) || PyLong_Check(slice)) {
+ from = to = PyLong_AsLong(slice);
+ if (from < 0) {
+ from = to = length + from;
+ }
+ step = 1;
+ min = max = from;
+
+ // Range check.
+ if (from < 0 || from >= length) {
+ PyErr_Format(PyExc_IndexError, "list assignment index out of range");
+ return -1;
+ }
+ } else if (PySlice_Check(slice)) {
+ from = to = step = slice_length = 0;
+ PySlice_GetIndicesEx(
+#if PY_MAJOR_VERSION < 3
+ reinterpret_cast<PySliceObject*>(slice),
+#else
+ slice,
+#endif
+ length, &from, &to, &step, &slice_length);
+ if (from < to) {
+ min = from;
+ max = to - 1;
+ } else {
+ min = to + 1;
+ max = from;
+ }
+ } else {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ Py_ssize_t i = from;
+ std::vector<bool> to_delete(length, false);
+ while (i >= min && i <= max) {
+ to_delete[i] = true;
+ i += step;
+ }
+
+ to = 0;
+ for (i = 0; i < length; ++i) {
+ if (!to_delete[i]) {
+ if (i != to) {
+ reflection->SwapElements(message, field_descriptor, i, to);
+ if (cmessage_list != NULL) {
+ // If a list of cmessages is passed in (i.e. from a repeated
+ // composite container), swap those as well to correspond to the
+ // swaps in the underlying message so they're in the right order
+ // when we start releasing.
+ PyObject* tmp = PyList_GET_ITEM(cmessage_list, i);
+ PyList_SET_ITEM(cmessage_list, i,
+ PyList_GET_ITEM(cmessage_list, to));
+ PyList_SET_ITEM(cmessage_list, to, tmp);
+ }
+ }
+ ++to;
+ }
+ }
+
+ while (i > to) {
+ if (cmessage_list == NULL) {
+ reflection->RemoveLast(message, field_descriptor);
+ } else {
+ CMessage* last_cmessage = reinterpret_cast<CMessage*>(
+ PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1));
+ repeated_composite_container::ReleaseLastTo(
+ field_descriptor, message, last_cmessage);
+ if (PySequence_DelItem(cmessage_list, -1) < 0) {
+ return -1;
+ }
+ }
+ --i;
+ }
+
+ return 0;
+}
+
+int InitAttributes(CMessage* self, PyObject* arg, PyObject* kwargs) {
+ ScopedPyObjectPtr descriptor;
+ if (arg == NULL) {
+ descriptor.reset(
+ PyObject_GetAttr(reinterpret_cast<PyObject*>(self), kDESCRIPTOR));
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ } else {
+ descriptor.reset(arg);
+ descriptor.inc();
+ }
+ ScopedPyObjectPtr is_extendable(PyObject_GetAttr(descriptor, kis_extendable));
+ if (is_extendable == NULL) {
+ return NULL;
+ }
+ int retcode = PyObject_IsTrue(is_extendable);
+ if (retcode == -1) {
+ return NULL;
+ }
+ if (retcode) {
+ PyObject* py_extension_dict = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL);
+ if (py_extension_dict == NULL) {
+ return NULL;
+ }
+ ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>(
+ py_extension_dict);
+ extension_dict->parent = self;
+ extension_dict->message = self->message;
+ self->extensions = extension_dict;
+ }
+
+ if (kwargs == NULL) {
+ return 0;
+ }
+
+ Py_ssize_t pos = 0;
+ PyObject* name;
+ PyObject* value;
+ while (PyDict_Next(kwargs, &pos, &name, &value)) {
+ if (!PyString_Check(name)) {
+ PyErr_SetString(PyExc_ValueError, "Field name must be a string");
+ return -1;
+ }
+ PyObject* py_cdescriptor = GetDescriptor(self, name);
+ if (py_cdescriptor == NULL) {
+ PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.",
+ PyString_AsString(name));
+ return -1;
+ }
+ const google::protobuf::FieldDescriptor* descriptor =
+ reinterpret_cast<CFieldDescriptor*>(py_cdescriptor)->descriptor;
+ if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ ScopedPyObjectPtr container(GetAttr(self, name));
+ if (container == NULL) {
+ return -1;
+ }
+ if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ if (repeated_composite_container::Extend(
+ reinterpret_cast<RepeatedCompositeContainer*>(container.get()),
+ value)
+ == NULL) {
+ return -1;
+ }
+ } else {
+ if (repeated_scalar_container::Extend(
+ reinterpret_cast<RepeatedScalarContainer*>(container.get()),
+ value) ==
+ NULL) {
+ return -1;
+ }
+ }
+ } else if (descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ ScopedPyObjectPtr message(GetAttr(self, name));
+ if (message == NULL) {
+ return -1;
+ }
+ if (MergeFrom(reinterpret_cast<CMessage*>(message.get()),
+ value) == NULL) {
+ return -1;
+ }
+ } else {
+ if (SetAttr(self, name, value) < 0) {
+ return -1;
+ }
+ }
+ }
+ return 0;
+}
+
+static PyObject* New(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
+ CMessage* self = reinterpret_cast<CMessage*>(type->tp_alloc(type, 0));
+ if (self == NULL) {
+ return NULL;
+ }
+
+ self->message = NULL;
+ self->parent = NULL;
+ self->parent_field = NULL;
+ self->read_only = false;
+ self->extensions = NULL;
+
+ self->composite_fields = PyDict_New();
+ if (self->composite_fields == NULL) {
+ return NULL;
+ }
+ return reinterpret_cast<PyObject*>(self);
+}
+
+PyObject* NewEmpty(PyObject* type) {
+ return New(reinterpret_cast<PyTypeObject*>(type), NULL, NULL);
+}
+
+static int Init(CMessage* self, PyObject* args, PyObject* kwargs) {
+ if (kwargs == NULL) {
+ // TODO(anuraag): Set error
+ return -1;
+ }
+
+ PyObject* descriptor = PyTuple_GetItem(args, 0);
+ if (descriptor == NULL || PyTuple_Size(args) != 1) {
+ PyErr_SetString(PyExc_ValueError, "args must contain one arg: descriptor");
+ return -1;
+ }
+
+ ScopedPyObjectPtr py_message_type(PyObject_GetAttr(descriptor, kfull_name));
+ if (py_message_type == NULL) {
+ return -1;
+ }
+
+ const char* message_type = PyString_AsString(py_message_type.get());
+ const google::protobuf::Message* message = CreateMessage(message_type);
+ if (message == NULL) {
+ return -1;
+ }
+
+ self->message = message->New();
+ self->owner.reset(self->message);
+
+ if (InitAttributes(self, descriptor, kwargs) < 0) {
+ return -1;
+ }
+ return 0;
+}
+
+// ---------------------------------------------------------------------
+// Deallocating a CMessage
+//
+// Deallocating a CMessage requires that we clear any weak references
+// from children to the message being deallocated.
+
+// Clear the weak reference from the child to the parent.
+struct ClearWeakReferences : public ChildVisitor {
+ int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
+ container->parent = NULL;
+ // The elements in the container have the same parent as the
+ // container itself, so NULL out that pointer as well.
+ const Py_ssize_t n = PyList_GET_SIZE(container->child_messages);
+ for (Py_ssize_t i = 0; i < n; ++i) {
+ CMessage* child_cmessage = reinterpret_cast<CMessage*>(
+ PyList_GET_ITEM(container->child_messages, i));
+ child_cmessage->parent = NULL;
+ }
+ return 0;
+ }
+
+ int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
+ container->parent = NULL;
+ return 0;
+ }
+
+ int VisitCMessage(CMessage* cmessage,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ cmessage->parent = NULL;
+ return 0;
+ }
+};
+
+static void Dealloc(CMessage* self) {
+ // Null out all weak references from children to this message.
+ GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
+
+ Py_CLEAR(self->extensions);
+ Py_CLEAR(self->composite_fields);
+ self->owner.reset();
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+// ---------------------------------------------------------------------
+
+
+PyObject* IsInitialized(CMessage* self, PyObject* args) {
+ PyObject* errors = NULL;
+ if (PyArg_ParseTuple(args, "|O", &errors) < 0) {
+ return NULL;
+ }
+ if (self->message->IsInitialized()) {
+ Py_RETURN_TRUE;
+ }
+ if (errors != NULL) {
+ ScopedPyObjectPtr initialization_errors(
+ FindInitializationErrors(self));
+ if (initialization_errors == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr extend_name(PyString_FromString("extend"));
+ if (extend_name == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr result(PyObject_CallMethodObjArgs(
+ errors,
+ extend_name.get(),
+ initialization_errors.get(),
+ NULL));
+ if (result == NULL) {
+ return NULL;
+ }
+ }
+ Py_RETURN_FALSE;
+}
+
+PyObject* HasFieldByDescriptor(
+ CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) {
+ google::protobuf::Message* message = self->message;
+ if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
+ PyErr_SetString(PyExc_KeyError,
+ "Field does not belong to message!");
+ return NULL;
+ }
+ if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ PyErr_SetString(PyExc_KeyError,
+ "Field is repeated. A singular method is required.");
+ return NULL;
+ }
+ bool has_field =
+ message->GetReflection()->HasField(*message, field_descriptor);
+ return PyBool_FromLong(has_field ? 1 : 0);
+}
+
+const google::protobuf::FieldDescriptor* FindFieldWithOneofs(
+ const google::protobuf::Message* message, const char* field_name, bool* in_oneof) {
+ const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ descriptor->FindFieldByName(field_name);
+ if (field_descriptor == NULL) {
+ const google::protobuf::OneofDescriptor* oneof_desc =
+ message->GetDescriptor()->FindOneofByName(field_name);
+ if (oneof_desc == NULL) {
+ *in_oneof = false;
+ return NULL;
+ } else {
+ *in_oneof = true;
+ return message->GetReflection()->GetOneofFieldDescriptor(
+ *message, oneof_desc);
+ }
+ }
+ return field_descriptor;
+}
+
+PyObject* HasField(CMessage* self, PyObject* arg) {
+#if PY_MAJOR_VERSION < 3
+ char* field_name;
+ if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) {
+#else
+ char* field_name = PyUnicode_AsUTF8(arg);
+ if (!field_name) {
+#endif
+ return NULL;
+ }
+
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
+ bool is_in_oneof;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ FindFieldWithOneofs(message, field_name, &is_in_oneof);
+ if (field_descriptor == NULL) {
+ if (!is_in_oneof) {
+ PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name);
+ return NULL;
+ } else {
+ Py_RETURN_FALSE;
+ }
+ }
+
+ if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ PyErr_Format(PyExc_ValueError,
+ "Protocol message has no singular \"%s\" field.", field_name);
+ return NULL;
+ }
+
+ bool has_field =
+ message->GetReflection()->HasField(*message, field_descriptor);
+ if (!has_field && field_descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_ENUM) {
+ // We may have an invalid enum value stored in the UnknownFieldSet and need
+ // to check presence in there as well.
+ const google::protobuf::UnknownFieldSet& unknown_field_set =
+ message->GetReflection()->GetUnknownFields(*message);
+ for (int i = 0; i < unknown_field_set.field_count(); ++i) {
+ if (unknown_field_set.field(i).number() == field_descriptor->number()) {
+ Py_RETURN_TRUE;
+ }
+ }
+ Py_RETURN_FALSE;
+ }
+ return PyBool_FromLong(has_field ? 1 : 0);
+}
+
+PyObject* ClearExtension(CMessage* self, PyObject* arg) {
+ if (self->extensions != NULL) {
+ return extension_dict::ClearExtension(self->extensions, arg);
+ }
+ PyErr_SetString(PyExc_TypeError, "Message is not extendable");
+ return NULL;
+}
+
+PyObject* HasExtension(CMessage* self, PyObject* arg) {
+ if (self->extensions != NULL) {
+ return extension_dict::HasExtension(self->extensions, arg);
+ }
+ PyErr_SetString(PyExc_TypeError, "Message is not extendable");
+ return NULL;
+}
+
+// ---------------------------------------------------------------------
+// Releasing messages
+//
+// The Python API's ClearField() and Clear() methods behave
+// differently than their C++ counterparts. While the C++ versions
+// clears the children the Python versions detaches the children,
+// without touching their content. This impedance mismatch causes
+// some complexity in the implementation, which is captured in this
+// section.
+//
+// When a CMessage field is cleared we need to:
+//
+// * Release the Message used as the backing store for the CMessage
+// from its parent.
+//
+// * Change the owner field of the released CMessage and all of its
+// children to point to the newly released Message.
+//
+// * Clear the weak references from the released CMessage to the
+// parent.
+//
+// When a RepeatedCompositeContainer field is cleared we need to:
+//
+// * Release all the Message used as the backing store for the
+// CMessages stored in the container.
+//
+// * Change the owner field of all the released CMessage and all of
+// their children to point to the newly released Messages.
+//
+// * Clear the weak references from the released container to the
+// parent.
+
+struct SetOwnerVisitor : public ChildVisitor {
+ // new_owner must outlive this object.
+ explicit SetOwnerVisitor(const shared_ptr<Message>& new_owner)
+ : new_owner_(new_owner) {}
+
+ int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
+ repeated_composite_container::SetOwner(container, new_owner_);
+ return 0;
+ }
+
+ int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
+ repeated_scalar_container::SetOwner(container, new_owner_);
+ return 0;
+ }
+
+ int VisitCMessage(CMessage* cmessage,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ return SetOwner(cmessage, new_owner_);
+ }
+
+ private:
+ const shared_ptr<Message>& new_owner_;
+};
+
+// Change the owner of this CMessage and all its children, recursively.
+int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) {
+ self->owner = new_owner;
+ if (ForEachCompositeField(self, SetOwnerVisitor(new_owner)) == -1)
+ return -1;
+ return 0;
+}
+
+// Releases the message specified by 'field' and returns the
+// pointer. If the field does not exist a new message is created using
+// 'descriptor'. The caller takes ownership of the returned pointer.
+Message* ReleaseMessage(google::protobuf::Message* message,
+ const google::protobuf::Descriptor* descriptor,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ Message* released_message = message->GetReflection()->ReleaseMessage(
+ message, field_descriptor, global_message_factory);
+ // ReleaseMessage will return NULL which differs from
+ // child_cmessage->message, if the field does not exist. In this case,
+ // the latter points to the default instance via a const_cast<>, so we
+ // have to reset it to a new mutable object since we are taking ownership.
+ if (released_message == NULL) {
+ const Message* prototype = global_message_factory->GetPrototype(
+ descriptor);
+ GOOGLE_DCHECK(prototype != NULL);
+ released_message = prototype->New();
+ }
+
+ return released_message;
+}
+
+int ReleaseSubMessage(google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ CMessage* child_cmessage) {
+ // Release the Message
+ shared_ptr<Message> released_message(ReleaseMessage(
+ message, child_cmessage->message->GetDescriptor(), field_descriptor));
+ child_cmessage->message = released_message.get();
+ child_cmessage->owner.swap(released_message);
+ child_cmessage->parent = NULL;
+ child_cmessage->parent_field = NULL;
+ child_cmessage->read_only = false;
+ return ForEachCompositeField(child_cmessage,
+ SetOwnerVisitor(child_cmessage->owner));
+}
+
+struct ReleaseChild : public ChildVisitor {
+ // message must outlive this object.
+ explicit ReleaseChild(google::protobuf::Message* parent_message) :
+ parent_message_(parent_message) {}
+
+ int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) {
+ return repeated_composite_container::Release(
+ reinterpret_cast<RepeatedCompositeContainer*>(container));
+ }
+
+ int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) {
+ return repeated_scalar_container::Release(
+ reinterpret_cast<RepeatedScalarContainer*>(container));
+ }
+
+ int VisitCMessage(CMessage* cmessage,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ return ReleaseSubMessage(parent_message_, field_descriptor,
+ reinterpret_cast<CMessage*>(cmessage));
+ }
+
+ google::protobuf::Message* parent_message_;
+};
+
+int InternalReleaseFieldByDescriptor(
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* composite_field,
+ google::protobuf::Message* parent_message) {
+ return VisitCompositeField(
+ field_descriptor,
+ composite_field,
+ ReleaseChild(parent_message));
+}
+
+int InternalReleaseField(CMessage* self, PyObject* composite_field,
+ PyObject* name) {
+ PyObject* cdescriptor = GetDescriptor(self, name);
+ if (cdescriptor != NULL) {
+ const google::protobuf::FieldDescriptor* descriptor =
+ reinterpret_cast<CFieldDescriptor*>(cdescriptor)->descriptor;
+ return InternalReleaseFieldByDescriptor(
+ descriptor, composite_field, self->message);
+ }
+
+ return 0;
+}
+
+PyObject* ClearFieldByDescriptor(
+ CMessage* self,
+ const google::protobuf::FieldDescriptor* descriptor) {
+ if (!FIELD_BELONGS_TO_MESSAGE(descriptor, self->message)) {
+ PyErr_SetString(PyExc_KeyError,
+ "Field does not belong to message!");
+ return NULL;
+ }
+ AssureWritable(self);
+ self->message->GetReflection()->ClearField(self->message, descriptor);
+ Py_RETURN_NONE;
+}
+
+PyObject* ClearField(CMessage* self, PyObject* arg) {
+ char* field_name;
+ if (!PyString_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError, "field name must be a string");
+ return NULL;
+ }
+#if PY_MAJOR_VERSION < 3
+ if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) {
+ return NULL;
+ }
+#else
+ field_name = PyUnicode_AsUTF8(arg);
+#endif
+ AssureWritable(self);
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Descriptor* descriptor = message->GetDescriptor();
+ ScopedPyObjectPtr arg_in_oneof;
+ bool is_in_oneof;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ FindFieldWithOneofs(message, field_name, &is_in_oneof);
+ if (field_descriptor == NULL) {
+ if (!is_in_oneof) {
+ PyErr_Format(PyExc_ValueError,
+ "Protocol message has no \"%s\" field.", field_name);
+ return NULL;
+ } else {
+ Py_RETURN_NONE;
+ }
+ } else if (is_in_oneof) {
+ arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str()));
+ arg = arg_in_oneof.get();
+ }
+
+ PyObject* composite_field = PyDict_GetItem(self->composite_fields,
+ arg);
+
+ // Only release the field if there's a possibility that there are
+ // references to it.
+ if (composite_field != NULL) {
+ if (InternalReleaseField(self, composite_field, arg) < 0) {
+ return NULL;
+ }
+ PyDict_DelItem(self->composite_fields, arg);
+ }
+ message->GetReflection()->ClearField(message, field_descriptor);
+ if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) {
+ google::protobuf::UnknownFieldSet* unknown_field_set =
+ message->GetReflection()->MutableUnknownFields(message);
+ unknown_field_set->DeleteByNumber(field_descriptor->number());
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* Clear(CMessage* self) {
+ AssureWritable(self);
+ if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1)
+ return NULL;
+
+ // The old ExtensionDict still aliases this CMessage, but all its
+ // fields have been released.
+ if (self->extensions != NULL) {
+ Py_CLEAR(self->extensions);
+ PyObject* py_extension_dict = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&ExtensionDict_Type), NULL);
+ if (py_extension_dict == NULL) {
+ return NULL;
+ }
+ ExtensionDict* extension_dict = reinterpret_cast<ExtensionDict*>(
+ py_extension_dict);
+ extension_dict->parent = self;
+ extension_dict->message = self->message;
+ self->extensions = extension_dict;
+ }
+ PyDict_Clear(self->composite_fields);
+ self->message->Clear();
+ Py_RETURN_NONE;
+}
+
+// ---------------------------------------------------------------------
+
+static string GetMessageName(CMessage* self) {
+ if (self->parent_field != NULL) {
+ return self->parent_field->descriptor->full_name();
+ } else {
+ return self->message->GetDescriptor()->full_name();
+ }
+}
+
+static PyObject* SerializeToString(CMessage* self, PyObject* args) {
+ if (!self->message->IsInitialized()) {
+ ScopedPyObjectPtr errors(FindInitializationErrors(self));
+ if (errors == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr comma(PyString_FromString(","));
+ if (comma == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr joined(
+ PyObject_CallMethod(comma.get(), "join", "O", errors.get()));
+ if (joined == NULL) {
+ return NULL;
+ }
+ PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s",
+ GetMessageName(self).c_str(), PyString_AsString(joined.get()));
+ return NULL;
+ }
+ int size = self->message->ByteSize();
+ if (size <= 0) {
+ return PyBytes_FromString("");
+ }
+ PyObject* result = PyBytes_FromStringAndSize(NULL, size);
+ if (result == NULL) {
+ return NULL;
+ }
+ char* buffer = PyBytes_AS_STRING(result);
+ self->message->SerializeWithCachedSizesToArray(
+ reinterpret_cast<uint8*>(buffer));
+ return result;
+}
+
+static PyObject* SerializePartialToString(CMessage* self) {
+ string contents;
+ self->message->SerializePartialToString(&contents);
+ return PyBytes_FromStringAndSize(contents.c_str(), contents.size());
+}
+
+// Formats proto fields for ascii dumps using python formatting functions where
+// appropriate.
+class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter {
+ public:
+ PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {}
+
+ // Python has some differences from C++ when printing floating point numbers.
+ //
+ // 1) Trailing .0 is always printed.
+ // 2) Outputted is rounded to 12 digits.
+ //
+ // We override floating point printing with the C-API function for printing
+ // Python floats to ensure consistency.
+ string PrintFloat(float value) const { return PrintDouble(value); }
+ string PrintDouble(double value) const {
+ reinterpret_cast<PyFloatObject*>(float_holder_.get())->ob_fval = value;
+ ScopedPyObjectPtr s(PyObject_Str(float_holder_.get()));
+ if (s == NULL) return string();
+#if PY_MAJOR_VERSION < 3
+ char *cstr = PyBytes_AS_STRING(static_cast<PyObject*>(s));
+#else
+ char *cstr = PyUnicode_AsUTF8(s);
+#endif
+ return string(cstr);
+ }
+
+ private:
+ // Holder for a python float object which we use to allow us to use
+ // the Python API for printing doubles. We initialize once and then
+ // directly modify it for every float printed to save on allocations
+ // and refcounting.
+ ScopedPyObjectPtr float_holder_;
+};
+
+static PyObject* ToStr(CMessage* self) {
+ google::protobuf::TextFormat::Printer printer;
+ // Passes ownership
+ printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter());
+ printer.SetHideUnknownFields(true);
+ string output;
+ if (!printer.PrintToString(*self->message, &output)) {
+ PyErr_SetString(PyExc_ValueError, "Unable to convert message to str");
+ return NULL;
+ }
+ return PyString_FromString(output.c_str());
+}
+
+PyObject* MergeFrom(CMessage* self, PyObject* arg) {
+ CMessage* other_message;
+ if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Must be a message");
+ return NULL;
+ }
+
+ other_message = reinterpret_cast<CMessage*>(arg);
+ if (other_message->message->GetDescriptor() !=
+ self->message->GetDescriptor()) {
+ PyErr_Format(PyExc_TypeError,
+ "Tried to merge from a message with a different type. "
+ "to: %s, from: %s",
+ self->message->GetDescriptor()->full_name().c_str(),
+ other_message->message->GetDescriptor()->full_name().c_str());
+ return NULL;
+ }
+ AssureWritable(self);
+
+ // TODO(tibell): Message::MergeFrom might turn some child Messages
+ // into mutable messages, invalidating the message field in the
+ // corresponding CMessages. We should run a FixupMessageReferences
+ // pass here.
+
+ self->message->MergeFrom(*other_message->message);
+ Py_RETURN_NONE;
+}
+
+static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
+ CMessage* other_message;
+ if (!PyObject_TypeCheck(reinterpret_cast<PyObject *>(arg), &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Must be a message");
+ return NULL;
+ }
+
+ other_message = reinterpret_cast<CMessage*>(arg);
+
+ if (self == other_message) {
+ Py_RETURN_NONE;
+ }
+
+ if (other_message->message->GetDescriptor() !=
+ self->message->GetDescriptor()) {
+ PyErr_Format(PyExc_TypeError,
+ "Tried to copy from a message with a different type. "
+ "to: %s, from: %s",
+ self->message->GetDescriptor()->full_name().c_str(),
+ other_message->message->GetDescriptor()->full_name().c_str());
+ return NULL;
+ }
+
+ AssureWritable(self);
+
+ // CopyFrom on the message will not clean up self->composite_fields,
+ // which can leave us in an inconsistent state, so clear it out here.
+ Clear(self);
+
+ self->message->CopyFrom(*other_message->message);
+
+ Py_RETURN_NONE;
+}
+
+static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
+ const void* data;
+ Py_ssize_t data_length;
+ if (PyObject_AsReadBuffer(arg, &data, &data_length) < 0) {
+ return NULL;
+ }
+
+ AssureWritable(self);
+ google::protobuf::io::CodedInputStream input(
+ reinterpret_cast<const uint8*>(data), data_length);
+ input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory);
+ bool success = self->message->MergePartialFromCodedStream(&input);
+ if (success) {
+ return PyInt_FromLong(input.CurrentPosition());
+ } else {
+ PyErr_Format(DecodeError_class, "Error parsing message");
+ return NULL;
+ }
+}
+
+static PyObject* ParseFromString(CMessage* self, PyObject* arg) {
+ if (Clear(self) == NULL) {
+ return NULL;
+ }
+ return MergeFromString(self, arg);
+}
+
+static PyObject* ByteSize(CMessage* self, PyObject* args) {
+ return PyLong_FromLong(self->message->ByteSize());
+}
+
+static PyObject* RegisterExtension(PyObject* cls,
+ PyObject* extension_handle) {
+ ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR));
+ if (message_descriptor == NULL) {
+ return NULL;
+ }
+ if (PyObject_SetAttrString(extension_handle, "containing_type",
+ message_descriptor) < 0) {
+ return NULL;
+ }
+ ScopedPyObjectPtr extensions_by_name(
+ PyObject_GetAttr(cls, k_extensions_by_name));
+ if (extensions_by_name == NULL) {
+ PyErr_SetString(PyExc_TypeError, "no extensions_by_name on class");
+ return NULL;
+ }
+ ScopedPyObjectPtr full_name(PyObject_GetAttr(extension_handle, kfull_name));
+ if (full_name == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) {
+ return NULL;
+ }
+
+ // Also store a mapping from extension number to implementing class.
+ ScopedPyObjectPtr extensions_by_number(
+ PyObject_GetAttr(cls, k_extensions_by_number));
+ if (extensions_by_number == NULL) {
+ PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class");
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number"));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(extensions_by_number, number, extension_handle) < 0) {
+ return NULL;
+ }
+
+ CFieldDescriptor* cdescriptor =
+ extension_dict::InternalGetCDescriptorFromExtension(extension_handle);
+ ScopedPyObjectPtr py_cdescriptor(reinterpret_cast<PyObject*>(cdescriptor));
+ if (cdescriptor == NULL) {
+ return NULL;
+ }
+ Py_INCREF(extension_handle);
+ cdescriptor->descriptor_field = extension_handle;
+ const google::protobuf::FieldDescriptor* descriptor = cdescriptor->descriptor;
+ // Check if it's a message set
+ if (descriptor->is_extension() &&
+ descriptor->containing_type()->options().message_set_wire_format() &&
+ descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE &&
+ descriptor->message_type() == descriptor->extension_scope() &&
+ descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) {
+ ScopedPyObjectPtr message_name(PyString_FromStringAndSize(
+ descriptor->message_type()->full_name().c_str(),
+ descriptor->message_type()->full_name().size()));
+ if (message_name == NULL) {
+ return NULL;
+ }
+ PyDict_SetItem(extensions_by_name, message_name, extension_handle);
+ }
+
+ Py_RETURN_NONE;
+}
+
+static PyObject* SetInParent(CMessage* self, PyObject* args) {
+ AssureWritable(self);
+ Py_RETURN_NONE;
+}
+
+static PyObject* WhichOneof(CMessage* self, PyObject* arg) {
+ char* oneof_name;
+ if (!PyString_Check(arg)) {
+ PyErr_SetString(PyExc_TypeError, "field name must be a string");
+ return NULL;
+ }
+ oneof_name = PyString_AsString(arg);
+ if (oneof_name == NULL) {
+ return NULL;
+ }
+ const google::protobuf::OneofDescriptor* oneof_desc =
+ self->message->GetDescriptor()->FindOneofByName(oneof_name);
+ if (oneof_desc == NULL) {
+ PyErr_Format(PyExc_ValueError,
+ "Protocol message has no oneof \"%s\" field.", oneof_name);
+ return NULL;
+ }
+ const google::protobuf::FieldDescriptor* field_in_oneof =
+ self->message->GetReflection()->GetOneofFieldDescriptor(
+ *self->message, oneof_desc);
+ if (field_in_oneof == NULL) {
+ Py_RETURN_NONE;
+ } else {
+ return PyString_FromString(field_in_oneof->name().c_str());
+ }
+}
+
+static PyObject* ListFields(CMessage* self) {
+ vector<const google::protobuf::FieldDescriptor*> fields;
+ self->message->GetReflection()->ListFields(*self->message, &fields);
+
+ PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR);
+ if (descriptor == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr fields_by_name(
+ PyObject_GetAttr(descriptor, kfields_by_name));
+ if (fields_by_name == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr(
+ reinterpret_cast<PyObject*>(Py_TYPE(self)), k_extensions_by_name));
+ if (extensions_by_name == NULL) {
+ PyErr_SetString(PyExc_ValueError, "no extensionsbyname");
+ return NULL;
+ }
+ // Normally, the list will be exactly the size of the fields.
+ PyObject* all_fields = PyList_New(fields.size());
+ if (all_fields == NULL) {
+ return NULL;
+ }
+
+ // When there are unknown extensions, the py list will *not* contain
+ // the field information. Thus the actual size of the py list will be
+ // smaller than the size of fields. Set the actual size at the end.
+ Py_ssize_t actual_size = 0;
+ for (Py_ssize_t i = 0; i < fields.size(); ++i) {
+ ScopedPyObjectPtr t(PyTuple_New(2));
+ if (t == NULL) {
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+
+ if (fields[i]->is_extension()) {
+ const string& field_name = fields[i]->full_name();
+ PyObject* extension_field = PyDict_GetItemString(extensions_by_name,
+ field_name.c_str());
+ if (extension_field == NULL) {
+ // If we couldn't fetch extension_field, it means the module that
+ // defines this extension has not been explicitly imported in Python
+ // code, and the extension hasn't been registered. There's nothing much
+ // we can do about this, so just skip it in the output to match the
+ // behavior of the python implementation.
+ continue;
+ }
+ PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions);
+ if (extensions == NULL) {
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+ // 'extension' reference later stolen by PyTuple_SET_ITEM.
+ PyObject* extension = PyObject_GetItem(extensions, extension_field);
+ if (extension == NULL) {
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+ Py_INCREF(extension_field);
+ PyTuple_SET_ITEM(t.get(), 0, extension_field);
+ // Steals reference to 'extension'
+ PyTuple_SET_ITEM(t.get(), 1, extension);
+ } else {
+ const string& field_name = fields[i]->name();
+ ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize(
+ field_name.c_str(), field_name.length()));
+ if (py_field_name == NULL) {
+ PyErr_SetString(PyExc_ValueError, "bad string");
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+ PyObject* field_descriptor =
+ PyDict_GetItem(fields_by_name, py_field_name);
+ if (field_descriptor == NULL) {
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+
+ PyObject* field_value = GetAttr(self, py_field_name);
+ if (field_value == NULL) {
+ PyErr_SetObject(PyExc_ValueError, py_field_name);
+ Py_DECREF(all_fields);
+ return NULL;
+ }
+ Py_INCREF(field_descriptor);
+ PyTuple_SET_ITEM(t.get(), 0, field_descriptor);
+ PyTuple_SET_ITEM(t.get(), 1, field_value);
+ }
+ PyList_SET_ITEM(all_fields, actual_size, t.release());
+ ++actual_size;
+ }
+ Py_SIZE(all_fields) = actual_size;
+ return all_fields;
+}
+
+PyObject* FindInitializationErrors(CMessage* self) {
+ google::protobuf::Message* message = self->message;
+ vector<string> errors;
+ message->FindInitializationErrors(&errors);
+
+ PyObject* error_list = PyList_New(errors.size());
+ if (error_list == NULL) {
+ return NULL;
+ }
+ for (Py_ssize_t i = 0; i < errors.size(); ++i) {
+ const string& error = errors[i];
+ PyObject* error_string = PyString_FromStringAndSize(
+ error.c_str(), error.length());
+ if (error_string == NULL) {
+ Py_DECREF(error_list);
+ return NULL;
+ }
+ PyList_SET_ITEM(error_list, i, error_string);
+ }
+ return error_list;
+}
+
+static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
+ if (!PyObject_TypeCheck(other, &CMessage_Type)) {
+ if (opid == Py_EQ) {
+ Py_RETURN_FALSE;
+ } else if (opid == Py_NE) {
+ Py_RETURN_TRUE;
+ }
+ }
+ if (opid == Py_EQ || opid == Py_NE) {
+ ScopedPyObjectPtr self_fields(ListFields(self));
+ ScopedPyObjectPtr other_fields(ListFields(
+ reinterpret_cast<CMessage*>(other)));
+ return PyObject_RichCompare(self_fields, other_fields, opid);
+ } else {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+}
+
+PyObject* InternalGetScalar(
+ CMessage* self,
+ const google::protobuf::FieldDescriptor* field_descriptor) {
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+
+ if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
+ PyErr_SetString(
+ PyExc_KeyError, "Field does not belong to message!");
+ return NULL;
+ }
+
+ PyObject* result = NULL;
+ switch (field_descriptor->cpp_type()) {
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ int32 value = reflection->GetInt32(*message, field_descriptor);
+ result = PyInt_FromLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ int64 value = reflection->GetInt64(*message, field_descriptor);
+ result = PyLong_FromLongLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ uint32 value = reflection->GetUInt32(*message, field_descriptor);
+ result = PyInt_FromSize_t(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ uint64 value = reflection->GetUInt64(*message, field_descriptor);
+ result = PyLong_FromUnsignedLongLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ float value = reflection->GetFloat(*message, field_descriptor);
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ double value = reflection->GetDouble(*message, field_descriptor);
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ bool value = reflection->GetBool(*message, field_descriptor);
+ result = PyBool_FromLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ string value = reflection->GetString(*message, field_descriptor);
+ result = ToStringObject(field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ if (!message->GetReflection()->HasField(*message, field_descriptor)) {
+ // Look for the value in the unknown fields.
+ google::protobuf::UnknownFieldSet* unknown_field_set =
+ message->GetReflection()->MutableUnknownFields(message);
+ for (int i = 0; i < unknown_field_set->field_count(); ++i) {
+ if (unknown_field_set->field(i).number() ==
+ field_descriptor->number()) {
+ result = PyInt_FromLong(unknown_field_set->field(i).varint());
+ break;
+ }
+ }
+ }
+
+ if (result == NULL) {
+ const google::protobuf::EnumValueDescriptor* enum_value =
+ message->GetReflection()->GetEnum(*message, field_descriptor);
+ result = PyInt_FromLong(enum_value->number());
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Getting a value from a field of unknown type %d",
+ field_descriptor->cpp_type());
+ }
+
+ return result;
+}
+
+PyObject* InternalGetSubMessage(CMessage* self,
+ CFieldDescriptor* cfield_descriptor) {
+ PyObject* field = cfield_descriptor->descriptor_field;
+ ScopedPyObjectPtr message_type(PyObject_GetAttr(field, kmessage_type));
+ if (message_type == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr concrete_class(
+ PyObject_GetAttr(message_type, k_concrete_class));
+ if (concrete_class == NULL) {
+ return NULL;
+ }
+ PyObject* py_cmsg = cmessage::NewEmpty(concrete_class);
+ if (py_cmsg == NULL) {
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(py_cmsg, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a CMessage!");
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
+
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ cfield_descriptor->descriptor;
+ const google::protobuf::Reflection* reflection = self->message->GetReflection();
+ const google::protobuf::Message& sub_message = reflection->GetMessage(
+ *self->message, field_descriptor, global_message_factory);
+ cmsg->owner = self->owner;
+ cmsg->parent = self;
+ cmsg->parent_field = cfield_descriptor;
+ cmsg->read_only = !reflection->HasField(*self->message, field_descriptor);
+ cmsg->message = const_cast<google::protobuf::Message*>(&sub_message);
+
+ if (InitAttributes(cmsg, NULL, NULL) < 0) {
+ Py_DECREF(py_cmsg);
+ return NULL;
+ }
+ return py_cmsg;
+}
+
+int InternalSetScalar(
+ CMessage* self,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* arg) {
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+
+ if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
+ PyErr_SetString(
+ PyExc_KeyError, "Field does not belong to message!");
+ return -1;
+ }
+
+ if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) {
+ return -1;
+ }
+
+ switch (field_descriptor->cpp_type()) {
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(arg, value, -1);
+ reflection->SetInt32(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(arg, value, -1);
+ reflection->SetInt64(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(arg, value, -1);
+ reflection->SetUInt32(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(arg, value, -1);
+ reflection->SetUInt64(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
+ reflection->SetFloat(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
+ reflection->SetDouble(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(arg, value, -1);
+ reflection->SetBool(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ if (!CheckAndSetString(
+ arg, message, field_descriptor, reflection, false, -1)) {
+ return -1;
+ }
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ GOOGLE_CHECK_GET_INT32(arg, value, -1);
+ const google::protobuf::EnumDescriptor* enum_descriptor =
+ field_descriptor->enum_type();
+ const google::protobuf::EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->SetEnum(message, field_descriptor, enum_value);
+ } else {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
+ return -1;
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Setting value to a field of unknown type %d",
+ field_descriptor->cpp_type());
+ return -1;
+ }
+
+ return 0;
+}
+
+PyObject* FromString(PyTypeObject* cls, PyObject* serialized) {
+ PyObject* py_cmsg = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(cls), NULL);
+ if (py_cmsg == NULL) {
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
+
+ ScopedPyObjectPtr py_length(MergeFromString(cmsg, serialized));
+ if (py_length == NULL) {
+ Py_DECREF(py_cmsg);
+ return NULL;
+ }
+
+ if (InitAttributes(cmsg, NULL, NULL) < 0) {
+ Py_DECREF(py_cmsg);
+ return NULL;
+ }
+ return py_cmsg;
+}
+
+static PyObject* AddDescriptors(PyTypeObject* cls,
+ PyObject* descriptor) {
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ k_extensions_by_name, PyDict_New()) < 0) {
+ return NULL;
+ }
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ k_extensions_by_number, PyDict_New()) < 0) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr field_descriptors(PyDict_New());
+
+ ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields"));
+ if (fields == NULL) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER"));
+ if (_NUMBER_string == NULL) {
+ return NULL;
+ }
+
+ const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get());
+ for (int i = 0; i < fields_size; ++i) {
+ PyObject* field = PyList_GET_ITEM(fields.get(), i);
+ ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname));
+ ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name));
+ if (field_name == NULL || full_field_name == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Name is null");
+ return NULL;
+ }
+
+ PyObject* field_descriptor =
+ cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name);
+ if (field_descriptor == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Couldn't find field");
+ return NULL;
+ }
+ Py_INCREF(field);
+ CFieldDescriptor* cfield_descriptor = reinterpret_cast<CFieldDescriptor*>(
+ field_descriptor);
+ cfield_descriptor->descriptor_field = field;
+ if (PyDict_SetItem(field_descriptors, field_name, field_descriptor) < 0) {
+ return NULL;
+ }
+
+ // The FieldDescriptor's name field might either be of type bytes or
+ // of type unicode, depending on whether the FieldDescriptor was
+ // parsed from a serialized message or read from the
+ // <message>_pb2.py module.
+ ScopedPyObjectPtr field_name_upcased(
+ PyObject_CallMethod(field_name, "upper", NULL));
+ if (field_name_upcased == NULL) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr field_number_name(PyObject_CallMethod(
+ field_name_upcased, "__add__", "(O)", _NUMBER_string.get()));
+ if (field_number_name == NULL) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr number(PyInt_FromLong(
+ cfield_descriptor->descriptor->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ field_number_name, number) == -1) {
+ return NULL;
+ }
+ }
+
+ PyDict_SetItem(cls->tp_dict, k__descriptors, field_descriptors);
+
+ // Enum Values
+ ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor,
+ "enum_types"));
+ if (enum_types == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types));
+ if (type_iter == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr enum_type;
+ while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) {
+ ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs(
+ EnumTypeWrapper_class, enum_type.get(), NULL));
+ if (wrapped == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname));
+ if (enum_name == NULL) {
+ return NULL;
+ }
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ enum_name, wrapped) == -1) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values"));
+ if (enum_values == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values));
+ if (values_iter == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr enum_value;
+ while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) {
+ ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname));
+ if (value_name == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value,
+ "number"));
+ if (value_number == NULL) {
+ return NULL;
+ }
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ value_name, value_number) == -1) {
+ return NULL;
+ }
+ }
+ if (PyErr_Occurred()) { // If PyIter_Next failed
+ return NULL;
+ }
+ }
+ if (PyErr_Occurred()) { // If PyIter_Next failed
+ return NULL;
+ }
+
+ ScopedPyObjectPtr extension_dict(
+ PyObject_GetAttr(descriptor, kextensions_by_name));
+ if (extension_dict == NULL || !PyDict_Check(extension_dict)) {
+ PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict");
+ return NULL;
+ }
+ Py_ssize_t pos = 0;
+ PyObject* extension_name;
+ PyObject* extension_field;
+
+ while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) {
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ extension_name, extension_field) == -1) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_cfield_descriptor(
+ PyObject_GetAttrString(extension_field, "_cdescriptor"));
+ if (py_cfield_descriptor == NULL) {
+ return NULL;
+ }
+ CFieldDescriptor* cfield_descriptor =
+ reinterpret_cast<CFieldDescriptor*>(py_cfield_descriptor.get());
+ Py_INCREF(extension_field);
+ cfield_descriptor->descriptor_field = extension_field;
+
+ ScopedPyObjectPtr field_name_upcased(
+ PyObject_CallMethod(extension_name, "upper", NULL));
+ if (field_name_upcased == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr field_number_name(PyObject_CallMethod(
+ field_name_upcased, "__add__", "(O)", _NUMBER_string.get()));
+ if (field_number_name == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr number(PyInt_FromLong(
+ cfield_descriptor->descriptor->number()));
+ if (number == NULL) {
+ return NULL;
+ }
+ if (PyObject_SetAttr(reinterpret_cast<PyObject*>(cls),
+ field_number_name, PyInt_FromLong(
+ cfield_descriptor->descriptor->number())) == -1) {
+ return NULL;
+ }
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* DeepCopy(CMessage* self, PyObject* arg) {
+ PyObject* clone = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(Py_TYPE(self)), NULL);
+ if (clone == NULL) {
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(clone, &CMessage_Type)) {
+ Py_DECREF(clone);
+ return NULL;
+ }
+ if (InitAttributes(reinterpret_cast<CMessage*>(clone), NULL, NULL) < 0) {
+ Py_DECREF(clone);
+ return NULL;
+ }
+ if (MergeFrom(reinterpret_cast<CMessage*>(clone),
+ reinterpret_cast<PyObject*>(self)) == NULL) {
+ Py_DECREF(clone);
+ return NULL;
+ }
+ return clone;
+}
+
+PyObject* ToUnicode(CMessage* self) {
+ // Lazy import to prevent circular dependencies
+ ScopedPyObjectPtr text_format(
+ PyImport_ImportModule("google.protobuf.text_format"));
+ if (text_format == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr method_name(PyString_FromString("MessageToString"));
+ if (method_name == NULL) {
+ return NULL;
+ }
+ Py_INCREF(Py_True);
+ ScopedPyObjectPtr encoded(PyObject_CallMethodObjArgs(text_format, method_name,
+ self, Py_True, NULL));
+ Py_DECREF(Py_True);
+ if (encoded == NULL) {
+ return NULL;
+ }
+#if PY_MAJOR_VERSION < 3
+ PyObject* decoded = PyString_AsDecodedObject(encoded, "utf-8", NULL);
+#else
+ PyObject* decoded = PyUnicode_FromEncodedObject(encoded, "utf-8", NULL);
+#endif
+ if (decoded == NULL) {
+ return NULL;
+ }
+ return decoded;
+}
+
+PyObject* Reduce(CMessage* self) {
+ ScopedPyObjectPtr constructor(reinterpret_cast<PyObject*>(Py_TYPE(self)));
+ constructor.inc();
+ ScopedPyObjectPtr args(PyTuple_New(0));
+ if (args == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr state(PyDict_New());
+ if (state == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr serialized(SerializePartialToString(self));
+ if (serialized == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItemString(state, "serialized", serialized) < 0) {
+ return NULL;
+ }
+ return Py_BuildValue("OOO", constructor.get(), args.get(), state.get());
+}
+
+PyObject* SetState(CMessage* self, PyObject* state) {
+ if (!PyDict_Check(state)) {
+ PyErr_SetString(PyExc_TypeError, "state not a dict");
+ return NULL;
+ }
+ PyObject* serialized = PyDict_GetItemString(state, "serialized");
+ if (serialized == NULL) {
+ return NULL;
+ }
+ if (ParseFromString(self, serialized) == NULL) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+// CMessage static methods:
+PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) {
+ return cdescriptor_pool::FindFieldByName(descriptor_pool, arg);
+}
+
+PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) {
+ return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg);
+}
+
+static PyMemberDef Members[] = {
+ {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0,
+ "Extension dict"},
+ {NULL}
+};
+
+static PyMethodDef Methods[] = {
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the message." },
+ { "__setstate__", (PyCFunction)SetState, METH_O,
+ "Inputs picklable representation of the message." },
+ { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
+ "Outputs a unicode representation of the message." },
+ { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS,
+ "Adds field descriptors to the class" },
+ { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
+ "Returns the size of the message in bytes." },
+ { "Clear", (PyCFunction)Clear, METH_NOARGS,
+ "Clears the message." },
+ { "ClearExtension", (PyCFunction)ClearExtension, METH_O,
+ "Clears a message field." },
+ { "ClearField", (PyCFunction)ClearField, METH_O,
+ "Clears a message field." },
+ { "CopyFrom", (PyCFunction)CopyFrom, METH_O,
+ "Copies a protocol message into the current message." },
+ { "FindInitializationErrors", (PyCFunction)FindInitializationErrors,
+ METH_NOARGS,
+ "Finds unset required fields." },
+ { "FromString", (PyCFunction)FromString, METH_O | METH_CLASS,
+ "Creates new method instance from given serialized data." },
+ { "HasExtension", (PyCFunction)HasExtension, METH_O,
+ "Checks if a message field is set." },
+ { "HasField", (PyCFunction)HasField, METH_O,
+ "Checks if a message field is set." },
+ { "IsInitialized", (PyCFunction)IsInitialized, METH_VARARGS,
+ "Checks if all required fields of a protocol message are set." },
+ { "ListFields", (PyCFunction)ListFields, METH_NOARGS,
+ "Lists all set fields of a message." },
+ { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
+ "Merges a protocol message into the current message." },
+ { "MergeFromString", (PyCFunction)MergeFromString, METH_O,
+ "Merges a serialized message into the current message." },
+ { "ParseFromString", (PyCFunction)ParseFromString, METH_O,
+ "Parses a serialized message into the current message." },
+ { "RegisterExtension", (PyCFunction)RegisterExtension, METH_O | METH_CLASS,
+ "Registers an extension with the current message." },
+ { "SerializePartialToString", (PyCFunction)SerializePartialToString,
+ METH_NOARGS,
+ "Serializes the message to a string, even if it isn't initialized." },
+ { "SerializeToString", (PyCFunction)SerializeToString, METH_NOARGS,
+ "Serializes the message to a string, only for initialized messages." },
+ { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
+ "Sets the has bit of the given field in its parent message." },
+ { "WhichOneof", (PyCFunction)WhichOneof, METH_O,
+ "Returns the name of the field set inside a oneof, "
+ "or None if no field is set." },
+
+ // Static Methods.
+ { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC,
+ "Registers a new protocol buffer file in the global C++ descriptor pool." },
+ { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor,
+ METH_O | METH_STATIC, "Finds a field descriptor in the message pool." },
+ { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor,
+ METH_O | METH_STATIC,
+ "Finds a extension descriptor in the message pool." },
+ { NULL, NULL}
+};
+
+PyObject* GetAttr(CMessage* self, PyObject* name) {
+ PyObject* value = PyDict_GetItem(self->composite_fields, name);
+ if (value != NULL) {
+ Py_INCREF(value);
+ return value;
+ }
+
+ PyObject* descriptor = GetDescriptor(self, name);
+ if (descriptor != NULL) {
+ CFieldDescriptor* cdescriptor =
+ reinterpret_cast<CFieldDescriptor*>(descriptor);
+ const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor;
+ if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ if (field_descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyObject* py_container = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&RepeatedCompositeContainer_Type),
+ NULL);
+ if (py_container == NULL) {
+ return NULL;
+ }
+ RepeatedCompositeContainer* container =
+ reinterpret_cast<RepeatedCompositeContainer*>(py_container);
+ PyObject* field = cdescriptor->descriptor_field;
+ PyObject* message_type = PyObject_GetAttr(field, kmessage_type);
+ if (message_type == NULL) {
+ return NULL;
+ }
+ PyObject* concrete_class =
+ PyObject_GetAttr(message_type, k_concrete_class);
+ if (concrete_class == NULL) {
+ return NULL;
+ }
+ container->parent = self;
+ container->parent_field = cdescriptor;
+ container->message = self->message;
+ container->owner = self->owner;
+ container->subclass_init = concrete_class;
+ Py_DECREF(message_type);
+ if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) {
+ Py_DECREF(py_container);
+ return NULL;
+ }
+ return py_container;
+ } else {
+ ScopedPyObjectPtr init_args(PyTuple_Pack(2, self, cdescriptor));
+ PyObject* py_container = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type),
+ init_args);
+ if (py_container == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) {
+ Py_DECREF(py_container);
+ return NULL;
+ }
+ return py_container;
+ }
+ } else {
+ if (field_descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyObject* sub_message = InternalGetSubMessage(self, cdescriptor);
+ if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) {
+ Py_DECREF(sub_message);
+ return NULL;
+ }
+ return sub_message;
+ } else {
+ return InternalGetScalar(self, field_descriptor);
+ }
+ }
+ }
+
+ return CMessage_Type.tp_base->tp_getattro(reinterpret_cast<PyObject*>(self),
+ name);
+}
+
+int SetAttr(CMessage* self, PyObject* name, PyObject* value) {
+ if (PyDict_Contains(self->composite_fields, name)) {
+ PyErr_SetString(PyExc_TypeError, "Can't set composite field");
+ return -1;
+ }
+
+ PyObject* descriptor = GetDescriptor(self, name);
+ if (descriptor != NULL) {
+ AssureWritable(self);
+ CFieldDescriptor* cdescriptor =
+ reinterpret_cast<CFieldDescriptor*>(descriptor);
+ const google::protobuf::FieldDescriptor* field_descriptor = cdescriptor->descriptor;
+ if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) {
+ PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated "
+ "field \"%s\" in protocol message object.",
+ field_descriptor->name().c_str());
+ return -1;
+ } else {
+ if (field_descriptor->cpp_type() ==
+ google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyErr_Format(PyExc_AttributeError, "Assignment not allowed to "
+ "field \"%s\" in protocol message object.",
+ field_descriptor->name().c_str());
+ return -1;
+ } else {
+ return InternalSetScalar(self, field_descriptor, value);
+ }
+ }
+ }
+
+ PyErr_Format(PyExc_AttributeError, "Assignment not allowed");
+ return -1;
+}
+
+} // namespace cmessage
+
+PyTypeObject CMessage_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "google.protobuf.internal."
+ "cpp._message.CMessage", // tp_name
+ sizeof(CMessage), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)cmessage::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ 0, // tp_as_sequence
+ 0, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ (reprfunc)cmessage::ToStr, // tp_str
+ (getattrofunc)cmessage::GetAttr, // tp_getattro
+ (setattrofunc)cmessage::SetAttr, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
+ "A ProtocolMessage", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ (richcmpfunc)cmessage::RichCompare, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ cmessage::Methods, // tp_methods
+ cmessage::Members, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ (initproc)cmessage::Init, // tp_init
+ 0, // tp_alloc
+ cmessage::New, // tp_new
+};
+
+// --- Exposing the C proto living inside Python proto to C code:
+
+const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg);
+Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg);
+
+static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) {
+ if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
+ return cmsg->message;
+}
+
+static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
+ if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
+ if (PyDict_Size(cmsg->composite_fields) != 0 ||
+ (cmsg->extensions != NULL &&
+ PyDict_Size(cmsg->extensions->values) != 0)) {
+ // There is currently no way of accurately syncing arbitrary changes to
+ // the underlying C++ message back to the CMessage (e.g. removed repeated
+ // composite containers). We only allow direct mutation of the underlying
+ // C++ message if there is no child data in the CMessage.
+ return NULL;
+ }
+ cmessage::AssureWritable(cmsg);
+ return cmsg->message;
+}
+
+static const char module_docstring[] =
+"python-proto2 is a module that can be used to enhance proto2 Python API\n"
+"performance.\n"
+"\n"
+"It provides access to the protocol buffers C++ reflection API that\n"
+"implements the basic protocol buffer functions.";
+
+void InitGlobals() {
+ // TODO(gps): Check all return values in this function for NULL and propagate
+ // the error (MemoryError) on up to result in an import failure. These should
+ // also be freed and reset to NULL during finalization.
+ kPythonZero = PyInt_FromLong(0);
+ kint32min_py = PyInt_FromLong(kint32min);
+ kint32max_py = PyInt_FromLong(kint32max);
+ kuint32max_py = PyLong_FromLongLong(kuint32max);
+ kint64min_py = PyLong_FromLongLong(kint64min);
+ kint64max_py = PyLong_FromLongLong(kint64max);
+ kuint64max_py = PyLong_FromUnsignedLongLong(kuint64max);
+
+ kDESCRIPTOR = PyString_FromString("DESCRIPTOR");
+ k__descriptors = PyString_FromString("__descriptors");
+ kfull_name = PyString_FromString("full_name");
+ kis_extendable = PyString_FromString("is_extendable");
+ kextensions_by_name = PyString_FromString("extensions_by_name");
+ k_extensions_by_name = PyString_FromString("_extensions_by_name");
+ k_extensions_by_number = PyString_FromString("_extensions_by_number");
+ k_concrete_class = PyString_FromString("_concrete_class");
+ kmessage_type = PyString_FromString("message_type");
+ kname = PyString_FromString("name");
+ kfields_by_name = PyString_FromString("fields_by_name");
+
+ global_message_factory = new DynamicMessageFactory(GetDescriptorPool());
+ global_message_factory->SetDelegateToGeneratedFactory(true);
+
+ descriptor_pool = reinterpret_cast<google::protobuf::python::CDescriptorPool*>(
+ Python_NewCDescriptorPool(NULL, NULL));
+}
+
+bool InitProto2MessageModule(PyObject *m) {
+ InitGlobals();
+
+ google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented;
+ if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) {
+ return false;
+ }
+
+ // All three of these are actually set elsewhere, directly onto the child
+ // protocol buffer message class, but set them here as well to document that
+ // subclasses need to set these.
+ PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
+ PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict,
+ k_extensions_by_name, Py_None);
+ PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict,
+ k_extensions_by_number, Py_None);
+
+ PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(
+ &google::protobuf::python::CMessage_Type));
+
+ google::protobuf::python::RepeatedScalarContainer_Type.tp_new = PyType_GenericNew;
+ google::protobuf::python::RepeatedScalarContainer_Type.tp_hash =
+ PyObject_HashNotImplemented;
+ if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "RepeatedScalarContainer",
+ reinterpret_cast<PyObject*>(
+ &google::protobuf::python::RepeatedScalarContainer_Type));
+
+ google::protobuf::python::RepeatedCompositeContainer_Type.tp_new = PyType_GenericNew;
+ google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash =
+ PyObject_HashNotImplemented;
+ if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(
+ m, "RepeatedCompositeContainer",
+ reinterpret_cast<PyObject*>(
+ &google::protobuf::python::RepeatedCompositeContainer_Type));
+
+ google::protobuf::python::ExtensionDict_Type.tp_new = PyType_GenericNew;
+ google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented;
+ if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(
+ m, "ExtensionDict",
+ reinterpret_cast<PyObject*>(&google::protobuf::python::ExtensionDict_Type));
+
+ if (!google::protobuf::python::InitDescriptor()) {
+ return false;
+ }
+
+ PyObject* enum_type_wrapper = PyImport_ImportModule(
+ "google.protobuf.internal.enum_type_wrapper");
+ if (enum_type_wrapper == NULL) {
+ return false;
+ }
+ google::protobuf::python::EnumTypeWrapper_class =
+ PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper");
+ Py_DECREF(enum_type_wrapper);
+
+ PyObject* message_module = PyImport_ImportModule(
+ "google.protobuf.message");
+ if (message_module == NULL) {
+ return false;
+ }
+ google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module,
+ "EncodeError");
+ google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module,
+ "DecodeError");
+ Py_DECREF(message_module);
+
+ PyObject* pickle_module = PyImport_ImportModule("pickle");
+ if (pickle_module == NULL) {
+ return false;
+ }
+ google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module,
+ "PickleError");
+ Py_DECREF(pickle_module);
+
+ // Override {Get,Mutable}CProtoInsidePyProto.
+ google::protobuf::python::GetCProtoInsidePyProtoPtr =
+ google::protobuf::python::GetCProtoInsidePyProtoImpl;
+ google::protobuf::python::MutableCProtoInsidePyProtoPtr =
+ google::protobuf::python::MutableCProtoInsidePyProtoImpl;
+
+ return true;
+}
+
+} // namespace python
+} // namespace protobuf
+
+
+#if PY_MAJOR_VERSION >= 3
+static struct PyModuleDef _module = {
+ PyModuleDef_HEAD_INIT,
+ "_message",
+ google::protobuf::python::module_docstring,
+ -1,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+#define INITFUNC PyInit__message
+#define INITFUNC_ERRORVAL NULL
+#else // Python 2
+#define INITFUNC init_message
+#define INITFUNC_ERRORVAL
+#endif
+
+extern "C" {
+ PyMODINIT_FUNC INITFUNC(void) {
+ PyObject* m;
+#if PY_MAJOR_VERSION >= 3
+ m = PyModule_Create(&_module);
+#else
+ m = Py_InitModule3("_message", NULL, google::protobuf::python::module_docstring);
+#endif
+ if (m == NULL) {
+ return INITFUNC_ERRORVAL;
+ }
+
+ if (!google::protobuf::python::InitProto2MessageModule(m)) {
+ Py_DECREF(m);
+ return INITFUNC_ERRORVAL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ return m;
+#endif
+ }
+}
+} // namespace google
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
new file mode 100644
index 0000000..28e504f
--- /dev/null
+++ b/python/google/protobuf/pyext/message.h
@@ -0,0 +1,305 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+#include <string>
+
+
+namespace google {
+namespace protobuf {
+
+class Message;
+class Reflection;
+class FieldDescriptor;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CFieldDescriptor;
+struct ExtensionDict;
+
+typedef struct CMessage {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python CMessage holds a reference to it in
+ // order to keep it alive as long as there's a Python object that
+ // references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Weak reference to a parent CMessage object. This is NULL for any top-level
+ // message and is set for any child message (i.e. a child submessage or a
+ // part of a repeated composite field).
+ //
+ // Used to make sure all ancestors are also mutable when first modifying
+ // a child submessage (in other words, turning a default message instance
+ // into a mutable one).
+ //
+ // If a submessage is released (becomes a new top-level message), this field
+ // MUST be set to NULL. The parent may get deallocated and further attempts
+ // to use this pointer will result in a crash.
+ struct CMessage* parent;
+
+ // Weak reference to the parent's descriptor that describes this submessage.
+ // Used together with the parent's message when making a default message
+ // instance mutable.
+ // TODO(anuraag): With a bit of work on the Python/C++ layer, it should be
+ // possible to make this a direct pointer to a C++ FieldDescriptor, this would
+ // be easier if this implementation replaces upstream.
+ CFieldDescriptor* parent_field;
+
+ // Pointer to the C++ Message object for this CMessage. The
+ // CMessage does not own this pointer.
+ Message* message;
+
+ // Indicates this submessage is pointing to a default instance of a message.
+ // Submessages are always first created as read only messages and are then
+ // made writable, at which point this field is set to false.
+ bool read_only;
+
+ // A reference to a Python dictionary containing CMessage,
+ // RepeatedCompositeContainer, and RepeatedScalarContainer
+ // objects. Used as a cache to make sure we don't have to make a
+ // Python wrapper for the C++ Message objects on every access, or
+ // deal with the synchronization nightmare that could create.
+ PyObject* composite_fields;
+
+ // A reference to the dictionary containing the message's extensions.
+ // Similar to composite_fields, acting as a cache, but also contains the
+ // required extension dict logic.
+ ExtensionDict* extensions;
+} CMessage;
+
+extern PyTypeObject CMessage_Type;
+
+namespace cmessage {
+
+// Create a new empty message that can be populated by the parent.
+PyObject* NewEmpty(PyObject* type);
+
+// Release a submessage from its proto tree, making it a new top-level messgae.
+// A new message will be created if this is a read-only default instance.
+//
+// Corresponds to reflection api method ReleaseMessage.
+int ReleaseSubMessage(google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ CMessage* child_cmessage);
+
+// Initializes a new CMessage instance for a submessage. Only called once per
+// submessage as the result is cached in composite_fields.
+//
+// Corresponds to reflection api method GetMessage.
+PyObject* InternalGetSubMessage(CMessage* self,
+ CFieldDescriptor* cfield_descriptor);
+
+// Deletes a range of C++ submessages in a repeated field (following a
+// removal in a RepeatedCompositeContainer).
+//
+// Releases messages to the provided cmessage_list if it is not NULL rather
+// than just removing them from the underlying proto. This cmessage_list must
+// have a CMessage for each underlying submessage. The CMessages refered to
+// by slice will be removed from cmessage_list by this function.
+//
+// Corresponds to reflection api method RemoveLast.
+int InternalDeleteRepeatedField(google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* slice, PyObject* cmessage_list);
+
+// Sets the specified scalar value to the message.
+int InternalSetScalar(CMessage* self,
+ const google::protobuf::FieldDescriptor* field_descriptor,
+ PyObject* value);
+
+// Retrieves the specified scalar value from the message.
+//
+// Returns a new python reference.
+PyObject* InternalGetScalar(CMessage* self,
+ const google::protobuf::FieldDescriptor* field_descriptor);
+
+// Clears the message, removing all contained data. Extension dictionary and
+// submessages are released first if there are remaining external references.
+//
+// Corresponds to message api method Clear.
+PyObject* Clear(CMessage* self);
+
+// Clears the data described by the given descriptor. Used to clear extensions
+// (which don't have names). Extension release is handled by ExtensionDict
+// class, not this function.
+// TODO(anuraag): Try to make this discrepancy in release semantics with
+// ClearField less confusing.
+//
+// Corresponds to reflection api method ClearField.
+PyObject* ClearFieldByDescriptor(
+ CMessage* self,
+ const google::protobuf::FieldDescriptor* descriptor);
+
+// Clears the data for the given field name. The message is released if there
+// are any external references.
+//
+// Corresponds to reflection api method ClearField.
+PyObject* ClearField(CMessage* self, PyObject* arg);
+
+// Checks if the message has the field described by the descriptor. Used for
+// extensions (which have no name).
+//
+// Corresponds to reflection api method HasField
+PyObject* HasFieldByDescriptor(
+ CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor);
+
+// Checks if the message has the named field.
+//
+// Corresponds to reflection api method HasField.
+PyObject* HasField(CMessage* self, PyObject* arg);
+
+// Initializes constants/enum values on a message. This is called by
+// RepeatedCompositeContainer and ExtensionDict after calling the constructor.
+// TODO(anuraag): Make it always called from within the constructor since it can
+int InitAttributes(CMessage* self, PyObject* descriptor, PyObject* kwargs);
+
+PyObject* MergeFrom(CMessage* self, PyObject* arg);
+
+// Retrieves an attribute named 'name' from CMessage 'self'. Returns
+// the attribute value on success, or NULL on failure.
+//
+// Returns a new reference.
+PyObject* GetAttr(CMessage* self, PyObject* name);
+
+// Set the value of the attribute named 'name', for CMessage 'self',
+// to the value 'value'. Returns -1 on failure.
+int SetAttr(CMessage* self, PyObject* name, PyObject* value);
+
+PyObject* FindInitializationErrors(CMessage* self);
+
+// Set the owner field of self and any children of self, recursively.
+// Used when self is being released and thus has a new owner (the
+// released Message.)
+int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner);
+
+int AssureWritable(CMessage* self);
+
+} // namespace cmessage
+
+/* Is 64bit */
+#define IS_64BIT (SIZEOF_LONG == 8)
+
+#define FIELD_BELONGS_TO_MESSAGE(field_descriptor, message) \
+ ((message)->GetDescriptor() == (field_descriptor)->containing_type())
+
+#define FIELD_IS_REPEATED(field_descriptor) \
+ ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED)
+
+#define GOOGLE_CHECK_GET_INT32(arg, value, err) \
+ int32 value; \
+ if (!CheckAndGetInteger(arg, &value, kint32min_py, kint32max_py)) { \
+ return err; \
+ }
+
+#define GOOGLE_CHECK_GET_INT64(arg, value, err) \
+ int64 value; \
+ if (!CheckAndGetInteger(arg, &value, kint64min_py, kint64max_py)) { \
+ return err; \
+ }
+
+#define GOOGLE_CHECK_GET_UINT32(arg, value, err) \
+ uint32 value; \
+ if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint32max_py)) { \
+ return err; \
+ }
+
+#define GOOGLE_CHECK_GET_UINT64(arg, value, err) \
+ uint64 value; \
+ if (!CheckAndGetInteger(arg, &value, kPythonZero, kuint64max_py)) { \
+ return err; \
+ }
+
+#define GOOGLE_CHECK_GET_FLOAT(arg, value, err) \
+ float value; \
+ if (!CheckAndGetFloat(arg, &value)) { \
+ return err; \
+ } \
+
+#define GOOGLE_CHECK_GET_DOUBLE(arg, value, err) \
+ double value; \
+ if (!CheckAndGetDouble(arg, &value)) { \
+ return err; \
+ }
+
+#define GOOGLE_CHECK_GET_BOOL(arg, value, err) \
+ bool value; \
+ if (!CheckAndGetBool(arg, &value)) { \
+ return err; \
+ }
+
+
+extern PyObject* kPythonZero;
+extern PyObject* kint32min_py;
+extern PyObject* kint32max_py;
+extern PyObject* kuint32max_py;
+extern PyObject* kint64min_py;
+extern PyObject* kint64max_py;
+extern PyObject* kuint64max_py;
+
+#define C(str) const_cast<char*>(str)
+
+void FormatTypeError(PyObject* arg, char* expected_types);
+template<class T>
+bool CheckAndGetInteger(
+ PyObject* arg, T* value, PyObject* min, PyObject* max);
+bool CheckAndGetDouble(PyObject* arg, double* value);
+bool CheckAndGetFloat(PyObject* arg, float* value);
+bool CheckAndGetBool(PyObject* arg, bool* value);
+bool CheckAndSetString(
+ PyObject* arg, google::protobuf::Message* message,
+ const google::protobuf::FieldDescriptor* descriptor,
+ const google::protobuf::Reflection* reflection,
+ bool append,
+ int index);
+PyObject* ToStringObject(
+ const google::protobuf::FieldDescriptor* descriptor, string value);
+
+extern PyObject* PickleError_class;
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
diff --git a/python/google/protobuf/pyext/message_factory_cpp2_test.py b/python/google/protobuf/pyext/message_factory_cpp2_test.py
new file mode 100644
index 0000000..fb52e1b
--- /dev/null
+++ b/python/google/protobuf/pyext/message_factory_cpp2_test.py
@@ -0,0 +1,56 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.message_factory."""
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
+
+# We must set the implementation version above before the google3 imports.
+# pylint: disable=g-import-not-at-top
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+# Run all tests from the original module by putting them in our namespace.
+# pylint: disable=wildcard-import
+from google.protobuf.internal.message_factory_test import *
+
+
+class ConfirmCppApi2Test(basetest.TestCase):
+
+ def testImplementationSetting(self):
+ self.assertEqual('cpp', api_implementation.Type())
+ self.assertEqual(2, api_implementation.Version())
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/pyext/proto2_api_test.proto b/python/google/protobuf/pyext/proto2_api_test.proto
new file mode 100644
index 0000000..eef9b73
--- /dev/null
+++ b/python/google/protobuf/pyext/proto2_api_test.proto
@@ -0,0 +1,38 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import "google/protobuf/internal/cpp/proto1_api_test.proto";
+
+package google.protobuf.python.internal;
+
+message TestNestedProto1APIMessage {
+ optional int32 a = 1;
+ optional TestMessage.NestedMessage b = 2;
+}
diff --git a/python/google/protobuf/pyext/python.proto b/python/google/protobuf/pyext/python.proto
new file mode 100644
index 0000000..ee6d5ab
--- /dev/null
+++ b/python/google/protobuf/pyext/python.proto
@@ -0,0 +1,66 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: tibell@google.com (Johan Tibell)
+//
+// These message definitions are used to exercises known corner cases
+// in the C++ implementation of the Python API.
+
+
+package google.protobuf.python.internal;
+
+// Protos optimized for SPEED use a strict superset of the generated code
+// of equivalent ones optimized for CODE_SIZE, so we should optimize all our
+// tests for speed unless explicitly testing code size optimization.
+option optimize_for = SPEED;
+
+message TestAllTypes {
+ message NestedMessage {
+ optional int32 bb = 1;
+ optional ForeignMessage cc = 2;
+ }
+
+ repeated NestedMessage repeated_nested_message = 1;
+ optional NestedMessage optional_nested_message = 2;
+ optional int32 optional_int32 = 3;
+}
+
+message ForeignMessage {
+ optional int32 c = 1;
+ repeated int32 d = 2;
+}
+
+message TestAllExtensions {
+ extensions 1 to max;
+}
+
+extend TestAllExtensions {
+ optional TestAllTypes.NestedMessage optional_nested_message_extension = 1;
+}
diff --git a/python/google/protobuf/pyext/python_protobuf.h b/python/google/protobuf/pyext/python_protobuf.h
new file mode 100644
index 0000000..c5b0b1c
--- /dev/null
+++ b/python/google/protobuf/pyext/python_protobuf.h
@@ -0,0 +1,57 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: qrczak@google.com (Marcin Kowalczyk)
+//
+// This module exposes the C proto inside the given Python proto, in
+// case the Python proto is implemented with a C proto.
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__
+#define GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__
+
+#include <Python.h>
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+namespace python {
+
+// Return the pointer to the C proto inside the given Python proto,
+// or NULL when this is not a Python proto implemented with a C proto.
+const Message* GetCProtoInsidePyProto(PyObject* msg);
+Message* MutableCProtoInsidePyProto(PyObject* msg);
+
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__
diff --git a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py
new file mode 100755
index 0000000..d7fce5f
--- /dev/null
+++ b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py
@@ -0,0 +1,94 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Unittest for reflection.py, which tests the generated C++ implementation."""
+
+__author__ = 'jasonh@google.com (Jason Hsueh)'
+
+import os
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp'
+os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2'
+
+from google.apputils import basetest
+from google.protobuf.internal import api_implementation
+from google.protobuf.internal import more_extensions_dynamic_pb2
+from google.protobuf.internal import more_extensions_pb2
+from google.protobuf.internal.reflection_test import *
+
+
+class ReflectionCppTest(basetest.TestCase):
+ def testImplementationSetting(self):
+ self.assertEqual('cpp', api_implementation.Type())
+ self.assertEqual(2, api_implementation.Version())
+
+ def testExtensionOfGeneratedTypeInDynamicFile(self):
+ """Tests that a file built dynamically can extend a generated C++ type.
+
+ The C++ implementation uses a DescriptorPool that has the generated
+ DescriptorPool as an underlay. Typically, a type can only find
+ extensions in its own pool. With the python C-extension, the generated C++
+ extendee may be available, but not the extension. This tests that the
+ C-extension implements the correct special handling to make such extensions
+ available.
+ """
+ pb1 = more_extensions_pb2.ExtendedMessage()
+ # Test that basic accessors work.
+ self.assertFalse(
+ pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
+ self.assertFalse(
+ pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
+ pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17
+ pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24
+ self.assertTrue(
+ pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
+ self.assertTrue(
+ pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
+
+ # Now serialize the data and parse to a new message.
+ pb2 = more_extensions_pb2.ExtendedMessage()
+ pb2.MergeFromString(pb1.SerializeToString())
+
+ self.assertTrue(
+ pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension))
+ self.assertTrue(
+ pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension))
+ self.assertEqual(
+ 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension])
+ self.assertEqual(
+ 24,
+ pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a)
+
+
+
+if __name__ == '__main__':
+ basetest.main()
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
new file mode 100644
index 0000000..b164505
--- /dev/null
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -0,0 +1,763 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#include <google/protobuf/pyext/repeated_composite_container.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyInt_Check PyLong_Check
+ #define PyInt_AsLong PyLong_AsLong
+ #define PyInt_FromLong PyLong_FromLong
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+extern google::protobuf::DynamicMessageFactory* global_message_factory;
+
+namespace repeated_composite_container {
+
+// TODO(tibell): We might also want to check:
+// GOOGLE_CHECK_NOTNULL((self)->owner.get());
+#define GOOGLE_CHECK_ATTACHED(self) \
+ do { \
+ GOOGLE_CHECK_NOTNULL((self)->message); \
+ GOOGLE_CHECK_NOTNULL((self)->parent_field); \
+ } while (0);
+
+#define GOOGLE_CHECK_RELEASED(self) \
+ do { \
+ GOOGLE_CHECK((self)->owner.get() == NULL); \
+ GOOGLE_CHECK((self)->message == NULL); \
+ GOOGLE_CHECK((self)->parent_field == NULL); \
+ GOOGLE_CHECK((self)->parent == NULL); \
+ } while (0);
+
+// Returns a new reference.
+static PyObject* GetKey(PyObject* x) {
+ // Just the identity function.
+ Py_INCREF(x);
+ return x;
+}
+
+#define GET_KEY(keyfunc, value) \
+ ((keyfunc) == NULL ? \
+ GetKey((value)) : \
+ PyObject_CallFunctionObjArgs((keyfunc), (value), NULL))
+
+// Converts a comparison function that returns -1, 0, or 1 into a
+// less-than predicate.
+//
+// Returns -1 on error, 1 if x < y, 0 if x >= y.
+static int islt(PyObject *x, PyObject *y, PyObject *compare) {
+ if (compare == NULL)
+ return PyObject_RichCompareBool(x, y, Py_LT);
+
+ ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL));
+ if (res == NULL)
+ return -1;
+ if (!PyInt_Check(res)) {
+ PyErr_Format(PyExc_TypeError,
+ "comparison function must return int, not %.200s",
+ Py_TYPE(res)->tp_name);
+ return -1;
+ }
+ return PyInt_AsLong(res) < 0;
+}
+
+// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps
+// TODO(anuraag): Is there a better way to do this then reinventing the wheel?
+static int InternalQuickSort(RepeatedCompositeContainer* self,
+ Py_ssize_t start,
+ Py_ssize_t limit,
+ PyObject* cmp,
+ PyObject* keyfunc) {
+ if (limit - start <= 1)
+ return 0; // Nothing to sort.
+
+ GOOGLE_CHECK_ATTACHED(self);
+
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor;
+ Py_ssize_t left;
+ Py_ssize_t right;
+
+ PyObject* children = self->child_messages;
+
+ do {
+ left = start;
+ right = limit;
+ ScopedPyObjectPtr mid(
+ GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2)));
+ do {
+ ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left)));
+ int is_lt = islt(key, mid, cmp);
+ if (is_lt == -1)
+ return -1;
+ /* array[left]<x */
+ while (is_lt) {
+ ++left;
+ ScopedPyObjectPtr key(GET_KEY(keyfunc,
+ PyList_GET_ITEM(children, left)));
+ is_lt = islt(key, mid, cmp);
+ if (is_lt == -1)
+ return -1;
+ }
+ key.reset(GET_KEY(keyfunc, PyList_GET_ITEM(children, right - 1)));
+ is_lt = islt(mid, key, cmp);
+ if (is_lt == -1)
+ return -1;
+ while (is_lt) {
+ --right;
+ ScopedPyObjectPtr key(GET_KEY(keyfunc,
+ PyList_GET_ITEM(children, right - 1)));
+ is_lt = islt(mid, key, cmp);
+ if (is_lt == -1)
+ return -1;
+ }
+ if (left < right) {
+ --right;
+ if (left < right) {
+ reflection->SwapElements(message, descriptor, left, right);
+ PyObject* tmp = PyList_GET_ITEM(children, left);
+ PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right));
+ PyList_SET_ITEM(children, right, tmp);
+ }
+ ++left;
+ }
+ } while (left < right);
+
+ if ((right - start) < (limit - left)) {
+ /* sort [start..right[ */
+ if (start < (right - 1)) {
+ InternalQuickSort(self, start, right, cmp, keyfunc);
+ }
+
+ /* sort [left..limit[ */
+ start = left;
+ } else {
+ /* sort [left..limit[ */
+ if (left < (limit - 1)) {
+ InternalQuickSort(self, left, limit, cmp, keyfunc);
+ }
+
+ /* sort [start..right[ */
+ limit = right;
+ }
+ } while (start < (limit - 1));
+
+ return 0;
+}
+
+#undef GET_KEY
+
+// ---------------------------------------------------------------------
+// len()
+
+static Py_ssize_t Length(RepeatedCompositeContainer* self) {
+ google::protobuf::Message* message = self->message;
+ if (message != NULL) {
+ return message->GetReflection()->FieldSize(*message,
+ self->parent_field->descriptor);
+ } else {
+ // The container has been released (i.e. by a call to Clear() or
+ // ClearField() on the parent) and thus there's no message.
+ return PyList_GET_SIZE(self->child_messages);
+ }
+}
+
+// Returns 0 if successful; returns -1 and sets an exception if
+// unsuccessful.
+static int UpdateChildMessages(RepeatedCompositeContainer* self) {
+ if (self->message == NULL)
+ return 0;
+
+ // A MergeFrom on a parent message could have caused extra messages to be
+ // added in the underlying protobuf so add them to our list. They can never
+ // be removed in such a way so there's no need to worry about that.
+ Py_ssize_t message_length = Length(self);
+ Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages);
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ for (Py_ssize_t i = child_length; i < message_length; ++i) {
+ const Message& sub_message = reflection->GetRepeatedMessage(
+ *(self->message), self->parent_field->descriptor, i);
+ ScopedPyObjectPtr py_cmsg(cmessage::NewEmpty(self->subclass_init));
+ if (py_cmsg == NULL) {
+ return -1;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg.get());
+ cmsg->owner = self->owner;
+ cmsg->message = const_cast<google::protobuf::Message*>(&sub_message);
+ cmsg->parent = self->parent;
+ if (cmessage::InitAttributes(cmsg, NULL, NULL) < 0) {
+ return -1;
+ }
+ PyList_Append(self->child_messages, py_cmsg);
+ }
+ return 0;
+}
+
+// ---------------------------------------------------------------------
+// add()
+
+static PyObject* AddToAttached(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwargs) {
+ GOOGLE_CHECK_ATTACHED(self);
+
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ if (cmessage::AssureWritable(self->parent) == -1)
+ return NULL;
+ google::protobuf::Message* message = self->message;
+ google::protobuf::Message* sub_message =
+ message->GetReflection()->AddMessage(message,
+ self->parent_field->descriptor);
+ PyObject* py_cmsg = cmessage::NewEmpty(self->subclass_init);
+ if (py_cmsg == NULL) {
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
+
+ cmsg->owner = self->owner;
+ cmsg->message = sub_message;
+ cmsg->parent = self->parent;
+ // cmessage::InitAttributes must be called after cmsg->message has
+ // been set.
+ if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) {
+ Py_DECREF(py_cmsg);
+ return NULL;
+ }
+ PyList_Append(self->child_messages, py_cmsg);
+ return py_cmsg;
+}
+
+static PyObject* AddToReleased(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwargs) {
+ GOOGLE_CHECK_RELEASED(self);
+
+ // Create the CMessage
+ PyObject* py_cmsg = PyObject_CallObject(self->subclass_init, NULL);
+ if (py_cmsg == NULL)
+ return NULL;
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
+ if (cmessage::InitAttributes(cmsg, NULL, kwargs) < 0) {
+ Py_DECREF(py_cmsg);
+ return NULL;
+ }
+
+ // The Message got created by the call to subclass_init above and
+ // it set self->owner to the newly allocated message.
+
+ PyList_Append(self->child_messages, py_cmsg);
+ return py_cmsg;
+}
+
+PyObject* Add(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwargs) {
+ if (self->message == NULL)
+ return AddToReleased(self, args, kwargs);
+ else
+ return AddToAttached(self, args, kwargs);
+}
+
+// ---------------------------------------------------------------------
+// extend()
+
+PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
+ cmessage::AssureWritable(self->parent);
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return NULL;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter))) != NULL) {
+ if (!PyObject_TypeCheck(next, &CMessage_Type)) {
+ PyErr_SetString(PyExc_TypeError, "Not a cmessage");
+ return NULL;
+ }
+ ScopedPyObjectPtr new_message(Add(self, NULL, NULL));
+ if (new_message == NULL) {
+ return NULL;
+ }
+ CMessage* new_cmessage = reinterpret_cast<CMessage*>(new_message.get());
+ if (cmessage::MergeFrom(new_cmessage, next) == NULL) {
+ return NULL;
+ }
+ }
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other) {
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ return Extend(self, other);
+}
+
+PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice) {
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ Py_ssize_t from;
+ Py_ssize_t to;
+ Py_ssize_t step;
+ Py_ssize_t length = Length(self);
+ Py_ssize_t slicelength;
+ if (PySlice_Check(slice)) {
+#if PY_MAJOR_VERSION >= 3
+ if (PySlice_GetIndicesEx(slice,
+#else
+ if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
+#endif
+ length, &from, &to, &step, &slicelength) == -1) {
+ return NULL;
+ }
+ return PyList_GetSlice(self->child_messages, from, to);
+ } else if (PyInt_Check(slice) || PyLong_Check(slice)) {
+ from = to = PyLong_AsLong(slice);
+ if (from < 0) {
+ from = to = length + from;
+ }
+ PyObject* result = PyList_GetItem(self->child_messages, from);
+ if (result == NULL) {
+ return NULL;
+ }
+ Py_INCREF(result);
+ return result;
+ }
+ PyErr_SetString(PyExc_TypeError, "index must be an integer or slice");
+ return NULL;
+}
+
+int AssignSubscript(RepeatedCompositeContainer* self,
+ PyObject* slice,
+ PyObject* value) {
+ if (UpdateChildMessages(self) < 0) {
+ return -1;
+ }
+ if (value != NULL) {
+ PyErr_SetString(PyExc_TypeError, "does not support assignment");
+ return -1;
+ }
+
+ // Delete from the underlying Message, if any.
+ if (self->message != NULL) {
+ if (cmessage::InternalDeleteRepeatedField(self->message,
+ self->parent_field->descriptor,
+ slice,
+ self->child_messages) < 0) {
+ return -1;
+ }
+ } else {
+ Py_ssize_t from;
+ Py_ssize_t to;
+ Py_ssize_t step;
+ Py_ssize_t length = Length(self);
+ Py_ssize_t slicelength;
+ if (PySlice_Check(slice)) {
+#if PY_MAJOR_VERSION >= 3
+ if (PySlice_GetIndicesEx(slice,
+#else
+ if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
+#endif
+ length, &from, &to, &step, &slicelength) == -1) {
+ return -1;
+ }
+ return PySequence_DelSlice(self->child_messages, from, to);
+ } else if (PyInt_Check(slice) || PyLong_Check(slice)) {
+ from = to = PyLong_AsLong(slice);
+ if (from < 0) {
+ from = to = length + from;
+ }
+ return PySequence_DelItem(self->child_messages, from);
+ }
+ }
+
+ return 0;
+}
+
+static PyObject* Remove(RepeatedCompositeContainer* self, PyObject* value) {
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ Py_ssize_t index = PySequence_Index(self->child_messages, value);
+ if (index == -1) {
+ return NULL;
+ }
+ ScopedPyObjectPtr py_index(PyLong_FromLong(index));
+ if (AssignSubscript(self, py_index, NULL) < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyObject* RichCompare(RepeatedCompositeContainer* self,
+ PyObject* other,
+ int opid) {
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(other, &RepeatedCompositeContainer_Type)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Can only compare repeated composite fields "
+ "against other repeated composite fields.");
+ return NULL;
+ }
+ if (opid == Py_EQ || opid == Py_NE) {
+ // TODO(anuraag): Don't make new lists just for this...
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(self, full_slice));
+ if (list == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr other_list(
+ Subscript(
+ reinterpret_cast<RepeatedCompositeContainer*>(other), full_slice));
+ if (other_list == NULL) {
+ return NULL;
+ }
+ return PyObject_RichCompare(list, other_list, opid);
+ } else {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+}
+
+// ---------------------------------------------------------------------
+// sort()
+
+static PyObject* SortAttached(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwds) {
+ // Sort the underlying Message array.
+ PyObject *compare = NULL;
+ int reverse = 0;
+ PyObject *keyfunc = NULL;
+ static char *kwlist[] = {"cmp", "key", "reverse", 0};
+
+ if (args != NULL) {
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort",
+ kwlist, &compare, &keyfunc, &reverse))
+ return NULL;
+ }
+ if (compare == Py_None)
+ compare = NULL;
+ if (keyfunc == Py_None)
+ keyfunc = NULL;
+
+ const Py_ssize_t length = Length(self);
+ if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0)
+ return NULL;
+
+ // Finally reverse the result if requested.
+ if (reverse) {
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ const google::protobuf::FieldDescriptor* descriptor = self->parent_field->descriptor;
+
+ // Reverse the Message array.
+ for (int i = 0; i < length / 2; ++i)
+ reflection->SwapElements(message, descriptor, i, length - i - 1);
+
+ // Reverse the Python list.
+ ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages,
+ "reverse", NULL));
+ if (res == NULL)
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
+}
+
+static PyObject* SortReleased(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwds) {
+ ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort"));
+ if (m == NULL)
+ return NULL;
+ if (PyObject_Call(m, args, kwds) == NULL)
+ return NULL;
+ Py_RETURN_NONE;
+}
+
+static PyObject* Sort(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwds) {
+ // Support the old sort_function argument for backwards
+ // compatibility.
+ if (kwds != NULL) {
+ PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function");
+ if (sort_func != NULL) {
+ // Must set before deleting as sort_func is a borrowed reference
+ // and kwds might be the only thing keeping it alive.
+ PyDict_SetItemString(kwds, "cmp", sort_func);
+ PyDict_DelItemString(kwds, "sort_function");
+ }
+ }
+
+ if (UpdateChildMessages(self) < 0)
+ return NULL;
+ if (self->message == NULL) {
+ return SortReleased(self, args, kwds);
+ } else {
+ return SortAttached(self, args, kwds);
+ }
+}
+
+// ---------------------------------------------------------------------
+
+static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) {
+ if (UpdateChildMessages(self) < 0) {
+ return NULL;
+ }
+ Py_ssize_t length = Length(self);
+ if (index < 0) {
+ index = length + index;
+ }
+ PyObject* item = PyList_GetItem(self->child_messages, index);
+ if (item == NULL) {
+ return NULL;
+ }
+ Py_INCREF(item);
+ return item;
+}
+
+// The caller takes ownership of the returned Message.
+Message* ReleaseLast(const FieldDescriptor* field,
+ const Descriptor* type,
+ Message* message) {
+ GOOGLE_CHECK_NOTNULL(field);
+ GOOGLE_CHECK_NOTNULL(type);
+ GOOGLE_CHECK_NOTNULL(message);
+
+ Message* released_message = message->GetReflection()->ReleaseLast(
+ message, field);
+ // TODO(tibell): Deal with proto1.
+
+ // ReleaseMessage will return NULL which differs from
+ // child_cmessage->message, if the field does not exist. In this case,
+ // the latter points to the default instance via a const_cast<>, so we
+ // have to reset it to a new mutable object since we are taking ownership.
+ if (released_message == NULL) {
+ const Message* prototype = global_message_factory->GetPrototype(type);
+ GOOGLE_CHECK_NOTNULL(prototype);
+ return prototype->New();
+ } else {
+ return released_message;
+ }
+}
+
+// Release field of message and transfer the ownership to cmessage.
+void ReleaseLastTo(const FieldDescriptor* field,
+ Message* message,
+ CMessage* cmessage) {
+ GOOGLE_CHECK_NOTNULL(field);
+ GOOGLE_CHECK_NOTNULL(message);
+ GOOGLE_CHECK_NOTNULL(cmessage);
+
+ shared_ptr<Message> released_message(
+ ReleaseLast(field, cmessage->message->GetDescriptor(), message));
+ cmessage->parent = NULL;
+ cmessage->parent_field = NULL;
+ cmessage->message = released_message.get();
+ cmessage->read_only = false;
+ cmessage::SetOwner(cmessage, released_message);
+}
+
+// Called to release a container using
+// ClearField('container_field_name') on the parent.
+int Release(RepeatedCompositeContainer* self) {
+ if (UpdateChildMessages(self) < 0) {
+ PyErr_WriteUnraisable(PyBytes_FromString("Failed to update released "
+ "messages"));
+ return -1;
+ }
+
+ Message* message = self->message;
+ const FieldDescriptor* field = self->parent_field->descriptor;
+
+ // The reflection API only lets us release the last message in a
+ // repeated field. Therefore we iterate through the children
+ // starting with the last one.
+ const Py_ssize_t size = PyList_GET_SIZE(self->child_messages);
+ GOOGLE_DCHECK_EQ(size, message->GetReflection()->FieldSize(*message, field));
+ for (Py_ssize_t i = size - 1; i >= 0; --i) {
+ CMessage* child_cmessage = reinterpret_cast<CMessage*>(
+ PyList_GET_ITEM(self->child_messages, i));
+ ReleaseLastTo(field, message, child_cmessage);
+ }
+
+ // Detach from containing message.
+ self->parent = NULL;
+ self->parent_field = NULL;
+ self->message = NULL;
+ self->owner.reset();
+
+ return 0;
+}
+
+int SetOwner(RepeatedCompositeContainer* self,
+ const shared_ptr<Message>& new_owner) {
+ GOOGLE_CHECK_ATTACHED(self);
+
+ self->owner = new_owner;
+ const Py_ssize_t n = PyList_GET_SIZE(self->child_messages);
+ for (Py_ssize_t i = 0; i < n; ++i) {
+ PyObject* msg = PyList_GET_ITEM(self->child_messages, i);
+ if (cmessage::SetOwner(reinterpret_cast<CMessage*>(msg), new_owner) == -1) {
+ return -1;
+ }
+ }
+ return 0;
+}
+
+static int Init(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwargs) {
+ self->message = NULL;
+ self->parent = NULL;
+ self->parent_field = NULL;
+ self->subclass_init = NULL;
+ self->child_messages = PyList_New(0);
+ return 0;
+}
+
+static void Dealloc(RepeatedCompositeContainer* self) {
+ Py_CLEAR(self->child_messages);
+ // TODO(tibell): Do we need to call delete on these objects to make
+ // sure their destructors are called?
+ self->owner.reset();
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+static PySequenceMethods SqMethods = {
+ (lenfunc)Length, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ (ssizeargfunc)Item /* sq_item */
+};
+
+static PyMappingMethods MpMethods = {
+ (lenfunc)Length, /* mp_length */
+ (binaryfunc)Subscript, /* mp_subscript */
+ (objobjargproc)AssignSubscript,/* mp_ass_subscript */
+};
+
+static PyMethodDef Methods[] = {
+ { "add", (PyCFunction) Add, METH_VARARGS | METH_KEYWORDS,
+ "Adds an object to the repeated container." },
+ { "extend", (PyCFunction) Extend, METH_O,
+ "Adds objects to the repeated container." },
+ { "remove", (PyCFunction) Remove, METH_O,
+ "Removes an object from the repeated container." },
+ { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS,
+ "Sorts the repeated container." },
+ { "MergeFrom", (PyCFunction) MergeFrom, METH_O,
+ "Adds objects to the repeated container." },
+ { NULL, NULL }
+};
+
+} // namespace repeated_composite_container
+
+PyTypeObject RepeatedCompositeContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "google.protobuf.internal."
+ "cpp._message.RepeatedCompositeContainer", // tp_name
+ sizeof(RepeatedCompositeContainer), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)repeated_composite_container::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ &repeated_composite_container::SqMethods, // tp_as_sequence
+ &repeated_composite_container::MpMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Repeated scalar container", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ (richcmpfunc)repeated_composite_container::RichCompare, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ repeated_composite_container::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ (initproc)repeated_composite_container::Init, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h
new file mode 100644
index 0000000..e8ed30e
--- /dev/null
+++ b/python/google/protobuf/pyext/repeated_composite_container.h
@@ -0,0 +1,172 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+#include <string>
+#include <vector>
+
+
+namespace google {
+namespace protobuf {
+
+class FieldDescriptor;
+class Message;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CMessage;
+struct CFieldDescriptor;
+
+// A RepeatedCompositeContainer can be in one of two states: attached
+// or released.
+//
+// When in the attached state all modifications to the container are
+// done both on the 'message' and on the 'child_messages'
+// list. In this state all Messages refered to by the children in
+// 'child_messages' are owner by the 'owner'.
+//
+// When in the released state 'message', 'owner', 'parent', and
+// 'parent_field' are NULL.
+typedef struct RepeatedCompositeContainer {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python RepeatedCompositeContainer holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Weak reference to parent object. May be NULL. Used to make sure
+ // the parent is writable before modifying the
+ // RepeatedCompositeContainer.
+ CMessage* parent;
+
+ // A descriptor used to modify the underlying 'message'.
+ CFieldDescriptor* parent_field;
+
+ // Pointer to the C++ Message that contains this container. The
+ // RepeatedCompositeContainer does not own this pointer.
+ //
+ // If NULL, this message has been released from its parent (by
+ // calling Clear() or ClearField() on the parent.
+ Message* message;
+
+ // A callable that is used to create new child messages.
+ PyObject* subclass_init;
+
+ // A list of child messages.
+ PyObject* child_messages;
+} RepeatedCompositeContainer;
+
+extern PyTypeObject RepeatedCompositeContainer_Type;
+
+namespace repeated_composite_container {
+
+// Returns the number of items in this repeated composite container.
+static Py_ssize_t Length(RepeatedCompositeContainer* self);
+
+// Appends a new CMessage to the container and returns it. The
+// CMessage is initialized using the content of kwargs.
+//
+// Returns a new reference if successful; returns NULL and sets an
+// exception if unsuccessful.
+PyObject* Add(RepeatedCompositeContainer* self,
+ PyObject* args,
+ PyObject* kwargs);
+
+// Appends all the CMessages in the input iterator to the container.
+//
+// Returns None if successful; returns NULL and sets an exception if
+// unsuccessful.
+PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value);
+
+// Appends a new message to the container for each message in the
+// input iterator, merging each data element in. Equivalent to extend.
+//
+// Returns None if successful; returns NULL and sets an exception if
+// unsuccessful.
+PyObject* MergeFrom(RepeatedCompositeContainer* self, PyObject* other);
+
+// Accesses messages in the container.
+//
+// Returns a new reference to the message for an integer parameter.
+// Returns a new reference to a list of messages for a slice.
+PyObject* Subscript(RepeatedCompositeContainer* self, PyObject* slice);
+
+// Deletes items from the container (cannot be used for assignment).
+//
+// Returns 0 on success, -1 on failure.
+int AssignSubscript(RepeatedCompositeContainer* self,
+ PyObject* slice,
+ PyObject* value);
+
+// Releases the messages in the container to the given message.
+//
+// Returns 0 on success, -1 on failure.
+int ReleaseToMessage(RepeatedCompositeContainer* self,
+ google::protobuf::Message* new_message);
+
+// Releases the messages in the container to a new message.
+//
+// Returns 0 on success, -1 on failure.
+int Release(RepeatedCompositeContainer* self);
+
+// Returns 0 on success, -1 on failure.
+int SetOwner(RepeatedCompositeContainer* self,
+ const shared_ptr<Message>& new_owner);
+
+// Removes the last element of the repeated message field 'field' on
+// the Message 'message', and transfers the ownership of the released
+// Message to 'cmessage'.
+//
+// Corresponds to reflection api method ReleaseMessage.
+void ReleaseLastTo(const FieldDescriptor* field,
+ Message* message,
+ CMessage* cmessage);
+
+} // namespace repeated_composite_container
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
new file mode 100644
index 0000000..b0fcd81
--- /dev/null
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -0,0 +1,825 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#include <google/protobuf/pyext/repeated_scalar_container.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+#include <google/protobuf/stubs/common.h>
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/dynamic_message.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/pyext/descriptor.h>
+#include <google/protobuf/pyext/message.h>
+#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+
+#if PY_MAJOR_VERSION >= 3
+ #define PyInt_FromLong PyLong_FromLong
+ #if PY_VERSION_HEX < 0x03030000
+ #error "Python 3.0 - 3.2 are not supported."
+ #else
+ #define PyString_AsString(ob) \
+ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AS_STRING(ob))
+ #endif
+#endif
+
+namespace google {
+namespace protobuf {
+namespace python {
+
+extern google::protobuf::DynamicMessageFactory* global_message_factory;
+
+namespace repeated_scalar_container {
+
+static int InternalAssignRepeatedField(
+ RepeatedScalarContainer* self, PyObject* list) {
+ self->message->GetReflection()->ClearField(self->message,
+ self->parent_field->descriptor);
+ for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) {
+ PyObject* value = PyList_GET_ITEM(list, i);
+ if (Append(self, value) == NULL) {
+ return -1;
+ }
+ }
+ return 0;
+}
+
+static Py_ssize_t Len(RepeatedScalarContainer* self) {
+ google::protobuf::Message* message = self->message;
+ return message->GetReflection()->FieldSize(*message,
+ self->parent_field->descriptor);
+}
+
+static int AssignItem(RepeatedScalarContainer* self,
+ Py_ssize_t index,
+ PyObject* arg) {
+ cmessage::AssureWritable(self->parent);
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ self->parent_field->descriptor;
+ if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
+ PyErr_SetString(
+ PyExc_KeyError, "Field does not belong to message!");
+ return -1;
+ }
+
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ int field_size = reflection->FieldSize(*message, field_descriptor);
+ if (index < 0) {
+ index = field_size + index;
+ }
+ if (index < 0 || index >= field_size) {
+ PyErr_Format(PyExc_IndexError,
+ "list assignment index (%d) out of range",
+ static_cast<int>(index));
+ return -1;
+ }
+
+ if (arg == NULL) {
+ ScopedPyObjectPtr py_index(PyLong_FromLong(index));
+ return cmessage::InternalDeleteRepeatedField(message, field_descriptor,
+ py_index, NULL);
+ }
+
+ if (PySequence_Check(arg) && !(PyBytes_Check(arg) || PyUnicode_Check(arg))) {
+ PyErr_SetString(PyExc_TypeError, "Value must be scalar");
+ return -1;
+ }
+
+ switch (field_descriptor->cpp_type()) {
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(arg, value, -1);
+ reflection->SetRepeatedInt32(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(arg, value, -1);
+ reflection->SetRepeatedInt64(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(arg, value, -1);
+ reflection->SetRepeatedUInt32(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(arg, value, -1);
+ reflection->SetRepeatedUInt64(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ GOOGLE_CHECK_GET_FLOAT(arg, value, -1);
+ reflection->SetRepeatedFloat(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ GOOGLE_CHECK_GET_DOUBLE(arg, value, -1);
+ reflection->SetRepeatedDouble(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(arg, value, -1);
+ reflection->SetRepeatedBool(message, field_descriptor, index, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ if (!CheckAndSetString(
+ arg, message, field_descriptor, reflection, false, index)) {
+ return -1;
+ }
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ GOOGLE_CHECK_GET_INT32(arg, value, -1);
+ const google::protobuf::EnumDescriptor* enum_descriptor =
+ field_descriptor->enum_type();
+ const google::protobuf::EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->SetRepeatedEnum(message, field_descriptor, index,
+ enum_value);
+ } else {
+ ScopedPyObjectPtr s(PyObject_Str(arg));
+ if (s != NULL) {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
+ PyString_AsString(s.get()));
+ }
+ return -1;
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Adding value to a field of unknown type %d",
+ field_descriptor->cpp_type());
+ return -1;
+ }
+ return 0;
+}
+
+static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) {
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ self->parent_field->descriptor;
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+
+ int field_size = reflection->FieldSize(*message, field_descriptor);
+ if (index < 0) {
+ index = field_size + index;
+ }
+ if (index < 0 || index >= field_size) {
+ PyErr_Format(PyExc_IndexError,
+ "list assignment index (%d) out of range",
+ static_cast<int>(index));
+ return NULL;
+ }
+
+ PyObject* result = NULL;
+ switch (field_descriptor->cpp_type()) {
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ int32 value = reflection->GetRepeatedInt32(
+ *message, field_descriptor, index);
+ result = PyInt_FromLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ int64 value = reflection->GetRepeatedInt64(
+ *message, field_descriptor, index);
+ result = PyLong_FromLongLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ uint32 value = reflection->GetRepeatedUInt32(
+ *message, field_descriptor, index);
+ result = PyLong_FromLongLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ uint64 value = reflection->GetRepeatedUInt64(
+ *message, field_descriptor, index);
+ result = PyLong_FromUnsignedLongLong(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ float value = reflection->GetRepeatedFloat(
+ *message, field_descriptor, index);
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ double value = reflection->GetRepeatedDouble(
+ *message, field_descriptor, index);
+ result = PyFloat_FromDouble(value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ bool value = reflection->GetRepeatedBool(
+ *message, field_descriptor, index);
+ result = PyBool_FromLong(value ? 1 : 0);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ const google::protobuf::EnumValueDescriptor* enum_value =
+ message->GetReflection()->GetRepeatedEnum(
+ *message, field_descriptor, index);
+ result = PyInt_FromLong(enum_value->number());
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ string value = reflection->GetRepeatedString(
+ *message, field_descriptor, index);
+ result = ToStringObject(field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
+ PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast<PyObject*>(
+ &CMessage_Type), NULL);
+ if (py_cmsg == NULL) {
+ return NULL;
+ }
+ CMessage* cmsg = reinterpret_cast<CMessage*>(py_cmsg);
+ const google::protobuf::Message& msg = reflection->GetRepeatedMessage(
+ *message, field_descriptor, index);
+ cmsg->owner = self->owner;
+ cmsg->parent = self->parent;
+ cmsg->message = const_cast<google::protobuf::Message*>(&msg);
+ cmsg->read_only = false;
+ result = reinterpret_cast<PyObject*>(py_cmsg);
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError,
+ "Getting value from a repeated field of unknown type %d",
+ field_descriptor->cpp_type());
+ }
+
+ return result;
+}
+
+static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) {
+ Py_ssize_t from;
+ Py_ssize_t to;
+ Py_ssize_t step;
+ Py_ssize_t length;
+ Py_ssize_t slicelength;
+ bool return_list = false;
+#if PY_MAJOR_VERSION < 3
+ if (PyInt_Check(slice)) {
+ from = to = PyInt_AsLong(slice);
+ } else // NOLINT
+#endif
+ if (PyLong_Check(slice)) {
+ from = to = PyLong_AsLong(slice);
+ } else if (PySlice_Check(slice)) {
+ length = Len(self);
+#if PY_MAJOR_VERSION >= 3
+ if (PySlice_GetIndicesEx(slice,
+#else
+ if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
+#endif
+ length, &from, &to, &step, &slicelength) == -1) {
+ return NULL;
+ }
+ return_list = true;
+ } else {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return NULL;
+ }
+
+ if (!return_list) {
+ return Item(self, from);
+ }
+
+ PyObject* list = PyList_New(0);
+ if (list == NULL) {
+ return NULL;
+ }
+ if (from <= to) {
+ if (step < 0) {
+ return list;
+ }
+ for (Py_ssize_t index = from; index < to; index += step) {
+ if (index < 0 || index >= length) {
+ break;
+ }
+ ScopedPyObjectPtr s(Item(self, index));
+ PyList_Append(list, s);
+ }
+ } else {
+ if (step > 0) {
+ return list;
+ }
+ for (Py_ssize_t index = from; index > to; index += step) {
+ if (index < 0 || index >= length) {
+ break;
+ }
+ ScopedPyObjectPtr s(Item(self, index));
+ PyList_Append(list, s);
+ }
+ }
+ return list;
+}
+
+PyObject* Append(RepeatedScalarContainer* self, PyObject* item) {
+ cmessage::AssureWritable(self->parent);
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ self->parent_field->descriptor;
+
+ if (!FIELD_BELONGS_TO_MESSAGE(field_descriptor, message)) {
+ PyErr_SetString(
+ PyExc_KeyError, "Field does not belong to message!");
+ return NULL;
+ }
+
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ switch (field_descriptor->cpp_type()) {
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT32: {
+ GOOGLE_CHECK_GET_INT32(item, value, NULL);
+ reflection->AddInt32(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_INT64: {
+ GOOGLE_CHECK_GET_INT64(item, value, NULL);
+ reflection->AddInt64(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: {
+ GOOGLE_CHECK_GET_UINT32(item, value, NULL);
+ reflection->AddUInt32(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: {
+ GOOGLE_CHECK_GET_UINT64(item, value, NULL);
+ reflection->AddUInt64(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: {
+ GOOGLE_CHECK_GET_FLOAT(item, value, NULL);
+ reflection->AddFloat(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: {
+ GOOGLE_CHECK_GET_DOUBLE(item, value, NULL);
+ reflection->AddDouble(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: {
+ GOOGLE_CHECK_GET_BOOL(item, value, NULL);
+ reflection->AddBool(message, field_descriptor, value);
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
+ if (!CheckAndSetString(
+ item, message, field_descriptor, reflection, true, -1)) {
+ return NULL;
+ }
+ break;
+ }
+ case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
+ GOOGLE_CHECK_GET_INT32(item, value, NULL);
+ const google::protobuf::EnumDescriptor* enum_descriptor =
+ field_descriptor->enum_type();
+ const google::protobuf::EnumValueDescriptor* enum_value =
+ enum_descriptor->FindValueByNumber(value);
+ if (enum_value != NULL) {
+ reflection->AddEnum(message, field_descriptor, enum_value);
+ } else {
+ ScopedPyObjectPtr s(PyObject_Str(item));
+ if (s != NULL) {
+ PyErr_Format(PyExc_ValueError, "Unknown enum value: %s",
+ PyString_AsString(s.get()));
+ }
+ return NULL;
+ }
+ break;
+ }
+ default:
+ PyErr_Format(
+ PyExc_SystemError, "Adding value to a field of unknown type %d",
+ field_descriptor->cpp_type());
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
+}
+
+static int AssSubscript(RepeatedScalarContainer* self,
+ PyObject* slice,
+ PyObject* value) {
+ Py_ssize_t from;
+ Py_ssize_t to;
+ Py_ssize_t step;
+ Py_ssize_t length;
+ Py_ssize_t slicelength;
+ bool create_list = false;
+
+ cmessage::AssureWritable(self->parent);
+ google::protobuf::Message* message = self->message;
+ const google::protobuf::FieldDescriptor* field_descriptor =
+ self->parent_field->descriptor;
+
+#if PY_MAJOR_VERSION < 3
+ if (PyInt_Check(slice)) {
+ from = to = PyInt_AsLong(slice);
+ } else
+#endif
+ if (PyLong_Check(slice)) {
+ from = to = PyLong_AsLong(slice);
+ } else if (PySlice_Check(slice)) {
+ const google::protobuf::Reflection* reflection = message->GetReflection();
+ length = reflection->FieldSize(*message, field_descriptor);
+#if PY_MAJOR_VERSION >= 3
+ if (PySlice_GetIndicesEx(slice,
+#else
+ if (PySlice_GetIndicesEx(reinterpret_cast<PySliceObject*>(slice),
+#endif
+ length, &from, &to, &step, &slicelength) == -1) {
+ return -1;
+ }
+ create_list = true;
+ } else {
+ PyErr_SetString(PyExc_TypeError, "list indices must be integers");
+ return -1;
+ }
+
+ if (value == NULL) {
+ return cmessage::InternalDeleteRepeatedField(
+ message, field_descriptor, slice, NULL);
+ }
+
+ if (!create_list) {
+ return AssignItem(self, from, value);
+ }
+
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr new_list(Subscript(self, full_slice));
+ if (new_list == NULL) {
+ return -1;
+ }
+ if (PySequence_SetSlice(new_list, from, to, value) < 0) {
+ return -1;
+ }
+
+ return InternalAssignRepeatedField(self, new_list);
+}
+
+PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) {
+ cmessage::AssureWritable(self->parent);
+ if (PyObject_Not(value)) {
+ Py_RETURN_NONE;
+ }
+ ScopedPyObjectPtr iter(PyObject_GetIter(value));
+ if (iter == NULL) {
+ PyErr_SetString(PyExc_TypeError, "Value must be iterable");
+ return NULL;
+ }
+ ScopedPyObjectPtr next;
+ while ((next.reset(PyIter_Next(iter))) != NULL) {
+ if (Append(self, next) == NULL) {
+ return NULL;
+ }
+ }
+ if (PyErr_Occurred()) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyObject* Insert(RepeatedScalarContainer* self, PyObject* args) {
+ Py_ssize_t index;
+ PyObject* value;
+ if (!PyArg_ParseTuple(args, "lO", &index, &value)) {
+ return NULL;
+ }
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ ScopedPyObjectPtr new_list(Subscript(self, full_slice));
+ if (PyList_Insert(new_list, index, value) < 0) {
+ return NULL;
+ }
+ int ret = InternalAssignRepeatedField(self, new_list);
+ if (ret < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyObject* Remove(RepeatedScalarContainer* self, PyObject* value) {
+ Py_ssize_t match_index = -1;
+ for (Py_ssize_t i = 0; i < Len(self); ++i) {
+ ScopedPyObjectPtr elem(Item(self, i));
+ if (PyObject_RichCompareBool(elem, value, Py_EQ)) {
+ match_index = i;
+ break;
+ }
+ }
+ if (match_index == -1) {
+ PyErr_SetString(PyExc_ValueError, "remove(x): x not in container");
+ return NULL;
+ }
+ if (AssignItem(self, match_index, NULL) < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyObject* RichCompare(RepeatedScalarContainer* self,
+ PyObject* other,
+ int opid) {
+ if (opid != Py_EQ && opid != Py_NE) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
+ // Copy the contents of this repeated scalar container, and other if it is
+ // also a repeated scalar container, into Python lists so we can delegate
+ // to the list's compare method.
+
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+
+ ScopedPyObjectPtr other_list_deleter;
+ if (PyObject_TypeCheck(other, &RepeatedScalarContainer_Type)) {
+ other_list_deleter.reset(Subscript(
+ reinterpret_cast<RepeatedScalarContainer*>(other), full_slice));
+ other = other_list_deleter.get();
+ }
+
+ ScopedPyObjectPtr list(Subscript(self, full_slice));
+ if (list == NULL) {
+ return NULL;
+ }
+ return PyObject_RichCompare(list, other, opid);
+}
+
+PyObject* Reduce(RepeatedScalarContainer* unused_self) {
+ PyErr_Format(
+ PickleError_class,
+ "can't pickle repeated message fields, convert to list first");
+ return NULL;
+}
+
+static PyObject* Sort(RepeatedScalarContainer* self,
+ PyObject* args,
+ PyObject* kwds) {
+ // Support the old sort_function argument for backwards
+ // compatibility.
+ if (kwds != NULL) {
+ PyObject* sort_func = PyDict_GetItemString(kwds, "sort_function");
+ if (sort_func != NULL) {
+ // Must set before deleting as sort_func is a borrowed reference
+ // and kwds might be the only thing keeping it alive.
+ if (PyDict_SetItemString(kwds, "cmp", sort_func) == -1)
+ return NULL;
+ if (PyDict_DelItemString(kwds, "sort_function") == -1)
+ return NULL;
+ }
+ }
+
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr list(Subscript(self, full_slice));
+ if (list == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr m(PyObject_GetAttrString(list, "sort"));
+ if (m == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr res(PyObject_Call(m, args, kwds));
+ if (res == NULL) {
+ return NULL;
+ }
+ int ret = InternalAssignRepeatedField(self, list);
+ if (ret < 0) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static int Init(RepeatedScalarContainer* self,
+ PyObject* args,
+ PyObject* kwargs) {
+ PyObject* py_parent;
+ PyObject* py_parent_field;
+ if (!PyArg_UnpackTuple(args, "__init__()", 2, 2, &py_parent,
+ &py_parent_field)) {
+ return -1;
+ }
+
+ if (!PyObject_TypeCheck(py_parent, &CMessage_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "expect %s, but got %s",
+ CMessage_Type.tp_name,
+ Py_TYPE(py_parent)->tp_name);
+ return -1;
+ }
+
+ if (!PyObject_TypeCheck(py_parent_field, &CFieldDescriptor_Type)) {
+ PyErr_Format(PyExc_TypeError,
+ "expect %s, but got %s",
+ CFieldDescriptor_Type.tp_name,
+ Py_TYPE(py_parent_field)->tp_name);
+ return -1;
+ }
+
+ CMessage* cmessage = reinterpret_cast<CMessage*>(py_parent);
+ CFieldDescriptor* cdescriptor = reinterpret_cast<CFieldDescriptor*>(
+ py_parent_field);
+
+ if (!FIELD_BELONGS_TO_MESSAGE(cdescriptor->descriptor, cmessage->message)) {
+ PyErr_SetString(
+ PyExc_KeyError, "Field does not belong to message!");
+ return -1;
+ }
+
+ self->message = cmessage->message;
+ self->parent = cmessage;
+ self->parent_field = cdescriptor;
+ self->owner = cmessage->owner;
+ return 0;
+}
+
+// Initializes the underlying Message object of "to" so it becomes a new parent
+// repeated scalar, and copies all the values from "from" to it. A child scalar
+// container can be released by passing it as both from and to (e.g. making it
+// the recipient of the new parent message and copying the values from itself).
+static int InitializeAndCopyToParentContainer(
+ RepeatedScalarContainer* from,
+ RepeatedScalarContainer* to) {
+ ScopedPyObjectPtr full_slice(PySlice_New(NULL, NULL, NULL));
+ if (full_slice == NULL) {
+ return -1;
+ }
+ ScopedPyObjectPtr values(Subscript(from, full_slice));
+ if (values == NULL) {
+ return -1;
+ }
+ google::protobuf::Message* new_message = global_message_factory->GetPrototype(
+ from->message->GetDescriptor())->New();
+ to->parent = NULL;
+ // TODO(anuraag): Document why it's OK to hang on to parent_field,
+ // even though it's a weak reference. It ought to be enough to
+ // hold on to the FieldDescriptor only.
+ to->parent_field = from->parent_field;
+ to->message = new_message;
+ to->owner.reset(new_message);
+ if (InternalAssignRepeatedField(to, values) < 0) {
+ return -1;
+ }
+ return 0;
+}
+
+int Release(RepeatedScalarContainer* self) {
+ return InitializeAndCopyToParentContainer(self, self);
+}
+
+PyObject* DeepCopy(RepeatedScalarContainer* self, PyObject* arg) {
+ ScopedPyObjectPtr init_args(
+ PyTuple_Pack(2, self->parent, self->parent_field));
+ PyObject* clone = PyObject_CallObject(
+ reinterpret_cast<PyObject*>(&RepeatedScalarContainer_Type), init_args);
+ if (clone == NULL) {
+ return NULL;
+ }
+ if (!PyObject_TypeCheck(clone, &RepeatedScalarContainer_Type)) {
+ Py_DECREF(clone);
+ return NULL;
+ }
+ if (InitializeAndCopyToParentContainer(
+ self, reinterpret_cast<RepeatedScalarContainer*>(clone)) < 0) {
+ Py_DECREF(clone);
+ return NULL;
+ }
+ return clone;
+}
+
+static void Dealloc(RepeatedScalarContainer* self) {
+ self->owner.reset();
+ Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
+}
+
+void SetOwner(RepeatedScalarContainer* self,
+ const shared_ptr<Message>& new_owner) {
+ self->owner = new_owner;
+}
+
+static PySequenceMethods SqMethods = {
+ (lenfunc)Len, /* sq_length */
+ 0, /* sq_concat */
+ 0, /* sq_repeat */
+ (ssizeargfunc)Item, /* sq_item */
+ 0, /* sq_slice */
+ (ssizeobjargproc)AssignItem /* sq_ass_item */
+};
+
+static PyMappingMethods MpMethods = {
+ (lenfunc)Len, /* mp_length */
+ (binaryfunc)Subscript, /* mp_subscript */
+ (objobjargproc)AssSubscript, /* mp_ass_subscript */
+};
+
+static PyMethodDef Methods[] = {
+ { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
+ "Makes a deep copy of the class." },
+ { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
+ "Outputs picklable representation of the repeated field." },
+ { "append", (PyCFunction)Append, METH_O,
+ "Appends an object to the repeated container." },
+ { "extend", (PyCFunction)Extend, METH_O,
+ "Appends objects to the repeated container." },
+ { "insert", (PyCFunction)Insert, METH_VARARGS,
+ "Appends objects to the repeated container." },
+ { "remove", (PyCFunction)Remove, METH_O,
+ "Removes an object from the repeated container." },
+ { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS,
+ "Sorts the repeated container."},
+ { NULL, NULL }
+};
+
+} // namespace repeated_scalar_container
+
+PyTypeObject RepeatedScalarContainer_Type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "google.protobuf.internal."
+ "cpp._message.RepeatedScalarContainer", // tp_name
+ sizeof(RepeatedScalarContainer), // tp_basicsize
+ 0, // tp_itemsize
+ (destructor)repeated_scalar_container::Dealloc, // tp_dealloc
+ 0, // tp_print
+ 0, // tp_getattr
+ 0, // tp_setattr
+ 0, // tp_compare
+ 0, // tp_repr
+ 0, // tp_as_number
+ &repeated_scalar_container::SqMethods, // tp_as_sequence
+ &repeated_scalar_container::MpMethods, // tp_as_mapping
+ 0, // tp_hash
+ 0, // tp_call
+ 0, // tp_str
+ 0, // tp_getattro
+ 0, // tp_setattro
+ 0, // tp_as_buffer
+ Py_TPFLAGS_DEFAULT, // tp_flags
+ "A Repeated scalar container", // tp_doc
+ 0, // tp_traverse
+ 0, // tp_clear
+ (richcmpfunc)repeated_scalar_container::RichCompare, // tp_richcompare
+ 0, // tp_weaklistoffset
+ 0, // tp_iter
+ 0, // tp_iternext
+ repeated_scalar_container::Methods, // tp_methods
+ 0, // tp_members
+ 0, // tp_getset
+ 0, // tp_base
+ 0, // tp_dict
+ 0, // tp_descr_get
+ 0, // tp_descr_set
+ 0, // tp_dictoffset
+ (initproc)repeated_scalar_container::Init, // tp_init
+};
+
+} // namespace python
+} // namespace protobuf
+} // namespace google
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h
new file mode 100644
index 0000000..8a30138
--- /dev/null
+++ b/python/google/protobuf/pyext/repeated_scalar_container.h
@@ -0,0 +1,112 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: anuraag@google.com (Anuraag Agrawal)
+// Author: tibell@google.com (Johan Tibell)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
+
+#include <Python.h>
+
+#include <memory>
+#ifndef _SHARED_PTR_H
+#include <google/protobuf/stubs/shared_ptr.h>
+#endif
+
+
+namespace google {
+namespace protobuf {
+
+class Message;
+
+using internal::shared_ptr;
+
+namespace python {
+
+struct CFieldDescriptor;
+struct CMessage;
+
+typedef struct RepeatedScalarContainer {
+ PyObject_HEAD;
+
+ // This is the top-level C++ Message object that owns the whole
+ // proto tree. Every Python RepeatedScalarContainer holds a
+ // reference to it in order to keep it alive as long as there's a
+ // Python object that references any part of the tree.
+ shared_ptr<Message> owner;
+
+ // Pointer to the C++ Message that contains this container. The
+ // RepeatedScalarContainer does not own this pointer.
+ Message* message;
+
+ // Weak reference to a parent CMessage object (i.e. may be NULL.)
+ //
+ // Used to make sure all ancestors are also mutable when first
+ // modifying the container.
+ CMessage* parent;
+
+ // Weak reference to the parent's descriptor that describes this
+ // field. Used together with the parent's message when making a
+ // default message instance mutable.
+ CFieldDescriptor* parent_field;
+} RepeatedScalarContainer;
+
+extern PyTypeObject RepeatedScalarContainer_Type;
+
+namespace repeated_scalar_container {
+
+// Appends the scalar 'item' to the end of the container 'self'.
+//
+// Returns None if successful; returns NULL and sets an exception if
+// unsuccessful.
+PyObject* Append(RepeatedScalarContainer* self, PyObject* item);
+
+// Releases the messages in the container to a new message.
+//
+// Returns 0 on success, -1 on failure.
+int Release(RepeatedScalarContainer* self);
+
+// Appends all the elements in the input iterator to the container.
+//
+// Returns None if successful; returns NULL and sets an exception if
+// unsuccessful.
+PyObject* Extend(RepeatedScalarContainer* self, PyObject* value);
+
+// Set the owner field of self and any children of self.
+void SetOwner(RepeatedScalarContainer* self,
+ const shared_ptr<Message>& new_owner);
+
+} // namespace repeated_scalar_container
+} // namespace python
+} // namespace protobuf
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
new file mode 100644
index 0000000..1b27a89
--- /dev/null
+++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h
@@ -0,0 +1,95 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: tibell@google.com (Johan Tibell)
+
+#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
+#define GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
+
+#include <Python.h>
+
+namespace google {
+class ScopedPyObjectPtr {
+ public:
+ // Constructor. Defaults to intializing with NULL.
+ // There is no way to create an uninitialized ScopedPyObjectPtr.
+ explicit ScopedPyObjectPtr(PyObject* p = NULL) : ptr_(p) { }
+
+ // Destructor. If there is a PyObject object, delete it.
+ ~ScopedPyObjectPtr() {
+ Py_XDECREF(ptr_);
+ }
+
+ // Reset. Deletes the current owned object, if any.
+ // Then takes ownership of a new object, if given.
+ // this->reset(this->get()) works.
+ PyObject* reset(PyObject* p = NULL) {
+ if (p != ptr_) {
+ Py_XDECREF(ptr_);
+ ptr_ = p;
+ }
+ return ptr_;
+ }
+
+ // Releases ownership of the object.
+ PyObject* release() {
+ PyObject* p = ptr_;
+ ptr_ = NULL;
+ return p;
+ }
+
+ operator PyObject*() { return ptr_; }
+
+ PyObject* operator->() const {
+ assert(ptr_ != NULL);
+ return ptr_;
+ }
+
+ PyObject* get() const { return ptr_; }
+
+ Py_ssize_t refcnt() const { return Py_REFCNT(ptr_); }
+
+ void inc() const { Py_INCREF(ptr_); }
+
+ // Comparison operators.
+ // These return whether a ScopedPyObjectPtr and a raw pointer
+ // refer to the same object, not just to two different but equal
+ // objects.
+ bool operator==(const PyObject* p) const { return ptr_ == p; }
+ bool operator!=(const PyObject* p) const { return ptr_ != p; }
+
+ private:
+ PyObject* ptr_;
+
+ GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(ScopedPyObjectPtr);
+};
+
+} // namespace google
+#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCOPED_PYOBJECT_PTR_H__
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 5b23803..7aac623 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -29,9 +29,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# This code is meant to work on Python 2.4 and above only.
-#
-# TODO(robinson): Helpers for verbose, common checks like seeing if a
-# descriptor's cpp_type is CPPTYPE_MESSAGE.
"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.
@@ -50,27 +47,29 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
-try:
- from cStringIO import StringIO
-except ImportError:
- from StringIO import StringIO
-import struct
-import weakref
-# We use "as" to avoid name collisions with variables.
-from google.protobuf.internal import containers
-from google.protobuf.internal import decoder
-from google.protobuf.internal import encoder
-from google.protobuf.internal import message_listener as message_listener_mod
-from google.protobuf.internal import type_checkers
-from google.protobuf.internal import wire_format
+from google.protobuf.internal import api_implementation
from google.protobuf import descriptor as descriptor_mod
-from google.protobuf import message as message_mod
-from google.protobuf import text_format
+from google.protobuf import message
_FieldDescriptor = descriptor_mod.FieldDescriptor
+if api_implementation.Type() == 'cpp':
+ if api_implementation.Version() == 2:
+ from google.protobuf.pyext import cpp_message
+ _NewMessage = cpp_message.NewMessage
+ _InitMessage = cpp_message.InitMessage
+ else:
+ from google.protobuf.internal import cpp_message
+ _NewMessage = cpp_message.NewMessage
+ _InitMessage = cpp_message.InitMessage
+else:
+ from google.protobuf.internal import python_message
+ _NewMessage = python_message.NewMessage
+ _InitMessage = python_message.InitMessage
+
+
class GeneratedProtocolMessageType(type):
"""Metaclass for protocol message classes created at runtime from Descriptors.
@@ -92,6 +91,10 @@ class GeneratedProtocolMessageType(type):
myproto_instance = MyProtoClass()
myproto.foo_field = 23
...
+
+ The above example will not work for nested types. If you wish to include them,
+ use reflection.MakeClass() instead of manually instantiating the class in
+ order to create the appropriate class structure.
"""
# Must be consistent with the protocol-compiler code in
@@ -120,10 +123,12 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- _AddSlots(descriptor, dictionary)
- _AddClassAttributesForNestedExtensions(descriptor, dictionary)
+ bases = _NewMessage(bases, descriptor, dictionary)
superclass = super(GeneratedProtocolMessageType, cls)
- return superclass.__new__(cls, name, bases, dictionary)
+
+ new_class = superclass.__new__(cls, name, bases, dictionary)
+ setattr(descriptor, '_concrete_class', new_class)
+ return new_class
def __init__(cls, name, bases, dictionary):
"""Here we perform the majority of our work on the class.
@@ -143,1006 +148,58 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
-
- cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
- if (descriptor.has_options and
- descriptor.GetOptions().message_set_wire_format):
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number))
-
- # We act as a "friend" class of the descriptor, setting
- # its _concrete_class attribute the first time we use a
- # given descriptor to initialize a concrete protocol message
- # class. We also attach stuff to each FieldDescriptor for quick
- # lookup later on.
- concrete_class_attr_name = '_concrete_class'
- if not hasattr(descriptor, concrete_class_attr_name):
- setattr(descriptor, concrete_class_attr_name, cls)
- for field in descriptor.fields:
- _AttachFieldHelpers(cls, field)
-
- _AddEnumValues(descriptor, cls)
- _AddInitMethod(descriptor, cls)
- _AddPropertiesForFields(descriptor, cls)
- _AddPropertiesForExtensions(descriptor, cls)
- _AddStaticMethods(cls)
- _AddMessageMethods(descriptor, cls)
- _AddPrivateHelperMethods(cls)
+ _InitMessage(descriptor, cls)
superclass = super(GeneratedProtocolMessageType, cls)
superclass.__init__(name, bases, dictionary)
-# Stateless helpers for GeneratedProtocolMessageType below.
-# Outside clients should not access these directly.
-#
-# I opted not to make any of these methods on the metaclass, to make it more
-# clear that I'm not really using any state there and to keep clients from
-# thinking that they have direct access to these construction helpers.
-
-
-def _PropertyName(proto_field_name):
- """Returns the name of the public property attribute which
- clients can use to get and (in some cases) set the value
- of a protocol message field.
-
- Args:
- proto_field_name: The protocol message field name, exactly
- as it appears (or would appear) in a .proto file.
- """
- # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
- # nnorwitz makes my day by writing:
- # """
- # FYI. See the keyword module in the stdlib. This could be as simple as:
- #
- # if keyword.iskeyword(proto_field_name):
- # return proto_field_name + "_"
- # return proto_field_name
- # """
- # Kenton says: The above is a BAD IDEA. People rely on being able to use
- # getattr() and setattr() to reflectively manipulate field values. If we
- # rename the properties, then every such user has to also make sure to apply
- # the same transformation. Note that currently if you name a field "yield",
- # you can still access it just fine using getattr/setattr -- it's not even
- # that cumbersome to do so.
- # TODO(kenton): Remove this method entirely if/when everyone agrees with my
- # position.
- return proto_field_name
-
-
-def _VerifyExtensionHandle(message, extension_handle):
- """Verify that the given extension handle is valid."""
-
- if not isinstance(extension_handle, _FieldDescriptor):
- raise KeyError('HasExtension() expects an extension handle, got: %s' %
- extension_handle)
-
- if not extension_handle.is_extension:
- raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
-
- if extension_handle.containing_type is not message.DESCRIPTOR:
- raise KeyError('Extension "%s" extends message type "%s", but this '
- 'message is of type "%s".' %
- (extension_handle.full_name,
- extension_handle.containing_type.full_name,
- message.DESCRIPTOR.full_name))
-
-
-def _AddSlots(message_descriptor, dictionary):
- """Adds a __slots__ entry to dictionary, containing the names of all valid
- attributes for this message type.
-
- Args:
- message_descriptor: A Descriptor instance describing this message type.
- dictionary: Class dictionary to which we'll add a '__slots__' entry.
- """
- dictionary['__slots__'] = ['_cached_byte_size',
- '_cached_byte_size_dirty',
- '_fields',
- '_is_present_in_parent',
- '_listener',
- '_listener_for_children',
- '__weakref__']
-
-
-def _IsMessageSetExtension(field):
- return (field.is_extension and
- field.containing_type.has_options and
- field.containing_type.GetOptions().message_set_wire_format and
- field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == _FieldDescriptor.LABEL_OPTIONAL)
-
-
-def _AttachFieldHelpers(cls, field_descriptor):
- is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
- field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
- sizer = encoder.MessageSetItemSizer(field_descriptor.number)
- else:
- field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed)
- sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed)
-
- field_descriptor._encoder = field_encoder
- field_descriptor._sizer = sizer
- field_descriptor._default_constructor = _DefaultValueConstructorForField(
- field_descriptor)
-
- def AddDecoder(wiretype, is_packed):
- tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor))
-
- AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
- False)
-
- if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
- # To support wire compatibility of adding packed = true, add a decoder for
- # packed values regardless of the field's options.
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
-
-
-def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
- extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- assert extension_name not in dictionary
- dictionary[extension_name] = extension_field
-
-
-def _AddEnumValues(descriptor, cls):
- """Sets class-level attributes for all enum fields defined in this message.
-
- Args:
- descriptor: Descriptor object for this message type.
- cls: Class we're constructing for this message type.
- """
- for enum_type in descriptor.enum_types:
- for enum_value in enum_type.values:
- setattr(cls, enum_value.name, enum_value.number)
-
-
-def _DefaultValueConstructorForField(field):
- """Returns a function which returns a default value for a field.
+def ParseMessage(descriptor, byte_str):
+ """Generate a new Message instance from this Descriptor and a byte string.
Args:
- field: FieldDescriptor object for this field.
+ descriptor: Protobuf Descriptor object
+ byte_str: Serialized protocol buffer byte string
- The returned function has one argument:
- message: Message instance containing this field, or a weakref proxy
- of same.
-
- That function in turn returns a default value for this field. The default
- value may refer back to |message| via a weak reference.
- """
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- if field.default_value != []:
- raise ValueError('Repeated field default value not empty list: %s' % (
- field.default_value))
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- # We can't look at _concrete_class yet since it might not have
- # been set. (Depends on order in which we initialize the classes).
- message_type = field.message_type
- def MakeRepeatedMessageDefault(message):
- return containers.RepeatedCompositeFieldContainer(
- message._listener_for_children, field.message_type)
- return MakeRepeatedMessageDefault
- else:
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
- def MakeRepeatedScalarDefault(message):
- return containers.RepeatedScalarFieldContainer(
- message._listener_for_children, type_checker)
- return MakeRepeatedScalarDefault
-
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- # _concrete_class may not yet be initialized.
- message_type = field.message_type
- def MakeSubMessageDefault(message):
- result = message_type._concrete_class()
- result._SetListener(message._listener_for_children)
- return result
- return MakeSubMessageDefault
-
- def MakeScalarDefault(message):
- return field.default_value
- return MakeScalarDefault
-
-
-def _AddInitMethod(message_descriptor, cls):
- """Adds an __init__ method to cls."""
- fields = message_descriptor.fields
- def init(self, **kwargs):
- self._cached_byte_size = 0
- self._cached_byte_size_dirty = False
- self._fields = {}
- self._is_present_in_parent = False
- self._listener = message_listener_mod.NullMessageListener()
- self._listener_for_children = _Listener(self)
- for field_name, field_value in kwargs.iteritems():
- field = _GetFieldByName(message_descriptor, field_name)
- if field is None:
- raise TypeError("%s() got an unexpected keyword argument '%s'" %
- (message_descriptor.name, field_name))
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- copy = field._default_constructor(self)
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
- else: # Scalar
- copy.extend(field_value)
- self._fields[field] = copy
- elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- copy = field._default_constructor(self)
- copy.MergeFrom(field_value)
- self._fields[field] = copy
- else:
- self._fields[field] = field_value
-
- init.__module__ = None
- init.__doc__ = None
- cls.__init__ = init
-
-
-def _GetFieldByName(message_descriptor, field_name):
- """Returns a field descriptor by field name.
-
- Args:
- message_descriptor: A Descriptor describing all fields in message.
- field_name: The name of the field to retrieve.
Returns:
- The field descriptor associated with the field name.
- """
- try:
- return message_descriptor.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
-
-
-def _AddPropertiesForFields(descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- for field in descriptor.fields:
- _AddPropertiesForField(field, cls)
-
- if descriptor.is_extendable:
- # _ExtensionDict is just an adaptor with no state so we allocate a new one
- # every time it is accessed.
- cls.Extensions = property(lambda self: _ExtensionDict(self))
-
-
-def _AddPropertiesForField(field, cls):
- """Adds a public property for a protocol message field.
- Clients can use this property to get and (in the case
- of non-repeated scalar fields) directly set the value
- of a protocol message field.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- # Catch it if we add other types that we should
- # handle specially here.
- assert _FieldDescriptor.MAX_CPPTYPE == 10
-
- constant_name = field.name.upper() + "_FIELD_NUMBER"
- setattr(cls, constant_name, field.number)
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- _AddPropertiesForRepeatedField(field, cls)
- elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- _AddPropertiesForNonRepeatedCompositeField(field, cls)
- else:
- _AddPropertiesForNonRepeatedScalarField(field, cls)
-
-
-def _AddPropertiesForRepeatedField(field, cls):
- """Adds a public property for a "repeated" protocol message field. Clients
- can use this property to get the value of the field, which will be either a
- _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
- below).
-
- Note that when clients add values to these containers, we perform
- type-checking in the case of repeated scalar fields, and we also set any
- necessary "has" bits as a side-effect.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
-
- def getter(self):
- field_value = self._fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = field._default_constructor(self)
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- field_value = self._fields.setdefault(field, field_value)
- return field_value
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
-
- # We define a setter just so we can throw an exception with a more
- # helpful error message.
- def setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % proto_field_name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForNonRepeatedScalarField(field, cls):
- """Adds a public property for a nonrepeated, scalar protocol message field.
- Clients can use this property to get and directly set the value of the field.
- Note that when the client sets the value of a field by using this property,
- all necessary "has" bits are set as a side-effect, and we also perform
- type-checking.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
- default_value = field.default_value
-
- def getter(self):
- return self._fields.get(field, default_value)
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
- def setter(self, new_value):
- type_checker.CheckValue(new_value)
- self._fields[field] = new_value
- # Check _cached_byte_size_dirty inline to improve performance, since scalar
- # setters are called frequently.
- if not self._cached_byte_size_dirty:
- self._Modified()
- setter.__module__ = None
- setter.__doc__ = 'Setter for %s.' % proto_field_name
-
- # Add a property to encapsulate the getter/setter.
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForNonRepeatedCompositeField(field, cls):
- """Adds a public property for a nonrepeated, composite protocol message field.
- A composite field is a "group" or "message" field.
-
- Clients can use this property to get the value of the field, but cannot
- assign to the property directly.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
+ Newly created protobuf Message object.
"""
- # TODO(robinson): Remove duplication with similar method
- # for non-repeated scalars.
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
- message_type = field.message_type
-
- def getter(self):
- field_value = self._fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = message_type._concrete_class()
- field_value._SetListener(self._listener_for_children)
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- field_value = self._fields.setdefault(field, field_value)
- return field_value
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
-
- # We define a setter just so we can throw an exception with a more
- # helpful error message.
- def setter(self, new_value):
- raise AttributeError('Assignment not allowed to composite field '
- '"%s" in protocol message object.' % proto_field_name)
-
- # Add a property to encapsulate the getter.
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForExtensions(descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- constant_name = extension_name.upper() + "_FIELD_NUMBER"
- setattr(cls, constant_name, extension_field.number)
-
-
-def _AddStaticMethods(cls):
- # TODO(robinson): This probably needs to be thread-safe(?)
- def RegisterExtension(extension_handle):
- extension_handle.containing_type = cls.DESCRIPTOR
- _AttachFieldHelpers(cls, extension_handle)
-
- # Try to insert our extension, failing if an extension with the same number
- # already exists.
- actual_handle = cls._extensions_by_number.setdefault(
- extension_handle.number, extension_handle)
- if actual_handle is not extension_handle:
- raise AssertionError(
- 'Extensions "%s" and "%s" both try to extend message type "%s" with '
- 'field number %d.' %
- (extension_handle.full_name, actual_handle.full_name,
- cls.DESCRIPTOR.full_name, extension_handle.number))
-
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- handle = extension_handle # avoid line wrapping
- if _IsMessageSetExtension(handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
-
- cls.RegisterExtension = staticmethod(RegisterExtension)
-
- def FromString(s):
- message = cls()
- message.MergeFromString(s)
- return message
- cls.FromString = staticmethod(FromString)
-
-
-def _IsPresent(item):
- """Given a (FieldDescriptor, value) tuple from _fields, return true if the
- value should be included in the list returned by ListFields()."""
-
- if item[0].label == _FieldDescriptor.LABEL_REPEATED:
- return bool(item[1])
- elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- return item[1]._is_present_in_parent
- else:
- return True
-
-
-def _AddListFieldsMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def ListFields(self):
- all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
- all_fields.sort(key = lambda item: item[0].number)
- return all_fields
-
- cls.ListFields = ListFields
-
-
-def _AddHasFieldMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- singular_fields = {}
- for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
-
- def HasField(self, field_name):
- try:
- field = singular_fields[field_name]
- except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
-
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(field)
- return value is not None and value._is_present_in_parent
- else:
- return field in self._fields
- cls.HasField = HasField
+ result_class = MakeClass(descriptor)
+ new_msg = result_class()
+ new_msg.ParseFromString(byte_str)
+ return new_msg
-def _AddClearFieldMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def ClearField(self, field_name):
- try:
- field = message_descriptor.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+def MakeClass(descriptor):
+ """Construct a class object for a protobuf described by descriptor.
- if field in self._fields:
- # Note: If the field is a sub-message, its listener will still point
- # at us. That's fine, because the worst than can happen is that it
- # will call _Modified() and invalidate our byte size. Big deal.
- del self._fields[field]
+ Composite descriptors are handled by defining the new class as a member of the
+ parent class, recursing as deep as necessary.
+ This is the dynamic equivalent to:
- # Always call _Modified() -- even if nothing was changed, this is
- # a mutating method, and thus calling it should cause the field to become
- # present in the parent message.
- self._Modified()
-
- cls.ClearField = ClearField
-
-
-def _AddClearExtensionMethod(cls):
- """Helper for _AddMessageMethods()."""
- def ClearExtension(self, extension_handle):
- _VerifyExtensionHandle(self, extension_handle)
-
- # Similar to ClearField(), above.
- if extension_handle in self._fields:
- del self._fields[extension_handle]
- self._Modified()
- cls.ClearExtension = ClearExtension
-
-
-def _AddClearMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def Clear(self):
- # Clear fields.
- self._fields = {}
- self._Modified()
- cls.Clear = Clear
-
-
-def _AddHasExtensionMethod(cls):
- """Helper for _AddMessageMethods()."""
- def HasExtension(self, extension_handle):
- _VerifyExtensionHandle(self, extension_handle)
- if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
- raise KeyError('"%s" is repeated.' % extension_handle.full_name)
-
- if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(extension_handle)
- return value is not None and value._is_present_in_parent
- else:
- return extension_handle in self._fields
- cls.HasExtension = HasExtension
-
-
-def _AddEqualsMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def __eq__(self, other):
- if (not isinstance(other, message_mod.Message) or
- other.DESCRIPTOR != self.DESCRIPTOR):
- return False
-
- if self is other:
- return True
-
- return self.ListFields() == other.ListFields()
-
- cls.__eq__ = __eq__
-
-
-def _AddStrMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def __str__(self):
- return text_format.MessageToString(self)
- cls.__str__ = __str__
-
-
-def _AddSetListenerMethod(cls):
- """Helper for _AddMessageMethods()."""
- def SetListener(self, listener):
- if listener is None:
- self._listener = message_listener_mod.NullMessageListener()
- else:
- self._listener = listener
- cls._SetListener = SetListener
-
-
-def _BytesForNonRepeatedElement(value, field_number, field_type):
- """Returns the number of bytes needed to serialize a non-repeated element.
- The returned byte count includes space for tag information and any
- other additional space associated with serializing value.
+ class Parent(message.Message):
+ __metaclass__ = GeneratedProtocolMessageType
+ DESCRIPTOR = descriptor
+ class Child(message.Message):
+ __metaclass__ = GeneratedProtocolMessageType
+ DESCRIPTOR = descriptor.nested_types[0]
+
+ Sample usage:
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(proto2_string)
+ msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
Args:
- value: Value we're serializing.
- field_number: Field number of this value. (Since the field number
- is stored as part of a varint-encoded tag, this has an impact
- on the total bytes required to serialize the value).
- field_type: The type of the field. One of the TYPE_* constants
- within FieldDescriptor.
- """
- try:
- fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
- return fn(field_number, value)
- except KeyError:
- raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
-
-
-def _AddByteSizeMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def ByteSize(self):
- if not self._cached_byte_size_dirty:
- return self._cached_byte_size
-
- size = 0
- for field_descriptor, field_value in self.ListFields():
- size += field_descriptor._sizer(field_value)
-
- self._cached_byte_size = size
- self._cached_byte_size_dirty = False
- self._listener_for_children.dirty = False
- return size
-
- cls.ByteSize = ByteSize
-
-
-def _AddSerializeToStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def SerializeToString(self):
- # Check if the message has all of its required fields set.
- errors = []
- if not self.IsInitialized():
- raise message_mod.EncodeError(
- 'Message is missing required fields: ' +
- ','.join(self.FindInitializationErrors()))
- return self.SerializePartialToString()
- cls.SerializeToString = SerializeToString
-
-
-def _AddSerializePartialToStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def SerializePartialToString(self):
- out = StringIO()
- self._InternalSerialize(out.write)
- return out.getvalue()
- cls.SerializePartialToString = SerializePartialToString
-
- def InternalSerialize(self, write_bytes):
- for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value)
- cls._InternalSerialize = InternalSerialize
-
-
-def _AddMergeFromStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def MergeFromString(self, serialized):
- length = len(serialized)
- try:
- if self._InternalParse(serialized, 0, length) != length:
- # The only reason _InternalParse would return early is if it
- # encountered an end-group tag.
- raise message_mod.DecodeError('Unexpected end-group tag.')
- except IndexError:
- raise message_mod.DecodeError('Truncated message.')
- except struct.error, e:
- raise message_mod.DecodeError(e)
- return length # Return this for legacy reasons.
- cls.MergeFromString = MergeFromString
-
- local_ReadTag = decoder.ReadTag
- local_SkipField = decoder.SkipField
- decoders_by_tag = cls._decoders_by_tag
-
- def InternalParse(self, buffer, pos, end):
- self._Modified()
- field_dict = self._fields
- while pos != end:
- (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
- field_decoder = decoders_by_tag.get(tag_bytes)
- if field_decoder is None:
- new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
- if new_pos == -1:
- return pos
- pos = new_pos
- else:
- pos = field_decoder(buffer, new_pos, end, self, field_dict)
- return pos
- cls._InternalParse = InternalParse
-
-
-def _AddIsInitializedMethod(message_descriptor, cls):
- """Adds the IsInitialized and FindInitializationError methods to the
- protocol message class."""
-
- required_fields = [field for field in message_descriptor.fields
- if field.label == _FieldDescriptor.LABEL_REQUIRED]
-
- def IsInitialized(self, errors=None):
- """Checks if all required fields of a message are set.
-
- Args:
- errors: A list which, if provided, will be populated with the field
- paths of all missing required fields.
-
- Returns:
- True iff the specified message has all required fields set.
- """
-
- # Performance is critical so we avoid HasField() and ListFields().
-
- for field in required_fields:
- if (field not in self._fields or
- (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
- not self._fields[field]._is_present_in_parent)):
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
-
- for field, value in self._fields.iteritems():
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for element in value:
- if not element.IsInitialized():
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
- elif value._is_present_in_parent and not value.IsInitialized():
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
-
- return True
-
- cls.IsInitialized = IsInitialized
-
- def FindInitializationErrors(self):
- """Finds required fields which are not initialized.
-
- Returns:
- A list of strings. Each string is a path to an uninitialized field from
- the top-level message, e.g. "foo.bar[5].baz".
- """
-
- errors = [] # simplify things
-
- for field in required_fields:
- if not self.HasField(field.name):
- errors.append(field.name)
-
- for field, value in self.ListFields():
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.is_extension:
- name = "(%s)" % field.full_name
- else:
- name = field.name
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in xrange(len(value)):
- element = value[i]
- prefix = "%s[%d]." % (name, i)
- sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
- else:
- prefix = name + "."
- sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
-
- return errors
-
- cls.FindInitializationErrors = FindInitializationErrors
-
-
-def _AddMergeFromMethod(cls):
- LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
- CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
-
- def MergeFrom(self, msg):
- assert msg is not self
- self._Modified()
-
- fields = self._fields
-
- for field, value in msg._fields.iteritems():
- if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE:
- field_value = fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = field._default_constructor(self)
- fields[field] = field_value
- field_value.MergeFrom(value)
- else:
- self._fields[field] = value
- cls.MergeFrom = MergeFrom
-
-
-def _AddMessageMethods(message_descriptor, cls):
- """Adds implementations of all Message methods to cls."""
- _AddListFieldsMethod(message_descriptor, cls)
- _AddHasFieldMethod(message_descriptor, cls)
- _AddClearFieldMethod(message_descriptor, cls)
- if message_descriptor.is_extendable:
- _AddClearExtensionMethod(cls)
- _AddHasExtensionMethod(cls)
- _AddClearMethod(message_descriptor, cls)
- _AddEqualsMethod(message_descriptor, cls)
- _AddStrMethod(message_descriptor, cls)
- _AddSetListenerMethod(cls)
- _AddByteSizeMethod(message_descriptor, cls)
- _AddSerializeToStringMethod(message_descriptor, cls)
- _AddSerializePartialToStringMethod(message_descriptor, cls)
- _AddMergeFromStringMethod(message_descriptor, cls)
- _AddIsInitializedMethod(message_descriptor, cls)
- _AddMergeFromMethod(cls)
-
-
-def _AddPrivateHelperMethods(cls):
- """Adds implementation of private helper methods to cls."""
-
- def Modified(self):
- """Sets the _cached_byte_size_dirty bit to true,
- and propagates this to our listener iff this was a state change.
- """
-
- # Note: Some callers check _cached_byte_size_dirty before calling
- # _Modified() as an extra optimization. So, if this method is ever
- # changed such that it does stuff even when _cached_byte_size_dirty is
- # already true, the callers need to be updated.
- if not self._cached_byte_size_dirty:
- self._cached_byte_size_dirty = True
- self._listener_for_children.dirty = True
- self._is_present_in_parent = True
- self._listener.Modified()
-
- cls._Modified = Modified
- cls.SetInParent = Modified
-
-
-class _Listener(object):
-
- """MessageListener implementation that a parent message registers with its
- child message.
-
- In order to support semantics like:
-
- foo.bar.baz.qux = 23
- assert foo.HasField('bar')
-
- ...child objects must have back references to their parents.
- This helper class is at the heart of this support.
- """
-
- def __init__(self, parent_message):
- """Args:
- parent_message: The message whose _Modified() method we should call when
- we receive Modified() messages.
- """
- # This listener establishes a back reference from a child (contained) object
- # to its parent (containing) object. We make this a weak reference to avoid
- # creating cyclic garbage when the client finishes with the 'parent' object
- # in the tree.
- if isinstance(parent_message, weakref.ProxyType):
- self._parent_message_weakref = parent_message
- else:
- self._parent_message_weakref = weakref.proxy(parent_message)
-
- # As an optimization, we also indicate directly on the listener whether
- # or not the parent message is dirty. This way we can avoid traversing
- # up the tree in the common case.
- self.dirty = False
-
- def Modified(self):
- if self.dirty:
- return
- try:
- # Propagate the signal to our parents iff this is the first field set.
- self._parent_message_weakref._Modified()
- except ReferenceError:
- # We can get here if a client has kept a reference to a child object,
- # and is now setting a field on it, but the child's parent has been
- # garbage-collected. This is not an error.
- pass
-
-
-# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
-# TODO(robinson): Unify error handling of "unknown extension" crap.
-# TODO(robinson): Support iteritems()-style iteration over all
-# extensions with the "has" bits turned on?
-class _ExtensionDict(object):
-
- """Dict-like container for supporting an indexable "Extensions"
- field on proto instances.
-
- Note that in all cases we expect extension handles to be
- FieldDescriptors.
+ descriptor: A descriptor.Descriptor object describing the protobuf.
+ Returns:
+ The Message class object described by the descriptor.
"""
+ attributes = {}
+ for name, nested_type in descriptor.nested_types_by_name.items():
+ attributes[name] = MakeClass(nested_type)
- def __init__(self, extended_message):
- """extended_message: Message instance for which we are the Extensions dict.
- """
-
- self._extended_message = extended_message
-
- def __getitem__(self, extension_handle):
- """Returns the current value of the given extension handle."""
-
- _VerifyExtensionHandle(self._extended_message, extension_handle)
-
- result = self._extended_message._fields.get(extension_handle)
- if result is not None:
- return result
-
- if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
- result = extension_handle._default_constructor(self._extended_message)
- elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- result = extension_handle.message_type._concrete_class()
- try:
- result._SetListener(self._extended_message._listener_for_children)
- except ReferenceError:
- pass
- else:
- # Singular scalar -- just return the default without inserting into the
- # dict.
- return extension_handle.default_value
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- result = self._extended_message._fields.setdefault(
- extension_handle, result)
-
- return result
-
- def __eq__(self, other):
- if not isinstance(other, self.__class__):
- return False
+ attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor
- my_fields = self._extended_message.ListFields()
- other_fields = other._extended_message.ListFields()
-
- # Get rid of non-extension fields.
- my_fields = [ field for field in my_fields if field.is_extension ]
- other_fields = [ field for field in other_fields if field.is_extension ]
-
- return my_fields == other_fields
-
- def __ne__(self, other):
- return not self == other
-
- # Note that this is only meaningful for non-repeated, scalar extension
- # fields. Note also that we may have to call _Modified() when we do
- # successfully set a field this way, to set any necssary "has" bits in the
- # ancestors of the extended message.
- def __setitem__(self, extension_handle, value):
- """If extension_handle specifies a non-repeated, scalar extension
- field, sets the value of that field.
- """
-
- _VerifyExtensionHandle(self._extended_message, extension_handle)
-
- if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
- extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
- raise TypeError(
- 'Cannot assign to extension "%s" because it is a repeated or '
- 'composite type.' % extension_handle.full_name)
-
- # It's slightly wasteful to lookup the type checker each time,
- # but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle.cpp_type, extension_handle.type)
- type_checker.CheckValue(value)
- self._extended_message._fields[extension_handle] = value
- self._extended_message._Modified()
-
- def _FindExtensionByName(self, name):
- """Tries to find a known extension with the specified name.
-
- Args:
- name: Extension full name.
-
- Returns:
- Extension field descriptor.
- """
- return self._extended_message._extensions_by_name.get(name, None)
+ return GeneratedProtocolMessageType(str(descriptor.name), (message.Message,),
+ attributes)
diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py
new file mode 100644
index 0000000..7466fec
--- /dev/null
+++ b/python/google/protobuf/symbol_database.py
@@ -0,0 +1,185 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""A database of Python protocol buffer generated symbols.
+
+SymbolDatabase makes it easy to create new instances of a registered type, given
+only the type's protocol buffer symbol name. Once all symbols are registered,
+they can be accessed using either the MessageFactory interface which
+SymbolDatabase exposes, or the DescriptorPool interface of the underlying
+pool.
+
+Example usage:
+
+ db = symbol_database.SymbolDatabase()
+
+ # Register symbols of interest, from one or multiple files.
+ db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR)
+ db.RegisterMessage(my_proto_pb2.MyMessage)
+ db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR)
+
+ # The database can be used as a MessageFactory, to generate types based on
+ # their name:
+ types = db.GetMessages(['my_proto.proto'])
+ my_message_instance = types['MyMessage']()
+
+ # The database's underlying descriptor pool can be queried, so it's not
+ # necessary to know a type's filename to be able to generate it:
+ filename = db.pool.FindFileContainingSymbol('MyMessage')
+ my_message_instance = db.GetMessages([filename])['MyMessage']()
+
+ # This functionality is also provided directly via a convenience method:
+ my_message_instance = db.GetSymbol('MyMessage')()
+"""
+
+
+from google.protobuf import descriptor_pool
+
+
+class SymbolDatabase(object):
+ """A database of Python generated symbols.
+
+ SymbolDatabase also models message_factory.MessageFactory.
+
+ The symbol database can be used to keep a global registry of all protocol
+ buffer types used within a program.
+ """
+
+ def __init__(self):
+ """Constructor."""
+
+ self._symbols = {}
+ self._symbols_by_file = {}
+ self.pool = descriptor_pool.DescriptorPool()
+
+ def RegisterMessage(self, message):
+ """Registers the given message type in the local database.
+
+ Args:
+ message: a message.Message, to be registered.
+
+ Returns:
+ The provided message.
+ """
+
+ desc = message.DESCRIPTOR
+ self._symbols[desc.full_name] = message
+ if desc.file.name not in self._symbols_by_file:
+ self._symbols_by_file[desc.file.name] = {}
+ self._symbols_by_file[desc.file.name][desc.full_name] = message
+ self.pool.AddDescriptor(desc)
+ return message
+
+ def RegisterEnumDescriptor(self, enum_descriptor):
+ """Registers the given enum descriptor in the local database.
+
+ Args:
+ enum_descriptor: a descriptor.EnumDescriptor.
+
+ Returns:
+ The provided descriptor.
+ """
+ self.pool.AddEnumDescriptor(enum_descriptor)
+ return enum_descriptor
+
+ def RegisterFileDescriptor(self, file_descriptor):
+ """Registers the given file descriptor in the local database.
+
+ Args:
+ file_descriptor: a descriptor.FileDescriptor.
+
+ Returns:
+ The provided descriptor.
+ """
+ self.pool.AddFileDescriptor(file_descriptor)
+
+ def GetSymbol(self, symbol):
+ """Tries to find a symbol in the local database.
+
+ Currently, this method only returns message.Message instances, however, if
+ may be extended in future to support other symbol types.
+
+ Args:
+ symbol: A str, a protocol buffer symbol.
+
+ Returns:
+ A Python class corresponding to the symbol.
+
+ Raises:
+ KeyError: if the symbol could not be found.
+ """
+
+ return self._symbols[symbol]
+
+ def GetPrototype(self, descriptor):
+ """Builds a proto2 message class based on the passed in descriptor.
+
+ Passing a descriptor with a fully qualified name matching a previous
+ invocation will cause the same class to be returned.
+
+ Args:
+ descriptor: The descriptor to build from.
+
+ Returns:
+ A class describing the passed in descriptor.
+ """
+
+ return self.GetSymbol(descriptor.full_name)
+
+ def GetMessages(self, files):
+ """Gets all the messages from a specified file.
+
+ This will find and resolve dependencies, failing if they are not registered
+ in the symbol database.
+
+
+ Args:
+ files: The file names to extract messages from.
+
+ Returns:
+ A dictionary mapping proto names to the message classes. This will include
+ any dependent messages as well as any messages defined in the same file as
+ a specified message.
+
+ Raises:
+ KeyError: if a file could not be found.
+ """
+
+ result = {}
+ for f in files:
+ result.update(self._symbols_by_file[f])
+ return result
+
+_DEFAULT = SymbolDatabase()
+
+
+def Default():
+ """Returns the default SymbolDatabase."""
+ return _DEFAULT
diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py
new file mode 100644
index 0000000..ed0aabf
--- /dev/null
+++ b/python/google/protobuf/text_encoding.py
@@ -0,0 +1,110 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#PY25 compatible for GAE.
+#
+"""Encoding related utilities."""
+
+import re
+import sys ##PY25
+
+# Lookup table for utf8
+_cescape_utf8_to_str = [chr(i) for i in xrange(0, 256)]
+_cescape_utf8_to_str[9] = r'\t' # optional escape
+_cescape_utf8_to_str[10] = r'\n' # optional escape
+_cescape_utf8_to_str[13] = r'\r' # optional escape
+_cescape_utf8_to_str[39] = r"\'" # optional escape
+
+_cescape_utf8_to_str[34] = r'\"' # necessary escape
+_cescape_utf8_to_str[92] = r'\\' # necessary escape
+
+# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32)
+_cescape_byte_to_str = ([r'\%03o' % i for i in xrange(0, 32)] +
+ [chr(i) for i in xrange(32, 127)] +
+ [r'\%03o' % i for i in xrange(127, 256)])
+_cescape_byte_to_str[9] = r'\t' # optional escape
+_cescape_byte_to_str[10] = r'\n' # optional escape
+_cescape_byte_to_str[13] = r'\r' # optional escape
+_cescape_byte_to_str[39] = r"\'" # optional escape
+
+_cescape_byte_to_str[34] = r'\"' # necessary escape
+_cescape_byte_to_str[92] = r'\\' # necessary escape
+
+
+def CEscape(text, as_utf8):
+ """Escape a bytes string for use in an ascii protocol buffer.
+
+ text.encode('string_escape') does not seem to satisfy our needs as it
+ encodes unprintable characters using two-digit hex escapes whereas our
+ C++ unescaping function allows hex escapes to be any length. So,
+ "\0011".encode('string_escape') ends up being "\\x011", which will be
+ decoded in C++ as a single-character string with char code 0x11.
+
+ Args:
+ text: A byte string to be escaped
+ as_utf8: Specifies if result should be returned in UTF-8 encoding
+ Returns:
+ Escaped string
+ """
+ # PY3 hack: make Ord work for str and bytes:
+ # //platforms/networking/data uses unicode here, hence basestring.
+ Ord = ord if isinstance(text, basestring) else lambda x: x
+ if as_utf8:
+ return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text)
+ return ''.join(_cescape_byte_to_str[Ord(c)] for c in text)
+
+
+_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])')
+_cescape_highbit_to_str = ([chr(i) for i in range(0, 127)] +
+ [r'\%03o' % i for i in range(127, 256)])
+
+
+def CUnescape(text):
+ """Unescape a text string with C-style escape sequences to UTF-8 bytes."""
+
+ def ReplaceHex(m):
+ # Only replace the match if the number of leading back slashes is odd. i.e.
+ # the slash itself is not escaped.
+ if len(m.group(1)) & 1:
+ return m.group(1) + 'x0' + m.group(2)
+ return m.group(0)
+
+ # This is required because the 'string_escape' encoding doesn't
+ # allow single-digit hex escapes (like '\xf').
+ result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
+
+ if sys.version_info[0] < 3: ##PY25
+##!PY25 if str is bytes: # PY2
+ return result.decode('string_escape')
+ result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result)
+ return (result.encode('ascii') # Make it bytes to allow decode.
+ .decode('unicode_escape')
+ # Make it bytes again to return the proper type.
+ .encode('raw_unicode_escape'))
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index cc6ac90..50f76f2 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#PY25 compatible for GAE.
+#
+# Copyright 2007 Google Inc. All Rights Reserved.
+
"""Contains routines for printing protocol messages in text format."""
__author__ = 'kenton@google.com (Kenton Varda)'
@@ -35,46 +39,92 @@ __author__ = 'kenton@google.com (Kenton Varda)'
import cStringIO
import re
-from collections import deque
from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
+from google.protobuf import text_encoding
+
+__all__ = ['MessageToString', 'PrintMessage', 'PrintField',
+ 'PrintFieldValue', 'Merge']
-__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField',
- 'PrintFieldValue', 'Merge' ]
+_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
+ type_checkers.Int32ValueChecker(),
+ type_checkers.Uint64ValueChecker(),
+ type_checkers.Int64ValueChecker())
+_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
+_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
+_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
+ descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
-# Infinity and NaN are not explicitly supported by Python pre-2.6, and
-# float('inf') does not work on Windows (pre-2.6).
-_INFINITY = 1e10000 # overflows, thus will actually be infinity.
-_NAN = _INFINITY * 0
+class Error(Exception):
+ """Top-level module error for text_format."""
-class ParseError(Exception):
+
+class ParseError(Error):
"""Thrown in case of ASCII parsing error."""
-def MessageToString(message):
+def MessageToString(message, as_utf8=False, as_one_line=False,
+ pointy_brackets=False, use_index_order=False,
+ float_format=None):
+ """Convert protobuf message to text format.
+
+ Floating point values can be formatted compactly with 15 digits of
+ precision (which is the most that IEEE 754 "double" can guarantee)
+ using float_format='.15g'.
+
+ Args:
+ message: The protocol buffers message.
+ as_utf8: Produce text output in UTF8 format.
+ as_one_line: Don't introduce newlines between fields.
+ pointy_brackets: If True, use angle brackets instead of curly braces for
+ nesting.
+ use_index_order: If True, print fields of a proto message using the order
+ defined in source code instead of the field number. By default, use the
+ field number order.
+ float_format: If set, use this to specify floating point number formatting
+ (per the "Format Specification Mini-Language"); otherwise, str() is used.
+
+ Returns:
+ A string of the text formatted protocol buffer message.
+ """
out = cStringIO.StringIO()
- PrintMessage(message, out)
+ PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line,
+ pointy_brackets=pointy_brackets,
+ use_index_order=use_index_order,
+ float_format=float_format)
result = out.getvalue()
out.close()
+ if as_one_line:
+ return result.rstrip()
return result
-def PrintMessage(message, out, indent = 0):
- for field, value in message.ListFields():
+def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
+ pointy_brackets=False, use_index_order=False,
+ float_format=None):
+ fields = message.ListFields()
+ if use_index_order:
+ fields.sort(key=lambda x: x[0].index)
+ for field, value in fields:
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
for element in value:
- PrintField(field, element, out, indent)
+ PrintField(field, element, out, indent, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ float_format=float_format)
else:
- PrintField(field, value, out, indent)
+ PrintField(field, value, out, indent, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ float_format=float_format)
-def PrintField(field, value, out, indent = 0):
+def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
+ pointy_brackets=False, float_format=None):
"""Print a single field name/value pair. For repeated fields, the value
should be a single element."""
- out.write(' ' * indent);
+ out.write(' ' * indent)
if field.is_extension:
out.write('[')
if (field.containing_type.GetOptions().message_set_wire_format and
@@ -96,54 +146,168 @@ def PrintField(field, value, out, indent = 0):
# don't include it.
out.write(': ')
- PrintFieldValue(field, value, out, indent)
- out.write('\n')
+ PrintFieldValue(field, value, out, indent, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ float_format=float_format)
+ if as_one_line:
+ out.write(' ')
+ else:
+ out.write('\n')
-def PrintFieldValue(field, value, out, indent = 0):
+def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
+ as_one_line=False, pointy_brackets=False,
+ float_format=None):
"""Print a single field value (not including name). For repeated fields,
the value should be a single element."""
+ if pointy_brackets:
+ openb = '<'
+ closeb = '>'
+ else:
+ openb = '{'
+ closeb = '}'
+
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- out.write(' {\n')
- PrintMessage(value, out, indent + 2)
- out.write(' ' * indent + '}')
+ if as_one_line:
+ out.write(' %s ' % openb)
+ PrintMessage(value, out, indent, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ float_format=float_format)
+ out.write(closeb)
+ else:
+ out.write(' %s\n' % openb)
+ PrintMessage(value, out, indent + 2, as_utf8, as_one_line,
+ pointy_brackets=pointy_brackets,
+ float_format=float_format)
+ out.write(' ' * indent + closeb)
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- out.write(field.enum_type.values_by_number[value].name)
+ enum_value = field.enum_type.values_by_number.get(value, None)
+ if enum_value is not None:
+ out.write(enum_value.name)
+ else:
+ out.write(str(value))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
out.write('\"')
- out.write(_CEscape(value))
+ if isinstance(value, unicode):
+ out_value = value.encode('utf-8')
+ else:
+ out_value = value
+ if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ # We need to escape non-UTF8 chars in TYPE_BYTES field.
+ out_as_utf8 = False
+ else:
+ out_as_utf8 = as_utf8
+ out.write(text_encoding.CEscape(out_value, out_as_utf8))
out.write('\"')
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
if value:
- out.write("true")
+ out.write('true')
else:
- out.write("false")
+ out.write('false')
+ elif field.cpp_type in _FLOAT_TYPES and float_format is not None:
+ out.write('{1:{0}}'.format(float_format, value))
else:
out.write(str(value))
+def _ParseOrMerge(lines, message, allow_multiple_scalars):
+ """Converts an ASCII representation of a protocol message into a message.
+
+ Args:
+ lines: Lines of a message's ASCII representation.
+ message: A protocol buffer message to merge into.
+ allow_multiple_scalars: Determines if repeated values for a non-repeated
+ field are permitted, e.g., the string "foo: 1 foo: 2" for a
+ required/optional field named "foo".
+
+ Raises:
+ ParseError: On ASCII parsing problems.
+ """
+ tokenizer = _Tokenizer(lines)
+ while not tokenizer.AtEnd():
+ _MergeField(tokenizer, message, allow_multiple_scalars)
+
+
+def Parse(text, message):
+ """Parses an ASCII representation of a protocol message into a message.
+
+ Args:
+ text: Message ASCII representation.
+ message: A protocol buffer message to merge into.
+
+ Returns:
+ The same message passed as argument.
+
+ Raises:
+ ParseError: On ASCII parsing problems.
+ """
+ if not isinstance(text, str): text = text.decode('utf-8')
+ return ParseLines(text.split('\n'), message)
+
+
def Merge(text, message):
- """Merges an ASCII representation of a protocol message into a message.
+ """Parses an ASCII representation of a protocol message into a message.
+
+ Like Parse(), but allows repeated values for a non-repeated field, and uses
+ the last one.
Args:
text: Message ASCII representation.
message: A protocol buffer message to merge into.
+ Returns:
+ The same message passed as argument.
+
Raises:
ParseError: On ASCII parsing problems.
"""
- tokenizer = _Tokenizer(text)
- while not tokenizer.AtEnd():
- _MergeField(tokenizer, message)
+ return MergeLines(text.split('\n'), message)
+
+
+def ParseLines(lines, message):
+ """Parses an ASCII representation of a protocol message into a message.
+
+ Args:
+ lines: An iterable of lines of a message's ASCII representation.
+ message: A protocol buffer message to merge into.
+
+ Returns:
+ The same message passed as argument.
+
+ Raises:
+ ParseError: On ASCII parsing problems.
+ """
+ _ParseOrMerge(lines, message, False)
+ return message
+
+
+def MergeLines(lines, message):
+ """Parses an ASCII representation of a protocol message into a message.
+
+ Args:
+ lines: An iterable of lines of a message's ASCII representation.
+ message: A protocol buffer message to merge into.
+
+ Returns:
+ The same message passed as argument.
+ Raises:
+ ParseError: On ASCII parsing problems.
+ """
+ _ParseOrMerge(lines, message, True)
+ return message
-def _MergeField(tokenizer, message):
+
+def _MergeField(tokenizer, message, allow_multiple_scalars):
"""Merges a single protocol message field into a message.
Args:
tokenizer: A tokenizer to parse the field name and values.
message: A protocol message to record the data.
+ allow_multiple_scalars: Determines if repeated values for a non-repeated
+ field are permitted, e.g., the string "foo: 1 foo: 2" for a
+ required/optional field named "foo".
Raises:
ParseError: In case of ASCII parsing problems.
@@ -159,7 +323,9 @@ def _MergeField(tokenizer, message):
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" does not have extensions.' %
message_descriptor.full_name)
+ # pylint: disable=protected-access
field = message.Extensions._FindExtensionByName(name)
+ # pylint: enable=protected-access
if not field:
raise tokenizer.ParseErrorPreviousToken(
'Extension "%s" not registered.' % name)
@@ -208,23 +374,31 @@ def _MergeField(tokenizer, message):
sub_message = message.Extensions[field]
else:
sub_message = getattr(message, field.name)
- sub_message.SetInParent()
+ sub_message.SetInParent()
while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd():
raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
- _MergeField(tokenizer, sub_message)
+ _MergeField(tokenizer, sub_message, allow_multiple_scalars)
else:
- _MergeScalarField(tokenizer, message, field)
+ _MergeScalarField(tokenizer, message, field, allow_multiple_scalars)
+
+ # For historical reasons, fields may optionally be separated by commas or
+ # semicolons.
+ if not tokenizer.TryConsume(','):
+ tokenizer.TryConsume(';')
-def _MergeScalarField(tokenizer, message, field):
+def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars):
"""Merges a single protocol message scalar field into a message.
Args:
tokenizer: A tokenizer to parse the field value.
message: A protocol message to record the data.
field: The descriptor of the field to be merged.
+ allow_multiple_scalars: Determines if repeated values for a non-repeated
+ field are permitted, e.g., the string "foo: 1 foo: 2" for a
+ required/optional field named "foo".
Raises:
ParseError: In case of ASCII parsing problems.
@@ -257,24 +431,7 @@ def _MergeScalarField(tokenizer, message, field):
elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
value = tokenizer.ConsumeByteString()
elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
- # Enum can be specified by a number (the enum value), or by
- # a string literal (the enum name).
- enum_descriptor = field.enum_type
- if tokenizer.LookingAtInteger():
- number = tokenizer.ConsumeInt32()
- enum_value = enum_descriptor.values_by_number.get(number, None)
- if enum_value is None:
- raise tokenizer.ParseErrorPreviousToken(
- 'Enum type "%s" has no value with number %d.' % (
- enum_descriptor.full_name, number))
- else:
- identifier = tokenizer.ConsumeIdentifier()
- enum_value = enum_descriptor.values_by_name.get(identifier, None)
- if enum_value is None:
- raise tokenizer.ParseErrorPreviousToken(
- 'Enum type "%s" has no value named %s.' % (
- enum_descriptor.full_name, identifier))
- value = enum_value.number
+ value = tokenizer.ConsumeEnum(field)
else:
raise RuntimeError('Unknown field type %d' % field.type)
@@ -285,9 +442,19 @@ def _MergeScalarField(tokenizer, message, field):
getattr(message, field.name).append(value)
else:
if field.is_extension:
- message.Extensions[field] = value
+ if not allow_multiple_scalars and message.HasExtension(field):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" extensions.' %
+ (message.DESCRIPTOR.full_name, field.full_name))
+ else:
+ message.Extensions[field] = value
else:
- setattr(message, field.name, value)
+ if not allow_multiple_scalars and message.HasField(field.name):
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" should not have multiple "%s" fields.' %
+ (message.DESCRIPTOR.full_name, field.name))
+ else:
+ setattr(message, field.name, value)
class _Tokenizer(object):
@@ -305,26 +472,19 @@ class _Tokenizer(object):
'[0-9+-][0-9a-zA-Z_.+-]*|' # a number
'\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string
'\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string
- _IDENTIFIER = re.compile('\w+')
- _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(),
- type_checkers.Int32ValueChecker(),
- type_checkers.Uint64ValueChecker(),
- type_checkers.Int64ValueChecker()]
- _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE)
- _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE)
-
- def __init__(self, text_message):
- self._text_message = text_message
+ _IDENTIFIER = re.compile(r'\w+')
+ def __init__(self, lines):
self._position = 0
self._line = -1
self._column = 0
self._token_start = None
self.token = ''
- self._lines = deque(text_message.split('\n'))
+ self._lines = iter(lines)
self._current_line = ''
self._previous_line = 0
self._previous_column = 0
+ self._more_lines = True
self._SkipWhitespace()
self.NextToken()
@@ -334,25 +494,27 @@ class _Tokenizer(object):
Returns:
True iff the end was reached.
"""
- return not self._lines and not self._current_line
+ return not self.token
def _PopLine(self):
- while not self._current_line:
- if not self._lines:
+ while len(self._current_line) <= self._column:
+ try:
+ self._current_line = self._lines.next()
+ except StopIteration:
self._current_line = ''
+ self._more_lines = False
return
- self._line += 1
- self._column = 0
- self._current_line = self._lines.popleft()
+ else:
+ self._line += 1
+ self._column = 0
def _SkipWhitespace(self):
while True:
self._PopLine()
- match = re.match(self._WHITESPACE, self._current_line)
+ match = self._WHITESPACE.match(self._current_line, self._column)
if not match:
break
length = len(match.group(0))
- self._current_line = self._current_line[length:]
self._column += length
def TryConsume(self, token):
@@ -381,17 +543,6 @@ class _Tokenizer(object):
if not self.TryConsume(token):
raise self._ParseError('Expected "%s".' % token)
- def LookingAtInteger(self):
- """Checks if the current token is an integer.
-
- Returns:
- True iff the current token is an integer.
- """
- if not self.token:
- return False
- c = self.token[0]
- return (c >= '0' and c <= '9') or c == '-' or c == '+'
-
def ConsumeIdentifier(self):
"""Consumes protocol message field identifier.
@@ -402,7 +553,7 @@ class _Tokenizer(object):
ParseError: If an identifier couldn't be consumed.
"""
result = self.token
- if not re.match(self._IDENTIFIER, result):
+ if not self._IDENTIFIER.match(result):
raise self._ParseError('Expected identifier.')
self.NextToken()
return result
@@ -417,9 +568,9 @@ class _Tokenizer(object):
ParseError: If a signed 32bit integer couldn't be consumed.
"""
try:
- result = self._ParseInteger(self.token, is_signed=True, is_long=False)
+ result = ParseInteger(self.token, is_signed=True, is_long=False)
except ValueError, e:
- raise self._IntegerParseError(e)
+ raise self._ParseError(str(e))
self.NextToken()
return result
@@ -433,9 +584,9 @@ class _Tokenizer(object):
ParseError: If an unsigned 32bit integer couldn't be consumed.
"""
try:
- result = self._ParseInteger(self.token, is_signed=False, is_long=False)
+ result = ParseInteger(self.token, is_signed=False, is_long=False)
except ValueError, e:
- raise self._IntegerParseError(e)
+ raise self._ParseError(str(e))
self.NextToken()
return result
@@ -449,9 +600,9 @@ class _Tokenizer(object):
ParseError: If a signed 64bit integer couldn't be consumed.
"""
try:
- result = self._ParseInteger(self.token, is_signed=True, is_long=True)
+ result = ParseInteger(self.token, is_signed=True, is_long=True)
except ValueError, e:
- raise self._IntegerParseError(e)
+ raise self._ParseError(str(e))
self.NextToken()
return result
@@ -465,9 +616,9 @@ class _Tokenizer(object):
ParseError: If an unsigned 64bit integer couldn't be consumed.
"""
try:
- result = self._ParseInteger(self.token, is_signed=False, is_long=True)
+ result = ParseInteger(self.token, is_signed=False, is_long=True)
except ValueError, e:
- raise self._IntegerParseError(e)
+ raise self._ParseError(str(e))
self.NextToken()
return result
@@ -480,21 +631,10 @@ class _Tokenizer(object):
Raises:
ParseError: If a floating point number couldn't be consumed.
"""
- text = self.token
- if re.match(self._FLOAT_INFINITY, text):
- self.NextToken()
- if text.startswith('-'):
- return -_INFINITY
- return _INFINITY
-
- if re.match(self._FLOAT_NAN, text):
- self.NextToken()
- return _NAN
-
try:
- result = float(text)
+ result = ParseFloat(self.token)
except ValueError, e:
- raise self._FloatParseError(e)
+ raise self._ParseError(str(e))
self.NextToken()
return result
@@ -507,14 +647,12 @@ class _Tokenizer(object):
Raises:
ParseError: If a boolean value couldn't be consumed.
"""
- if self.token == 'true':
- self.NextToken()
- return True
- elif self.token == 'false':
- self.NextToken()
- return False
- else:
- raise self._ParseError('Expected "true" or "false".')
+ try:
+ result = ParseBool(self.token)
+ except ValueError, e:
+ raise self._ParseError(str(e))
+ self.NextToken()
+ return result
def ConsumeString(self):
"""Consumes a string value.
@@ -525,7 +663,11 @@ class _Tokenizer(object):
Raises:
ParseError: If a string value couldn't be consumed.
"""
- return unicode(self.ConsumeByteString(), 'utf-8')
+ the_bytes = self.ConsumeByteString()
+ try:
+ return unicode(the_bytes, 'utf-8')
+ except UnicodeDecodeError, e:
+ raise self._StringParseError(e)
def ConsumeByteString(self):
"""Consumes a byte array value.
@@ -536,10 +678,11 @@ class _Tokenizer(object):
Raises:
ParseError: If a byte array value couldn't be consumed.
"""
- list = [self._ConsumeSingleByteString()]
- while len(self.token) > 0 and self.token[0] in ('\'', '"'):
- list.append(self._ConsumeSingleByteString())
- return "".join(list)
+ the_list = [self._ConsumeSingleByteString()]
+ while self.token and self.token[0] in ('\'', '"'):
+ the_list.append(self._ConsumeSingleByteString())
+ return ''.encode('latin1').join(the_list) ##PY25
+##!PY25 return b''.join(the_list)
def _ConsumeSingleByteString(self):
"""Consume one token of a string literal.
@@ -550,48 +693,24 @@ class _Tokenizer(object):
"""
text = self.token
if len(text) < 1 or text[0] not in ('\'', '"'):
- raise self._ParseError('Exptected string.')
+ raise self._ParseError('Expected string.')
if len(text) < 2 or text[-1] != text[0]:
raise self._ParseError('String missing ending quote.')
try:
- result = _CUnescape(text[1:-1])
+ result = text_encoding.CUnescape(text[1:-1])
except ValueError, e:
raise self._ParseError(str(e))
self.NextToken()
return result
- def _ParseInteger(self, text, is_signed=False, is_long=False):
- """Parses an integer.
-
- Args:
- text: The text to parse.
- is_signed: True if a signed integer must be parsed.
- is_long: True if a long integer must be parsed.
-
- Returns:
- The integer value.
-
- Raises:
- ValueError: Thrown Iff the text is not a valid integer.
- """
- pos = 0
- if text.startswith('-'):
- pos += 1
-
- base = 10
- if text.startswith('0x', pos) or text.startswith('0X', pos):
- base = 16
- elif text.startswith('0', pos):
- base = 8
-
- # Do the actual parsing. Exception handling is propagated to caller.
- result = int(text, base)
-
- # Check if the integer is sane. Exceptions handled by callers.
- checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
- checker.CheckValue(result)
+ def ConsumeEnum(self, field):
+ try:
+ result = ParseEnum(field, self.token)
+ except ValueError, e:
+ raise self._ParseError(str(e))
+ self.NextToken()
return result
def ParseErrorPreviousToken(self, message):
@@ -611,63 +730,144 @@ class _Tokenizer(object):
return ParseError('%d:%d : %s' % (
self._line + 1, self._column + 1, message))
- def _IntegerParseError(self, e):
- return self._ParseError('Couldn\'t parse integer: ' + str(e))
-
- def _FloatParseError(self, e):
- return self._ParseError('Couldn\'t parse number: ' + str(e))
+ def _StringParseError(self, e):
+ return self._ParseError('Couldn\'t parse string: ' + str(e))
def NextToken(self):
"""Reads the next meaningful token."""
self._previous_line = self._line
self._previous_column = self._column
- if self.AtEnd():
- self.token = ''
- return
+
self._column += len(self.token)
+ self._SkipWhitespace()
- # Make sure there is data to work on.
- self._PopLine()
+ if not self._more_lines:
+ self.token = ''
+ return
- match = re.match(self._TOKEN, self._current_line)
+ match = self._TOKEN.match(self._current_line, self._column)
if match:
token = match.group(0)
- self._current_line = self._current_line[len(token):]
self.token = token
else:
- self.token = self._current_line[0]
- self._current_line = self._current_line[1:]
- self._SkipWhitespace()
+ self.token = self._current_line[self._column]
+
+
+def ParseInteger(text, is_signed=False, is_long=False):
+ """Parses an integer.
+
+ Args:
+ text: The text to parse.
+ is_signed: True if a signed integer must be parsed.
+ is_long: True if a long integer must be parsed.
+
+ Returns:
+ The integer value.
+
+ Raises:
+ ValueError: Thrown Iff the text is not a valid integer.
+ """
+ # Do the actual parsing. Exception handling is propagated to caller.
+ try:
+ # We force 32-bit values to int and 64-bit values to long to make
+ # alternate implementations where the distinction is more significant
+ # (e.g. the C++ implementation) simpler.
+ if is_long:
+ result = long(text, 0)
+ else:
+ result = int(text, 0)
+ except ValueError:
+ raise ValueError('Couldn\'t parse integer: %s' % text)
+
+ # Check if the integer is sane. Exceptions handled by callers.
+ checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+ checker.CheckValue(result)
+ return result
+
+
+def ParseFloat(text):
+ """Parse a floating point number.
+
+ Args:
+ text: Text to parse.
+ Returns:
+ The number parsed.
-# text.encode('string_escape') does not seem to satisfy our needs as it
-# encodes unprintable characters using two-digit hex escapes whereas our
-# C++ unescaping function allows hex escapes to be any length. So,
-# "\0011".encode('string_escape') ends up being "\\x011", which will be
-# decoded in C++ as a single-character string with char code 0x11.
-def _CEscape(text):
- def escape(c):
- o = ord(c)
- if o == 10: return r"\n" # optional escape
- if o == 13: return r"\r" # optional escape
- if o == 9: return r"\t" # optional escape
- if o == 39: return r"\'" # optional escape
+ Raises:
+ ValueError: If a floating point number couldn't be parsed.
+ """
+ try:
+ # Assume Python compatible syntax.
+ return float(text)
+ except ValueError:
+ # Check alternative spellings.
+ if _FLOAT_INFINITY.match(text):
+ if text[0] == '-':
+ return float('-inf')
+ else:
+ return float('inf')
+ elif _FLOAT_NAN.match(text):
+ return float('nan')
+ else:
+ # assume '1.0f' format
+ try:
+ return float(text.rstrip('f'))
+ except ValueError:
+ raise ValueError('Couldn\'t parse float: %s' % text)
+
+
+def ParseBool(text):
+ """Parse a boolean value.
+
+ Args:
+ text: Text to parse.
+
+ Returns:
+ Boolean values parsed
- if o == 34: return r'\"' # necessary escape
- if o == 92: return r"\\" # necessary escape
+ Raises:
+ ValueError: If text is not a valid boolean.
+ """
+ if text in ('true', 't', '1'):
+ return True
+ elif text in ('false', 'f', '0'):
+ return False
+ else:
+ raise ValueError('Expected "true" or "false".')
- if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
- return c
- return "".join([escape(c) for c in text])
+def ParseEnum(field, value):
+ """Parse an enum value.
-_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])')
+ The value can be specified by a number (the enum value), or by
+ a string literal (the enum name).
+ Args:
+ field: Enum field descriptor.
+ value: String value.
-def _CUnescape(text):
- def ReplaceHex(m):
- return chr(int(m.group(0)[2:], 16))
- # This is required because the 'string_escape' encoding doesn't
- # allow single-digit hex escapes (like '\xf').
- result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
- return result.decode('string_escape')
+ Returns:
+ Enum value number.
+
+ Raises:
+ ValueError: If the enum value could not be parsed.
+ """
+ enum_descriptor = field.enum_type
+ try:
+ number = int(value, 0)
+ except ValueError:
+ # Identifier.
+ enum_value = enum_descriptor.values_by_name.get(value, None)
+ if enum_value is None:
+ raise ValueError(
+ 'Enum type "%s" has no value named %s.' % (
+ enum_descriptor.full_name, value))
+ else:
+ # Numeric value.
+ enum_value = enum_descriptor.values_by_number.get(number, None)
+ if enum_value is None:
+ raise ValueError(
+ 'Enum type "%s" has no value with number %d.' % (
+ enum_descriptor.full_name, number))
+ return enum_value.number
diff --git a/python/setup.py b/python/setup.py
index 7242dae..9441d0e 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -1,25 +1,41 @@
#! /usr/bin/python
#
# See README for usage instructions.
+import sys
+import os
+import subprocess
# We must use setuptools, not distutils, because we need to use the
# namespace_packages option for the "google" package.
-from ez_setup import use_setuptools
-use_setuptools()
-
-from setuptools import setup
+try:
+ from setuptools import setup, Extension
+except ImportError:
+ try:
+ from ez_setup import use_setuptools
+ use_setuptools()
+ from setuptools import setup, Extension
+ except ImportError:
+ sys.stderr.write(
+ "Could not import setuptools; make sure you have setuptools or "
+ "ez_setup installed.\n")
+ raise
+from distutils.command.clean import clean as _clean
+from distutils.command.build_py import build_py as _build_py
from distutils.spawn import find_executable
-import sys
-import os
-import subprocess
maintainer_email = "protobuf@googlegroups.com"
# Find the Protocol Compiler.
-if os.path.exists("../src/protoc"):
+if 'PROTOC' in os.environ and os.path.exists(os.environ['PROTOC']):
+ protoc = os.environ['PROTOC']
+elif os.path.exists("../src/protoc"):
protoc = "../src/protoc"
elif os.path.exists("../src/protoc.exe"):
protoc = "../src/protoc.exe"
+elif os.path.exists("../vsprojects/Debug/protoc.exe"):
+ protoc = "../vsprojects/Debug/protoc.exe"
+elif os.path.exists("../vsprojects/Release/protoc.exe"):
+ protoc = "../vsprojects/Release/protoc.exe"
else:
protoc = find_executable("protoc")
@@ -30,14 +46,14 @@ def generate_proto(source):
output = source.replace(".proto", "_pb2.py").replace("../src/", "")
- if not os.path.exists(source):
- print "Can't find required file: " + source
- sys.exit(-1)
-
if (not os.path.exists(output) or
(os.path.exists(source) and
os.path.getmtime(source) > os.path.getmtime(output))):
- print "Generating %s..." % output
+ print ("Generating %s..." % output)
+
+ if not os.path.exists(source):
+ sys.stderr.write("Can't find required file: %s\n" % source)
+ sys.exit(-1)
if protoc == None:
sys.stderr.write(
@@ -49,74 +65,131 @@ def generate_proto(source):
if subprocess.call(protoc_command) != 0:
sys.exit(-1)
-def MakeTestSuite():
- # This is apparently needed on some systems to make sure that the tests
- # work even if a previous version is already installed.
- if 'google' in sys.modules:
- del sys.modules['google']
-
+def GenerateUnittestProtos():
generate_proto("../src/google/protobuf/unittest.proto")
+ generate_proto("../src/google/protobuf/unittest_custom_options.proto")
generate_proto("../src/google/protobuf/unittest_import.proto")
+ generate_proto("../src/google/protobuf/unittest_import_public.proto")
generate_proto("../src/google/protobuf/unittest_mset.proto")
generate_proto("../src/google/protobuf/unittest_no_generic_services.proto")
+ generate_proto("google/protobuf/internal/descriptor_pool_test1.proto")
+ generate_proto("google/protobuf/internal/descriptor_pool_test2.proto")
+ generate_proto("google/protobuf/internal/test_bad_identifiers.proto")
+ generate_proto("google/protobuf/internal/missing_enum_values.proto")
generate_proto("google/protobuf/internal/more_extensions.proto")
+ generate_proto("google/protobuf/internal/more_extensions_dynamic.proto")
generate_proto("google/protobuf/internal/more_messages.proto")
+ generate_proto("google/protobuf/internal/factory_test1.proto")
+ generate_proto("google/protobuf/internal/factory_test2.proto")
+ generate_proto("google/protobuf/pyext/python.proto")
+def MakeTestSuite():
+ # Test C++ implementation
import unittest
- import google.protobuf.internal.generator_test as generator_test
- import google.protobuf.internal.descriptor_test as descriptor_test
- import google.protobuf.internal.reflection_test as reflection_test
- import google.protobuf.internal.service_reflection_test \
- as service_reflection_test
- import google.protobuf.internal.text_format_test as text_format_test
- import google.protobuf.internal.wire_format_test as wire_format_test
+ import google.protobuf.pyext.descriptor_cpp2_test as descriptor_cpp2_test
+ import google.protobuf.pyext.message_factory_cpp2_test \
+ as message_factory_cpp2_test
+ import google.protobuf.pyext.reflection_cpp2_generated_test \
+ as reflection_cpp2_generated_test
loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
- for test in [ generator_test,
- descriptor_test,
- reflection_test,
- service_reflection_test,
- text_format_test,
- wire_format_test ]:
+ for test in [ descriptor_cpp2_test,
+ message_factory_cpp2_test,
+ reflection_cpp2_generated_test]:
suite.addTest(loader.loadTestsFromModule(test))
-
return suite
-if __name__ == '__main__':
- # TODO(kenton): Integrate this into setuptools somehow?
- if len(sys.argv) >= 2 and sys.argv[1] == "clean":
- # Delete generated _pb2.py files and .pyc files in the code tree.
+class clean(_clean):
+ def run(self):
+ # Delete generated files in the code tree.
for (dirpath, dirnames, filenames) in os.walk("."):
for filename in filenames:
filepath = os.path.join(dirpath, filename)
- if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"):
+ if filepath.endswith("_pb2.py") or filepath.endswith(".pyc") or \
+ filepath.endswith(".so") or filepath.endswith(".o") or \
+ filepath.endswith('google/protobuf/compiler/__init__.py'):
os.remove(filepath)
- else:
+ # _clean is an old-style class, so super() doesn't work.
+ _clean.run(self)
+
+class build_py(_build_py):
+ def run(self):
# Generate necessary .proto file if it doesn't exist.
- # TODO(kenton): Maybe we should hook this into a distutils command?
generate_proto("../src/google/protobuf/descriptor.proto")
+ generate_proto("../src/google/protobuf/compiler/plugin.proto")
+ GenerateUnittestProtos()
+
+ # Make sure google.protobuf/** are valid packages.
+ for path in ['', 'internal/', 'compiler/', 'pyext/']:
+ try:
+ open('google/protobuf/%s__init__.py' % path, 'a').close()
+ except EnvironmentError:
+ pass
+ # _build_py is an old-style class, so super() doesn't work.
+ _build_py.run(self)
+ # TODO(mrovner): Subclass to run 2to3 on some files only.
+ # Tracing what https://wiki.python.org/moin/PortingPythonToPy3k's "Approach 2"
+ # section on how to get 2to3 to run on source files during install under
+ # Python 3. This class seems like a good place to put logic that calls
+ # python3's distutils.util.run_2to3 on the subset of the files we have in our
+ # release that are subject to conversion.
+ # See code reference in previous code review.
+
+if __name__ == '__main__':
+ ext_module_list = []
+ cpp_impl = '--cpp_implementation'
+ if cpp_impl in sys.argv:
+ sys.argv.remove(cpp_impl)
+ # C++ implementation extension
+ ext_module_list.append(Extension(
+ "google.protobuf.pyext._message",
+ [ "google/protobuf/pyext/descriptor.cc",
+ "google/protobuf/pyext/message.cc",
+ "google/protobuf/pyext/extension_dict.cc",
+ "google/protobuf/pyext/repeated_scalar_container.cc",
+ "google/protobuf/pyext/repeated_composite_container.cc" ],
+ define_macros=[('GOOGLE_PROTOBUF_HAS_ONEOF', '1')],
+ include_dirs = [ ".", "../src"],
+ libraries = [ "protobuf" ],
+ library_dirs = [ '../src/.libs' ],
+ ))
setup(name = 'protobuf',
- version = '2.3.0',
+ version = '2.6.0',
packages = [ 'google' ],
namespace_packages = [ 'google' ],
test_suite = 'setup.MakeTestSuite',
+ google_test_dir = "google/protobuf/internal",
# Must list modules explicitly so that we don't install tests.
py_modules = [
+ 'google.protobuf.internal.api_implementation',
'google.protobuf.internal.containers',
+ 'google.protobuf.internal.cpp_message',
'google.protobuf.internal.decoder',
'google.protobuf.internal.encoder',
+ 'google.protobuf.internal.enum_type_wrapper',
'google.protobuf.internal.message_listener',
+ 'google.protobuf.internal.python_message',
'google.protobuf.internal.type_checkers',
'google.protobuf.internal.wire_format',
'google.protobuf.descriptor',
'google.protobuf.descriptor_pb2',
+ 'google.protobuf.compiler.plugin_pb2',
'google.protobuf.message',
+ 'google.protobuf.descriptor_database',
+ 'google.protobuf.descriptor_pool',
+ 'google.protobuf.message_factory',
'google.protobuf.reflection',
'google.protobuf.service',
'google.protobuf.service_reflection',
- 'google.protobuf.text_format' ],
+ 'google.protobuf.symbol_database',
+ 'google.protobuf.text_encoding',
+ 'google.protobuf.text_format'],
+ cmdclass = { 'clean': clean, 'build_py': build_py },
+ install_requires = ['setuptools'],
+ setup_requires = ['google-apputils'],
+ ext_modules = ext_module_list,
url = 'http://code.google.com/p/protobuf/',
maintainer = maintainer_email,
maintainer_email = 'protobuf@googlegroups.com',