summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authorWink Saville <wink@google.com>2010-05-27 16:25:37 -0700
committerWink Saville <wink@google.com>2010-05-27 16:25:37 -0700
commitfbaaef999ba563838ebd00874ed8a1c01fbf286d (patch)
tree24ff5c76344e90abc5b0fe6f07120ea0d2d011ee /python
parent79a4a60053f74ab71c7c3ec436d2f6caedc5be61 (diff)
downloadexternal_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.zip
external_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.tar.gz
external_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.tar.bz2
Add protobuf 2.2.0a sources
This is the contents of protobuf-2.2.0a.tar.bz2 from http://code.google.com/p/protobuf/downloads/list and is the base code for the javamicro code generator. Change-Id: Ie9a0440a824d615086445b6636944484b3155afa
Diffstat (limited to 'python')
-rw-r--r--python/README.txt73
-rwxr-xr-xpython/ez_setup.py281
-rwxr-xr-xpython/google/__init__.py1
-rwxr-xr-xpython/google/protobuf/__init__.py0
-rwxr-xr-xpython/google/protobuf/descriptor.py433
-rwxr-xr-xpython/google/protobuf/internal/__init__.py0
-rwxr-xr-xpython/google/protobuf/internal/containers.py227
-rwxr-xr-xpython/google/protobuf/internal/decoder.py209
-rwxr-xr-xpython/google/protobuf/internal/decoder_test.py256
-rwxr-xr-xpython/google/protobuf/internal/descriptor_test.py113
-rwxr-xr-xpython/google/protobuf/internal/encoder.py280
-rwxr-xr-xpython/google/protobuf/internal/encoder_test.py286
-rwxr-xr-xpython/google/protobuf/internal/generator_test.py100
-rwxr-xr-xpython/google/protobuf/internal/input_stream.py338
-rwxr-xr-xpython/google/protobuf/internal/input_stream_test.py314
-rwxr-xr-xpython/google/protobuf/internal/message_listener.py69
-rwxr-xr-xpython/google/protobuf/internal/message_test.py53
-rw-r--r--python/google/protobuf/internal/more_extensions.proto58
-rw-r--r--python/google/protobuf/internal/more_messages.proto51
-rwxr-xr-xpython/google/protobuf/internal/output_stream.py125
-rwxr-xr-xpython/google/protobuf/internal/output_stream_test.py178
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py1976
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py136
-rwxr-xr-xpython/google/protobuf/internal/test_util.py612
-rwxr-xr-xpython/google/protobuf/internal/text_format_test.py397
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py287
-rwxr-xr-xpython/google/protobuf/internal/wire_format.py247
-rwxr-xr-xpython/google/protobuf/internal/wire_format_test.py253
-rwxr-xr-xpython/google/protobuf/message.py245
-rwxr-xr-xpython/google/protobuf/reflection.py1661
-rwxr-xr-xpython/google/protobuf/service.py222
-rwxr-xr-xpython/google/protobuf/service_reflection.py284
-rwxr-xr-xpython/google/protobuf/text_format.py650
-rwxr-xr-xpython/mox.py1401
-rwxr-xr-xpython/setup.py136
-rwxr-xr-xpython/stubout.py140
36 files changed, 12092 insertions, 0 deletions
diff --git a/python/README.txt b/python/README.txt
new file mode 100644
index 0000000..96f1a73
--- /dev/null
+++ b/python/README.txt
@@ -0,0 +1,73 @@
+Protocol Buffers - Google's data interchange format
+Copyright 2008 Google Inc.
+
+This directory contains the Python Protocol Buffers runtime library.
+
+Normally, this directory comes as part of the protobuf package, available
+from:
+
+ http://code.google.com/p/protobuf
+
+The complete package includes the C++ source code, which includes the
+Protocol Compiler (protoc). If you downloaded this package from PyPI
+or some other Python-specific source, you may have received only the
+Python part of the code. In this case, you will need to obtain the
+Protocol Compiler from some other source before you can use this
+package.
+
+Development Warning
+===================
+
+The Python implementation of Protocol Buffers is not as mature as the C++
+and Java implementations. It may be more buggy, and it is known to be
+pretty slow at this time. If you would like to help fix these issues,
+join the Protocol Buffers discussion list and let us know!
+
+Installation
+============
+
+1) Make sure you have Python 2.4 or newer. If in doubt, run:
+
+ $ python -V
+
+2) If you do not have setuptools installed, note that it will be
+ downloaded and installed automatically as soon as you run setup.py.
+ If you would rather install it manually, you may do so by following
+ the instructions on this page:
+
+ http://peak.telecommunity.com/DevCenter/EasyInstall#installation-instructions
+
+3) Build the C++ code, or install a binary distribution of protoc. If
+ you install a binary distribution, make sure that it is the same
+ version as this package. If in doubt, run:
+
+ $ protoc --version
+
+4) Run the tests:
+
+ $ python setup.py test
+
+ If some tests fail, this library may not work correctly on your
+ system. Continue at your own risk.
+
+ Please note that there is a known problem with some versions of
+ Python on Cygwin which causes the tests to fail after printing the
+ error: "sem_init: Resource temporarily unavailable". This appears
+ to be a bug either in Cygwin or in Python:
+ http://www.cygwin.com/ml/cygwin/2005-07/msg01378.html
+ We do not know if or when it might me fixed. We also do not know
+ how likely it is that this bug will affect users in practice.
+
+5) Install:
+
+ $ python setup.py install
+
+ This step may require superuser privileges.
+
+Usage
+=====
+
+The complete documentation for Protocol Buffers is available via the
+web at:
+
+ http://code.google.com/apis/protocolbuffers/
diff --git a/python/ez_setup.py b/python/ez_setup.py
new file mode 100755
index 0000000..b7a9849
--- /dev/null
+++ b/python/ez_setup.py
@@ -0,0 +1,281 @@
+#!python
+
+# This file was obtained from:
+# http://peak.telecommunity.com/dist/ez_setup.py
+# on 2009/4/17.
+
+"""Bootstrap setuptools installation
+
+If you want to use setuptools in your package's setup.py, just include this
+file in the same directory with it, and add this to the top of your setup.py::
+
+ from ez_setup import use_setuptools
+ use_setuptools()
+
+If you want to require a specific version of setuptools, set a download
+mirror, or use an alternate download directory, you can do so by supplying
+the appropriate options to ``use_setuptools()``.
+
+This file can also be run as a script to install or upgrade setuptools.
+"""
+import sys
+DEFAULT_VERSION = "0.6c9"
+DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
+
+md5_data = {
+ 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
+ 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
+ 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
+ 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
+ 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
+ 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
+ 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
+ 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
+ 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
+ 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
+ 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
+ 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
+ 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
+ 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
+ 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
+ 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
+ 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
+ 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
+ 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
+ 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
+ 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
+ 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
+ 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
+ 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
+ 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
+ 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
+ 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
+ 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
+ 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
+ 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
+ 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
+ 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
+ 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
+ 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
+}
+
+import sys, os
+try: from hashlib import md5
+except ImportError: from md5 import md5
+
+def _validate_md5(egg_name, data):
+ if egg_name in md5_data:
+ digest = md5(data).hexdigest()
+ if digest != md5_data[egg_name]:
+ print >>sys.stderr, (
+ "md5 validation of %s failed! (Possible download problem?)"
+ % egg_name
+ )
+ sys.exit(2)
+ return data
+
+def use_setuptools(
+ version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+ download_delay=15
+):
+ """Automatically find/download setuptools and make it available on sys.path
+
+ `version` should be a valid setuptools version number that is available
+ as an egg for download under the `download_base` URL (which should end with
+ a '/'). `to_dir` is the directory where setuptools will be downloaded, if
+ it is not already available. If `download_delay` is specified, it should
+ be the number of seconds that will be paused before initiating a download,
+ should one be required. If an older version of setuptools is installed,
+ this routine will print a message to ``sys.stderr`` and raise SystemExit in
+ an attempt to abort the calling script.
+ """
+ was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
+ def do_download():
+ egg = download_setuptools(version, download_base, to_dir, download_delay)
+ sys.path.insert(0, egg)
+ import setuptools; setuptools.bootstrap_install_from = egg
+ try:
+ import pkg_resources
+ except ImportError:
+ return do_download()
+ try:
+ pkg_resources.require("setuptools>="+version); return
+ except pkg_resources.VersionConflict, e:
+ if was_imported:
+ print >>sys.stderr, (
+ "The required version of setuptools (>=%s) is not available, and\n"
+ "can't be installed while this script is running. Please install\n"
+ " a more recent version first, using 'easy_install -U setuptools'."
+ "\n\n(Currently using %r)"
+ ) % (version, e.args[0])
+ sys.exit(2)
+ else:
+ del pkg_resources, sys.modules['pkg_resources'] # reload ok
+ return do_download()
+ except pkg_resources.DistributionNotFound:
+ return do_download()
+
+def download_setuptools(
+ version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
+ delay = 15
+):
+ """Download setuptools from a specified location and return its filename
+
+ `version` should be a valid setuptools version number that is available
+ as an egg for download under the `download_base` URL (which should end
+ with a '/'). `to_dir` is the directory where the egg will be downloaded.
+ `delay` is the number of seconds to pause before an actual download attempt.
+ """
+ import urllib2, shutil
+ egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
+ url = download_base + egg_name
+ saveto = os.path.join(to_dir, egg_name)
+ src = dst = None
+ if not os.path.exists(saveto): # Avoid repeated downloads
+ try:
+ from distutils import log
+ if delay:
+ log.warn("""
+---------------------------------------------------------------------------
+This script requires setuptools version %s to run (even to display
+help). I will attempt to download it for you (from
+%s), but
+you may need to enable firewall access for this script first.
+I will start the download in %d seconds.
+
+(Note: if this machine does not have network access, please obtain the file
+
+ %s
+
+and place it in this directory before rerunning this script.)
+---------------------------------------------------------------------------""",
+ version, download_base, delay, url
+ ); from time import sleep; sleep(delay)
+ log.warn("Downloading %s", url)
+ src = urllib2.urlopen(url)
+ # Read/write all in one block, so we don't create a corrupt file
+ # if the download is interrupted.
+ data = _validate_md5(egg_name, src.read())
+ dst = open(saveto,"wb"); dst.write(data)
+ finally:
+ if src: src.close()
+ if dst: dst.close()
+ return os.path.realpath(saveto)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+def main(argv, version=DEFAULT_VERSION):
+ """Install or upgrade setuptools and EasyInstall"""
+ try:
+ import setuptools
+ except ImportError:
+ egg = None
+ try:
+ egg = download_setuptools(version, delay=0)
+ sys.path.insert(0,egg)
+ from setuptools.command.easy_install import main
+ return main(list(argv)+[egg]) # we're done here
+ finally:
+ if egg and os.path.exists(egg):
+ os.unlink(egg)
+ else:
+ if setuptools.__version__ == '0.0.1':
+ print >>sys.stderr, (
+ "You have an obsolete version of setuptools installed. Please\n"
+ "remove it from your system entirely before rerunning this script."
+ )
+ sys.exit(2)
+
+ req = "setuptools>="+version
+ import pkg_resources
+ try:
+ pkg_resources.require(req)
+ except pkg_resources.VersionConflict:
+ try:
+ from setuptools.command.easy_install import main
+ except ImportError:
+ from easy_install import main
+ main(list(argv)+[download_setuptools(delay=0)])
+ sys.exit(0) # try to force an exit
+ else:
+ if argv:
+ from setuptools.command.easy_install import main
+ main(argv)
+ else:
+ print "Setuptools version",version,"or greater has been installed."
+ print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
+
+def update_md5(filenames):
+ """Update our built-in md5 registry"""
+
+ import re
+
+ for name in filenames:
+ base = os.path.basename(name)
+ f = open(name,'rb')
+ md5_data[base] = md5(f.read()).hexdigest()
+ f.close()
+
+ data = [" %r: %r,\n" % it for it in md5_data.items()]
+ data.sort()
+ repl = "".join(data)
+
+ import inspect
+ srcfile = inspect.getsourcefile(sys.modules[__name__])
+ f = open(srcfile, 'rb'); src = f.read(); f.close()
+
+ match = re.search("\nmd5_data = {\n([^}]+)}", src)
+ if not match:
+ print >>sys.stderr, "Internal error!"
+ sys.exit(2)
+
+ src = src[:match.start(1)] + repl + src[match.end(1):]
+ f = open(srcfile,'w')
+ f.write(src)
+ f.close()
+
+
+if __name__=='__main__':
+ if len(sys.argv)>2 and sys.argv[1]=='--md5update':
+ update_md5(sys.argv[2:])
+ else:
+ main(sys.argv[1:])
+
+
+
+
+
+
diff --git a/python/google/__init__.py b/python/google/__init__.py
new file mode 100755
index 0000000..de40ea7
--- /dev/null
+++ b/python/google/__init__.py
@@ -0,0 +1 @@
+__import__('pkg_resources').declare_namespace(__name__)
diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/python/google/protobuf/__init__.py
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
new file mode 100755
index 0000000..8e3fc2e
--- /dev/null
+++ b/python/google/protobuf/descriptor.py
@@ -0,0 +1,433 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# TODO(robinson): We probably need to provide deep-copy methods for
+# descriptor types. When a FieldDescriptor is passed into
+# Descriptor.__init__(), we should make a deep copy and then set
+# containing_type on it. Alternatively, we could just get
+# rid of containing_type (iit's not needed for reflection.py, at least).
+#
+# TODO(robinson): Print method?
+#
+# TODO(robinson): Useful __repr__?
+
+"""Descriptors essentially contain exactly the information found in a .proto
+file, in types that make this information accessible in Python.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+class DescriptorBase(object):
+
+ """Descriptors base class.
+
+ This class is the base of all descriptor classes. It provides common options
+ related functionaility.
+ """
+
+ def __init__(self, options, options_class_name):
+ """Initialize the descriptor given its options message and the name of the
+ class of the options message. The name of the class is required in case
+ the options message is None and has to be created.
+ """
+ self._options = options
+ self._options_class_name = options_class_name
+
+ def GetOptions(self):
+ """Retrieves descriptor options.
+
+ This method returns the options set or creates the default options for the
+ descriptor.
+ """
+ if self._options:
+ return self._options
+ from google.protobuf import descriptor_pb2
+ try:
+ options_class = getattr(descriptor_pb2, self._options_class_name)
+ except AttributeError:
+ raise RuntimeError('Unknown options class name %s!' %
+ (self._options_class_name))
+ self._options = options_class()
+ return self._options
+
+
+class Descriptor(DescriptorBase):
+
+ """Descriptor for a protocol message type.
+
+ A Descriptor instance has the following attributes:
+
+ name: (str) Name of this protocol message type.
+ full_name: (str) Fully-qualified name of this protocol message type,
+ which will include protocol "package" name and the name of any
+ enclosing types.
+
+ filename: (str) Name of the .proto file containing this message.
+
+ containing_type: (Descriptor) Reference to the descriptor of the
+ type containing us, or None if we have no containing type.
+
+ fields: (list of FieldDescriptors) Field descriptors for all
+ fields in this type.
+ fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor
+ objects as in |fields|, but indexed by "number" attribute in each
+ FieldDescriptor.
+ fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor
+ objects as in |fields|, but indexed by "name" attribute in each
+ FieldDescriptor.
+
+ nested_types: (list of Descriptors) Descriptor references
+ for all protocol message types nested within this one.
+ nested_types_by_name: (dict str -> Descriptor) Same Descriptor
+ objects as in |nested_types|, but indexed by "name" attribute
+ in each Descriptor.
+
+ enum_types: (list of EnumDescriptors) EnumDescriptor references
+ for all enums contained within this type.
+ enum_types_by_name: (dict str ->EnumDescriptor) Same EnumDescriptor
+ objects as in |enum_types|, but indexed by "name" attribute
+ in each EnumDescriptor.
+ enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping
+ from enum value name to EnumValueDescriptor for that value.
+
+ extensions: (list of FieldDescriptor) All extensions defined directly
+ within this message type (NOT within a nested type).
+ extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor
+ objects as |extensions|, but indexed by "name" attribute of each
+ FieldDescriptor.
+
+ options: (descriptor_pb2.MessageOptions) Protocol message options or None
+ to use default message options.
+ """
+
+ def __init__(self, name, full_name, filename, containing_type,
+ fields, nested_types, enum_types, extensions, options=None):
+ """Arguments to __init__() are as described in the description
+ of Descriptor fields above.
+ """
+ super(Descriptor, self).__init__(options, 'MessageOptions')
+ self.name = name
+ self.full_name = full_name
+ self.filename = filename
+ self.containing_type = containing_type
+
+ # We have fields in addition to fields_by_name and fields_by_number,
+ # so that:
+ # 1. Clients can index fields by "order in which they're listed."
+ # 2. Clients can easily iterate over all fields with the terse
+ # syntax: for f in descriptor.fields: ...
+ self.fields = fields
+ for field in self.fields:
+ field.containing_type = self
+ self.fields_by_number = dict((f.number, f) for f in fields)
+ self.fields_by_name = dict((f.name, f) for f in fields)
+
+ self.nested_types = nested_types
+ self.nested_types_by_name = dict((t.name, t) for t in nested_types)
+
+ self.enum_types = enum_types
+ for enum_type in self.enum_types:
+ enum_type.containing_type = self
+ self.enum_types_by_name = dict((t.name, t) for t in enum_types)
+ self.enum_values_by_name = dict(
+ (v.name, v) for t in enum_types for v in t.values)
+
+ self.extensions = extensions
+ for extension in self.extensions:
+ extension.extension_scope = self
+ self.extensions_by_name = dict((f.name, f) for f in extensions)
+
+
+# TODO(robinson): We should have aggressive checking here,
+# for example:
+# * If you specify a repeated field, you should not be allowed
+# to specify a default value.
+# * [Other examples here as needed].
+#
+# TODO(robinson): for this and other *Descriptor classes, we
+# might also want to lock things down aggressively (e.g.,
+# prevent clients from setting the attributes). Having
+# stronger invariants here in general will reduce the number
+# of runtime checks we must do in reflection.py...
+class FieldDescriptor(DescriptorBase):
+
+ """Descriptor for a single field in a .proto file.
+
+ A FieldDescriptor instance has the following attriubtes:
+
+ name: (str) Name of this field, exactly as it appears in .proto.
+ full_name: (str) Name of this field, including containing scope. This is
+ particularly relevant for extensions.
+ index: (int) Dense, 0-indexed index giving the order that this
+ field textually appears within its message in the .proto file.
+ number: (int) Tag number declared for this field in the .proto file.
+
+ type: (One of the TYPE_* constants below) Declared type.
+ cpp_type: (One of the CPPTYPE_* constants below) C++ type used to
+ represent this field.
+
+ label: (One of the LABEL_* constants below) Tells whether this
+ field is optional, required, or repeated.
+ default_value: (Varies) Default value of this field. Only
+ meaningful for non-repeated scalar fields. Repeated fields
+ should always set this to [], and non-repeated composite
+ fields should always set this to None.
+
+ containing_type: (Descriptor) Descriptor of the protocol message
+ type that contains this field. Set by the Descriptor constructor
+ if we're passed into one.
+ Somewhat confusingly, for extension fields, this is the
+ descriptor of the EXTENDED message, not the descriptor
+ of the message containing this field. (See is_extension and
+ extension_scope below).
+ message_type: (Descriptor) If a composite field, a descriptor
+ of the message type contained in this field. Otherwise, this is None.
+ enum_type: (EnumDescriptor) If this field contains an enum, a
+ descriptor of that enum. Otherwise, this is None.
+
+ is_extension: True iff this describes an extension field.
+ extension_scope: (Descriptor) Only meaningful if is_extension is True.
+ Gives the message that immediately contains this extension field.
+ Will be None iff we're a top-level (file-level) extension field.
+
+ options: (descriptor_pb2.FieldOptions) Protocol message field options or
+ None to use default field options.
+ """
+
+ # Must be consistent with C++ FieldDescriptor::Type enum in
+ # descriptor.h.
+ #
+ # TODO(robinson): Find a way to eliminate this repetition.
+ TYPE_DOUBLE = 1
+ TYPE_FLOAT = 2
+ TYPE_INT64 = 3
+ TYPE_UINT64 = 4
+ TYPE_INT32 = 5
+ TYPE_FIXED64 = 6
+ TYPE_FIXED32 = 7
+ TYPE_BOOL = 8
+ TYPE_STRING = 9
+ TYPE_GROUP = 10
+ TYPE_MESSAGE = 11
+ TYPE_BYTES = 12
+ TYPE_UINT32 = 13
+ TYPE_ENUM = 14
+ TYPE_SFIXED32 = 15
+ TYPE_SFIXED64 = 16
+ TYPE_SINT32 = 17
+ TYPE_SINT64 = 18
+ MAX_TYPE = 18
+
+ # Must be consistent with C++ FieldDescriptor::CppType enum in
+ # descriptor.h.
+ #
+ # TODO(robinson): Find a way to eliminate this repetition.
+ CPPTYPE_INT32 = 1
+ CPPTYPE_INT64 = 2
+ CPPTYPE_UINT32 = 3
+ CPPTYPE_UINT64 = 4
+ CPPTYPE_DOUBLE = 5
+ CPPTYPE_FLOAT = 6
+ CPPTYPE_BOOL = 7
+ CPPTYPE_ENUM = 8
+ CPPTYPE_STRING = 9
+ CPPTYPE_MESSAGE = 10
+ MAX_CPPTYPE = 10
+
+ # Must be consistent with C++ FieldDescriptor::Label enum in
+ # descriptor.h.
+ #
+ # TODO(robinson): Find a way to eliminate this repetition.
+ LABEL_OPTIONAL = 1
+ LABEL_REQUIRED = 2
+ LABEL_REPEATED = 3
+ MAX_LABEL = 3
+
+ def __init__(self, name, full_name, index, number, type, cpp_type, label,
+ default_value, message_type, enum_type, containing_type,
+ is_extension, extension_scope, options=None):
+ """The arguments are as described in the description of FieldDescriptor
+ attributes above.
+
+ Note that containing_type may be None, and may be set later if necessary
+ (to deal with circular references between message types, for example).
+ Likewise for extension_scope.
+ """
+ super(FieldDescriptor, self).__init__(options, 'FieldOptions')
+ self.name = name
+ self.full_name = full_name
+ self.index = index
+ self.number = number
+ self.type = type
+ self.cpp_type = cpp_type
+ self.label = label
+ self.default_value = default_value
+ self.containing_type = containing_type
+ self.message_type = message_type
+ self.enum_type = enum_type
+ self.is_extension = is_extension
+ self.extension_scope = extension_scope
+
+
+class EnumDescriptor(DescriptorBase):
+
+ """Descriptor for an enum defined in a .proto file.
+
+ An EnumDescriptor instance has the following attributes:
+
+ name: (str) Name of the enum type.
+ full_name: (str) Full name of the type, including package name
+ and any enclosing type(s).
+ filename: (str) Name of the .proto file in which this appears.
+
+ values: (list of EnumValueDescriptors) List of the values
+ in this enum.
+ values_by_name: (dict str -> EnumValueDescriptor) Same as |values|,
+ but indexed by the "name" field of each EnumValueDescriptor.
+ values_by_number: (dict int -> EnumValueDescriptor) Same as |values|,
+ but indexed by the "number" field of each EnumValueDescriptor.
+ containing_type: (Descriptor) Descriptor of the immediate containing
+ type of this enum, or None if this is an enum defined at the
+ top level in a .proto file. Set by Descriptor's constructor
+ if we're passed into one.
+ options: (descriptor_pb2.EnumOptions) Enum options message or
+ None to use default enum options.
+ """
+
+ def __init__(self, name, full_name, filename, values,
+ containing_type=None, options=None):
+ """Arguments are as described in the attribute description above."""
+ super(EnumDescriptor, self).__init__(options, 'EnumOptions')
+ self.name = name
+ self.full_name = full_name
+ self.filename = filename
+ self.values = values
+ for value in self.values:
+ value.type = self
+ self.values_by_name = dict((v.name, v) for v in values)
+ self.values_by_number = dict((v.number, v) for v in values)
+ self.containing_type = containing_type
+
+
+class EnumValueDescriptor(DescriptorBase):
+
+ """Descriptor for a single value within an enum.
+
+ name: (str) Name of this value.
+ index: (int) Dense, 0-indexed index giving the order that this
+ value appears textually within its enum in the .proto file.
+ number: (int) Actual number assigned to this enum value.
+ type: (EnumDescriptor) EnumDescriptor to which this value
+ belongs. Set by EnumDescriptor's constructor if we're
+ passed into one.
+ options: (descriptor_pb2.EnumValueOptions) Enum value options message or
+ None to use default enum value options options.
+ """
+
+ def __init__(self, name, index, number, type=None, options=None):
+ """Arguments are as described in the attribute description above."""
+ super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions')
+ self.name = name
+ self.index = index
+ self.number = number
+ self.type = type
+
+
+class ServiceDescriptor(DescriptorBase):
+
+ """Descriptor for a service.
+
+ name: (str) Name of the service.
+ full_name: (str) Full name of the service, including package name.
+ index: (int) 0-indexed index giving the order that this services
+ definition appears withing the .proto file.
+ methods: (list of MethodDescriptor) List of methods provided by this
+ service.
+ options: (descriptor_pb2.ServiceOptions) Service options message or
+ None to use default service options.
+ """
+
+ def __init__(self, name, full_name, index, methods, options=None):
+ super(ServiceDescriptor, self).__init__(options, 'ServiceOptions')
+ self.name = name
+ self.full_name = full_name
+ self.index = index
+ self.methods = methods
+ # Set the containing service for each method in this service.
+ for method in self.methods:
+ method.containing_service = self
+
+ def FindMethodByName(self, name):
+ """Searches for the specified method, and returns its descriptor."""
+ for method in self.methods:
+ if name == method.name:
+ return method
+ return None
+
+
+class MethodDescriptor(DescriptorBase):
+
+ """Descriptor for a method in a service.
+
+ name: (str) Name of the method within the service.
+ full_name: (str) Full name of method.
+ index: (int) 0-indexed index of the method inside the service.
+ containing_service: (ServiceDescriptor) The service that contains this
+ method.
+ input_type: The descriptor of the message that this method accepts.
+ output_type: The descriptor of the message that this method returns.
+ options: (descriptor_pb2.MethodOptions) Method options message or
+ None to use default method options.
+ """
+
+ def __init__(self, name, full_name, index, containing_service,
+ input_type, output_type, options=None):
+ """The arguments are as described in the description of MethodDescriptor
+ attributes above.
+
+ Note that containing_service may be None, and may be set later if necessary.
+ """
+ super(MethodDescriptor, self).__init__(options, 'MethodOptions')
+ self.name = name
+ self.full_name = full_name
+ self.index = index
+ self.containing_service = containing_service
+ self.input_type = input_type
+ self.output_type = output_type
+
+
+def _ParseOptions(message, string):
+ """Parses serialized options.
+
+ This helper function is used to parse serialized options in generated
+ proto2 files. It must not be used outside proto2.
+ """
+ message.ParseFromString(string)
+ return message;
diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py
new file mode 100755
index 0000000..e69de29
--- /dev/null
+++ b/python/google/protobuf/internal/__init__.py
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
new file mode 100755
index 0000000..d8a825d
--- /dev/null
+++ b/python/google/protobuf/internal/containers.py
@@ -0,0 +1,227 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains container classes to represent different protocol buffer types.
+
+This file defines container classes which represent categories of protocol
+buffer field types which need extra maintenance. Currently these categories
+are:
+ - Repeated scalar fields - These are all repeated fields which aren't
+ composite (e.g. they are of simple types like int32, string, etc).
+ - Repeated composite fields - Repeated fields which are composite. This
+ includes groups and nested messages.
+"""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+
+class BaseContainer(object):
+
+ """Base container class."""
+
+ # Minimizes memory usage and disallows assignment to other attributes.
+ __slots__ = ['_message_listener', '_values']
+
+ def __init__(self, message_listener):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The RepeatedScalarFieldContainer will call this object's
+ TransitionToNonempty() method when it transitions from being empty to
+ being nonempty.
+ """
+ self._message_listener = message_listener
+ self._values = []
+
+ def __getitem__(self, key):
+ """Retrieves item by the specified key."""
+ return self._values[key]
+
+ def __len__(self):
+ """Returns the number of elements in the container."""
+ return len(self._values)
+
+ def __ne__(self, other):
+ """Checks if another instance isn't equal to this one."""
+ # The concrete classes should define __eq__.
+ return not self == other
+
+
+class RepeatedScalarFieldContainer(BaseContainer):
+
+ """Simple, type-checked, list-like container for holding repeated scalars."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_type_checker']
+
+ def __init__(self, message_listener, type_checker):
+ """
+ Args:
+ message_listener: A MessageListener implementation.
+ The RepeatedScalarFieldContainer will call this object's
+ TransitionToNonempty() method when it transitions from being empty to
+ being nonempty.
+ type_checker: A type_checkers.ValueChecker instance to run on elements
+ inserted into this container.
+ """
+ super(RepeatedScalarFieldContainer, self).__init__(message_listener)
+ self._type_checker = type_checker
+
+ def append(self, value):
+ """Appends an item to the list. Similar to list.append()."""
+ self.insert(len(self._values), value)
+
+ def insert(self, key, value):
+ """Inserts the item at the specified position. Similar to list.insert()."""
+ self._type_checker.CheckValue(value)
+ self._values.insert(key, value)
+ self._message_listener.ByteSizeDirty()
+ if len(self._values) == 1:
+ self._message_listener.TransitionToNonempty()
+
+ def extend(self, elem_seq):
+ """Extends by appending the given sequence. Similar to list.extend()."""
+ if not elem_seq:
+ return
+
+ orig_empty = len(self._values) == 0
+ new_values = []
+ for elem in elem_seq:
+ self._type_checker.CheckValue(elem)
+ new_values.append(elem)
+ self._values.extend(new_values)
+ self._message_listener.ByteSizeDirty()
+ if orig_empty:
+ self._message_listener.TransitionToNonempty()
+
+ def remove(self, elem):
+ """Removes an item from the list. Similar to list.remove()."""
+ self._values.remove(elem)
+ self._message_listener.ByteSizeDirty()
+
+ def __setitem__(self, key, value):
+ """Sets the item on the specified position."""
+ # No need to call TransitionToNonempty(), since if we're able to
+ # set the element at this index, we were already nonempty before
+ # this method was called.
+ self._message_listener.ByteSizeDirty()
+ self._type_checker.CheckValue(value)
+ self._values[key] = value
+
+ def __getslice__(self, start, stop):
+ """Retrieves the subset of items from between the specified indices."""
+ return self._values[start:stop]
+
+ def __setslice__(self, start, stop, values):
+ """Sets the subset of items from between the specified indices."""
+ new_values = []
+ for value in values:
+ self._type_checker.CheckValue(value)
+ new_values.append(value)
+ self._values[start:stop] = new_values
+ self._message_listener.ByteSizeDirty()
+
+ def __delitem__(self, key):
+ """Deletes the item at the specified position."""
+ del self._values[key]
+ self._message_listener.ByteSizeDirty()
+
+ def __delslice__(self, start, stop):
+ """Deletes the subset of items from between the specified indices."""
+ del self._values[start:stop]
+ self._message_listener.ByteSizeDirty()
+
+ def __eq__(self, other):
+ """Compares the current instance with another one."""
+ if self is other:
+ return True
+ # Special case for the same type which should be common and fast.
+ if isinstance(other, self.__class__):
+ return other._values == self._values
+ # We are presumably comparing against some other sequence type.
+ return other == self._values
+
+
+class RepeatedCompositeFieldContainer(BaseContainer):
+
+ """Simple, list-like container for holding repeated composite fields."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_message_descriptor']
+
+ def __init__(self, message_listener, message_descriptor):
+ """
+ Note that we pass in a descriptor instead of the generated directly,
+ since at the time we construct a _RepeatedCompositeFieldContainer we
+ haven't yet necessarily initialized the type that will be contained in the
+ container.
+
+ Args:
+ message_listener: A MessageListener implementation.
+ The RepeatedCompositeFieldContainer will call this object's
+ TransitionToNonempty() method when it transitions from being empty to
+ being nonempty.
+ message_descriptor: A Descriptor instance describing the protocol type
+ that should be present in this container. We'll use the
+ _concrete_class field of this descriptor when the client calls add().
+ """
+ super(RepeatedCompositeFieldContainer, self).__init__(message_listener)
+ self._message_descriptor = message_descriptor
+
+ def add(self):
+ new_element = self._message_descriptor._concrete_class()
+ new_element._SetListener(self._message_listener)
+ self._values.append(new_element)
+ self._message_listener.ByteSizeDirty()
+ self._message_listener.TransitionToNonempty()
+ return new_element
+
+ def __getslice__(self, start, stop):
+ """Retrieves the subset of items from between the specified indices."""
+ return self._values[start:stop]
+
+ def __delitem__(self, key):
+ """Deletes the item at the specified position."""
+ del self._values[key]
+ self._message_listener.ByteSizeDirty()
+
+ def __delslice__(self, start, stop):
+ """Deletes the subset of items from between the specified indices."""
+ del self._values[start:stop]
+ self._message_listener.ByteSizeDirty()
+
+ def __eq__(self, other):
+ """Compares the current instance with another one."""
+ if self is other:
+ return True
+ if not isinstance(other, self.__class__):
+ raise TypeError('Can only compare repeated composite fields against '
+ 'other repeated composite fields.')
+ return self._values == other._values
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
new file mode 100755
index 0000000..83d6fe0
--- /dev/null
+++ b/python/google/protobuf/internal/decoder.py
@@ -0,0 +1,209 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Class for decoding protocol buffer primitives.
+
+Contains the logic for decoding every logical protocol field type
+from one of the 5 physical wire types.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import struct
+from google.protobuf import message
+from google.protobuf.internal import input_stream
+from google.protobuf.internal import wire_format
+
+
+
+# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
+# that the interface is strongly inspired by WireFormat from the C++ proto2
+# implementation.
+
+
+class Decoder(object):
+
+ """Decodes logical protocol buffer fields from the wire."""
+
+ def __init__(self, s):
+ """Initializes the decoder to read from s.
+
+ Args:
+ s: An immutable sequence of bytes, which must be accessible
+ via the Python buffer() primitive (i.e., buffer(s)).
+ """
+ self._stream = input_stream.InputStream(s)
+
+ def EndOfStream(self):
+ """Returns true iff we've reached the end of the bytes we're reading."""
+ return self._stream.EndOfStream()
+
+ def Position(self):
+ """Returns the 0-indexed position in |s|."""
+ return self._stream.Position()
+
+ def ReadFieldNumberAndWireType(self):
+ """Reads a tag from the wire. Returns a (field_number, wire_type) pair."""
+ tag_and_type = self.ReadUInt32()
+ return wire_format.UnpackTag(tag_and_type)
+
+ def SkipBytes(self, bytes):
+ """Skips the specified number of bytes on the wire."""
+ self._stream.SkipBytes(bytes)
+
+ # Note that the Read*() methods below are not exactly symmetrical with the
+ # corresponding Encoder.Append*() methods. Those Encoder methods first
+ # encode a tag, but the Read*() methods below assume that the tag has already
+ # been read, and that the client wishes to read a field of the specified type
+ # starting at the current position.
+
+ def ReadInt32(self):
+ """Reads and returns a signed, varint-encoded, 32-bit integer."""
+ return self._stream.ReadVarint32()
+
+ def ReadInt64(self):
+ """Reads and returns a signed, varint-encoded, 64-bit integer."""
+ return self._stream.ReadVarint64()
+
+ def ReadUInt32(self):
+ """Reads and returns an signed, varint-encoded, 32-bit integer."""
+ return self._stream.ReadVarUInt32()
+
+ def ReadUInt64(self):
+ """Reads and returns an signed, varint-encoded,64-bit integer."""
+ return self._stream.ReadVarUInt64()
+
+ def ReadSInt32(self):
+ """Reads and returns a signed, zigzag-encoded, varint-encoded,
+ 32-bit integer."""
+ return wire_format.ZigZagDecode(self._stream.ReadVarUInt32())
+
+ def ReadSInt64(self):
+ """Reads and returns a signed, zigzag-encoded, varint-encoded,
+ 64-bit integer."""
+ return wire_format.ZigZagDecode(self._stream.ReadVarUInt64())
+
+ def ReadFixed32(self):
+ """Reads and returns an unsigned, fixed-width, 32-bit integer."""
+ return self._stream.ReadLittleEndian32()
+
+ def ReadFixed64(self):
+ """Reads and returns an unsigned, fixed-width, 64-bit integer."""
+ return self._stream.ReadLittleEndian64()
+
+ def ReadSFixed32(self):
+ """Reads and returns a signed, fixed-width, 32-bit integer."""
+ value = self._stream.ReadLittleEndian32()
+ if value >= (1 << 31):
+ value -= (1 << 32)
+ return value
+
+ def ReadSFixed64(self):
+ """Reads and returns a signed, fixed-width, 64-bit integer."""
+ value = self._stream.ReadLittleEndian64()
+ if value >= (1 << 63):
+ value -= (1 << 64)
+ return value
+
+ def ReadFloat(self):
+ """Reads and returns a 4-byte floating-point number."""
+ serialized = self._stream.ReadBytes(4)
+ return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0]
+
+ def ReadDouble(self):
+ """Reads and returns an 8-byte floating-point number."""
+ serialized = self._stream.ReadBytes(8)
+ return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0]
+
+ def ReadBool(self):
+ """Reads and returns a bool."""
+ i = self._stream.ReadVarUInt32()
+ return bool(i)
+
+ def ReadEnum(self):
+ """Reads and returns an enum value."""
+ return self._stream.ReadVarUInt32()
+
+ def ReadString(self):
+ """Reads and returns a length-delimited string."""
+ bytes = self.ReadBytes()
+ return unicode(bytes, 'utf-8')
+
+ def ReadBytes(self):
+ """Reads and returns a length-delimited byte sequence."""
+ length = self._stream.ReadVarUInt32()
+ return self._stream.ReadBytes(length)
+
+ def ReadMessageInto(self, msg):
+ """Calls msg.MergeFromString() to merge
+ length-delimited serialized message data into |msg|.
+
+ REQUIRES: The decoder must be positioned at the serialized "length"
+ prefix to a length-delmiited serialized message.
+
+ POSTCONDITION: The decoder is positioned just after the
+ serialized message, and we have merged those serialized
+ contents into |msg|.
+ """
+ length = self._stream.ReadVarUInt32()
+ sub_buffer = self._stream.GetSubBuffer(length)
+ num_bytes_used = msg.MergeFromString(sub_buffer)
+ if num_bytes_used != length:
+ raise message.DecodeError(
+ 'Submessage told to deserialize from %d-byte encoding, '
+ 'but used only %d bytes' % (length, num_bytes_used))
+ self._stream.SkipBytes(num_bytes_used)
+
+ def ReadGroupInto(self, expected_field_number, group):
+ """Calls group.MergeFromString() to merge
+ END_GROUP-delimited serialized message data into |group|.
+ We'll raise an exception if we don't find an END_GROUP
+ tag immediately after the serialized message contents.
+
+ REQUIRES: The decoder is positioned just after the START_GROUP
+ tag for this group.
+
+ POSTCONDITION: The decoder is positioned just after the
+ END_GROUP tag for this group, and we have merged
+ the contents of the group into |group|.
+ """
+ sub_buffer = self._stream.GetSubBuffer() # No a priori length limit.
+ num_bytes_used = group.MergeFromString(sub_buffer)
+ if num_bytes_used < 0:
+ raise message.DecodeError('Group message reported negative bytes read.')
+ self._stream.SkipBytes(num_bytes_used)
+ field_number, field_type = self.ReadFieldNumberAndWireType()
+ if field_type != wire_format.WIRETYPE_END_GROUP:
+ raise message.DecodeError('Group message did not end with an END_GROUP.')
+ if field_number != expected_field_number:
+ raise message.DecodeError('END_GROUP tag had field '
+ 'number %d, was expecting field number %d' % (
+ field_number, expected_field_number))
+ # We're now positioned just after the END_GROUP tag. Perfect.
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
new file mode 100755
index 0000000..98e4647
--- /dev/null
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -0,0 +1,256 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.decoder."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import struct
+import unittest
+from google.protobuf.internal import decoder
+from google.protobuf.internal import encoder
+from google.protobuf.internal import input_stream
+from google.protobuf.internal import wire_format
+from google.protobuf import message
+import logging
+import mox
+
+
+class DecoderTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.mock_stream = self.mox.CreateMock(input_stream.InputStream)
+ self.mock_message = self.mox.CreateMock(message.Message)
+
+ def testReadFieldNumberAndWireType(self):
+ # Test field numbers that will require various varint sizes.
+ for expected_field_number in (1, 15, 16, 2047, 2048):
+ for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
+ e = encoder.Encoder()
+ e.AppendTag(expected_field_number, expected_wire_type)
+ s = e.ToString()
+ d = decoder.Decoder(s)
+ field_number, wire_type = d.ReadFieldNumberAndWireType()
+ self.assertEqual(expected_field_number, field_number)
+ self.assertEqual(expected_wire_type, wire_type)
+
+ def ReadScalarTestHelper(self, test_name, decoder_method, expected_result,
+ expected_stream_method_name,
+ stream_method_return, *args):
+ """Helper for testReadScalars below.
+
+ Calls one of the Decoder.Read*() methods and ensures that the results are
+ as expected.
+
+ Args:
+ test_name: Name of this test, used for logging only.
+ decoder_method: Unbound decoder.Decoder method to call.
+ expected_result: Value we expect returned from decoder_method().
+ expected_stream_method_name: (string) Name of the InputStream
+ method we expect Decoder to call to actually read the value
+ on the wire.
+ stream_method_return: Value our mocked-out stream method should
+ return to the decoder.
+ args: Additional arguments that we expect to be passed to the
+ stream method.
+ """
+ logging.info('Testing %s scalar input.\n'
+ 'Calling %r(), and expecting that to call the '
+ 'stream method %s(%r), which will return %r. Finally, '
+ 'expecting the Decoder method to return %r'% (
+ test_name, decoder_method,
+ expected_stream_method_name, args, stream_method_return,
+ expected_result))
+
+ d = decoder.Decoder('')
+ d._stream = self.mock_stream
+ if decoder_method in (decoder.Decoder.ReadString,
+ decoder.Decoder.ReadBytes):
+ self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return))
+ # We have to use names instead of methods to work around some
+ # mox weirdness. (ResetAll() is overzealous).
+ expected_stream_method = getattr(self.mock_stream,
+ expected_stream_method_name)
+ expected_stream_method(*args).AndReturn(stream_method_return)
+
+ self.mox.ReplayAll()
+ result = decoder_method(d)
+ self.assertEqual(expected_result, result)
+ self.assert_(isinstance(result, type(expected_result)))
+ self.mox.VerifyAll()
+ self.mox.ResetAll()
+
+ VAL = 1.125 # Perfectly representable as a float (no rounding error).
+ LITTLE_FLOAT_VAL = '\x00\x00\x90?'
+ LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
+
+ def testReadScalars(self):
+ test_string = 'I can feel myself getting sutpider.'
+ scalar_tests = [
+ ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0],
+ ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0],
+ ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0],
+ ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0],
+ ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff,
+ 'ReadLittleEndian32', 0xffffffff],
+ ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff,
+ 'ReadLittleEndian64', 0xffffffffffffffff],
+ ['sfixed32', decoder.Decoder.ReadSFixed32, long(-1),
+ 'ReadLittleEndian32', long(0xffffffff)],
+ ['sfixed64', decoder.Decoder.ReadSFixed64, long(-1),
+ 'ReadLittleEndian64', 0xffffffffffffffff],
+ ['float', decoder.Decoder.ReadFloat, self.VAL,
+ 'ReadBytes', self.LITTLE_FLOAT_VAL, 4],
+ ['double', decoder.Decoder.ReadDouble, self.VAL,
+ 'ReadBytes', self.LITTLE_DOUBLE_VAL, 8],
+ ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1],
+ ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23],
+ ['string', decoder.Decoder.ReadString,
+ unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
+ len(test_string)],
+ ['utf8-string', decoder.Decoder.ReadString,
+ unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
+ len(test_string)],
+ ['bytes', decoder.Decoder.ReadBytes,
+ test_string, 'ReadBytes', test_string, len(test_string)],
+ # We test zigzag decoding routines more extensively below.
+ ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1],
+ ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1],
+ ]
+ # Ensure that we're testing different Decoder methods and using
+ # different test names in all test cases above.
+ self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
+ self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
+ for args in scalar_tests:
+ self.ReadScalarTestHelper(*args)
+
+ def testReadMessageInto(self):
+ length = 23
+ def Test(simulate_error):
+ d = decoder.Decoder('')
+ d._stream = self.mock_stream
+ self.mock_stream.ReadVarUInt32().AndReturn(length)
+ sub_buffer = object()
+ self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer)
+
+ if simulate_error:
+ self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1)
+ self.mox.ReplayAll()
+ self.assertRaises(
+ message.DecodeError, d.ReadMessageInto, self.mock_message)
+ else:
+ self.mock_message.MergeFromString(sub_buffer).AndReturn(length)
+ self.mock_stream.SkipBytes(length)
+ self.mox.ReplayAll()
+ d.ReadMessageInto(self.mock_message)
+
+ self.mox.VerifyAll()
+ self.mox.ResetAll()
+
+ Test(simulate_error=False)
+ Test(simulate_error=True)
+
+ def testReadGroupInto_Success(self):
+ # Test both the empty and nonempty cases.
+ for num_bytes in (5, 0):
+ field_number = expected_field_number = 10
+ d = decoder.Decoder('')
+ d._stream = self.mock_stream
+ sub_buffer = object()
+ self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
+ self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes)
+ self.mock_stream.SkipBytes(num_bytes)
+ self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
+ field_number, wire_format.WIRETYPE_END_GROUP))
+ self.mox.ReplayAll()
+ d.ReadGroupInto(expected_field_number, self.mock_message)
+ self.mox.VerifyAll()
+ self.mox.ResetAll()
+
+ def ReadGroupInto_FailureTestHelper(self, bytes_read):
+ d = decoder.Decoder('')
+ d._stream = self.mock_stream
+ sub_buffer = object()
+ self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
+ self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read)
+ return d
+
+ def testReadGroupInto_NegativeBytesReported(self):
+ expected_field_number = 10
+ d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1)
+ self.mox.ReplayAll()
+ self.assertRaises(message.DecodeError,
+ d.ReadGroupInto, expected_field_number,
+ self.mock_message)
+ self.mox.VerifyAll()
+
+ def testReadGroupInto_NoEndGroupTag(self):
+ field_number = expected_field_number = 10
+ num_bytes = 5
+ d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
+ self.mock_stream.SkipBytes(num_bytes)
+ # Right field number, wrong wire type.
+ self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
+ field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
+ self.mox.ReplayAll()
+ self.assertRaises(message.DecodeError,
+ d.ReadGroupInto, expected_field_number,
+ self.mock_message)
+ self.mox.VerifyAll()
+
+ def testReadGroupInto_WrongFieldNumberInEndGroupTag(self):
+ expected_field_number = 10
+ field_number = expected_field_number + 1
+ num_bytes = 5
+ d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
+ self.mock_stream.SkipBytes(num_bytes)
+ # Wrong field number, right wire type.
+ self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
+ field_number, wire_format.WIRETYPE_END_GROUP))
+ self.mox.ReplayAll()
+ self.assertRaises(message.DecodeError,
+ d.ReadGroupInto, expected_field_number,
+ self.mock_message)
+ self.mox.VerifyAll()
+
+ def testSkipBytes(self):
+ d = decoder.Decoder('')
+ num_bytes = 1024
+ self.mock_stream.SkipBytes(num_bytes)
+ d._stream = self.mock_stream
+ self.mox.ReplayAll()
+ d.SkipBytes(num_bytes)
+ self.mox.VerifyAll()
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
new file mode 100755
index 0000000..eb9f2be
--- /dev/null
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -0,0 +1,113 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Unittest for google.protobuf.internal.descriptor."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import unittest
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor
+
+class DescriptorTest(unittest.TestCase):
+
+ def setUp(self):
+ self.my_enum = descriptor.EnumDescriptor(
+ name='ForeignEnum',
+ full_name='protobuf_unittest.ForeignEnum',
+ filename='ForeignEnum',
+ values=[
+ descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
+ descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
+ descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
+ ])
+ self.my_message = descriptor.Descriptor(
+ name='NestedMessage',
+ full_name='protobuf_unittest.TestAllTypes.NestedMessage',
+ filename='some/filename/some.proto',
+ containing_type=None,
+ fields=[
+ descriptor.FieldDescriptor(
+ name='bb',
+ full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
+ index=0, number=1,
+ type=5, cpp_type=1, label=1,
+ default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None),
+ ],
+ nested_types=[],
+ enum_types=[
+ self.my_enum,
+ ],
+ extensions=[])
+ self.my_method = descriptor.MethodDescriptor(
+ name='Bar',
+ full_name='protobuf_unittest.TestService.Bar',
+ index=0,
+ containing_service=None,
+ input_type=None,
+ output_type=None)
+ self.my_service = descriptor.ServiceDescriptor(
+ name='TestServiceWithOptions',
+ full_name='protobuf_unittest.TestServiceWithOptions',
+ index=0,
+ methods=[
+ self.my_method
+ ])
+
+ def testEnumFixups(self):
+ self.assertEqual(self.my_enum, self.my_enum.values[0].type)
+
+ def testContainingTypeFixups(self):
+ self.assertEqual(self.my_message, self.my_message.fields[0].containing_type)
+ self.assertEqual(self.my_message, self.my_enum.containing_type)
+
+ def testContainingServiceFixups(self):
+ self.assertEqual(self.my_service, self.my_method.containing_service)
+
+ def testGetOptions(self):
+ self.assertEqual(self.my_enum.GetOptions(),
+ descriptor_pb2.EnumOptions())
+ self.assertEqual(self.my_enum.values[0].GetOptions(),
+ descriptor_pb2.EnumValueOptions())
+ self.assertEqual(self.my_message.GetOptions(),
+ descriptor_pb2.MessageOptions())
+ self.assertEqual(self.my_message.fields[0].GetOptions(),
+ descriptor_pb2.FieldOptions())
+ self.assertEqual(self.my_method.GetOptions(),
+ descriptor_pb2.MethodOptions())
+ self.assertEqual(self.my_service.GetOptions(),
+ descriptor_pb2.ServiceOptions())
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
new file mode 100755
index 0000000..3ec3b2b
--- /dev/null
+++ b/python/google/protobuf/internal/encoder.py
@@ -0,0 +1,280 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Class for encoding protocol message primitives.
+
+Contains the logic for encoding every logical protocol field type
+into one of the 5 physical wire types.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import struct
+from google.protobuf import message
+from google.protobuf.internal import wire_format
+from google.protobuf.internal import output_stream
+
+
+# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
+# that the interface is strongly inspired by WireFormat from the C++ proto2
+# implementation.
+
+
+class Encoder(object):
+
+ """Encodes logical protocol buffer fields to the wire format."""
+
+ def __init__(self):
+ self._stream = output_stream.OutputStream()
+
+ def ToString(self):
+ """Returns all values encoded in this object as a string."""
+ return self._stream.ToString()
+
+ # Append*NoTag methods. These are necessary for serializing packed
+ # repeated fields. The Append*() methods call these methods to do
+ # the actual serialization.
+ def AppendInt32NoTag(self, value):
+ """Appends a 32-bit integer to our buffer, varint-encoded."""
+ self._stream.AppendVarint32(value)
+
+ def AppendInt64NoTag(self, value):
+ """Appends a 64-bit integer to our buffer, varint-encoded."""
+ self._stream.AppendVarint64(value)
+
+ def AppendUInt32NoTag(self, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
+ self._stream.AppendVarUInt32(unsigned_value)
+
+ def AppendUInt64NoTag(self, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
+ self._stream.AppendVarUInt64(unsigned_value)
+
+ def AppendSInt32NoTag(self, value):
+ """Appends a 32-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ zigzag_value = wire_format.ZigZagEncode(value)
+ self._stream.AppendVarUInt32(zigzag_value)
+
+ def AppendSInt64NoTag(self, value):
+ """Appends a 64-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ zigzag_value = wire_format.ZigZagEncode(value)
+ self._stream.AppendVarUInt64(zigzag_value)
+
+ def AppendFixed32NoTag(self, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self._stream.AppendLittleEndian32(unsigned_value)
+
+ def AppendFixed64NoTag(self, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self._stream.AppendLittleEndian64(unsigned_value)
+
+ def AppendSFixed32NoTag(self, value):
+ """Appends a signed 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ sign = (value & 0x80000000) and -1 or 0
+ if value >> 32 != sign:
+ raise message.EncodeError('SFixed32 out of range: %d' % value)
+ self._stream.AppendLittleEndian32(value & 0xffffffff)
+
+ def AppendSFixed64NoTag(self, value):
+ """Appends a signed 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ sign = (value & 0x8000000000000000) and -1 or 0
+ if value >> 64 != sign:
+ raise message.EncodeError('SFixed64 out of range: %d' % value)
+ self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
+
+ def AppendFloatNoTag(self, value):
+ """Appends a floating-point number to our buffer."""
+ self._stream.AppendRawBytes(
+ struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
+
+ def AppendDoubleNoTag(self, value):
+ """Appends a double-precision floating-point number to our buffer."""
+ self._stream.AppendRawBytes(
+ struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
+
+ def AppendBoolNoTag(self, value):
+ """Appends a boolean to our buffer."""
+ self.AppendInt32NoTag(value)
+
+ def AppendEnumNoTag(self, value):
+ """Appends an enum value to our buffer."""
+ self.AppendInt32NoTag(value)
+
+
+ # All the Append*() methods below first append a tag+type pair to the buffer
+ # before appending the specified value.
+
+ def AppendInt32(self, field_number, value):
+ """Appends a 32-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendInt32NoTag(value)
+
+ def AppendInt64(self, field_number, value):
+ """Appends a 64-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendInt64NoTag(value)
+
+ def AppendUInt32(self, field_number, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendUInt32NoTag(unsigned_value)
+
+ def AppendUInt64(self, field_number, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendUInt64NoTag(unsigned_value)
+
+ def AppendSInt32(self, field_number, value):
+ """Appends a 32-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendSInt32NoTag(value)
+
+ def AppendSInt64(self, field_number, value):
+ """Appends a 64-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendSInt64NoTag(value)
+
+ def AppendFixed32(self, field_number, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendFixed32NoTag(unsigned_value)
+
+ def AppendFixed64(self, field_number, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendFixed64NoTag(unsigned_value)
+
+ def AppendSFixed32(self, field_number, value):
+ """Appends a signed 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendSFixed32NoTag(value)
+
+ def AppendSFixed64(self, field_number, value):
+ """Appends a signed 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendSFixed64NoTag(value)
+
+ def AppendFloat(self, field_number, value):
+ """Appends a floating-point number to our buffer."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendFloatNoTag(value)
+
+ def AppendDouble(self, field_number, value):
+ """Appends a double-precision floating-point number to our buffer."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendDoubleNoTag(value)
+
+ def AppendBool(self, field_number, value):
+ """Appends a boolean to our buffer."""
+ self.AppendInt32(field_number, value)
+
+ def AppendEnum(self, field_number, value):
+ """Appends an enum value to our buffer."""
+ self.AppendInt32(field_number, value)
+
+ def AppendString(self, field_number, value):
+ """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer,
+ with the length varint-encoded.
+ """
+ self.AppendBytes(field_number, value.encode('utf-8'))
+
+ def AppendBytes(self, field_number, value):
+ """Appends a length-prefixed sequence of bytes to our buffer, with the
+ length varint-encoded.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ self._stream.AppendVarUInt32(len(value))
+ self._stream.AppendRawBytes(value)
+
+ # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to
+ # avoid the extra string copy here. We can do so if we widen the Message
+ # interface to be able to serialize to a stream in addition to a string. The
+ # challenge when thinking ahead to the Python/C API implementation of Message
+ # is finding a stream-like Python thing to which we can write raw bytes
+ # from C. I'm not sure such a thing exists(?). (array.array is pretty much
+ # what we want, but it's not directly exposed in the Python/C API).
+
+ def AppendGroup(self, field_number, group):
+ """Appends a group to our buffer.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
+ self._stream.AppendRawBytes(group.SerializeToString())
+ self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
+
+ def AppendMessage(self, field_number, msg):
+ """Appends a nested message to our buffer.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ self._stream.AppendVarUInt32(msg.ByteSize())
+ self._stream.AppendRawBytes(msg.SerializeToString())
+
+ def AppendMessageSetItem(self, field_number, msg):
+ """Appends an item using the message set wire format.
+
+ The message set message looks like this:
+ message MessageSet {
+ repeated group Item = 1 {
+ required int32 type_id = 2;
+ required string message = 3;
+ }
+ }
+ """
+ self.AppendTag(1, wire_format.WIRETYPE_START_GROUP)
+ self.AppendInt32(2, field_number)
+ self.AppendMessage(3, msg)
+ self.AppendTag(1, wire_format.WIRETYPE_END_GROUP)
+
+ def AppendTag(self, field_number, wire_type):
+ """Appends a tag containing field number and wire type information."""
+ self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type))
diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py
new file mode 100755
index 0000000..bf75ea8
--- /dev/null
+++ b/python/google/protobuf/internal/encoder_test.py
@@ -0,0 +1,286 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.encoder."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import struct
+import logging
+import unittest
+from google.protobuf.internal import wire_format
+from google.protobuf.internal import encoder
+from google.protobuf.internal import output_stream
+from google.protobuf import message
+import mox
+
+
+class EncoderTest(unittest.TestCase):
+
+ def setUp(self):
+ self.mox = mox.Mox()
+ self.encoder = encoder.Encoder()
+ self.mock_stream = self.mox.CreateMock(output_stream.OutputStream)
+ self.mock_message = self.mox.CreateMock(message.Message)
+ self.encoder._stream = self.mock_stream
+
+ def PackTag(self, field_number, wire_type):
+ return wire_format.PackTag(field_number, wire_type)
+
+ def AppendScalarTestHelper(self, test_name, encoder_method,
+ expected_stream_method_name,
+ wire_type, field_value,
+ expected_value=None, expected_length=None,
+ is_tag_test=True):
+ """Helper for testAppendScalars.
+
+ Calls one of the Encoder methods, and ensures that the Encoder
+ in turn makes the expected calls into its OutputStream.
+
+ Args:
+ test_name: Name of this test, used only for logging.
+ encoder_method: Callable on self.encoder. This is the Encoder
+ method we're testing. If is_tag_test=True, the encoder method
+ accepts a field_number and field_value. if is_tag_test=False,
+ the encoder method accepts a field_value.
+ expected_stream_method_name: (string) Name of the OutputStream
+ method we expect Encoder to call to actually put the value
+ on the wire.
+ wire_type: The WIRETYPE_* constant we expect encoder to
+ use in the specified encoder_method.
+ field_value: The value we're trying to encode. Passed
+ into encoder_method.
+ expected_value: The value we expect Encoder to pass into
+ the OutputStream method. If None, we expect field_value
+ to pass through unmodified.
+ expected_length: The length we expect Encoder to pass to the
+ AppendVarUInt32 method. If None we expect the length of the
+ field_value.
+ is_tag_test: A Boolean. If True (the default), we append the
+ the packed field number and wire_type to the stream before
+ the field value.
+ """
+ if expected_value is None:
+ expected_value = field_value
+
+ logging.info('Testing %s scalar output.\n'
+ 'Calling %r(%r), and expecting that to call the '
+ 'stream method %s(%r).' % (
+ test_name, encoder_method, field_value,
+ expected_stream_method_name, expected_value))
+
+ if is_tag_test:
+ field_number = 10
+ # Should first append the field number and type information.
+ self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
+ # If we're length-delimited, we should then append the length.
+ if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ if expected_length is None:
+ expected_length = len(field_value)
+ self.mock_stream.AppendVarUInt32(expected_length)
+
+ # Should then append the value itself.
+ # We have to use names instead of methods to work around some
+ # mox weirdness. (ResetAll() is overzealous).
+ expected_stream_method = getattr(self.mock_stream,
+ expected_stream_method_name)
+ expected_stream_method(expected_value)
+
+ self.mox.ReplayAll()
+ if is_tag_test:
+ encoder_method(field_number, field_value)
+ else:
+ encoder_method(field_value)
+ self.mox.VerifyAll()
+ self.mox.ResetAll()
+
+ VAL = 1.125 # Perfectly representable as a float (no rounding error).
+ LITTLE_FLOAT_VAL = '\x00\x00\x90?'
+ LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
+
+ def testAppendScalars(self):
+ utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'
+ utf8_string = unicode(utf8_bytes, 'utf-8')
+ scalar_tests = [
+ ['int32', self.encoder.AppendInt32, 'AppendVarint32',
+ wire_format.WIRETYPE_VARINT, 0],
+ ['int64', self.encoder.AppendInt64, 'AppendVarint64',
+ wire_format.WIRETYPE_VARINT, 0],
+ ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32',
+ wire_format.WIRETYPE_VARINT, 0],
+ ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64',
+ wire_format.WIRETYPE_VARINT, 0],
+ ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32',
+ wire_format.WIRETYPE_FIXED32, 0],
+ ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64',
+ wire_format.WIRETYPE_FIXED64, 0],
+ ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32',
+ wire_format.WIRETYPE_FIXED32, -1, 0xffffffff],
+ ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64',
+ wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff],
+ ['float', self.encoder.AppendFloat, 'AppendRawBytes',
+ wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL],
+ ['double', self.encoder.AppendDouble, 'AppendRawBytes',
+ wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL],
+ ['bool', self.encoder.AppendBool, 'AppendVarint32',
+ wire_format.WIRETYPE_VARINT, False],
+ ['enum', self.encoder.AppendEnum, 'AppendVarint32',
+ wire_format.WIRETYPE_VARINT, 0],
+ ['string', self.encoder.AppendString, 'AppendRawBytes',
+ wire_format.WIRETYPE_LENGTH_DELIMITED,
+ "You're in a maze of twisty little passages, all alike."],
+ ['utf8-string', self.encoder.AppendString, 'AppendRawBytes',
+ wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string,
+ utf8_bytes, len(utf8_bytes)],
+ # We test zigzag encoding routines more extensively below.
+ ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32',
+ wire_format.WIRETYPE_VARINT, -1, 1],
+ ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64',
+ wire_format.WIRETYPE_VARINT, -1, 1],
+ ]
+ # Ensure that we're testing different Encoder methods and using
+ # different test names in all test cases above.
+ self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
+ self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
+ for args in scalar_tests:
+ self.AppendScalarTestHelper(*args)
+
+ def testAppendScalarsWithoutTags(self):
+ scalar_no_tag_tests = [
+ ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0],
+ ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0],
+ ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0],
+ ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0],
+ ['fixed32', self.encoder.AppendFixed32NoTag,
+ 'AppendLittleEndian32', None, 0],
+ ['fixed64', self.encoder.AppendFixed64NoTag,
+ 'AppendLittleEndian64', None, 0],
+ ['sfixed32', self.encoder.AppendSFixed32NoTag,
+ 'AppendLittleEndian32', None, 0],
+ ['sfixed64', self.encoder.AppendSFixed64NoTag,
+ 'AppendLittleEndian64', None, 0],
+ ['float', self.encoder.AppendFloatNoTag,
+ 'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL],
+ ['double', self.encoder.AppendDoubleNoTag,
+ 'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL],
+ ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
+ ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
+ ['sint32', self.encoder.AppendSInt32NoTag,
+ 'AppendVarUInt32', None, -1, 1],
+ ['sint64', self.encoder.AppendSInt64NoTag,
+ 'AppendVarUInt64', None, -1, 1],
+ ]
+
+ self.assertEqual(len(scalar_no_tag_tests),
+ len(set(t[0] for t in scalar_no_tag_tests)))
+ self.assert_(len(scalar_no_tag_tests) >=
+ len(set(t[1] for t in scalar_no_tag_tests)))
+ for args in scalar_no_tag_tests:
+ # For no tag tests, the wire_type is not used, so we put in None.
+ self.AppendScalarTestHelper(is_tag_test=False, *args)
+
+ def testAppendGroup(self):
+ field_number = 23
+ # Should first append the start-group marker.
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP))
+ # Should then serialize itself.
+ self.mock_message.SerializeToString().AndReturn('foo')
+ self.mock_stream.AppendRawBytes('foo')
+ # Should finally append the end-group marker.
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP))
+
+ self.mox.ReplayAll()
+ self.encoder.AppendGroup(field_number, self.mock_message)
+ self.mox.VerifyAll()
+
+ def testAppendMessage(self):
+ field_number = 23
+ byte_size = 42
+ # Should first append the field number and type information.
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
+ # Should then append its length.
+ self.mock_message.ByteSize().AndReturn(byte_size)
+ self.mock_stream.AppendVarUInt32(byte_size)
+ # Should then serialize itself to the encoder.
+ self.mock_message.SerializeToString().AndReturn('foo')
+ self.mock_stream.AppendRawBytes('foo')
+
+ self.mox.ReplayAll()
+ self.encoder.AppendMessage(field_number, self.mock_message)
+ self.mox.VerifyAll()
+
+ def testAppendMessageSetItem(self):
+ field_number = 23
+ byte_size = 42
+ # Should first append the field number and type information.
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(1, wire_format.WIRETYPE_START_GROUP))
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(2, wire_format.WIRETYPE_VARINT))
+ self.mock_stream.AppendVarint32(field_number)
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED))
+ # Should then append its length.
+ self.mock_message.ByteSize().AndReturn(byte_size)
+ self.mock_stream.AppendVarUInt32(byte_size)
+ # Should then serialize itself to the encoder.
+ self.mock_message.SerializeToString().AndReturn('foo')
+ self.mock_stream.AppendRawBytes('foo')
+ self.mock_stream.AppendVarUInt32(
+ self.PackTag(1, wire_format.WIRETYPE_END_GROUP))
+
+ self.mox.ReplayAll()
+ self.encoder.AppendMessageSetItem(field_number, self.mock_message)
+ self.mox.VerifyAll()
+
+ def testAppendSFixed(self):
+ # Most of our bounds-checking is done in output_stream.py,
+ # but encoder.py is responsible for transforming signed
+ # fixed-width integers into unsigned ones, so we test here
+ # to ensure that we're not losing any entropy when we do
+ # that conversion.
+ field_number = 10
+ self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
+ 10, wire_format.UINT32_MAX + 1)
+ self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
+ 10, -(1 << 32))
+ self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
+ 10, wire_format.UINT64_MAX + 1)
+ self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
+ 10, -(1 << 64))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py
new file mode 100755
index 0000000..11fcfa0
--- /dev/null
+++ b/python/google/protobuf/internal/generator_test.py
@@ -0,0 +1,100 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py
+# first, since it's testing the subtler code, and since it provides decent
+# indirect testing of the protocol compiler output.
+
+"""Unittest that directly tests the output of the pure-Python protocol
+compiler. See //net/proto2/internal/reflection_test.py for a test which
+further ensures that we can use Python protocol message objects as we expect.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import unittest
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+
+
+class GeneratorTest(unittest.TestCase):
+
+ def testNestedMessageDescriptor(self):
+ field_name = 'optional_nested_message'
+ proto_type = unittest_pb2.TestAllTypes
+ self.assertEqual(
+ proto_type.NestedMessage.DESCRIPTOR,
+ proto_type.DESCRIPTOR.fields_by_name[field_name].message_type)
+
+ def testEnums(self):
+ # We test only module-level enums here.
+ # TODO(robinson): Examine descriptors directly to check
+ # enum descriptor output.
+ self.assertEqual(4, unittest_pb2.FOREIGN_FOO)
+ self.assertEqual(5, unittest_pb2.FOREIGN_BAR)
+ self.assertEqual(6, unittest_pb2.FOREIGN_BAZ)
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(1, proto.FOO)
+ self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
+ self.assertEqual(2, proto.BAR)
+ self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
+ self.assertEqual(3, proto.BAZ)
+ self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
+
+ def testContainingTypeBehaviorForExtensions(self):
+ self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR)
+ self.assertEqual(unittest_pb2.TestRequired.single.containing_type,
+ unittest_pb2.TestAllExtensions.DESCRIPTOR)
+
+ def testExtensionScope(self):
+ self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope,
+ None)
+ self.assertEqual(unittest_pb2.TestRequired.single.extension_scope,
+ unittest_pb2.TestRequired.DESCRIPTOR)
+
+ def testIsExtension(self):
+ self.assertTrue(unittest_pb2.optional_int32_extension.is_extension)
+ self.assertTrue(unittest_pb2.TestRequired.single.is_extension)
+
+ message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR
+ non_extension_descriptor = message_descriptor.fields_by_name['a']
+ self.assertTrue(not non_extension_descriptor.is_extension)
+
+ def testOptions(self):
+ proto = unittest_mset_pb2.TestMessageSet()
+ self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py
new file mode 100755
index 0000000..7bda17e
--- /dev/null
+++ b/python/google/protobuf/internal/input_stream.py
@@ -0,0 +1,338 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""InputStream is the primitive interface for reading bits from the wire.
+
+All protocol buffer deserialization can be expressed in terms of
+the InputStream primitives provided here.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import array
+import struct
+from google.protobuf import message
+from google.protobuf.internal import wire_format
+
+
+# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
+# that the interface is strongly inspired by CodedInputStream from the C++
+# proto2 implementation.
+
+
+class InputStreamBuffer(object):
+
+ """Contains all logic for reading bits, and dealing with stream position.
+
+ If an InputStream method ever raises an exception, the stream is left
+ in an indeterminate state and is not safe for further use.
+ """
+
+ def __init__(self, s):
+ # What we really want is something like array('B', s), where elements we
+ # read from the array are already given to us as one-byte integers. BUT
+ # using array() instead of buffer() would force full string copies to result
+ # from each GetSubBuffer() call.
+ #
+ # So, if the N serialized bytes of a single protocol buffer object are
+ # split evenly between 2 child messages, and so on recursively, using
+ # array('B', s) instead of buffer() would incur an additional N*logN bytes
+ # copied during deserialization.
+ #
+ # The higher constant overhead of having to ord() for every byte we read
+ # from the buffer in _ReadVarintHelper() could definitely lead to worse
+ # performance in many real-world scenarios, even if the asymptotic
+ # complexity is better. However, our real answer is that the mythical
+ # Python/C extension module output mode for the protocol compiler will
+ # be blazing-fast and will eliminate most use of this class anyway.
+ self._buffer = buffer(s)
+ self._pos = 0
+
+ def EndOfStream(self):
+ """Returns true iff we're at the end of the stream.
+ If this returns true, then a call to any other InputStream method
+ will raise an exception.
+ """
+ return self._pos >= len(self._buffer)
+
+ def Position(self):
+ """Returns the current position in the stream, or equivalently, the
+ number of bytes read so far.
+ """
+ return self._pos
+
+ def GetSubBuffer(self, size=None):
+ """Returns a sequence-like object that represents a portion of our
+ underlying sequence.
+
+ Position 0 in the returned object corresponds to self.Position()
+ in this stream.
+
+ If size is specified, then the returned object ends after the
+ next "size" bytes in this stream. If size is not specified,
+ then the returned object ends at the end of this stream.
+
+ We guarantee that the returned object R supports the Python buffer
+ interface (and thus that the call buffer(R) will work).
+
+ Note that the returned buffer is read-only.
+
+ The intended use for this method is for nested-message and nested-group
+ deserialization, where we want to make a recursive MergeFromString()
+ call on the portion of the original sequence that contains the serialized
+ nested message. (And we'd like to do so without making unnecessary string
+ copies).
+
+ REQUIRES: size is nonnegative.
+ """
+ # Note that buffer() doesn't perform any actual string copy.
+ if size is None:
+ return buffer(self._buffer, self._pos)
+ else:
+ if size < 0:
+ raise message.DecodeError('Negative size %d' % size)
+ return buffer(self._buffer, self._pos, size)
+
+ def SkipBytes(self, num_bytes):
+ """Skip num_bytes bytes ahead, or go to the end of the stream, whichever
+ comes first.
+
+ REQUIRES: num_bytes is nonnegative.
+ """
+ if num_bytes < 0:
+ raise message.DecodeError('Negative num_bytes %d' % num_bytes)
+ self._pos += num_bytes
+ self._pos = min(self._pos, len(self._buffer))
+
+ def ReadBytes(self, size):
+ """Reads up to 'size' bytes from the stream, stopping early
+ only if we reach the end of the stream. Returns the bytes read
+ as a string.
+ """
+ if size < 0:
+ raise message.DecodeError('Negative size %d' % size)
+ s = (self._buffer[self._pos : self._pos + size])
+ self._pos += len(s) # Only advance by the number of bytes actually read.
+ return s
+
+ def ReadLittleEndian32(self):
+ """Interprets the next 4 bytes of the stream as a little-endian
+ encoded, unsiged 32-bit integer, and returns that integer.
+ """
+ try:
+ i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
+ self._buffer[self._pos : self._pos + 4])
+ self._pos += 4
+ return i[0] # unpack() result is a 1-element tuple.
+ except struct.error, e:
+ raise message.DecodeError(e)
+
+ def ReadLittleEndian64(self):
+ """Interprets the next 8 bytes of the stream as a little-endian
+ encoded, unsiged 64-bit integer, and returns that integer.
+ """
+ try:
+ i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
+ self._buffer[self._pos : self._pos + 8])
+ self._pos += 8
+ return i[0] # unpack() result is a 1-element tuple.
+ except struct.error, e:
+ raise message.DecodeError(e)
+
+ def ReadVarint32(self):
+ """Reads a varint from the stream, interprets this varint
+ as a signed, 32-bit integer, and returns the integer.
+ """
+ i = self.ReadVarint64()
+ if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
+ raise message.DecodeError('Value out of range for int32: %d' % i)
+ return int(i)
+
+ def ReadVarUInt32(self):
+ """Reads a varint from the stream, interprets this varint
+ as an unsigned, 32-bit integer, and returns the integer.
+ """
+ i = self.ReadVarUInt64()
+ if i > wire_format.UINT32_MAX:
+ raise message.DecodeError('Value out of range for uint32: %d' % i)
+ return i
+
+ def ReadVarint64(self):
+ """Reads a varint from the stream, interprets this varint
+ as a signed, 64-bit integer, and returns the integer.
+ """
+ i = self.ReadVarUInt64()
+ if i > wire_format.INT64_MAX:
+ i -= (1 << 64)
+ return i
+
+ def ReadVarUInt64(self):
+ """Reads a varint from the stream, interprets this varint
+ as an unsigned, 64-bit integer, and returns the integer.
+ """
+ i = self._ReadVarintHelper()
+ if not 0 <= i <= wire_format.UINT64_MAX:
+ raise message.DecodeError('Value out of range for uint64: %d' % i)
+ return i
+
+ def _ReadVarintHelper(self):
+ """Helper for the various varint-reading methods above.
+ Reads an unsigned, varint-encoded integer from the stream and
+ returns this integer.
+
+ Does no bounds checking except to ensure that we read at most as many bytes
+ as could possibly be present in a varint-encoded 64-bit number.
+ """
+ result = 0
+ shift = 0
+ while 1:
+ if shift >= 64:
+ raise message.DecodeError('Too many bytes when decoding varint.')
+ try:
+ b = ord(self._buffer[self._pos])
+ except IndexError:
+ raise message.DecodeError('Truncated varint.')
+ self._pos += 1
+ result |= ((b & 0x7f) << shift)
+ shift += 7
+ if not (b & 0x80):
+ return result
+
+
+class InputStreamArray(object):
+
+ """Contains all logic for reading bits, and dealing with stream position.
+
+ If an InputStream method ever raises an exception, the stream is left
+ in an indeterminate state and is not safe for further use.
+
+ This alternative to InputStreamBuffer is used in environments where buffer()
+ is unavailble, such as Google App Engine.
+ """
+
+ def __init__(self, s):
+ self._buffer = array.array('B', s)
+ self._pos = 0
+
+ def EndOfStream(self):
+ return self._pos >= len(self._buffer)
+
+ def Position(self):
+ return self._pos
+
+ def GetSubBuffer(self, size=None):
+ if size is None:
+ return self._buffer[self._pos : ].tostring()
+ else:
+ if size < 0:
+ raise message.DecodeError('Negative size %d' % size)
+ return self._buffer[self._pos : self._pos + size].tostring()
+
+ def SkipBytes(self, num_bytes):
+ if num_bytes < 0:
+ raise message.DecodeError('Negative num_bytes %d' % num_bytes)
+ self._pos += num_bytes
+ self._pos = min(self._pos, len(self._buffer))
+
+ def ReadBytes(self, size):
+ if size < 0:
+ raise message.DecodeError('Negative size %d' % size)
+ s = self._buffer[self._pos : self._pos + size].tostring()
+ self._pos += len(s) # Only advance by the number of bytes actually read.
+ return s
+
+ def ReadLittleEndian32(self):
+ try:
+ i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
+ self._buffer[self._pos : self._pos + 4])
+ self._pos += 4
+ return i[0] # unpack() result is a 1-element tuple.
+ except struct.error, e:
+ raise message.DecodeError(e)
+
+ def ReadLittleEndian64(self):
+ try:
+ i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
+ self._buffer[self._pos : self._pos + 8])
+ self._pos += 8
+ return i[0] # unpack() result is a 1-element tuple.
+ except struct.error, e:
+ raise message.DecodeError(e)
+
+ def ReadVarint32(self):
+ i = self.ReadVarint64()
+ if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
+ raise message.DecodeError('Value out of range for int32: %d' % i)
+ return int(i)
+
+ def ReadVarUInt32(self):
+ i = self.ReadVarUInt64()
+ if i > wire_format.UINT32_MAX:
+ raise message.DecodeError('Value out of range for uint32: %d' % i)
+ return i
+
+ def ReadVarint64(self):
+ i = self.ReadVarUInt64()
+ if i > wire_format.INT64_MAX:
+ i -= (1 << 64)
+ return i
+
+ def ReadVarUInt64(self):
+ i = self._ReadVarintHelper()
+ if not 0 <= i <= wire_format.UINT64_MAX:
+ raise message.DecodeError('Value out of range for uint64: %d' % i)
+ return i
+
+ def _ReadVarintHelper(self):
+ result = 0
+ shift = 0
+ while 1:
+ if shift >= 64:
+ raise message.DecodeError('Too many bytes when decoding varint.')
+ try:
+ b = self._buffer[self._pos]
+ except IndexError:
+ raise message.DecodeError('Truncated varint.')
+ self._pos += 1
+ result |= ((b & 0x7f) << shift)
+ shift += 7
+ if not (b & 0x80):
+ return result
+
+
+try:
+ buffer('')
+ InputStream = InputStreamBuffer
+except NotImplementedError:
+ # Google App Engine: dev_appserver.py
+ InputStream = InputStreamArray
+except RuntimeError:
+ # Google App Engine: production
+ InputStream = InputStreamArray
diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py
new file mode 100755
index 0000000..ecec7f7
--- /dev/null
+++ b/python/google/protobuf/internal/input_stream_test.py
@@ -0,0 +1,314 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.input_stream."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import unittest
+from google.protobuf import message
+from google.protobuf.internal import wire_format
+from google.protobuf.internal import input_stream
+
+
+class InputStreamBufferTest(unittest.TestCase):
+
+ def setUp(self):
+ self.__original_input_stream = input_stream.InputStream
+ input_stream.InputStream = input_stream.InputStreamBuffer
+
+ def tearDown(self):
+ input_stream.InputStream = self.__original_input_stream
+
+ def testEndOfStream(self):
+ stream = input_stream.InputStream('abcd')
+ self.assertFalse(stream.EndOfStream())
+ self.assertEqual('abcd', stream.ReadBytes(10))
+ self.assertTrue(stream.EndOfStream())
+
+ def testPosition(self):
+ stream = input_stream.InputStream('abcd')
+ self.assertEqual(0, stream.Position())
+ self.assertEqual(0, stream.Position()) # No side-effects.
+ stream.ReadBytes(1)
+ self.assertEqual(1, stream.Position())
+ stream.ReadBytes(1)
+ self.assertEqual(2, stream.Position())
+ stream.ReadBytes(10)
+ self.assertEqual(4, stream.Position()) # Can't go past end of stream.
+
+ def testGetSubBuffer(self):
+ stream = input_stream.InputStream('abcd')
+ # Try leaving out the size.
+ self.assertEqual('abcd', str(stream.GetSubBuffer()))
+ stream.SkipBytes(1)
+ # GetSubBuffer() always starts at current size.
+ self.assertEqual('bcd', str(stream.GetSubBuffer()))
+ # Try 0-size.
+ self.assertEqual('', str(stream.GetSubBuffer(0)))
+ # Negative sizes should raise an error.
+ self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
+ # Positive sizes should work as expected.
+ self.assertEqual('b', str(stream.GetSubBuffer(1)))
+ self.assertEqual('bc', str(stream.GetSubBuffer(2)))
+ # Sizes longer than remaining bytes in the buffer should
+ # return the whole remaining buffer.
+ self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))
+
+ def testSkipBytes(self):
+ stream = input_stream.InputStream('')
+ # Skipping bytes when at the end of stream
+ # should have no effect.
+ stream.SkipBytes(0)
+ stream.SkipBytes(1)
+ stream.SkipBytes(2)
+ self.assertTrue(stream.EndOfStream())
+ self.assertEqual(0, stream.Position())
+
+ # Try skipping within a stream.
+ stream = input_stream.InputStream('abcd')
+ self.assertEqual(0, stream.Position())
+ stream.SkipBytes(1)
+ self.assertEqual(1, stream.Position())
+ stream.SkipBytes(10) # Can't skip past the end.
+ self.assertEqual(4, stream.Position())
+
+ # Ensure that a negative skip raises an exception.
+ stream = input_stream.InputStream('abcd')
+ stream.SkipBytes(1)
+ self.assertRaises(message.DecodeError, stream.SkipBytes, -1)
+
+ def testReadBytes(self):
+ s = 'abcd'
+ # Also test going past the total stream length.
+ for i in range(len(s) + 10):
+ stream = input_stream.InputStream(s)
+ self.assertEqual(s[:i], stream.ReadBytes(i))
+ self.assertEqual(min(i, len(s)), stream.Position())
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadBytes, -1)
+
+ def EnsureFailureOnEmptyStream(self, input_stream_method):
+ """Helper for integer-parsing tests below.
+ Ensures that the given InputStream method raises a DecodeError
+ if called on a stream with no bytes remaining.
+ """
+ stream = input_stream.InputStream('')
+ self.assertRaises(message.DecodeError, input_stream_method, stream)
+
+ def testReadLittleEndian32(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32)
+ s = ''
+ # Read 0.
+ s += '\x00\x00\x00\x00'
+ # Read 1.
+ s += '\x01\x00\x00\x00'
+ # Read a bunch of different bytes.
+ s += '\x01\x02\x03\x04'
+ # Read max unsigned 32-bit int.
+ s += '\xff\xff\xff\xff'
+ # Try a read with fewer than 4 bytes left in the stream.
+ s += '\x00\x00\x00'
+ stream = input_stream.InputStream(s)
+ self.assertEqual(0, stream.ReadLittleEndian32())
+ self.assertEqual(4, stream.Position())
+ self.assertEqual(1, stream.ReadLittleEndian32())
+ self.assertEqual(8, stream.Position())
+ self.assertEqual(0x04030201, stream.ReadLittleEndian32())
+ self.assertEqual(12, stream.Position())
+ self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32())
+ self.assertEqual(16, stream.Position())
+ self.assertRaises(message.DecodeError, stream.ReadLittleEndian32)
+
+ def testReadLittleEndian64(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64)
+ s = ''
+ # Read 0.
+ s += '\x00\x00\x00\x00\x00\x00\x00\x00'
+ # Read 1.
+ s += '\x01\x00\x00\x00\x00\x00\x00\x00'
+ # Read a bunch of different bytes.
+ s += '\x01\x02\x03\x04\x05\x06\x07\x08'
+ # Read max unsigned 64-bit int.
+ s += '\xff\xff\xff\xff\xff\xff\xff\xff'
+ # Try a read with fewer than 8 bytes left in the stream.
+ s += '\x00\x00\x00'
+ stream = input_stream.InputStream(s)
+ self.assertEqual(0, stream.ReadLittleEndian64())
+ self.assertEqual(8, stream.Position())
+ self.assertEqual(1, stream.ReadLittleEndian64())
+ self.assertEqual(16, stream.Position())
+ self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64())
+ self.assertEqual(24, stream.Position())
+ self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64())
+ self.assertEqual(32, stream.Position())
+ self.assertRaises(message.DecodeError, stream.ReadLittleEndian64)
+
+ def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method):
+ """Helper for tests below that test successful reads of various varints.
+
+ Args:
+ varints_and_ints: Iterable of (str, integer) pairs, where the string
+ gives the wire encoding and the integer gives the value we expect
+ to be returned by the read_method upon encountering this string.
+ read_method: Unbound InputStream method that is capable of reading
+ the encoded strings provided in the first elements of varints_and_ints.
+ """
+ s = ''.join(s for s, i in varints_and_ints)
+ stream = input_stream.InputStream(s)
+ expected_pos = 0
+ self.assertEqual(expected_pos, stream.Position())
+ for s, expected_int in varints_and_ints:
+ self.assertEqual(expected_int, read_method(stream))
+ expected_pos += len(s)
+ self.assertEqual(expected_pos, stream.Position())
+
+ def testReadVarint32Success(self):
+ varints_and_ints = [
+ ('\x00', 0),
+ ('\x01', 1),
+ ('\x7f', 127),
+ ('\x80\x01', 128),
+ ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
+ ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX),
+ ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN),
+ ]
+ self.ReadVarintSuccessTestHelper(varints_and_ints,
+ input_stream.InputStream.ReadVarint32)
+
+ def testReadVarint32Failure(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32)
+
+ # Try and fail to read INT32_MAX + 1.
+ s = '\x80\x80\x80\x80\x08'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarint32)
+
+ # Try and fail to read INT32_MIN - 1.
+ s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarint32)
+
+ # Try and fail to read something that looks like
+ # a varint with more than 10 bytes.
+ s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarint32)
+
+ def testReadVarUInt32Success(self):
+ varints_and_ints = [
+ ('\x00', 0),
+ ('\x01', 1),
+ ('\x7f', 127),
+ ('\x80\x01', 128),
+ ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX),
+ ]
+ self.ReadVarintSuccessTestHelper(varints_and_ints,
+ input_stream.InputStream.ReadVarUInt32)
+
+ def testReadVarUInt32Failure(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32)
+ # Try and fail to read UINT32_MAX + 1
+ s = '\x80\x80\x80\x80\x10'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
+
+ # Try and fail to read something that looks like
+ # a varint with more than 10 bytes.
+ s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
+
+ def testReadVarint64Success(self):
+ varints_and_ints = [
+ ('\x00', 0),
+ ('\x01', 1),
+ ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
+ ('\x7f', 127),
+ ('\x80\x01', 128),
+ ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX),
+ ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN),
+ ]
+ self.ReadVarintSuccessTestHelper(varints_and_ints,
+ input_stream.InputStream.ReadVarint64)
+
+ def testReadVarint64Failure(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64)
+ # Try and fail to read something with the mythical 64th bit set.
+ s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarint64)
+
+ # Try and fail to read something that looks like
+ # a varint with more than 10 bytes.
+ s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarint64)
+
+ def testReadVarUInt64Success(self):
+ varints_and_ints = [
+ ('\x00', 0),
+ ('\x01', 1),
+ ('\x7f', 127),
+ ('\x80\x01', 128),
+ ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63),
+ ]
+ self.ReadVarintSuccessTestHelper(varints_and_ints,
+ input_stream.InputStream.ReadVarUInt64)
+
+ def testReadVarUInt64Failure(self):
+ self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64)
+ # Try and fail to read something with the mythical 64th bit set.
+ s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
+
+ # Try and fail to read something that looks like
+ # a varint with more than 10 bytes.
+ s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
+ stream = input_stream.InputStream(s)
+ self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
+
+
+class InputStreamArrayTest(InputStreamBufferTest):
+
+ def setUp(self):
+ # Test InputStreamArray against the same tests in InputStreamBuffer
+ self.__original_input_stream = input_stream.InputStream
+ input_stream.InputStream = input_stream.InputStreamArray
+
+ def tearDown(self):
+ input_stream.InputStream = self.__original_input_stream
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py
new file mode 100755
index 0000000..4397895
--- /dev/null
+++ b/python/google/protobuf/internal/message_listener.py
@@ -0,0 +1,69 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Defines a listener interface for observing certain
+state transitions on Message objects.
+
+Also defines a null implementation of this interface.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+
+class MessageListener(object):
+
+ """Listens for transitions to nonempty and for invalidations of cached
+ byte sizes. Meant to be registered via Message._SetListener().
+ """
+
+ def TransitionToNonempty(self):
+ """Called the *first* time that this message becomes nonempty.
+ Implementations are free (but not required) to call this method multiple
+ times after the message has become nonempty.
+ """
+ raise NotImplementedError
+
+ def ByteSizeDirty(self):
+ """Called *every* time the cached byte size value
+ for this object is invalidated (transitions from being
+ "clean" to "dirty").
+ """
+ raise NotImplementedError
+
+
+class NullMessageListener(object):
+
+ """No-op MessageListener implementation."""
+
+ def TransitionToNonempty(self):
+ pass
+
+ def ByteSizeDirty(self):
+ pass
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
new file mode 100755
index 0000000..df344cf
--- /dev/null
+++ b/python/google/protobuf/internal/message_test.py
@@ -0,0 +1,53 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests python protocol buffers against the golden message."""
+
+__author__ = 'gps@google.com (Gregory P. Smith)'
+
+import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf.internal import test_util
+
+
+class MessageTest(test_util.GoldenMessageTestCase):
+
+ def testGoldenMessage(self):
+ golden_data = test_util.GoldenFile('golden_message').read()
+ golden_message = unittest_pb2.TestAllTypes()
+ golden_message.ParseFromString(golden_data)
+ self.ExpectAllFieldsSet(golden_message)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto
new file mode 100644
index 0000000..e2d9701
--- /dev/null
+++ b/python/google/protobuf/internal/more_extensions.proto
@@ -0,0 +1,58 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: robinson@google.com (Will Robinson)
+
+
+package google.protobuf.internal;
+
+
+message TopLevelMessage {
+ optional ExtendedMessage submessage = 1;
+}
+
+
+message ExtendedMessage {
+ extensions 1 to max;
+}
+
+
+message ForeignMessage {
+ optional int32 foreign_message_int = 1;
+}
+
+
+extend ExtendedMessage {
+ optional int32 optional_int_extension = 1;
+ optional ForeignMessage optional_message_extension = 2;
+
+ repeated int32 repeated_int_extension = 3;
+ repeated ForeignMessage repeated_message_extension = 4;
+}
diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto
new file mode 100644
index 0000000..c701b44
--- /dev/null
+++ b/python/google/protobuf/internal/more_messages.proto
@@ -0,0 +1,51 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// http://code.google.com/p/protobuf/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: robinson@google.com (Will Robinson)
+
+
+package google.protobuf.internal;
+
+// A message where tag numbers are listed out of order, to allow us to test our
+// canonicalization of serialized output, which should always be in tag order.
+// We also mix in some extensions for extra fun.
+message OutOfOrderFields {
+ optional sint32 optional_sint32 = 5;
+ extensions 4 to 4;
+ optional uint32 optional_uint32 = 3;
+ extensions 2 to 2;
+ optional int32 optional_int32 = 1;
+};
+
+
+extend OutOfOrderFields {
+ optional uint64 optional_uint64 = 4;
+ optional int64 optional_int64 = 2;
+}
diff --git a/python/google/protobuf/internal/output_stream.py b/python/google/protobuf/internal/output_stream.py
new file mode 100755
index 0000000..6c2d6f6
--- /dev/null
+++ b/python/google/protobuf/internal/output_stream.py
@@ -0,0 +1,125 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""OutputStream is the primitive interface for sticking bits on the wire.
+
+All protocol buffer serialization can be expressed in terms of
+the OutputStream primitives provided here.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import array
+import struct
+from google.protobuf import message
+from google.protobuf.internal import wire_format
+
+
+
+# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
+# that the interface is strongly inspired by CodedOutputStream from the C++
+# proto2 implementation.
+
+
+class OutputStream(object):
+
+ """Contains all logic for writing bits, and ToString() to get the result."""
+
+ def __init__(self):
+ self._buffer = array.array('B')
+
+ def AppendRawBytes(self, raw_bytes):
+ """Appends raw_bytes to our internal buffer."""
+ self._buffer.fromstring(raw_bytes)
+
+ def AppendLittleEndian32(self, unsigned_value):
+ """Appends an unsigned 32-bit integer to the internal buffer,
+ in little-endian byte order.
+ """
+ if not 0 <= unsigned_value <= wire_format.UINT32_MAX:
+ raise message.EncodeError(
+ 'Unsigned 32-bit out of range: %d' % unsigned_value)
+ self._buffer.fromstring(struct.pack(
+ wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value))
+
+ def AppendLittleEndian64(self, unsigned_value):
+ """Appends an unsigned 64-bit integer to the internal buffer,
+ in little-endian byte order.
+ """
+ if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
+ raise message.EncodeError(
+ 'Unsigned 64-bit out of range: %d' % unsigned_value)
+ self._buffer.fromstring(struct.pack(
+ wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value))
+
+ def AppendVarint32(self, value):
+ """Appends a signed 32-bit integer to the internal buffer,
+ encoded as a varint. (Note that a negative varint32 will
+ always require 10 bytes of space.)
+ """
+ if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX:
+ raise message.EncodeError('Value out of range: %d' % value)
+ self.AppendVarint64(value)
+
+ def AppendVarUInt32(self, value):
+ """Appends an unsigned 32-bit integer to the internal buffer,
+ encoded as a varint.
+ """
+ if not 0 <= value <= wire_format.UINT32_MAX:
+ raise message.EncodeError('Value out of range: %d' % value)
+ self.AppendVarUInt64(value)
+
+ def AppendVarint64(self, value):
+ """Appends a signed 64-bit integer to the internal buffer,
+ encoded as a varint.
+ """
+ if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX:
+ raise message.EncodeError('Value out of range: %d' % value)
+ if value < 0:
+ value += (1 << 64)
+ self.AppendVarUInt64(value)
+
+ def AppendVarUInt64(self, unsigned_value):
+ """Appends an unsigned 64-bit integer to the internal buffer,
+ encoded as a varint.
+ """
+ if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
+ raise message.EncodeError('Value out of range: %d' % unsigned_value)
+ while True:
+ bits = unsigned_value & 0x7f
+ unsigned_value >>= 7
+ if not unsigned_value:
+ self._buffer.append(bits)
+ break
+ self._buffer.append(0x80|bits)
+
+ def ToString(self):
+ """Returns a string containing the bytes in our internal buffer."""
+ return self._buffer.tostring()
diff --git a/python/google/protobuf/internal/output_stream_test.py b/python/google/protobuf/internal/output_stream_test.py
new file mode 100755
index 0000000..df92eec
--- /dev/null
+++ b/python/google/protobuf/internal/output_stream_test.py
@@ -0,0 +1,178 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.output_stream."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import unittest
+from google.protobuf import message
+from google.protobuf.internal import output_stream
+from google.protobuf.internal import wire_format
+
+
+class OutputStreamTest(unittest.TestCase):
+
+ def setUp(self):
+ self.stream = output_stream.OutputStream()
+
+ def testAppendRawBytes(self):
+ # Empty string.
+ self.stream.AppendRawBytes('')
+ self.assertEqual('', self.stream.ToString())
+
+ # Nonempty string.
+ self.stream.AppendRawBytes('abc')
+ self.assertEqual('abc', self.stream.ToString())
+
+ # Ensure that we're actually appending.
+ self.stream.AppendRawBytes('def')
+ self.assertEqual('abcdef', self.stream.ToString())
+
+ def AppendNumericTestHelper(self, append_fn, values_and_strings):
+ """For each (value, expected_string) pair in values_and_strings,
+ calls an OutputStream.Append*(value) method on an OutputStream and ensures
+ that the string written to that stream matches expected_string.
+
+ Args:
+ append_fn: Unbound OutputStream method that takes an integer or
+ long value as input.
+ values_and_strings: Iterable of (value, expected_string) pairs.
+ """
+ for conversion in (int, long):
+ for value, string in values_and_strings:
+ stream = output_stream.OutputStream()
+ expected_string = ''
+ append_fn(stream, conversion(value))
+ expected_string += string
+ self.assertEqual(expected_string, stream.ToString())
+
+ def AppendOverflowTestHelper(self, append_fn, value):
+ """Calls an OutputStream.Append*(value) method and asserts
+ that the method raises message.EncodeError.
+
+ Args:
+ append_fn: Unbound OutputStream method that takes an integer or
+ long value as input.
+ value: Value to pass to append_fn which should cause an
+ message.EncodeError.
+ """
+ stream = output_stream.OutputStream()
+ self.assertRaises(message.EncodeError, append_fn, stream, value)
+
+ def testAppendLittleEndian32(self):
+ append_fn = output_stream.OutputStream.AppendLittleEndian32
+ values_and_expected_strings = [
+ (0, '\x00\x00\x00\x00'),
+ (1, '\x01\x00\x00\x00'),
+ ((1 << 32) - 1, '\xff\xff\xff\xff'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, 1 << 32)
+ self.AppendOverflowTestHelper(append_fn, -1)
+
+ def testAppendLittleEndian64(self):
+ append_fn = output_stream.OutputStream.AppendLittleEndian64
+ values_and_expected_strings = [
+ (0, '\x00\x00\x00\x00\x00\x00\x00\x00'),
+ (1, '\x01\x00\x00\x00\x00\x00\x00\x00'),
+ ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, 1 << 64)
+ self.AppendOverflowTestHelper(append_fn, -1)
+
+ def testAppendVarint32(self):
+ append_fn = output_stream.OutputStream.AppendVarint32
+ values_and_expected_strings = [
+ (0, '\x00'),
+ (1, '\x01'),
+ (127, '\x7f'),
+ (128, '\x80\x01'),
+ (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
+ (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'),
+ (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1)
+ self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1)
+
+ def testAppendVarUInt32(self):
+ append_fn = output_stream.OutputStream.AppendVarUInt32
+ values_and_expected_strings = [
+ (0, '\x00'),
+ (1, '\x01'),
+ (127, '\x7f'),
+ (128, '\x80\x01'),
+ (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, -1)
+ self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1)
+
+ def testAppendVarint64(self):
+ append_fn = output_stream.OutputStream.AppendVarint64
+ values_and_expected_strings = [
+ (0, '\x00'),
+ (1, '\x01'),
+ (127, '\x7f'),
+ (128, '\x80\x01'),
+ (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
+ (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'),
+ (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1)
+ self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1)
+
+ def testAppendVarUInt64(self):
+ append_fn = output_stream.OutputStream.AppendVarUInt64
+ values_and_expected_strings = [
+ (0, '\x00'),
+ (1, '\x01'),
+ (127, '\x7f'),
+ (128, '\x80\x01'),
+ (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
+ ]
+ self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
+
+ self.AppendOverflowTestHelper(append_fn, -1)
+ self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
new file mode 100755
index 0000000..8610177
--- /dev/null
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -0,0 +1,1976 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Unittest for reflection.py, which also indirectly tests the output of the
+pure-Python protocol compiler.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import operator
+
+import unittest
+# TODO(robinson): When we split this test in two, only some of these imports
+# will be necessary in each test.
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_mset_pb2
+from google.protobuf import unittest_pb2
+from google.protobuf import descriptor_pb2
+from google.protobuf import descriptor
+from google.protobuf import message
+from google.protobuf import reflection
+from google.protobuf.internal import more_extensions_pb2
+from google.protobuf.internal import more_messages_pb2
+from google.protobuf.internal import wire_format
+from google.protobuf.internal import test_util
+from google.protobuf.internal import decoder
+
+
+class ReflectionTest(unittest.TestCase):
+
+ def assertIs(self, values, others):
+ self.assertEqual(len(values), len(others))
+ for i in range(len(values)):
+ self.assertTrue(values[i] is others[i])
+
+ def testSimpleHasBits(self):
+ # Test a scalar.
+ proto = unittest_pb2.TestAllTypes()
+ self.assertTrue(not proto.HasField('optional_int32'))
+ self.assertEqual(0, proto.optional_int32)
+ # HasField() shouldn't be true if all we've done is
+ # read the default value.
+ self.assertTrue(not proto.HasField('optional_int32'))
+ proto.optional_int32 = 1
+ # Setting a value however *should* set the "has" bit.
+ self.assertTrue(proto.HasField('optional_int32'))
+ proto.ClearField('optional_int32')
+ # And clearing that value should unset the "has" bit.
+ self.assertTrue(not proto.HasField('optional_int32'))
+
+ def testHasBitsWithSinglyNestedScalar(self):
+ # Helper used to test foreign messages and groups.
+ #
+ # composite_field_name should be the name of a non-repeated
+ # composite (i.e., foreign or group) field in TestAllTypes,
+ # and scalar_field_name should be the name of an integer-valued
+ # scalar field within that composite.
+ #
+ # I never thought I'd miss C++ macros and templates so much. :(
+ # This helper is semantically just:
+ #
+ # assert proto.composite_field.scalar_field == 0
+ # assert not proto.composite_field.HasField('scalar_field')
+ # assert not proto.HasField('composite_field')
+ #
+ # proto.composite_field.scalar_field = 10
+ # old_composite_field = proto.composite_field
+ #
+ # assert proto.composite_field.scalar_field == 10
+ # assert proto.composite_field.HasField('scalar_field')
+ # assert proto.HasField('composite_field')
+ #
+ # proto.ClearField('composite_field')
+ #
+ # assert not proto.composite_field.HasField('scalar_field')
+ # assert not proto.HasField('composite_field')
+ # assert proto.composite_field.scalar_field == 0
+ #
+ # # Now ensure that ClearField('composite_field') disconnected
+ # # the old field object from the object tree...
+ # assert old_composite_field is not proto.composite_field
+ # old_composite_field.scalar_field = 20
+ # assert not proto.composite_field.HasField('scalar_field')
+ # assert not proto.HasField('composite_field')
+ def TestCompositeHasBits(composite_field_name, scalar_field_name):
+ proto = unittest_pb2.TestAllTypes()
+ # First, check that we can get the scalar value, and see that it's the
+ # default (0), but that proto.HasField('omposite') and
+ # proto.composite.HasField('scalar') will still return False.
+ composite_field = getattr(proto, composite_field_name)
+ original_scalar_value = getattr(composite_field, scalar_field_name)
+ self.assertEqual(0, original_scalar_value)
+ # Assert that the composite object does not "have" the scalar.
+ self.assertTrue(not composite_field.HasField(scalar_field_name))
+ # Assert that proto does not "have" the composite field.
+ self.assertTrue(not proto.HasField(composite_field_name))
+
+ # Now set the scalar within the composite field. Ensure that the setting
+ # is reflected, and that proto.HasField('composite') and
+ # proto.composite.HasField('scalar') now both return True.
+ new_val = 20
+ setattr(composite_field, scalar_field_name, new_val)
+ self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
+ # Hold on to a reference to the current composite_field object.
+ old_composite_field = composite_field
+ # Assert that the has methods now return true.
+ self.assertTrue(composite_field.HasField(scalar_field_name))
+ self.assertTrue(proto.HasField(composite_field_name))
+
+ # Now call the clear method...
+ proto.ClearField(composite_field_name)
+
+ # ...and ensure that the "has" bits are all back to False...
+ composite_field = getattr(proto, composite_field_name)
+ self.assertTrue(not composite_field.HasField(scalar_field_name))
+ self.assertTrue(not proto.HasField(composite_field_name))
+ # ...and ensure that the scalar field has returned to its default.
+ self.assertEqual(0, getattr(composite_field, scalar_field_name))
+
+ # Finally, ensure that modifications to the old composite field object
+ # don't have any effect on the parent.
+ #
+ # (NOTE that when we clear the composite field in the parent, we actually
+ # don't recursively clear down the tree. Instead, we just disconnect the
+ # cleared composite from the tree.)
+ self.assertTrue(old_composite_field is not composite_field)
+ setattr(old_composite_field, scalar_field_name, new_val)
+ self.assertTrue(not composite_field.HasField(scalar_field_name))
+ self.assertTrue(not proto.HasField(composite_field_name))
+ self.assertEqual(0, getattr(composite_field, scalar_field_name))
+
+ # Test simple, single-level nesting when we set a scalar.
+ TestCompositeHasBits('optionalgroup', 'a')
+ TestCompositeHasBits('optional_nested_message', 'bb')
+ TestCompositeHasBits('optional_foreign_message', 'c')
+ TestCompositeHasBits('optional_import_message', 'd')
+
+ def testReferencesToNestedMessage(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ del proto
+ # A previous version had a bug where this would raise an exception when
+ # hitting a now-dead weak reference.
+ nested.bb = 23
+
+ def testDisconnectingNestedMessageBeforeSettingField(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ proto.ClearField('optional_nested_message') # Should disconnect from parent
+ self.assertTrue(nested is not proto.optional_nested_message)
+ nested.bb = 23
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+
+ def testHasBitsWhenModifyingRepeatedFields(self):
+ # Test nesting when we add an element to a repeated field in a submessage.
+ proto = unittest_pb2.TestNestedMessageHasBits()
+ proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
+ self.assertEqual(
+ [5], proto.optional_nested_message.nestedmessage_repeated_int32)
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ # Do the same test, but with a repeated composite field within the
+ # submessage.
+ proto.ClearField('optional_nested_message')
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
+ self.assertTrue(proto.HasField('optional_nested_message'))
+
+ def testHasBitsForManyLevelsOfNesting(self):
+ # Test nesting many levels deep.
+ recursive_proto = unittest_pb2.TestMutualRecursionA()
+ self.assertTrue(not recursive_proto.HasField('bb'))
+ self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
+ self.assertTrue(not recursive_proto.HasField('bb'))
+ recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
+ self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
+ self.assertTrue(recursive_proto.HasField('bb'))
+ self.assertTrue(recursive_proto.bb.HasField('a'))
+ self.assertTrue(recursive_proto.bb.a.HasField('bb'))
+ self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
+ self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
+ self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
+ self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
+
+ def testSingularListFields(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_fixed32 = 1
+ proto.optional_int32 = 5
+ proto.optional_string = 'foo'
+ self.assertEqual(
+ [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
+ (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
+ (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
+ proto.ListFields())
+
+ def testRepeatedListFields(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.repeated_fixed32.append(1)
+ proto.repeated_int32.append(5)
+ proto.repeated_int32.append(11)
+ proto.repeated_string.extend(['foo', 'bar'])
+ proto.repeated_string.extend([])
+ proto.repeated_string.append('baz')
+ proto.repeated_string.extend(str(x) for x in xrange(2))
+ proto.optional_int32 = 21
+ self.assertEqual(
+ [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
+ (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
+ (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
+ (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
+ ['foo', 'bar', 'baz', '0', '1']) ],
+ proto.ListFields())
+
+ def testSingularListExtensions(self):
+ proto = unittest_pb2.TestAllExtensions()
+ proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
+ proto.Extensions[unittest_pb2.optional_int32_extension ] = 5
+ proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
+ self.assertEqual(
+ [ (unittest_pb2.optional_int32_extension , 5),
+ (unittest_pb2.optional_fixed32_extension, 1),
+ (unittest_pb2.optional_string_extension , 'foo') ],
+ proto.ListFields())
+
+ def testRepeatedListExtensions(self):
+ proto = unittest_pb2.TestAllExtensions()
+ proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
+ proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5)
+ proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11)
+ proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
+ proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
+ proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
+ proto.Extensions[unittest_pb2.optional_int32_extension ] = 21
+ self.assertEqual(
+ [ (unittest_pb2.optional_int32_extension , 21),
+ (unittest_pb2.repeated_int32_extension , [5, 11]),
+ (unittest_pb2.repeated_fixed32_extension, [1]),
+ (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
+ proto.ListFields())
+
+ def testListFieldsAndExtensions(self):
+ proto = unittest_pb2.TestFieldOrderings()
+ test_util.SetAllFieldsAndExtensions(proto)
+ unittest_pb2.my_extension_int
+ self.assertEqual(
+ [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1),
+ (unittest_pb2.my_extension_int , 23),
+ (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
+ (unittest_pb2.my_extension_string , 'bar'),
+ (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
+ proto.ListFields())
+
+ def testDefaultValues(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(0, proto.optional_int32)
+ self.assertEqual(0, proto.optional_int64)
+ self.assertEqual(0, proto.optional_uint32)
+ self.assertEqual(0, proto.optional_uint64)
+ self.assertEqual(0, proto.optional_sint32)
+ self.assertEqual(0, proto.optional_sint64)
+ self.assertEqual(0, proto.optional_fixed32)
+ self.assertEqual(0, proto.optional_fixed64)
+ self.assertEqual(0, proto.optional_sfixed32)
+ self.assertEqual(0, proto.optional_sfixed64)
+ self.assertEqual(0.0, proto.optional_float)
+ self.assertEqual(0.0, proto.optional_double)
+ self.assertEqual(False, proto.optional_bool)
+ self.assertEqual('', proto.optional_string)
+ self.assertEqual('', proto.optional_bytes)
+
+ self.assertEqual(41, proto.default_int32)
+ self.assertEqual(42, proto.default_int64)
+ self.assertEqual(43, proto.default_uint32)
+ self.assertEqual(44, proto.default_uint64)
+ self.assertEqual(-45, proto.default_sint32)
+ self.assertEqual(46, proto.default_sint64)
+ self.assertEqual(47, proto.default_fixed32)
+ self.assertEqual(48, proto.default_fixed64)
+ self.assertEqual(49, proto.default_sfixed32)
+ self.assertEqual(-50, proto.default_sfixed64)
+ self.assertEqual(51.5, proto.default_float)
+ self.assertEqual(52e3, proto.default_double)
+ self.assertEqual(True, proto.default_bool)
+ self.assertEqual('hello', proto.default_string)
+ self.assertEqual('world', proto.default_bytes)
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
+ self.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ proto.default_import_enum)
+
+ proto = unittest_pb2.TestExtremeDefaultValues()
+ self.assertEqual(u'\u1234', proto.utf8_string)
+
+ def testHasFieldWithUnknownFieldName(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
+
+ def testClearFieldWithUnknownFieldName(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
+
+ def testDisallowedAssignments(self):
+ # It's illegal to assign values directly to repeated fields
+ # or to nonrepeated composite fields. Ensure that this fails.
+ proto = unittest_pb2.TestAllTypes()
+ # Repeated fields.
+ self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
+ # Lists shouldn't work, either.
+ self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
+ # Composite fields.
+ self.assertRaises(AttributeError, setattr, proto,
+ 'optional_nested_message', 23)
+ # Assignment to a repeated nested message field without specifying
+ # the index in the array of nested messages.
+ self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
+ 'bb', 34)
+ # Assignment to an attribute of a repeated field.
+ self.assertRaises(AttributeError, setattr, proto.repeated_float,
+ 'some_attribute', 34)
+ # proto.nonexistent_field = 23 should fail as well.
+ self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
+
+ # TODO(robinson): Add type-safety check for enums.
+ def testSingleScalarTypeSafety(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
+ self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
+ self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
+ self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+
+ def testSingleScalarBoundsChecking(self):
+ def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
+ pb = unittest_pb2.TestAllTypes()
+ setattr(pb, field_name, expected_min)
+ setattr(pb, field_name, expected_max)
+ self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
+ self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
+
+ TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
+ TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
+ TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
+ TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
+ TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
+
+ def testRepeatedScalarTypeSafety(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
+ self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
+ self.assertRaises(TypeError, proto.repeated_string, 10)
+ self.assertRaises(TypeError, proto.repeated_bytes, 10)
+
+ proto.repeated_int32.append(10)
+ proto.repeated_int32[0] = 23
+ self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
+ self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
+
+ def testSingleScalarGettersAndSetters(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(0, proto.optional_int32)
+ proto.optional_int32 = 1
+ self.assertEqual(1, proto.optional_int32)
+ # TODO(robinson): Test all other scalar field types.
+
+ def testSingleScalarClearField(self):
+ proto = unittest_pb2.TestAllTypes()
+ # Should be allowed to clear something that's not there (a no-op).
+ proto.ClearField('optional_int32')
+ proto.optional_int32 = 1
+ self.assertTrue(proto.HasField('optional_int32'))
+ proto.ClearField('optional_int32')
+ self.assertEqual(0, proto.optional_int32)
+ self.assertTrue(not proto.HasField('optional_int32'))
+ # TODO(robinson): Test all other scalar field types.
+
+ def testEnums(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(1, proto.FOO)
+ self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
+ self.assertEqual(2, proto.BAR)
+ self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
+ self.assertEqual(3, proto.BAZ)
+ self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
+
+ def testRepeatedScalars(self):
+ proto = unittest_pb2.TestAllTypes()
+
+ self.assertTrue(not proto.repeated_int32)
+ self.assertEqual(0, len(proto.repeated_int32))
+ proto.repeated_int32.append(5)
+ proto.repeated_int32.append(10)
+ proto.repeated_int32.append(15)
+ self.assertTrue(proto.repeated_int32)
+ self.assertEqual(3, len(proto.repeated_int32))
+
+ self.assertEqual([5, 10, 15], proto.repeated_int32)
+
+ # Test single retrieval.
+ self.assertEqual(5, proto.repeated_int32[0])
+ self.assertEqual(15, proto.repeated_int32[-1])
+ # Test out-of-bounds indices.
+ self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
+ self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
+ # Test incorrect types passed to __getitem__.
+ self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
+ self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
+
+ # Test single assignment.
+ proto.repeated_int32[1] = 20
+ self.assertEqual([5, 20, 15], proto.repeated_int32)
+
+ # Test insertion.
+ proto.repeated_int32.insert(1, 25)
+ self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
+
+ # Test slice retrieval.
+ proto.repeated_int32.append(30)
+ self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
+ self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
+
+ # Test slice assignment with an iterator
+ proto.repeated_int32[1:4] = (i for i in xrange(3))
+ self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
+
+ # Test slice assignment.
+ proto.repeated_int32[1:4] = [35, 40, 45]
+ self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
+
+ # Test that we can use the field as an iterator.
+ result = []
+ for i in proto.repeated_int32:
+ result.append(i)
+ self.assertEqual([5, 35, 40, 45, 30], result)
+
+ # Test single deletion.
+ del proto.repeated_int32[2]
+ self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
+
+ # Test slice deletion.
+ del proto.repeated_int32[2:]
+ self.assertEqual([5, 35], proto.repeated_int32)
+
+ # Test clearing.
+ proto.ClearField('repeated_int32')
+ self.assertTrue(not proto.repeated_int32)
+ self.assertEqual(0, len(proto.repeated_int32))
+
+ def testRepeatedScalarsRemove(self):
+ proto = unittest_pb2.TestAllTypes()
+
+ self.assertTrue(not proto.repeated_int32)
+ self.assertEqual(0, len(proto.repeated_int32))
+ proto.repeated_int32.append(5)
+ proto.repeated_int32.append(10)
+ proto.repeated_int32.append(5)
+ proto.repeated_int32.append(5)
+
+ self.assertEqual(4, len(proto.repeated_int32))
+ proto.repeated_int32.remove(5)
+ self.assertEqual(3, len(proto.repeated_int32))
+ self.assertEqual(10, proto.repeated_int32[0])
+ self.assertEqual(5, proto.repeated_int32[1])
+ self.assertEqual(5, proto.repeated_int32[2])
+
+ proto.repeated_int32.remove(5)
+ self.assertEqual(2, len(proto.repeated_int32))
+ self.assertEqual(10, proto.repeated_int32[0])
+ self.assertEqual(5, proto.repeated_int32[1])
+
+ proto.repeated_int32.remove(10)
+ self.assertEqual(1, len(proto.repeated_int32))
+ self.assertEqual(5, proto.repeated_int32[0])
+
+ # Remove a non-existent element.
+ self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
+
+ def testRepeatedComposites(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertTrue(not proto.repeated_nested_message)
+ self.assertEqual(0, len(proto.repeated_nested_message))
+ m0 = proto.repeated_nested_message.add()
+ m1 = proto.repeated_nested_message.add()
+ self.assertTrue(proto.repeated_nested_message)
+ self.assertEqual(2, len(proto.repeated_nested_message))
+ self.assertIs([m0, m1], proto.repeated_nested_message)
+ self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
+
+ # Test out-of-bounds indices.
+ self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
+ 1234)
+ self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
+ -1234)
+
+ # Test incorrect types passed to __getitem__.
+ self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
+ 'foo')
+ self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
+ None)
+
+ # Test slice retrieval.
+ m2 = proto.repeated_nested_message.add()
+ m3 = proto.repeated_nested_message.add()
+ m4 = proto.repeated_nested_message.add()
+ self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
+ self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
+
+ # Test that we can use the field as an iterator.
+ result = []
+ for i in proto.repeated_nested_message:
+ result.append(i)
+ self.assertIs([m0, m1, m2, m3, m4], result)
+
+ # Test single deletion.
+ del proto.repeated_nested_message[2]
+ self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
+
+ # Test slice deletion.
+ del proto.repeated_nested_message[2:]
+ self.assertIs([m0, m1], proto.repeated_nested_message)
+
+ # Test clearing.
+ proto.ClearField('repeated_nested_message')
+ self.assertTrue(not proto.repeated_nested_message)
+ self.assertEqual(0, len(proto.repeated_nested_message))
+
+ def testHandWrittenReflection(self):
+ # TODO(robinson): We probably need a better way to specify
+ # protocol types by hand. But then again, this isn't something
+ # we expect many people to do. Hmm.
+ FieldDescriptor = descriptor.FieldDescriptor
+ foo_field_descriptor = FieldDescriptor(
+ name='foo_field', full_name='MyProto.foo_field',
+ index=0, number=1, type=FieldDescriptor.TYPE_INT64,
+ cpp_type=FieldDescriptor.CPPTYPE_INT64,
+ label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
+ containing_type=None, message_type=None, enum_type=None,
+ is_extension=False, extension_scope=None,
+ options=descriptor_pb2.FieldOptions())
+ mydescriptor = descriptor.Descriptor(
+ name='MyProto', full_name='MyProto', filename='ignored',
+ containing_type=None, nested_types=[], enum_types=[],
+ fields=[foo_field_descriptor], extensions=[],
+ options=descriptor_pb2.MessageOptions())
+ class MyProtoClass(message.Message):
+ DESCRIPTOR = mydescriptor
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ myproto_instance = MyProtoClass()
+ self.assertEqual(0, myproto_instance.foo_field)
+ self.assertTrue(not myproto_instance.HasField('foo_field'))
+ myproto_instance.foo_field = 23
+ self.assertEqual(23, myproto_instance.foo_field)
+ self.assertTrue(myproto_instance.HasField('foo_field'))
+
+ def testTopLevelExtensionsForOptionalScalar(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.optional_int32_extension
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ self.assertEqual(0, extendee_proto.Extensions[extension])
+ # As with normal scalar fields, just doing a read doesn't actually set the
+ # "has" bit.
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ # Actually set the thing.
+ extendee_proto.Extensions[extension] = 23
+ self.assertEqual(23, extendee_proto.Extensions[extension])
+ self.assertTrue(extendee_proto.HasExtension(extension))
+ # Ensure that clearing works as well.
+ extendee_proto.ClearExtension(extension)
+ self.assertEqual(0, extendee_proto.Extensions[extension])
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+
+ def testTopLevelExtensionsForRepeatedScalar(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.repeated_string_extension
+ self.assertEqual(0, len(extendee_proto.Extensions[extension]))
+ extendee_proto.Extensions[extension].append('foo')
+ self.assertEqual(['foo'], extendee_proto.Extensions[extension])
+ string_list = extendee_proto.Extensions[extension]
+ extendee_proto.ClearExtension(extension)
+ self.assertEqual(0, len(extendee_proto.Extensions[extension]))
+ self.assertTrue(string_list is not extendee_proto.Extensions[extension])
+ # Shouldn't be allowed to do Extensions[extension] = 'a'
+ self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
+ extension, 'a')
+
+ def testTopLevelExtensionsForOptionalMessage(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.optional_foreign_message_extension
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ self.assertEqual(0, extendee_proto.Extensions[extension].c)
+ # As with normal (non-extension) fields, merely reading from the
+ # thing shouldn't set the "has" bit.
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ extendee_proto.Extensions[extension].c = 23
+ self.assertEqual(23, extendee_proto.Extensions[extension].c)
+ self.assertTrue(extendee_proto.HasExtension(extension))
+ # Save a reference here.
+ foreign_message = extendee_proto.Extensions[extension]
+ extendee_proto.ClearExtension(extension)
+ self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
+ # Setting a field on foreign_message now shouldn't set
+ # any "has" bits on extendee_proto.
+ foreign_message.c = 42
+ self.assertEqual(42, foreign_message.c)
+ self.assertTrue(foreign_message.HasField('c'))
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ # Shouldn't be allowed to do Extensions[extension] = 'a'
+ self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
+ extension, 'a')
+
+ def testTopLevelExtensionsForRepeatedMessage(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.repeatedgroup_extension
+ self.assertEqual(0, len(extendee_proto.Extensions[extension]))
+ group = extendee_proto.Extensions[extension].add()
+ group.a = 23
+ self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
+ group.a = 42
+ self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
+ group_list = extendee_proto.Extensions[extension]
+ extendee_proto.ClearExtension(extension)
+ self.assertEqual(0, len(extendee_proto.Extensions[extension]))
+ self.assertTrue(group_list is not extendee_proto.Extensions[extension])
+ # Shouldn't be allowed to do Extensions[extension] = 'a'
+ self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
+ extension, 'a')
+
+ def testNestedExtensions(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.TestRequired.single
+
+ # We just test the non-repeated case.
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ required = extendee_proto.Extensions[extension]
+ self.assertEqual(0, required.a)
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+ required.a = 23
+ self.assertEqual(23, extendee_proto.Extensions[extension].a)
+ self.assertTrue(extendee_proto.HasExtension(extension))
+ extendee_proto.ClearExtension(extension)
+ self.assertTrue(required is not extendee_proto.Extensions[extension])
+ self.assertTrue(not extendee_proto.HasExtension(extension))
+
+ # If message A directly contains message B, and
+ # a.HasField('b') is currently False, then mutating any
+ # extension in B should change a.HasField('b') to True
+ # (and so on up the object tree).
+ def testHasBitsForAncestorsOfExtendedMessage(self):
+ # Optional scalar extension.
+ toplevel = more_extensions_pb2.TopLevelMessage()
+ self.assertTrue(not toplevel.HasField('submessage'))
+ self.assertEqual(0, toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+ self.assertTrue(not toplevel.HasField('submessage'))
+ toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension] = 23
+ self.assertEqual(23, toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+ self.assertTrue(toplevel.HasField('submessage'))
+
+ # Repeated scalar extension.
+ toplevel = more_extensions_pb2.TopLevelMessage()
+ self.assertTrue(not toplevel.HasField('submessage'))
+ self.assertEqual([], toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_int_extension])
+ self.assertTrue(not toplevel.HasField('submessage'))
+ toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_int_extension].append(23)
+ self.assertEqual([23], toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_int_extension])
+ self.assertTrue(toplevel.HasField('submessage'))
+
+ # Optional message extension.
+ toplevel = more_extensions_pb2.TopLevelMessage()
+ self.assertTrue(not toplevel.HasField('submessage'))
+ self.assertEqual(0, toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_message_extension].foreign_message_int)
+ self.assertTrue(not toplevel.HasField('submessage'))
+ toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_message_extension].foreign_message_int = 23
+ self.assertEqual(23, toplevel.submessage.Extensions[
+ more_extensions_pb2.optional_message_extension].foreign_message_int)
+ self.assertTrue(toplevel.HasField('submessage'))
+
+ # Repeated message extension.
+ toplevel = more_extensions_pb2.TopLevelMessage()
+ self.assertTrue(not toplevel.HasField('submessage'))
+ self.assertEqual(0, len(toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_message_extension]))
+ self.assertTrue(not toplevel.HasField('submessage'))
+ foreign = toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_message_extension].add()
+ self.assertTrue(foreign is toplevel.submessage.Extensions[
+ more_extensions_pb2.repeated_message_extension][0])
+ self.assertTrue(toplevel.HasField('submessage'))
+
+ def testDisconnectionAfterClearingEmptyMessage(self):
+ toplevel = more_extensions_pb2.TopLevelMessage()
+ extendee_proto = toplevel.submessage
+ extension = more_extensions_pb2.optional_message_extension
+ extension_proto = extendee_proto.Extensions[extension]
+ extendee_proto.ClearExtension(extension)
+ extension_proto.foreign_message_int = 23
+
+ self.assertTrue(not toplevel.HasField('submessage'))
+ self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
+
+ def testExtensionFailureModes(self):
+ extendee_proto = unittest_pb2.TestAllExtensions()
+
+ # Try non-extension-handle arguments to HasExtension,
+ # ClearExtension(), and Extensions[]...
+ self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
+ self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
+
+ # Try something that *is* an extension handle, just not for
+ # this message...
+ unknown_handle = more_extensions_pb2.optional_int_extension
+ self.assertRaises(KeyError, extendee_proto.HasExtension,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.ClearExtension,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
+ unknown_handle)
+ self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
+ unknown_handle, 5)
+
+ # Try call HasExtension() with a valid handle, but for a
+ # *repeated* field. (Just as with non-extension repeated
+ # fields, Has*() isn't supported for extension repeated fields).
+ self.assertRaises(KeyError, extendee_proto.HasExtension,
+ unittest_pb2.repeated_string_extension)
+
+ def testStaticParseFrom(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto1)
+
+ string1 = proto1.SerializeToString()
+ proto2 = unittest_pb2.TestAllTypes.FromString(string1)
+
+ # Messages should be equal.
+ self.assertEqual(proto2, proto1)
+
+ def testMergeFromSingularField(self):
+ # Test merge with just a singular field.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optional_int32 = 1
+
+ proto2 = unittest_pb2.TestAllTypes()
+ # This shouldn't get overwritten.
+ proto2.optional_string = 'value'
+
+ proto2.MergeFrom(proto1)
+ self.assertEqual(1, proto2.optional_int32)
+ self.assertEqual('value', proto2.optional_string)
+
+ def testMergeFromRepeatedField(self):
+ # Test merge with just a repeated field.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.repeated_int32.append(1)
+ proto1.repeated_int32.append(2)
+
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.repeated_int32.append(0)
+ proto2.MergeFrom(proto1)
+
+ self.assertEqual(0, proto2.repeated_int32[0])
+ self.assertEqual(1, proto2.repeated_int32[1])
+ self.assertEqual(2, proto2.repeated_int32[2])
+
+ def testMergeFromOptionalGroup(self):
+ # Test merge with an optional group.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optionalgroup.a = 12
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.MergeFrom(proto1)
+ self.assertEqual(12, proto2.optionalgroup.a)
+
+ def testMergeFromRepeatedNestedMessage(self):
+ # Test merge with a repeated nested message.
+ proto1 = unittest_pb2.TestAllTypes()
+ m = proto1.repeated_nested_message.add()
+ m.bb = 123
+ m = proto1.repeated_nested_message.add()
+ m.bb = 321
+
+ proto2 = unittest_pb2.TestAllTypes()
+ m = proto2.repeated_nested_message.add()
+ m.bb = 999
+ proto2.MergeFrom(proto1)
+ self.assertEqual(999, proto2.repeated_nested_message[0].bb)
+ self.assertEqual(123, proto2.repeated_nested_message[1].bb)
+ self.assertEqual(321, proto2.repeated_nested_message[2].bb)
+
+ def testMergeFromAllFields(self):
+ # With all fields set.
+ proto1 = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto1)
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.MergeFrom(proto1)
+
+ # Messages should be equal.
+ self.assertEqual(proto2, proto1)
+
+ # Serialized string should be equal too.
+ string1 = proto1.SerializeToString()
+ string2 = proto2.SerializeToString()
+ self.assertEqual(string1, string2)
+
+ def testMergeFromExtensionsSingular(self):
+ proto1 = unittest_pb2.TestAllExtensions()
+ proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
+
+ proto2 = unittest_pb2.TestAllExtensions()
+ proto2.MergeFrom(proto1)
+ self.assertEqual(
+ 1, proto2.Extensions[unittest_pb2.optional_int32_extension])
+
+ def testMergeFromExtensionsRepeated(self):
+ proto1 = unittest_pb2.TestAllExtensions()
+ proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
+ proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
+
+ proto2 = unittest_pb2.TestAllExtensions()
+ proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
+ proto2.MergeFrom(proto1)
+ self.assertEqual(
+ 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
+ self.assertEqual(
+ 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
+ self.assertEqual(
+ 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
+ self.assertEqual(
+ 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
+
+ def testMergeFromExtensionsNestedMessage(self):
+ proto1 = unittest_pb2.TestAllExtensions()
+ ext1 = proto1.Extensions[
+ unittest_pb2.repeated_nested_message_extension]
+ m = ext1.add()
+ m.bb = 222
+ m = ext1.add()
+ m.bb = 333
+
+ proto2 = unittest_pb2.TestAllExtensions()
+ ext2 = proto2.Extensions[
+ unittest_pb2.repeated_nested_message_extension]
+ m = ext2.add()
+ m.bb = 111
+
+ proto2.MergeFrom(proto1)
+ ext2 = proto2.Extensions[
+ unittest_pb2.repeated_nested_message_extension]
+ self.assertEqual(3, len(ext2))
+ self.assertEqual(111, ext2[0].bb)
+ self.assertEqual(222, ext2[1].bb)
+ self.assertEqual(333, ext2[2].bb)
+
+ def testCopyFromSingularField(self):
+ # Test copy with just a singular field.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optional_int32 = 1
+ proto1.optional_string = 'important-text'
+
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.optional_string = 'value'
+
+ proto2.CopyFrom(proto1)
+ self.assertEqual(1, proto2.optional_int32)
+ self.assertEqual('important-text', proto2.optional_string)
+
+ def testCopyFromRepeatedField(self):
+ # Test copy with a repeated field.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.repeated_int32.append(1)
+ proto1.repeated_int32.append(2)
+
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.repeated_int32.append(0)
+ proto2.CopyFrom(proto1)
+
+ self.assertEqual(1, proto2.repeated_int32[0])
+ self.assertEqual(2, proto2.repeated_int32[1])
+
+ def testCopyFromAllFields(self):
+ # With all fields set.
+ proto1 = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto1)
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.CopyFrom(proto1)
+
+ # Messages should be equal.
+ self.assertEqual(proto2, proto1)
+
+ # Serialized string should be equal too.
+ string1 = proto1.SerializeToString()
+ string2 = proto2.SerializeToString()
+ self.assertEqual(string1, string2)
+
+ def testCopyFromSelf(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.repeated_int32.append(1)
+ proto1.optional_int32 = 2
+ proto1.optional_string = 'important-text'
+
+ proto1.CopyFrom(proto1)
+ self.assertEqual(1, proto1.repeated_int32[0])
+ self.assertEqual(2, proto1.optional_int32)
+ self.assertEqual('important-text', proto1.optional_string)
+
+ def testClear(self):
+ proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto)
+ # Clear the message.
+ proto.Clear()
+ self.assertEquals(proto.ByteSize(), 0)
+ empty_proto = unittest_pb2.TestAllTypes()
+ self.assertEquals(proto, empty_proto)
+
+ # Test if extensions which were set are cleared.
+ proto = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(proto)
+ # Clear the message.
+ proto.Clear()
+ self.assertEquals(proto.ByteSize(), 0)
+ empty_proto = unittest_pb2.TestAllExtensions()
+ self.assertEquals(proto, empty_proto)
+
+ def testIsInitialized(self):
+ # Trivial cases - all optional fields and extensions.
+ proto = unittest_pb2.TestAllTypes()
+ self.assertTrue(proto.IsInitialized())
+ proto = unittest_pb2.TestAllExtensions()
+ self.assertTrue(proto.IsInitialized())
+
+ # The case of uninitialized required fields.
+ proto = unittest_pb2.TestRequired()
+ self.assertFalse(proto.IsInitialized())
+ proto.a = proto.b = proto.c = 2
+ self.assertTrue(proto.IsInitialized())
+
+ # The case of uninitialized submessage.
+ proto = unittest_pb2.TestRequiredForeign()
+ self.assertTrue(proto.IsInitialized())
+ proto.optional_message.a = 1
+ self.assertFalse(proto.IsInitialized())
+ proto.optional_message.b = 0
+ proto.optional_message.c = 0
+ self.assertTrue(proto.IsInitialized())
+
+ # Uninitialized repeated submessage.
+ message1 = proto.repeated_message.add()
+ self.assertFalse(proto.IsInitialized())
+ message1.a = message1.b = message1.c = 0
+ self.assertTrue(proto.IsInitialized())
+
+ # Uninitialized repeated group in an extension.
+ proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.TestRequired.multi
+ message1 = proto.Extensions[extension].add()
+ message2 = proto.Extensions[extension].add()
+ self.assertFalse(proto.IsInitialized())
+ message1.a = 1
+ message1.b = 1
+ message1.c = 1
+ self.assertFalse(proto.IsInitialized())
+ message2.a = 2
+ message2.b = 2
+ message2.c = 2
+ self.assertTrue(proto.IsInitialized())
+
+ # Uninitialized nonrepeated message in an extension.
+ proto = unittest_pb2.TestAllExtensions()
+ extension = unittest_pb2.TestRequired.single
+ proto.Extensions[extension].a = 1
+ self.assertFalse(proto.IsInitialized())
+ proto.Extensions[extension].b = 2
+ proto.Extensions[extension].c = 3
+ self.assertTrue(proto.IsInitialized())
+
+ def testStringUTF8Encoding(self):
+ proto = unittest_pb2.TestAllTypes()
+
+ # Assignment of a unicode object to a field of type 'bytes' is not allowed.
+ self.assertRaises(TypeError,
+ setattr, proto, 'optional_bytes', u'unicode object')
+
+ # Check that the default value is of python's 'unicode' type.
+ self.assertEqual(type(proto.optional_string), unicode)
+
+ proto.optional_string = unicode('Testing')
+ self.assertEqual(proto.optional_string, str('Testing'))
+
+ # Assign a value of type 'str' which can be encoded in UTF-8.
+ proto.optional_string = str('Testing')
+ self.assertEqual(proto.optional_string, unicode('Testing'))
+
+ # Values of type 'str' are also accepted as long as they can be encoded in
+ # UTF-8.
+ self.assertEqual(type(proto.optional_string), str)
+
+ # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
+ self.assertRaises(ValueError,
+ setattr, proto, 'optional_string', str('a\x80a'))
+ # Assign a 'str' object which contains a UTF-8 encoded string.
+ self.assertRaises(ValueError,
+ setattr, proto, 'optional_string', 'Тест')
+ # No exception thrown.
+ proto.optional_string = 'abc'
+
+ def testStringUTF8Serialization(self):
+ proto = unittest_mset_pb2.TestMessageSet()
+ extension_message = unittest_mset_pb2.TestMessageSetExtension2
+ extension = extension_message.message_set_extension
+
+ test_utf8 = u'Тест'
+ test_utf8_bytes = test_utf8.encode('utf-8')
+
+ # 'Test' in another language, using UTF-8 charset.
+ proto.Extensions[extension].str = test_utf8
+
+ # Serialize using the MessageSet wire format (this is specified in the
+ # .proto file).
+ serialized = proto.SerializeToString()
+
+ # Check byte size.
+ self.assertEqual(proto.ByteSize(), len(serialized))
+
+ raw = unittest_mset_pb2.RawMessageSet()
+ raw.MergeFromString(serialized)
+
+ message2 = unittest_mset_pb2.TestMessageSetExtension2()
+
+ self.assertEqual(1, len(raw.item))
+ # Check that the type_id is the same as the tag ID in the .proto file.
+ self.assertEqual(raw.item[0].type_id, 1547769)
+
+ # Check the actually bytes on the wire.
+ self.assertTrue(
+ raw.item[0].message.endswith(test_utf8_bytes))
+ message2.MergeFromString(raw.item[0].message)
+
+ self.assertEqual(type(message2.str), unicode)
+ self.assertEqual(message2.str, test_utf8)
+
+ # How about if the bytes on the wire aren't a valid UTF-8 encoded string.
+ bytes = raw.item[0].message.replace(
+ test_utf8_bytes, len(test_utf8_bytes) * '\xff')
+ self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
+
+
+# Since we had so many tests for protocol buffer equality, we broke these out
+# into separate TestCase classes.
+
+
+class TestAllTypesEqualityTest(unittest.TestCase):
+
+ def setUp(self):
+ self.first_proto = unittest_pb2.TestAllTypes()
+ self.second_proto = unittest_pb2.TestAllTypes()
+
+ def testSelfEquality(self):
+ self.assertEqual(self.first_proto, self.first_proto)
+
+ def testEmptyProtosEqual(self):
+ self.assertEqual(self.first_proto, self.second_proto)
+
+
+class FullProtosEqualityTest(unittest.TestCase):
+
+ """Equality tests using completely-full protos as a starting point."""
+
+ def setUp(self):
+ self.first_proto = unittest_pb2.TestAllTypes()
+ self.second_proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(self.first_proto)
+ test_util.SetAllFields(self.second_proto)
+
+ def testNoneNotEqual(self):
+ self.assertNotEqual(self.first_proto, None)
+ self.assertNotEqual(None, self.second_proto)
+
+ def testNotEqualToOtherMessage(self):
+ third_proto = unittest_pb2.TestRequired()
+ self.assertNotEqual(self.first_proto, third_proto)
+ self.assertNotEqual(third_proto, self.second_proto)
+
+ def testAllFieldsFilledEquality(self):
+ self.assertEqual(self.first_proto, self.second_proto)
+
+ def testNonRepeatedScalar(self):
+ # Nonrepeated scalar field change should cause inequality.
+ self.first_proto.optional_int32 += 1
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ # ...as should clearing a field.
+ self.first_proto.ClearField('optional_int32')
+ self.assertNotEqual(self.first_proto, self.second_proto)
+
+ def testNonRepeatedComposite(self):
+ # Change a nonrepeated composite field.
+ self.first_proto.optional_nested_message.bb += 1
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ self.first_proto.optional_nested_message.bb -= 1
+ self.assertEqual(self.first_proto, self.second_proto)
+ # Clear a field in the nested message.
+ self.first_proto.optional_nested_message.ClearField('bb')
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ self.first_proto.optional_nested_message.bb = (
+ self.second_proto.optional_nested_message.bb)
+ self.assertEqual(self.first_proto, self.second_proto)
+ # Remove the nested message entirely.
+ self.first_proto.ClearField('optional_nested_message')
+ self.assertNotEqual(self.first_proto, self.second_proto)
+
+ def testRepeatedScalar(self):
+ # Change a repeated scalar field.
+ self.first_proto.repeated_int32.append(5)
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ self.first_proto.ClearField('repeated_int32')
+ self.assertNotEqual(self.first_proto, self.second_proto)
+
+ def testRepeatedComposite(self):
+ # Change value within a repeated composite field.
+ self.first_proto.repeated_nested_message[0].bb += 1
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ self.first_proto.repeated_nested_message[0].bb -= 1
+ self.assertEqual(self.first_proto, self.second_proto)
+ # Add a value to a repeated composite field.
+ self.first_proto.repeated_nested_message.add()
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ self.second_proto.repeated_nested_message.add()
+ self.assertEqual(self.first_proto, self.second_proto)
+
+ def testNonRepeatedScalarHasBits(self):
+ # Ensure that we test "has" bits as well as value for
+ # nonrepeated scalar field.
+ self.first_proto.ClearField('optional_int32')
+ self.second_proto.optional_int32 = 0
+ self.assertNotEqual(self.first_proto, self.second_proto)
+
+ def testNonRepeatedCompositeHasBits(self):
+ # Ensure that we test "has" bits as well as value for
+ # nonrepeated composite field.
+ self.first_proto.ClearField('optional_nested_message')
+ self.second_proto.optional_nested_message.ClearField('bb')
+ self.assertNotEqual(self.first_proto, self.second_proto)
+ # TODO(robinson): Replace next two lines with method
+ # to set the "has" bit without changing the value,
+ # if/when such a method exists.
+ self.first_proto.optional_nested_message.bb = 0
+ self.first_proto.optional_nested_message.ClearField('bb')
+ self.assertEqual(self.first_proto, self.second_proto)
+
+
+class ExtensionEqualityTest(unittest.TestCase):
+
+ def testExtensionEquality(self):
+ first_proto = unittest_pb2.TestAllExtensions()
+ second_proto = unittest_pb2.TestAllExtensions()
+ self.assertEqual(first_proto, second_proto)
+ test_util.SetAllExtensions(first_proto)
+ self.assertNotEqual(first_proto, second_proto)
+ test_util.SetAllExtensions(second_proto)
+ self.assertEqual(first_proto, second_proto)
+
+ # Ensure that we check value equality.
+ first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
+ self.assertNotEqual(first_proto, second_proto)
+ first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
+ self.assertEqual(first_proto, second_proto)
+
+ # Ensure that we also look at "has" bits.
+ first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
+ second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
+ self.assertNotEqual(first_proto, second_proto)
+ first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
+ self.assertEqual(first_proto, second_proto)
+
+ # Ensure that differences in cached values
+ # don't matter if "has" bits are both false.
+ first_proto = unittest_pb2.TestAllExtensions()
+ second_proto = unittest_pb2.TestAllExtensions()
+ self.assertEqual(
+ 0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
+ self.assertEqual(first_proto, second_proto)
+
+
+class MutualRecursionEqualityTest(unittest.TestCase):
+
+ def testEqualityWithMutualRecursion(self):
+ first_proto = unittest_pb2.TestMutualRecursionA()
+ second_proto = unittest_pb2.TestMutualRecursionA()
+ self.assertEqual(first_proto, second_proto)
+ first_proto.bb.a.bb.optional_int32 = 23
+ self.assertNotEqual(first_proto, second_proto)
+ second_proto.bb.a.bb.optional_int32 = 23
+ self.assertEqual(first_proto, second_proto)
+
+
+class ByteSizeTest(unittest.TestCase):
+
+ def setUp(self):
+ self.proto = unittest_pb2.TestAllTypes()
+ self.extended_proto = more_extensions_pb2.ExtendedMessage()
+ self.packed_proto = unittest_pb2.TestPackedTypes()
+ self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
+
+ def Size(self):
+ return self.proto.ByteSize()
+
+ def testEmptyMessage(self):
+ self.assertEqual(0, self.proto.ByteSize())
+
+ def testVarints(self):
+ def Test(i, expected_varint_size):
+ self.proto.Clear()
+ self.proto.optional_int64 = i
+ # Add one to the varint size for the tag info
+ # for tag 1.
+ self.assertEqual(expected_varint_size + 1, self.Size())
+ Test(0, 1)
+ Test(1, 1)
+ for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
+ Test((1 << i) - 1, num_bytes)
+ Test(-1, 10)
+ Test(-2, 10)
+ Test(-(1 << 63), 10)
+
+ def testStrings(self):
+ self.proto.optional_string = ''
+ # Need one byte for tag info (tag #14), and one byte for length.
+ self.assertEqual(2, self.Size())
+
+ self.proto.optional_string = 'abc'
+ # Need one byte for tag info (tag #14), and one byte for length.
+ self.assertEqual(2 + len(self.proto.optional_string), self.Size())
+
+ self.proto.optional_string = 'x' * 128
+ # Need one byte for tag info (tag #14), and TWO bytes for length.
+ self.assertEqual(3 + len(self.proto.optional_string), self.Size())
+
+ def testOtherNumerics(self):
+ self.proto.optional_fixed32 = 1234
+ # One byte for tag and 4 bytes for fixed32.
+ self.assertEqual(5, self.Size())
+ self.proto = unittest_pb2.TestAllTypes()
+
+ self.proto.optional_fixed64 = 1234
+ # One byte for tag and 8 bytes for fixed64.
+ self.assertEqual(9, self.Size())
+ self.proto = unittest_pb2.TestAllTypes()
+
+ self.proto.optional_float = 1.234
+ # One byte for tag and 4 bytes for float.
+ self.assertEqual(5, self.Size())
+ self.proto = unittest_pb2.TestAllTypes()
+
+ self.proto.optional_double = 1.234
+ # One byte for tag and 8 bytes for float.
+ self.assertEqual(9, self.Size())
+ self.proto = unittest_pb2.TestAllTypes()
+
+ self.proto.optional_sint32 = 64
+ # One byte for tag and 2 bytes for zig-zag-encoded 64.
+ self.assertEqual(3, self.Size())
+ self.proto = unittest_pb2.TestAllTypes()
+
+ def testComposites(self):
+ # 3 bytes.
+ self.proto.optional_nested_message.bb = (1 << 14)
+ # Plus one byte for bb tag.
+ # Plus 1 byte for optional_nested_message serialized size.
+ # Plus two bytes for optional_nested_message tag.
+ self.assertEqual(3 + 1 + 1 + 2, self.Size())
+
+ def testGroups(self):
+ # 4 bytes.
+ self.proto.optionalgroup.a = (1 << 21)
+ # Plus two bytes for |a| tag.
+ # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
+ self.assertEqual(4 + 2 + 2*2, self.Size())
+
+ def testRepeatedScalars(self):
+ self.proto.repeated_int32.append(10) # 1 byte.
+ self.proto.repeated_int32.append(128) # 2 bytes.
+ # Also need 2 bytes for each entry for tag.
+ self.assertEqual(1 + 2 + 2*2, self.Size())
+
+ def testRepeatedScalarsExtend(self):
+ self.proto.repeated_int32.extend([10, 128]) # 3 bytes.
+ # Also need 2 bytes for each entry for tag.
+ self.assertEqual(1 + 2 + 2*2, self.Size())
+
+ def testRepeatedScalarsRemove(self):
+ self.proto.repeated_int32.append(10) # 1 byte.
+ self.proto.repeated_int32.append(128) # 2 bytes.
+ # Also need 2 bytes for each entry for tag.
+ self.assertEqual(1 + 2 + 2*2, self.Size())
+ self.proto.repeated_int32.remove(128)
+ self.assertEqual(1 + 2, self.Size())
+
+ def testRepeatedComposites(self):
+ # Empty message. 2 bytes tag plus 1 byte length.
+ foreign_message_0 = self.proto.repeated_nested_message.add()
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ foreign_message_1 = self.proto.repeated_nested_message.add()
+ foreign_message_1.bb = 7
+ self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
+
+ def testRepeatedCompositesDelete(self):
+ # Empty message. 2 bytes tag plus 1 byte length.
+ foreign_message_0 = self.proto.repeated_nested_message.add()
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ foreign_message_1 = self.proto.repeated_nested_message.add()
+ foreign_message_1.bb = 9
+ self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
+
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ del self.proto.repeated_nested_message[0]
+ self.assertEqual(2 + 1 + 1 + 1, self.Size())
+
+ # Now add a new message.
+ foreign_message_2 = self.proto.repeated_nested_message.add()
+ foreign_message_2.bb = 12
+
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
+
+ # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
+ del self.proto.repeated_nested_message[1]
+ self.assertEqual(2 + 1 + 1 + 1, self.Size())
+
+ del self.proto.repeated_nested_message[0]
+ self.assertEqual(0, self.Size())
+
+ def testRepeatedGroups(self):
+ # 2-byte START_GROUP plus 2-byte END_GROUP.
+ group_0 = self.proto.repeatedgroup.add()
+ # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
+ # plus 2-byte END_GROUP.
+ group_1 = self.proto.repeatedgroup.add()
+ group_1.a = 7
+ self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
+
+ def testExtensions(self):
+ proto = unittest_pb2.TestAllExtensions()
+ self.assertEqual(0, proto.ByteSize())
+ extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte.
+ proto.Extensions[extension] = 23
+ # 1 byte for tag, 1 byte for value.
+ self.assertEqual(2, proto.ByteSize())
+
+ def testCacheInvalidationForNonrepeatedScalar(self):
+ # Test non-extension.
+ self.proto.optional_int32 = 1
+ self.assertEqual(2, self.proto.ByteSize())
+ self.proto.optional_int32 = 128
+ self.assertEqual(3, self.proto.ByteSize())
+ self.proto.ClearField('optional_int32')
+ self.assertEqual(0, self.proto.ByteSize())
+
+ # Test within extension.
+ extension = more_extensions_pb2.optional_int_extension
+ self.extended_proto.Extensions[extension] = 1
+ self.assertEqual(2, self.extended_proto.ByteSize())
+ self.extended_proto.Extensions[extension] = 128
+ self.assertEqual(3, self.extended_proto.ByteSize())
+ self.extended_proto.ClearExtension(extension)
+ self.assertEqual(0, self.extended_proto.ByteSize())
+
+ def testCacheInvalidationForRepeatedScalar(self):
+ # Test non-extension.
+ self.proto.repeated_int32.append(1)
+ self.assertEqual(3, self.proto.ByteSize())
+ self.proto.repeated_int32.append(1)
+ self.assertEqual(6, self.proto.ByteSize())
+ self.proto.repeated_int32[1] = 128
+ self.assertEqual(7, self.proto.ByteSize())
+ self.proto.ClearField('repeated_int32')
+ self.assertEqual(0, self.proto.ByteSize())
+
+ # Test within extension.
+ extension = more_extensions_pb2.repeated_int_extension
+ repeated = self.extended_proto.Extensions[extension]
+ repeated.append(1)
+ self.assertEqual(2, self.extended_proto.ByteSize())
+ repeated.append(1)
+ self.assertEqual(4, self.extended_proto.ByteSize())
+ repeated[1] = 128
+ self.assertEqual(5, self.extended_proto.ByteSize())
+ self.extended_proto.ClearExtension(extension)
+ self.assertEqual(0, self.extended_proto.ByteSize())
+
+ def testCacheInvalidationForNonrepeatedMessage(self):
+ # Test non-extension.
+ self.proto.optional_foreign_message.c = 1
+ self.assertEqual(5, self.proto.ByteSize())
+ self.proto.optional_foreign_message.c = 128
+ self.assertEqual(6, self.proto.ByteSize())
+ self.proto.optional_foreign_message.ClearField('c')
+ self.assertEqual(3, self.proto.ByteSize())
+ self.proto.ClearField('optional_foreign_message')
+ self.assertEqual(0, self.proto.ByteSize())
+ child = self.proto.optional_foreign_message
+ self.proto.ClearField('optional_foreign_message')
+ child.c = 128
+ self.assertEqual(0, self.proto.ByteSize())
+
+ # Test within extension.
+ extension = more_extensions_pb2.optional_message_extension
+ child = self.extended_proto.Extensions[extension]
+ self.assertEqual(0, self.extended_proto.ByteSize())
+ child.foreign_message_int = 1
+ self.assertEqual(4, self.extended_proto.ByteSize())
+ child.foreign_message_int = 128
+ self.assertEqual(5, self.extended_proto.ByteSize())
+ self.extended_proto.ClearExtension(extension)
+ self.assertEqual(0, self.extended_proto.ByteSize())
+
+ def testCacheInvalidationForRepeatedMessage(self):
+ # Test non-extension.
+ child0 = self.proto.repeated_foreign_message.add()
+ self.assertEqual(3, self.proto.ByteSize())
+ self.proto.repeated_foreign_message.add()
+ self.assertEqual(6, self.proto.ByteSize())
+ child0.c = 1
+ self.assertEqual(8, self.proto.ByteSize())
+ self.proto.ClearField('repeated_foreign_message')
+ self.assertEqual(0, self.proto.ByteSize())
+
+ # Test within extension.
+ extension = more_extensions_pb2.repeated_message_extension
+ child_list = self.extended_proto.Extensions[extension]
+ child0 = child_list.add()
+ self.assertEqual(2, self.extended_proto.ByteSize())
+ child_list.add()
+ self.assertEqual(4, self.extended_proto.ByteSize())
+ child0.foreign_message_int = 1
+ self.assertEqual(6, self.extended_proto.ByteSize())
+ child0.ClearField('foreign_message_int')
+ self.assertEqual(4, self.extended_proto.ByteSize())
+ self.extended_proto.ClearExtension(extension)
+ self.assertEqual(0, self.extended_proto.ByteSize())
+
+ def testPackedRepeatedScalars(self):
+ self.assertEqual(0, self.packed_proto.ByteSize())
+
+ self.packed_proto.packed_int32.append(10) # 1 byte.
+ self.packed_proto.packed_int32.append(128) # 2 bytes.
+ # The tag is 2 bytes (the field number is 90), and the varint
+ # storing the length is 1 byte.
+ int_size = 1 + 2 + 3
+ self.assertEqual(int_size, self.packed_proto.ByteSize())
+
+ self.packed_proto.packed_double.append(4.2) # 8 bytes
+ self.packed_proto.packed_double.append(3.25) # 8 bytes
+ # 2 more tag bytes, 1 more length byte.
+ double_size = 8 + 8 + 3
+ self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
+
+ self.packed_proto.ClearField('packed_int32')
+ self.assertEqual(double_size, self.packed_proto.ByteSize())
+
+ def testPackedExtensions(self):
+ self.assertEqual(0, self.packed_extended_proto.ByteSize())
+ extension = self.packed_extended_proto.Extensions[
+ unittest_pb2.packed_fixed32_extension]
+ extension.extend([1, 2, 3, 4]) # 16 bytes
+ # Tag is 3 bytes.
+ self.assertEqual(19, self.packed_extended_proto.ByteSize())
+
+
+# TODO(robinson): We need cross-language serialization consistency tests.
+# Issues to be sure to cover include:
+# * Handling of unrecognized tags ("uninterpreted_bytes").
+# * Handling of MessageSets.
+# * Consistent ordering of tags in the wire format,
+# including ordering between extensions and non-extension
+# fields.
+# * Consistent serialization of negative numbers, especially
+# negative int32s.
+# * Handling of empty submessages (with and without "has"
+# bits set).
+
+class SerializationTest(unittest.TestCase):
+
+ def testSerializeEmtpyMessage(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ second_proto = unittest_pb2.TestAllTypes()
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(first_proto.ByteSize(), len(serialized))
+ second_proto.MergeFromString(serialized)
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeAllFields(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ second_proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(first_proto)
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(first_proto.ByteSize(), len(serialized))
+ second_proto.MergeFromString(serialized)
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeAllExtensions(self):
+ first_proto = unittest_pb2.TestAllExtensions()
+ second_proto = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(first_proto)
+ serialized = first_proto.SerializeToString()
+ second_proto.MergeFromString(serialized)
+ self.assertEqual(first_proto, second_proto)
+
+ def testCanonicalSerializationOrder(self):
+ proto = more_messages_pb2.OutOfOrderFields()
+ # These are also their tag numbers. Even though we're setting these in
+ # reverse-tag order AND they're listed in reverse tag-order in the .proto
+ # file, they should nonetheless be serialized in tag order.
+ proto.optional_sint32 = 5
+ proto.Extensions[more_messages_pb2.optional_uint64] = 4
+ proto.optional_uint32 = 3
+ proto.Extensions[more_messages_pb2.optional_int64] = 2
+ proto.optional_int32 = 1
+ serialized = proto.SerializeToString()
+ self.assertEqual(proto.ByteSize(), len(serialized))
+ d = decoder.Decoder(serialized)
+ ReadTag = d.ReadFieldNumberAndWireType
+ self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
+ self.assertEqual(1, d.ReadInt32())
+ self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
+ self.assertEqual(2, d.ReadInt64())
+ self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
+ self.assertEqual(3, d.ReadUInt32())
+ self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
+ self.assertEqual(4, d.ReadUInt64())
+ self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
+ self.assertEqual(5, d.ReadSInt32())
+
+ def testCanonicalSerializationOrderSameAsCpp(self):
+ # Copy of the same test we use for C++.
+ proto = unittest_pb2.TestFieldOrderings()
+ test_util.SetAllFieldsAndExtensions(proto)
+ serialized = proto.SerializeToString()
+ test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
+
+ def testMergeFromStringWhenFieldsAlreadySet(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ first_proto.repeated_string.append('foobar')
+ first_proto.optional_int32 = 23
+ first_proto.optional_nested_message.bb = 42
+ serialized = first_proto.SerializeToString()
+
+ second_proto = unittest_pb2.TestAllTypes()
+ second_proto.repeated_string.append('baz')
+ second_proto.optional_int32 = 100
+ second_proto.optional_nested_message.bb = 999
+
+ second_proto.MergeFromString(serialized)
+ # Ensure that we append to repeated fields.
+ self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
+ # Ensure that we overwrite nonrepeatd scalars.
+ self.assertEqual(23, second_proto.optional_int32)
+ # Ensure that we recursively call MergeFromString() on
+ # submessages.
+ self.assertEqual(42, second_proto.optional_nested_message.bb)
+
+ def testMessageSetWireFormat(self):
+ proto = unittest_mset_pb2.TestMessageSet()
+ extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
+ extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
+ extension1 = extension_message1.message_set_extension
+ extension2 = extension_message2.message_set_extension
+ proto.Extensions[extension1].i = 123
+ proto.Extensions[extension2].str = 'foo'
+
+ # Serialize using the MessageSet wire format (this is specified in the
+ # .proto file).
+ serialized = proto.SerializeToString()
+
+ raw = unittest_mset_pb2.RawMessageSet()
+ self.assertEqual(False,
+ raw.DESCRIPTOR.GetOptions().message_set_wire_format)
+ raw.MergeFromString(serialized)
+ self.assertEqual(2, len(raw.item))
+
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.MergeFromString(raw.item[0].message)
+ self.assertEqual(123, message1.i)
+
+ message2 = unittest_mset_pb2.TestMessageSetExtension2()
+ message2.MergeFromString(raw.item[1].message)
+ self.assertEqual('foo', message2.str)
+
+ # Deserialize using the MessageSet wire format.
+ proto2 = unittest_mset_pb2.TestMessageSet()
+ proto2.MergeFromString(serialized)
+ self.assertEqual(123, proto2.Extensions[extension1].i)
+ self.assertEqual('foo', proto2.Extensions[extension2].str)
+
+ # Check byte size.
+ self.assertEqual(proto2.ByteSize(), len(serialized))
+ self.assertEqual(proto.ByteSize(), len(serialized))
+
+ def testMessageSetWireFormatUnknownExtension(self):
+ # Create a message using the message set wire format with an unknown
+ # message.
+ raw = unittest_mset_pb2.RawMessageSet()
+
+ # Add an item.
+ item = raw.item.add()
+ item.type_id = 1545008
+ extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.i = 12345
+ item.message = message1.SerializeToString()
+
+ # Add a second, unknown extension.
+ item = raw.item.add()
+ item.type_id = 1545009
+ extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
+ message1 = unittest_mset_pb2.TestMessageSetExtension1()
+ message1.i = 12346
+ item.message = message1.SerializeToString()
+
+ # Add another unknown extension.
+ item = raw.item.add()
+ item.type_id = 1545010
+ message1 = unittest_mset_pb2.TestMessageSetExtension2()
+ message1.str = 'foo'
+ item.message = message1.SerializeToString()
+
+ serialized = raw.SerializeToString()
+
+ # Parse message using the message set wire format.
+ proto = unittest_mset_pb2.TestMessageSet()
+ proto.MergeFromString(serialized)
+
+ # Check that the message parsed well.
+ extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
+ extension1 = extension_message1.message_set_extension
+ self.assertEquals(12345, proto.Extensions[extension1].i)
+
+ def testUnknownFields(self):
+ proto = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto)
+
+ serialized = proto.SerializeToString()
+
+ # The empty message should be parsable with all of the fields
+ # unknown.
+ proto2 = unittest_pb2.TestEmptyMessage()
+
+ # Parsing this message should succeed.
+ proto2.MergeFromString(serialized)
+
+ # Now test with a int64 field set.
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_int64 = 0x0fffffffffffffff
+ serialized = proto.SerializeToString()
+ # The empty message should be parsable with all of the fields
+ # unknown.
+ proto2 = unittest_pb2.TestEmptyMessage()
+ # Parsing this message should succeed.
+ proto2.MergeFromString(serialized)
+
+ def _CheckRaises(self, exc_class, callable_obj, exception):
+ """This method checks if the excpetion type and message are as expected."""
+ try:
+ callable_obj()
+ except exc_class, ex:
+ # Check if the exception message is the right one.
+ self.assertEqual(exception, str(ex))
+ return
+ else:
+ raise self.failureException('%s not raised' % str(exc_class))
+
+ def testSerializeUninitialized(self):
+ proto = unittest_pb2.TestRequired()
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Required field protobuf_unittest.TestRequired.a is not set.')
+ # Shouldn't raise exceptions.
+ partial = proto.SerializePartialToString()
+
+ proto.a = 1
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Required field protobuf_unittest.TestRequired.b is not set.')
+ # Shouldn't raise exceptions.
+ partial = proto.SerializePartialToString()
+
+ proto.b = 2
+ self._CheckRaises(
+ message.EncodeError,
+ proto.SerializeToString,
+ 'Required field protobuf_unittest.TestRequired.c is not set.')
+ # Shouldn't raise exceptions.
+ partial = proto.SerializePartialToString()
+
+ proto.c = 3
+ serialized = proto.SerializeToString()
+ # Shouldn't raise exceptions.
+ partial = proto.SerializePartialToString()
+
+ proto2 = unittest_pb2.TestRequired()
+ proto2.MergeFromString(serialized)
+ self.assertEqual(1, proto2.a)
+ self.assertEqual(2, proto2.b)
+ self.assertEqual(3, proto2.c)
+ proto2.ParseFromString(partial)
+ self.assertEqual(1, proto2.a)
+ self.assertEqual(2, proto2.b)
+ self.assertEqual(3, proto2.c)
+
+ def testSerializeAllPackedFields(self):
+ first_proto = unittest_pb2.TestPackedTypes()
+ second_proto = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(first_proto)
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(first_proto.ByteSize(), len(serialized))
+ bytes_read = second_proto.MergeFromString(serialized)
+ self.assertEqual(second_proto.ByteSize(), bytes_read)
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeAllPackedExtensions(self):
+ first_proto = unittest_pb2.TestPackedExtensions()
+ second_proto = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(first_proto)
+ serialized = first_proto.SerializeToString()
+ bytes_read = second_proto.MergeFromString(serialized)
+ self.assertEqual(second_proto.ByteSize(), bytes_read)
+ self.assertEqual(first_proto, second_proto)
+
+ def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
+ first_proto = unittest_pb2.TestPackedTypes()
+ first_proto.packed_int32.extend([1, 2])
+ first_proto.packed_double.append(3.0)
+ serialized = first_proto.SerializeToString()
+
+ second_proto = unittest_pb2.TestPackedTypes()
+ second_proto.packed_int32.append(3)
+ second_proto.packed_double.extend([1.0, 2.0])
+ second_proto.packed_sint32.append(4)
+
+ second_proto.MergeFromString(serialized)
+ self.assertEqual([3, 1, 2], second_proto.packed_int32)
+ self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
+ self.assertEqual([4], second_proto.packed_sint32)
+
+ def testPackedFieldsWireFormat(self):
+ proto = unittest_pb2.TestPackedTypes()
+ proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes
+ proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes
+ proto.packed_float.append(2.0) # 4 bytes, will be before double
+ serialized = proto.SerializeToString()
+ self.assertEqual(proto.ByteSize(), len(serialized))
+ d = decoder.Decoder(serialized)
+ ReadTag = d.ReadFieldNumberAndWireType
+ self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(1+1+1+2, d.ReadInt32())
+ self.assertEqual(1, d.ReadInt32())
+ self.assertEqual(2, d.ReadInt32())
+ self.assertEqual(150, d.ReadInt32())
+ self.assertEqual(3, d.ReadInt32())
+ self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(4, d.ReadInt32())
+ self.assertEqual(2.0, d.ReadFloat())
+ self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(8+8, d.ReadInt32())
+ self.assertEqual(1.0, d.ReadDouble())
+ self.assertEqual(1000.0, d.ReadDouble())
+ self.assertTrue(d.EndOfStream())
+
+ def testFieldNumbers(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
+ self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
+ self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
+ self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
+ self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
+ self.assertEqual(
+ unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
+
+ def testExtensionFieldNumbers(self):
+ self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
+ self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
+ self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
+ self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
+ self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
+ self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
+ self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
+ self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
+ self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
+ self.assertEqual(
+ unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
+ self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
+ self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
+ 21)
+ self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
+ self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
+ self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
+ self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
+ self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
+ self.assertEqual(
+ unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
+ self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
+ self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
+ 51)
+
+ def testInitKwargs(self):
+ proto = unittest_pb2.TestAllTypes(
+ optional_int32=1,
+ optional_string='foo',
+ optional_bool=True,
+ optional_bytes='bar',
+ optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
+ optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
+ optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
+ optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
+ repeated_int32=[1, 2, 3])
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('optional_int32'))
+ self.assertTrue(proto.HasField('optional_string'))
+ self.assertTrue(proto.HasField('optional_bool'))
+ self.assertTrue(proto.HasField('optional_bytes'))
+ self.assertTrue(proto.HasField('optional_nested_message'))
+ self.assertTrue(proto.HasField('optional_foreign_message'))
+ self.assertTrue(proto.HasField('optional_nested_enum'))
+ self.assertTrue(proto.HasField('optional_foreign_enum'))
+ self.assertEqual(1, proto.optional_int32)
+ self.assertEqual('foo', proto.optional_string)
+ self.assertEqual(True, proto.optional_bool)
+ self.assertEqual('bar', proto.optional_bytes)
+ self.assertEqual(1, proto.optional_nested_message.bb)
+ self.assertEqual(1, proto.optional_foreign_message.c)
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO,
+ proto.optional_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
+ self.assertEqual([1, 2, 3], proto.repeated_int32)
+
+ def testInitArgsUnknownFieldName(self):
+ def InitalizeEmptyMessageWithExtraKeywordArg():
+ unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
+ self._CheckRaises(ValueError,
+ InitalizeEmptyMessageWithExtraKeywordArg,
+ 'Protocol message has no "unknown" field.')
+
+ def testInitRequiredKwargs(self):
+ proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('a'))
+ self.assertTrue(proto.HasField('b'))
+ self.assertTrue(proto.HasField('c'))
+ self.assertTrue(not proto.HasField('dummy2'))
+ self.assertEqual(1, proto.a)
+ self.assertEqual(1, proto.b)
+ self.assertEqual(1, proto.c)
+
+ def testInitRequiredForeignKwargs(self):
+ proto = unittest_pb2.TestRequiredForeign(
+ optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
+ self.assertTrue(proto.IsInitialized())
+ self.assertTrue(proto.HasField('optional_message'))
+ self.assertTrue(proto.optional_message.IsInitialized())
+ self.assertTrue(proto.optional_message.HasField('a'))
+ self.assertTrue(proto.optional_message.HasField('b'))
+ self.assertTrue(proto.optional_message.HasField('c'))
+ self.assertTrue(not proto.optional_message.HasField('dummy2'))
+ self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
+ proto.optional_message)
+ self.assertEqual(1, proto.optional_message.a)
+ self.assertEqual(1, proto.optional_message.b)
+ self.assertEqual(1, proto.optional_message.c)
+
+ def testInitRepeatedKwargs(self):
+ proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
+ self.assertTrue(proto.IsInitialized())
+ self.assertEqual(1, proto.repeated_int32[0])
+ self.assertEqual(2, proto.repeated_int32[1])
+ self.assertEqual(3, proto.repeated_int32[2])
+
+
+class OptionsTest(unittest.TestCase):
+
+ def testMessageOptions(self):
+ proto = unittest_mset_pb2.TestMessageSet()
+ self.assertEqual(True,
+ proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(False,
+ proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+
+ def testPackedOptions(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_int32 = 1
+ proto.optional_double = 3.0
+ for field_descriptor, _ in proto.ListFields():
+ self.assertEqual(False, field_descriptor.GetOptions().packed)
+
+ proto = unittest_pb2.TestPackedTypes()
+ proto.packed_int32.append(1)
+ proto.packed_double.append(3.0)
+ for field_descriptor, _ in proto.ListFields():
+ self.assertEqual(True, field_descriptor.GetOptions().packed)
+ self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
+ field_descriptor.label)
+
+
+class UtilityTest(unittest.TestCase):
+
+ def testImergeSorted(self):
+ ImergeSorted = reflection._ImergeSorted
+ # Various types of emptiness.
+ self.assertEqual([], list(ImergeSorted()))
+ self.assertEqual([], list(ImergeSorted([])))
+ self.assertEqual([], list(ImergeSorted([], [])))
+
+ # One nonempty list.
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
+ self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
+
+ # Merging some nonempty lists together.
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
+
+ # Elements repeated across component iterators.
+ self.assertEqual([1, 2, 2, 3, 3],
+ list(ImergeSorted([1, 2], [3], [2, 3])))
+
+ # Elements repeated within an iterator.
+ self.assertEqual([1, 2, 2, 3, 3],
+ list(ImergeSorted([1, 2, 2], [3], [3])))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
new file mode 100755
index 0000000..e04f825
--- /dev/null
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -0,0 +1,136 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests for google.protobuf.internal.service_reflection."""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+import unittest
+from google.protobuf import unittest_pb2
+from google.protobuf import service_reflection
+from google.protobuf import service
+
+
+class FooUnitTest(unittest.TestCase):
+
+ def testService(self):
+ class MockRpcChannel(service.RpcChannel):
+ def CallMethod(self, method, controller, request, response, callback):
+ self.method = method
+ self.controller = controller
+ self.request = request
+ callback(response)
+
+ class MockRpcController(service.RpcController):
+ def SetFailed(self, msg):
+ self.failure_message = msg
+
+ self.callback_response = None
+
+ class MyService(unittest_pb2.TestService):
+ pass
+
+ self.callback_response = None
+
+ def MyCallback(response):
+ self.callback_response = response
+
+ rpc_controller = MockRpcController()
+ channel = MockRpcChannel()
+ srvc = MyService()
+ srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
+ self.assertEqual('Method Foo not implemented.',
+ rpc_controller.failure_message)
+ self.assertEqual(None, self.callback_response)
+
+ rpc_controller.failure_message = None
+
+ service_descriptor = unittest_pb2.TestService.GetDescriptor()
+ srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
+ unittest_pb2.BarRequest(), MyCallback)
+ self.assertEqual('Method Bar not implemented.',
+ rpc_controller.failure_message)
+ self.assertEqual(None, self.callback_response)
+
+ class MyServiceImpl(unittest_pb2.TestService):
+ def Foo(self, rpc_controller, request, done):
+ self.foo_called = True
+ def Bar(self, rpc_controller, request, done):
+ self.bar_called = True
+
+ srvc = MyServiceImpl()
+ rpc_controller.failure_message = None
+ srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
+ self.assertEqual(None, rpc_controller.failure_message)
+ self.assertEqual(True, srvc.foo_called)
+
+ rpc_controller.failure_message = None
+ srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
+ unittest_pb2.BarRequest(), MyCallback)
+ self.assertEqual(None, rpc_controller.failure_message)
+ self.assertEqual(True, srvc.bar_called)
+
+ def testServiceStub(self):
+ class MockRpcChannel(service.RpcChannel):
+ def CallMethod(self, method, controller, request,
+ response_class, callback):
+ self.method = method
+ self.controller = controller
+ self.request = request
+ callback(response_class())
+
+ self.callback_response = None
+
+ def MyCallback(response):
+ self.callback_response = response
+
+ channel = MockRpcChannel()
+ stub = unittest_pb2.TestService_Stub(channel)
+ rpc_controller = 'controller'
+ request = 'request'
+
+ # GetDescriptor now static, still works as instance method for compatability
+ self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
+ stub.GetDescriptor())
+
+ # Invoke method.
+ stub.Foo(rpc_controller, request, MyCallback)
+
+ self.assertTrue(isinstance(self.callback_response,
+ unittest_pb2.FooResponse))
+ self.assertEqual(request, channel.request)
+ self.assertEqual(rpc_controller, channel.controller)
+ self.assertEqual(stub.GetDescriptor().methods[0], channel.method)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
new file mode 100755
index 0000000..1a0da55
--- /dev/null
+++ b/python/google/protobuf/internal/test_util.py
@@ -0,0 +1,612 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Utilities for Python proto2 tests.
+
+This is intentionally modeled on C++ code in
+//net/proto2/internal/test_util.*.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import os.path
+
+import unittest
+from google.protobuf import unittest_import_pb2
+from google.protobuf import unittest_pb2
+
+
+def SetAllFields(message):
+ """Sets every field in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestAllTypes instance.
+ """
+
+ #
+ # Optional fields.
+ #
+
+ message.optional_int32 = 101
+ message.optional_int64 = 102
+ message.optional_uint32 = 103
+ message.optional_uint64 = 104
+ message.optional_sint32 = 105
+ message.optional_sint64 = 106
+ message.optional_fixed32 = 107
+ message.optional_fixed64 = 108
+ message.optional_sfixed32 = 109
+ message.optional_sfixed64 = 110
+ message.optional_float = 111
+ message.optional_double = 112
+ message.optional_bool = True
+ # TODO(robinson): Firmly spec out and test how
+ # protos interact with unicode. One specific example:
+ # what happens if we change the literal below to
+ # u'115'? What *should* happen? Still some discussion
+ # to finish with Kenton about bytes vs. strings
+ # and forcing everything to be utf8. :-/
+ message.optional_string = '115'
+ message.optional_bytes = '116'
+
+ message.optionalgroup.a = 117
+ message.optional_nested_message.bb = 118
+ message.optional_foreign_message.c = 119
+ message.optional_import_message.d = 120
+
+ message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
+ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
+ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
+
+ message.optional_string_piece = '124'
+ message.optional_cord = '125'
+
+ #
+ # Repeated fields.
+ #
+
+ message.repeated_int32.append(201)
+ message.repeated_int64.append(202)
+ message.repeated_uint32.append(203)
+ message.repeated_uint64.append(204)
+ message.repeated_sint32.append(205)
+ message.repeated_sint64.append(206)
+ message.repeated_fixed32.append(207)
+ message.repeated_fixed64.append(208)
+ message.repeated_sfixed32.append(209)
+ message.repeated_sfixed64.append(210)
+ message.repeated_float.append(211)
+ message.repeated_double.append(212)
+ message.repeated_bool.append(True)
+ message.repeated_string.append('215')
+ message.repeated_bytes.append('216')
+
+ message.repeatedgroup.add().a = 217
+ message.repeated_nested_message.add().bb = 218
+ message.repeated_foreign_message.add().c = 219
+ message.repeated_import_message.add().d = 220
+
+ message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
+
+ message.repeated_string_piece.append('224')
+ message.repeated_cord.append('225')
+
+ # Add a second one of each field.
+ message.repeated_int32.append(301)
+ message.repeated_int64.append(302)
+ message.repeated_uint32.append(303)
+ message.repeated_uint64.append(304)
+ message.repeated_sint32.append(305)
+ message.repeated_sint64.append(306)
+ message.repeated_fixed32.append(307)
+ message.repeated_fixed64.append(308)
+ message.repeated_sfixed32.append(309)
+ message.repeated_sfixed64.append(310)
+ message.repeated_float.append(311)
+ message.repeated_double.append(312)
+ message.repeated_bool.append(False)
+ message.repeated_string.append('315')
+ message.repeated_bytes.append('316')
+
+ message.repeatedgroup.add().a = 317
+ message.repeated_nested_message.add().bb = 318
+ message.repeated_foreign_message.add().c = 319
+ message.repeated_import_message.add().d = 320
+
+ message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
+ message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
+
+ message.repeated_string_piece.append('324')
+ message.repeated_cord.append('325')
+
+ #
+ # Fields that have defaults.
+ #
+
+ message.default_int32 = 401
+ message.default_int64 = 402
+ message.default_uint32 = 403
+ message.default_uint64 = 404
+ message.default_sint32 = 405
+ message.default_sint64 = 406
+ message.default_fixed32 = 407
+ message.default_fixed64 = 408
+ message.default_sfixed32 = 409
+ message.default_sfixed64 = 410
+ message.default_float = 411
+ message.default_double = 412
+ message.default_bool = False
+ message.default_string = '415'
+ message.default_bytes = '416'
+
+ message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
+ message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
+ message.default_import_enum = unittest_import_pb2.IMPORT_FOO
+
+ message.default_string_piece = '424'
+ message.default_cord = '425'
+
+
+def SetAllExtensions(message):
+ """Sets every extension in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestAllExtensions instance.
+ """
+
+ extensions = message.Extensions
+ pb2 = unittest_pb2
+ import_pb2 = unittest_import_pb2
+
+ #
+ # Optional fields.
+ #
+
+ extensions[pb2.optional_int32_extension] = 101
+ extensions[pb2.optional_int64_extension] = 102
+ extensions[pb2.optional_uint32_extension] = 103
+ extensions[pb2.optional_uint64_extension] = 104
+ extensions[pb2.optional_sint32_extension] = 105
+ extensions[pb2.optional_sint64_extension] = 106
+ extensions[pb2.optional_fixed32_extension] = 107
+ extensions[pb2.optional_fixed64_extension] = 108
+ extensions[pb2.optional_sfixed32_extension] = 109
+ extensions[pb2.optional_sfixed64_extension] = 110
+ extensions[pb2.optional_float_extension] = 111
+ extensions[pb2.optional_double_extension] = 112
+ extensions[pb2.optional_bool_extension] = True
+ extensions[pb2.optional_string_extension] = '115'
+ extensions[pb2.optional_bytes_extension] = '116'
+
+ extensions[pb2.optionalgroup_extension].a = 117
+ extensions[pb2.optional_nested_message_extension].bb = 118
+ extensions[pb2.optional_foreign_message_extension].c = 119
+ extensions[pb2.optional_import_message_extension].d = 120
+
+ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
+ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
+ extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ
+ extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ
+
+ extensions[pb2.optional_string_piece_extension] = '124'
+ extensions[pb2.optional_cord_extension] = '125'
+
+ #
+ # Repeated fields.
+ #
+
+ extensions[pb2.repeated_int32_extension].append(201)
+ extensions[pb2.repeated_int64_extension].append(202)
+ extensions[pb2.repeated_uint32_extension].append(203)
+ extensions[pb2.repeated_uint64_extension].append(204)
+ extensions[pb2.repeated_sint32_extension].append(205)
+ extensions[pb2.repeated_sint64_extension].append(206)
+ extensions[pb2.repeated_fixed32_extension].append(207)
+ extensions[pb2.repeated_fixed64_extension].append(208)
+ extensions[pb2.repeated_sfixed32_extension].append(209)
+ extensions[pb2.repeated_sfixed64_extension].append(210)
+ extensions[pb2.repeated_float_extension].append(211)
+ extensions[pb2.repeated_double_extension].append(212)
+ extensions[pb2.repeated_bool_extension].append(True)
+ extensions[pb2.repeated_string_extension].append('215')
+ extensions[pb2.repeated_bytes_extension].append('216')
+
+ extensions[pb2.repeatedgroup_extension].add().a = 217
+ extensions[pb2.repeated_nested_message_extension].add().bb = 218
+ extensions[pb2.repeated_foreign_message_extension].add().c = 219
+ extensions[pb2.repeated_import_message_extension].add().d = 220
+
+ extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR)
+ extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR)
+ extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR)
+
+ extensions[pb2.repeated_string_piece_extension].append('224')
+ extensions[pb2.repeated_cord_extension].append('225')
+
+ # Append a second one of each field.
+ extensions[pb2.repeated_int32_extension].append(301)
+ extensions[pb2.repeated_int64_extension].append(302)
+ extensions[pb2.repeated_uint32_extension].append(303)
+ extensions[pb2.repeated_uint64_extension].append(304)
+ extensions[pb2.repeated_sint32_extension].append(305)
+ extensions[pb2.repeated_sint64_extension].append(306)
+ extensions[pb2.repeated_fixed32_extension].append(307)
+ extensions[pb2.repeated_fixed64_extension].append(308)
+ extensions[pb2.repeated_sfixed32_extension].append(309)
+ extensions[pb2.repeated_sfixed64_extension].append(310)
+ extensions[pb2.repeated_float_extension].append(311)
+ extensions[pb2.repeated_double_extension].append(312)
+ extensions[pb2.repeated_bool_extension].append(False)
+ extensions[pb2.repeated_string_extension].append('315')
+ extensions[pb2.repeated_bytes_extension].append('316')
+
+ extensions[pb2.repeatedgroup_extension].add().a = 317
+ extensions[pb2.repeated_nested_message_extension].add().bb = 318
+ extensions[pb2.repeated_foreign_message_extension].add().c = 319
+ extensions[pb2.repeated_import_message_extension].add().d = 320
+
+ extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ)
+ extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ)
+ extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ)
+
+ extensions[pb2.repeated_string_piece_extension].append('324')
+ extensions[pb2.repeated_cord_extension].append('325')
+
+ #
+ # Fields with defaults.
+ #
+
+ extensions[pb2.default_int32_extension] = 401
+ extensions[pb2.default_int64_extension] = 402
+ extensions[pb2.default_uint32_extension] = 403
+ extensions[pb2.default_uint64_extension] = 404
+ extensions[pb2.default_sint32_extension] = 405
+ extensions[pb2.default_sint64_extension] = 406
+ extensions[pb2.default_fixed32_extension] = 407
+ extensions[pb2.default_fixed64_extension] = 408
+ extensions[pb2.default_sfixed32_extension] = 409
+ extensions[pb2.default_sfixed64_extension] = 410
+ extensions[pb2.default_float_extension] = 411
+ extensions[pb2.default_double_extension] = 412
+ extensions[pb2.default_bool_extension] = False
+ extensions[pb2.default_string_extension] = '415'
+ extensions[pb2.default_bytes_extension] = '416'
+
+ extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO
+ extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO
+ extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO
+
+ extensions[pb2.default_string_piece_extension] = '424'
+ extensions[pb2.default_cord_extension] = '425'
+
+
+def SetAllFieldsAndExtensions(message):
+ """Sets every field and extension in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestAllExtensions message.
+ """
+ message.my_int = 1
+ message.my_string = 'foo'
+ message.my_float = 1.0
+ message.Extensions[unittest_pb2.my_extension_int] = 23
+ message.Extensions[unittest_pb2.my_extension_string] = 'bar'
+
+
+def ExpectAllFieldsAndExtensionsInOrder(serialized):
+ """Ensures that serialized is the serialization we expect for a message
+ filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the
+ serialization is in canonical, tag-number order).
+ """
+ my_extension_int = unittest_pb2.my_extension_int
+ my_extension_string = unittest_pb2.my_extension_string
+ expected_strings = []
+ message = unittest_pb2.TestFieldOrderings()
+ message.my_int = 1 # Field 1.
+ expected_strings.append(message.SerializeToString())
+ message.Clear()
+ message.Extensions[my_extension_int] = 23 # Field 5.
+ expected_strings.append(message.SerializeToString())
+ message.Clear()
+ message.my_string = 'foo' # Field 11.
+ expected_strings.append(message.SerializeToString())
+ message.Clear()
+ message.Extensions[my_extension_string] = 'bar' # Field 50.
+ expected_strings.append(message.SerializeToString())
+ message.Clear()
+ message.my_float = 1.0
+ expected_strings.append(message.SerializeToString())
+ message.Clear()
+ expected = ''.join(expected_strings)
+
+ if expected != serialized:
+ raise ValueError('Expected %r, found %r' % (expected, serialized))
+
+
+class GoldenMessageTestCase(unittest.TestCase):
+ """This adds methods to TestCase useful for verifying our Golden Message."""
+
+ def ExpectAllFieldsSet(self, message):
+ """Check all fields for correct values have after Set*Fields() is called."""
+ self.assertTrue(message.HasField('optional_int32'))
+ self.assertTrue(message.HasField('optional_int64'))
+ self.assertTrue(message.HasField('optional_uint32'))
+ self.assertTrue(message.HasField('optional_uint64'))
+ self.assertTrue(message.HasField('optional_sint32'))
+ self.assertTrue(message.HasField('optional_sint64'))
+ self.assertTrue(message.HasField('optional_fixed32'))
+ self.assertTrue(message.HasField('optional_fixed64'))
+ self.assertTrue(message.HasField('optional_sfixed32'))
+ self.assertTrue(message.HasField('optional_sfixed64'))
+ self.assertTrue(message.HasField('optional_float'))
+ self.assertTrue(message.HasField('optional_double'))
+ self.assertTrue(message.HasField('optional_bool'))
+ self.assertTrue(message.HasField('optional_string'))
+ self.assertTrue(message.HasField('optional_bytes'))
+
+ self.assertTrue(message.HasField('optionalgroup'))
+ self.assertTrue(message.HasField('optional_nested_message'))
+ self.assertTrue(message.HasField('optional_foreign_message'))
+ self.assertTrue(message.HasField('optional_import_message'))
+
+ self.assertTrue(message.optionalgroup.HasField('a'))
+ self.assertTrue(message.optional_nested_message.HasField('bb'))
+ self.assertTrue(message.optional_foreign_message.HasField('c'))
+ self.assertTrue(message.optional_import_message.HasField('d'))
+
+ self.assertTrue(message.HasField('optional_nested_enum'))
+ self.assertTrue(message.HasField('optional_foreign_enum'))
+ self.assertTrue(message.HasField('optional_import_enum'))
+
+ self.assertTrue(message.HasField('optional_string_piece'))
+ self.assertTrue(message.HasField('optional_cord'))
+
+ self.assertEqual(101, message.optional_int32)
+ self.assertEqual(102, message.optional_int64)
+ self.assertEqual(103, message.optional_uint32)
+ self.assertEqual(104, message.optional_uint64)
+ self.assertEqual(105, message.optional_sint32)
+ self.assertEqual(106, message.optional_sint64)
+ self.assertEqual(107, message.optional_fixed32)
+ self.assertEqual(108, message.optional_fixed64)
+ self.assertEqual(109, message.optional_sfixed32)
+ self.assertEqual(110, message.optional_sfixed64)
+ self.assertEqual(111, message.optional_float)
+ self.assertEqual(112, message.optional_double)
+ self.assertEqual(True, message.optional_bool)
+ self.assertEqual('115', message.optional_string)
+ self.assertEqual('116', message.optional_bytes)
+
+ self.assertEqual(117, message.optionalgroup.a);
+ self.assertEqual(118, message.optional_nested_message.bb)
+ self.assertEqual(119, message.optional_foreign_message.c)
+ self.assertEqual(120, message.optional_import_message.d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.optional_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum)
+ self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.optional_import_enum)
+
+ # -----------------------------------------------------------------
+
+ self.assertEqual(2, len(message.repeated_int32))
+ self.assertEqual(2, len(message.repeated_int64))
+ self.assertEqual(2, len(message.repeated_uint32))
+ self.assertEqual(2, len(message.repeated_uint64))
+ self.assertEqual(2, len(message.repeated_sint32))
+ self.assertEqual(2, len(message.repeated_sint64))
+ self.assertEqual(2, len(message.repeated_fixed32))
+ self.assertEqual(2, len(message.repeated_fixed64))
+ self.assertEqual(2, len(message.repeated_sfixed32))
+ self.assertEqual(2, len(message.repeated_sfixed64))
+ self.assertEqual(2, len(message.repeated_float))
+ self.assertEqual(2, len(message.repeated_double))
+ self.assertEqual(2, len(message.repeated_bool))
+ self.assertEqual(2, len(message.repeated_string))
+ self.assertEqual(2, len(message.repeated_bytes))
+
+ self.assertEqual(2, len(message.repeatedgroup))
+ self.assertEqual(2, len(message.repeated_nested_message))
+ self.assertEqual(2, len(message.repeated_foreign_message))
+ self.assertEqual(2, len(message.repeated_import_message))
+ self.assertEqual(2, len(message.repeated_nested_enum))
+ self.assertEqual(2, len(message.repeated_foreign_enum))
+ self.assertEqual(2, len(message.repeated_import_enum))
+
+ self.assertEqual(2, len(message.repeated_string_piece))
+ self.assertEqual(2, len(message.repeated_cord))
+
+ self.assertEqual(201, message.repeated_int32[0])
+ self.assertEqual(202, message.repeated_int64[0])
+ self.assertEqual(203, message.repeated_uint32[0])
+ self.assertEqual(204, message.repeated_uint64[0])
+ self.assertEqual(205, message.repeated_sint32[0])
+ self.assertEqual(206, message.repeated_sint64[0])
+ self.assertEqual(207, message.repeated_fixed32[0])
+ self.assertEqual(208, message.repeated_fixed64[0])
+ self.assertEqual(209, message.repeated_sfixed32[0])
+ self.assertEqual(210, message.repeated_sfixed64[0])
+ self.assertEqual(211, message.repeated_float[0])
+ self.assertEqual(212, message.repeated_double[0])
+ self.assertEqual(True, message.repeated_bool[0])
+ self.assertEqual('215', message.repeated_string[0])
+ self.assertEqual('216', message.repeated_bytes[0])
+
+ self.assertEqual(217, message.repeatedgroup[0].a)
+ self.assertEqual(218, message.repeated_nested_message[0].bb)
+ self.assertEqual(219, message.repeated_foreign_message[0].c)
+ self.assertEqual(220, message.repeated_import_message[0].d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAR,
+ message.repeated_nested_enum[0])
+ self.assertEqual(unittest_pb2.FOREIGN_BAR,
+ message.repeated_foreign_enum[0])
+ self.assertEqual(unittest_import_pb2.IMPORT_BAR,
+ message.repeated_import_enum[0])
+
+ self.assertEqual(301, message.repeated_int32[1])
+ self.assertEqual(302, message.repeated_int64[1])
+ self.assertEqual(303, message.repeated_uint32[1])
+ self.assertEqual(304, message.repeated_uint64[1])
+ self.assertEqual(305, message.repeated_sint32[1])
+ self.assertEqual(306, message.repeated_sint64[1])
+ self.assertEqual(307, message.repeated_fixed32[1])
+ self.assertEqual(308, message.repeated_fixed64[1])
+ self.assertEqual(309, message.repeated_sfixed32[1])
+ self.assertEqual(310, message.repeated_sfixed64[1])
+ self.assertEqual(311, message.repeated_float[1])
+ self.assertEqual(312, message.repeated_double[1])
+ self.assertEqual(False, message.repeated_bool[1])
+ self.assertEqual('315', message.repeated_string[1])
+ self.assertEqual('316', message.repeated_bytes[1])
+
+ self.assertEqual(317, message.repeatedgroup[1].a)
+ self.assertEqual(318, message.repeated_nested_message[1].bb)
+ self.assertEqual(319, message.repeated_foreign_message[1].c)
+ self.assertEqual(320, message.repeated_import_message[1].d)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
+ message.repeated_nested_enum[1])
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ message.repeated_foreign_enum[1])
+ self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
+ message.repeated_import_enum[1])
+
+ # -----------------------------------------------------------------
+
+ self.assertTrue(message.HasField('default_int32'))
+ self.assertTrue(message.HasField('default_int64'))
+ self.assertTrue(message.HasField('default_uint32'))
+ self.assertTrue(message.HasField('default_uint64'))
+ self.assertTrue(message.HasField('default_sint32'))
+ self.assertTrue(message.HasField('default_sint64'))
+ self.assertTrue(message.HasField('default_fixed32'))
+ self.assertTrue(message.HasField('default_fixed64'))
+ self.assertTrue(message.HasField('default_sfixed32'))
+ self.assertTrue(message.HasField('default_sfixed64'))
+ self.assertTrue(message.HasField('default_float'))
+ self.assertTrue(message.HasField('default_double'))
+ self.assertTrue(message.HasField('default_bool'))
+ self.assertTrue(message.HasField('default_string'))
+ self.assertTrue(message.HasField('default_bytes'))
+
+ self.assertTrue(message.HasField('default_nested_enum'))
+ self.assertTrue(message.HasField('default_foreign_enum'))
+ self.assertTrue(message.HasField('default_import_enum'))
+
+ self.assertEqual(401, message.default_int32)
+ self.assertEqual(402, message.default_int64)
+ self.assertEqual(403, message.default_uint32)
+ self.assertEqual(404, message.default_uint64)
+ self.assertEqual(405, message.default_sint32)
+ self.assertEqual(406, message.default_sint64)
+ self.assertEqual(407, message.default_fixed32)
+ self.assertEqual(408, message.default_fixed64)
+ self.assertEqual(409, message.default_sfixed32)
+ self.assertEqual(410, message.default_sfixed64)
+ self.assertEqual(411, message.default_float)
+ self.assertEqual(412, message.default_double)
+ self.assertEqual(False, message.default_bool)
+ self.assertEqual('415', message.default_string)
+ self.assertEqual('416', message.default_bytes)
+
+ self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum)
+ self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum)
+ self.assertEqual(unittest_import_pb2.IMPORT_FOO,
+ message.default_import_enum)
+
+def GoldenFile(filename):
+ """Finds the given golden file and returns a file object representing it."""
+
+ # Search up the directory tree looking for the C++ protobuf source code.
+ path = '.'
+ while os.path.exists(path):
+ if os.path.exists(os.path.join(path, 'src/google/protobuf')):
+ # Found it. Load the golden file from the testdata directory.
+ full_path = os.path.join(path, 'src/google/protobuf/testdata', filename)
+ return open(full_path, 'rb')
+ path = os.path.join(path, '..')
+
+ raise RuntimeError(
+ 'Could not find golden files. This test must be run from within the '
+ 'protobuf source package so that it can read test data files from the '
+ 'C++ source tree.')
+
+
+def SetAllPackedFields(message):
+ """Sets every field in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestPackedTypes instance.
+ """
+ message.packed_int32.extend([101, 102])
+ message.packed_int64.extend([103, 104])
+ message.packed_uint32.extend([105, 106])
+ message.packed_uint64.extend([107, 108])
+ message.packed_sint32.extend([109, 110])
+ message.packed_sint64.extend([111, 112])
+ message.packed_fixed32.extend([113, 114])
+ message.packed_fixed64.extend([115, 116])
+ message.packed_sfixed32.extend([117, 118])
+ message.packed_sfixed64.extend([119, 120])
+ message.packed_float.extend([121.0, 122.0])
+ message.packed_double.extend([122.0, 123.0])
+ message.packed_bool.extend([True, False])
+ message.packed_enum.extend([unittest_pb2.FOREIGN_FOO,
+ unittest_pb2.FOREIGN_BAR])
+
+
+def SetAllPackedExtensions(message):
+ """Sets every extension in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestPackedExtensions instance.
+ """
+ extensions = message.Extensions
+ pb2 = unittest_pb2
+
+ extensions[pb2.packed_int32_extension].append(101)
+ extensions[pb2.packed_int64_extension].append(102)
+ extensions[pb2.packed_uint32_extension].append(103)
+ extensions[pb2.packed_uint64_extension].append(104)
+ extensions[pb2.packed_sint32_extension].append(105)
+ extensions[pb2.packed_sint64_extension].append(106)
+ extensions[pb2.packed_fixed32_extension].append(107)
+ extensions[pb2.packed_fixed64_extension].append(108)
+ extensions[pb2.packed_sfixed32_extension].append(109)
+ extensions[pb2.packed_sfixed64_extension].append(110)
+ extensions[pb2.packed_float_extension].append(111.0)
+ extensions[pb2.packed_double_extension].append(112.0)
+ extensions[pb2.packed_bool_extension].append(True)
+ extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ)
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
new file mode 100755
index 0000000..0cf2718
--- /dev/null
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -0,0 +1,397 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.text_format."""
+
+__author__ = 'kenton@google.com (Kenton Varda)'
+
+import difflib
+
+import unittest
+from google.protobuf import text_format
+from google.protobuf.internal import test_util
+from google.protobuf import unittest_pb2
+from google.protobuf import unittest_mset_pb2
+
+
+class TextFormatTest(test_util.GoldenMessageTestCase):
+ def ReadGolden(self, golden_filename):
+ f = test_util.GoldenFile(golden_filename)
+ golden_lines = f.readlines()
+ f.close()
+ return golden_lines
+
+ def CompareToGoldenFile(self, text, golden_filename):
+ golden_lines = self.ReadGolden(golden_filename)
+ self.CompareToGoldenLines(text, golden_lines)
+
+ def CompareToGoldenText(self, text, golden_text):
+ self.CompareToGoldenLines(text, golden_text.splitlines(1))
+
+ def CompareToGoldenLines(self, text, golden_lines):
+ actual_lines = text.splitlines(1)
+ self.assertEqual(golden_lines, actual_lines,
+ "Text doesn't match golden. Diff:\n" +
+ ''.join(difflib.ndiff(golden_lines, actual_lines)))
+
+ def testPrintAllFields(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_data.txt')
+
+ def testPrintAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.CompareToGoldenFile(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'text_format_unittest_extensions_data.txt')
+
+ def testPrintMessageSet(self):
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ message.message_set.Extensions[ext1].i = 23
+ message.message_set.Extensions[ext2].str = 'foo'
+ self.CompareToGoldenText(text_format.MessageToString(message),
+ 'message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+
+ def testPrintExotic(self):
+ message = unittest_pb2.TestAllTypes()
+ message.repeated_int64.append(-9223372036854775808);
+ message.repeated_uint64.append(18446744073709551615);
+ message.repeated_double.append(123.456);
+ message.repeated_double.append(1.23e22);
+ message.repeated_double.append(1.23e-18);
+ message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"');
+ self.CompareToGoldenText(
+ self.RemoveRedundantZeros(text_format.MessageToString(message)),
+ 'repeated_int64: -9223372036854775808\n'
+ 'repeated_uint64: 18446744073709551615\n'
+ 'repeated_double: 123.456\n'
+ 'repeated_double: 1.23e+22\n'
+ 'repeated_double: 1.23e-18\n'
+ 'repeated_string: '
+ '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
+
+ def testMessageToString(self):
+ message = unittest_pb2.ForeignMessage()
+ message.c = 123
+ self.assertEqual('c: 123\n', str(message))
+
+ def RemoveRedundantZeros(self, text):
+ # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
+ # these zeros in order to match the golden file.
+ return text.replace('e+0','e+').replace('e+0','e+') \
+ .replace('e-0','e-').replace('e-0','e-')
+
+ def testMergeGolden(self):
+ golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
+ parsed_message = unittest_pb2.TestAllTypes()
+ text_format.Merge(golden_text, parsed_message)
+
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ self.assertEquals(message, parsed_message)
+
+ def testMergeGoldenExtensions(self):
+ golden_text = '\n'.join(self.ReadGolden(
+ 'text_format_unittest_extensions_data.txt'))
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Merge(golden_text, parsed_message)
+
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ self.assertEquals(message, parsed_message)
+
+ def testMergeAllFields(self):
+ message = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(message)
+ ascii_text = text_format.MessageToString(message)
+
+ parsed_message = unittest_pb2.TestAllTypes()
+ text_format.Merge(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+ self.ExpectAllFieldsSet(message)
+
+ def testMergeAllExtensions(self):
+ message = unittest_pb2.TestAllExtensions()
+ test_util.SetAllExtensions(message)
+ ascii_text = text_format.MessageToString(message)
+
+ parsed_message = unittest_pb2.TestAllExtensions()
+ text_format.Merge(ascii_text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
+ def testMergeMessageSet(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_uint64: 1\n'
+ 'repeated_uint64: 2\n')
+ text_format.Merge(text, message)
+ self.assertEqual(1, message.repeated_uint64[0])
+ self.assertEqual(2, message.repeated_uint64[1])
+
+ message = unittest_mset_pb2.TestMessageSetContainer()
+ text = ('message_set {\n'
+ ' [protobuf_unittest.TestMessageSetExtension1] {\n'
+ ' i: 23\n'
+ ' }\n'
+ ' [protobuf_unittest.TestMessageSetExtension2] {\n'
+ ' str: \"foo\"\n'
+ ' }\n'
+ '}\n')
+ text_format.Merge(text, message)
+ ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
+ ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
+ self.assertEquals(23, message.message_set.Extensions[ext1].i)
+ self.assertEquals('foo', message.message_set.Extensions[ext2].str)
+
+ def testMergeExotic(self):
+ message = unittest_pb2.TestAllTypes()
+ text = ('repeated_int64: -9223372036854775808\n'
+ 'repeated_uint64: 18446744073709551615\n'
+ 'repeated_double: 123.456\n'
+ 'repeated_double: 1.23e+22\n'
+ 'repeated_double: 1.23e-18\n'
+ 'repeated_string: \n'
+ '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
+ text_format.Merge(text, message)
+
+ self.assertEqual(-9223372036854775808, message.repeated_int64[0])
+ self.assertEqual(18446744073709551615, message.repeated_uint64[0])
+ self.assertEqual(123.456, message.repeated_double[0])
+ self.assertEqual(1.23e22, message.repeated_double[1])
+ self.assertEqual(1.23e-18, message.repeated_double[2])
+ self.assertEqual(
+ '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0])
+
+ def testMergeUnknownField(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'unknown_field: 8\n'
+ self.assertRaisesWithMessage(
+ text_format.ParseError,
+ ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
+ '"unknown_field".'),
+ text_format.Merge, text, message)
+
+ def testMergeBadExtension(self):
+ message = unittest_pb2.TestAllTypes()
+ text = '[unknown_extension]: 8\n'
+ self.assertRaisesWithMessage(
+ text_format.ParseError,
+ '1:2 : Extension "unknown_extension" not registered.',
+ text_format.Merge, text, message)
+
+ def testMergeGroupNotClosed(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'RepeatedGroup: <'
+ self.assertRaisesWithMessage(
+ text_format.ParseError, '1:16 : Expected ">".',
+ text_format.Merge, text, message)
+
+ text = 'RepeatedGroup: {'
+ self.assertRaisesWithMessage(
+ text_format.ParseError, '1:16 : Expected "}".',
+ text_format.Merge, text, message)
+
+ def testMergeBadEnumValue(self):
+ message = unittest_pb2.TestAllTypes()
+ text = 'optional_nested_enum: BARR'
+ self.assertRaisesWithMessage(
+ text_format.ParseError,
+ ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
+ 'has no value named BARR.'),
+ text_format.Merge, text, message)
+
+ message = unittest_pb2.TestAllTypes()
+ text = 'optional_nested_enum: 100'
+ self.assertRaisesWithMessage(
+ text_format.ParseError,
+ ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
+ 'has no value with number 100.'),
+ text_format.Merge, text, message)
+
+ def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs):
+ """Same as assertRaises, but also compares the exception message."""
+ if hasattr(e_class, '__name__'):
+ exc_name = e_class.__name__
+ else:
+ exc_name = str(e_class)
+
+ try:
+ func(*args, **kwargs)
+ except e_class, expr:
+ if str(expr) != e:
+ msg = '%s raised, but with wrong message: "%s" instead of "%s"'
+ raise self.failureException(msg % (exc_name,
+ str(expr).encode('string_escape'),
+ e.encode('string_escape')))
+ return
+ else:
+ raise self.failureException('%s not raised' % exc_name)
+
+
+class TokenizerTest(unittest.TestCase):
+
+ def testSimpleTokenCases(self):
+ text = ('identifier1:"string1"\n \n\n'
+ 'identifier2 : \n \n123 \n identifier3 :\'string\'\n'
+ 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n'
+ 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
+ 'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
+ 'ID12: 2222222222222222222')
+ tokenizer = text_format._Tokenizer(text)
+ methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
+ ':',
+ (tokenizer.ConsumeString, 'string1'),
+ (tokenizer.ConsumeIdentifier, 'identifier2'),
+ ':',
+ (tokenizer.ConsumeInt32, 123),
+ (tokenizer.ConsumeIdentifier, 'identifier3'),
+ ':',
+ (tokenizer.ConsumeString, 'string'),
+ (tokenizer.ConsumeIdentifier, 'identifiER_4'),
+ ':',
+ (tokenizer.ConsumeFloat, 1.1e+2),
+ (tokenizer.ConsumeIdentifier, 'ID5'),
+ ':',
+ (tokenizer.ConsumeFloat, -0.23),
+ (tokenizer.ConsumeIdentifier, 'ID6'),
+ ':',
+ (tokenizer.ConsumeString, 'aaaa\'bbbb'),
+ (tokenizer.ConsumeIdentifier, 'ID7'),
+ ':',
+ (tokenizer.ConsumeString, 'aa\"bb'),
+ (tokenizer.ConsumeIdentifier, 'ID8'),
+ ':',
+ '{',
+ (tokenizer.ConsumeIdentifier, 'A'),
+ ':',
+ (tokenizer.ConsumeFloat, float('inf')),
+ (tokenizer.ConsumeIdentifier, 'B'),
+ ':',
+ (tokenizer.ConsumeFloat, float('-inf')),
+ (tokenizer.ConsumeIdentifier, 'C'),
+ ':',
+ (tokenizer.ConsumeBool, True),
+ (tokenizer.ConsumeIdentifier, 'D'),
+ ':',
+ (tokenizer.ConsumeBool, False),
+ '}',
+ (tokenizer.ConsumeIdentifier, 'ID9'),
+ ':',
+ (tokenizer.ConsumeUint32, 22),
+ (tokenizer.ConsumeIdentifier, 'ID10'),
+ ':',
+ (tokenizer.ConsumeInt64, -111111111111111111),
+ (tokenizer.ConsumeIdentifier, 'ID11'),
+ ':',
+ (tokenizer.ConsumeInt32, -22),
+ (tokenizer.ConsumeIdentifier, 'ID12'),
+ ':',
+ (tokenizer.ConsumeUint64, 2222222222222222222)]
+
+ i = 0
+ while not tokenizer.AtEnd():
+ m = methods[i]
+ if type(m) == str:
+ token = tokenizer.token
+ self.assertEqual(token, m)
+ tokenizer.NextToken()
+ else:
+ self.assertEqual(m[1], m[0]())
+ i += 1
+
+ def testConsumeIntegers(self):
+ # This test only tests the failures in the integer parsing methods as well
+ # as the '0' special cases.
+ int64_max = (1 << 63) - 1
+ uint32_max = (1 << 32) - 1
+ text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
+ self.assertEqual(-1, tokenizer.ConsumeInt32())
+
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32)
+ self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64())
+
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64)
+ self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64())
+ self.assertTrue(tokenizer.AtEnd())
+
+ text = '-0 -0 0 0'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertEqual(0, tokenizer.ConsumeUint32())
+ self.assertEqual(0, tokenizer.ConsumeUint64())
+ self.assertEqual(0, tokenizer.ConsumeUint32())
+ self.assertEqual(0, tokenizer.ConsumeUint64())
+ self.assertTrue(tokenizer.AtEnd())
+
+ def testConsumeByteString(self):
+ text = '"string1\''
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+ text = 'string1"'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+ text = '\n"\\xt"'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+ text = '\n"\\"'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+ text = '\n"\\x"'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
+
+ def testConsumeBool(self):
+ text = 'not-a-bool'
+ tokenizer = text_format._Tokenizer(text)
+ self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
new file mode 100755
index 0000000..a3bc57f
--- /dev/null
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -0,0 +1,287 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Provides type checking routines.
+
+This module defines type checking utilities in the forms of dictionaries:
+
+VALUE_CHECKERS: A dictionary of field types and a value validation object.
+TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
+ function.
+TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization
+ function.
+FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their
+ coresponding wire types.
+TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
+ function.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+from google.protobuf.internal import decoder
+from google.protobuf.internal import encoder
+from google.protobuf.internal import wire_format
+from google.protobuf import descriptor
+
+_FieldDescriptor = descriptor.FieldDescriptor
+
+
+def GetTypeChecker(cpp_type, field_type):
+ """Returns a type checker for a message field of the specified types.
+
+ Args:
+ cpp_type: C++ type of the field (see descriptor.py).
+ field_type: Protocol message field type (see descriptor.py).
+
+ Returns:
+ An instance of TypeChecker which can be used to verify the types
+ of values assigned to a field of the specified type.
+ """
+ if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and
+ field_type == _FieldDescriptor.TYPE_STRING):
+ return UnicodeValueChecker()
+ return _VALUE_CHECKERS[cpp_type]
+
+
+# None of the typecheckers below make any attempt to guard against people
+# subclassing builtin types and doing weird things. We're not trying to
+# protect against malicious clients here, just people accidentally shooting
+# themselves in the foot in obvious ways.
+
+class TypeChecker(object):
+
+ """Type checker used to catch type errors as early as possible
+ when the client is setting scalar fields in protocol messages.
+ """
+
+ def __init__(self, *acceptable_types):
+ self._acceptable_types = acceptable_types
+
+ def CheckValue(self, proposed_value):
+ if not isinstance(proposed_value, self._acceptable_types):
+ message = ('%.1024r has type %s, but expected one of: %s' %
+ (proposed_value, type(proposed_value), self._acceptable_types))
+ raise TypeError(message)
+
+
+# IntValueChecker and its subclasses perform integer type-checks
+# and bounds-checks.
+class IntValueChecker(object):
+
+ """Checker used for integer fields. Performs type-check and range check."""
+
+ def CheckValue(self, proposed_value):
+ if not isinstance(proposed_value, (int, long)):
+ message = ('%.1024r has type %s, but expected one of: %s' %
+ (proposed_value, type(proposed_value), (int, long)))
+ raise TypeError(message)
+ if not self._MIN <= proposed_value <= self._MAX:
+ raise ValueError('Value out of range: %d' % proposed_value)
+
+
+class UnicodeValueChecker(object):
+
+ """Checker used for string fields."""
+
+ def CheckValue(self, proposed_value):
+ if not isinstance(proposed_value, (str, unicode)):
+ message = ('%.1024r has type %s, but expected one of: %s' %
+ (proposed_value, type(proposed_value), (str, unicode)))
+ raise TypeError(message)
+
+ # If the value is of type 'str' make sure that it is in 7-bit ASCII
+ # encoding.
+ if isinstance(proposed_value, str):
+ try:
+ unicode(proposed_value, 'ascii')
+ except UnicodeDecodeError:
+ raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII '
+ 'encoding. Non-ASCII strings must be converted to '
+ 'unicode objects before being added.' %
+ (proposed_value))
+
+
+class Int32ValueChecker(IntValueChecker):
+ # We're sure to use ints instead of longs here since comparison may be more
+ # efficient.
+ _MIN = -2147483648
+ _MAX = 2147483647
+
+
+class Uint32ValueChecker(IntValueChecker):
+ _MIN = 0
+ _MAX = (1 << 32) - 1
+
+
+class Int64ValueChecker(IntValueChecker):
+ _MIN = -(1 << 63)
+ _MAX = (1 << 63) - 1
+
+
+class Uint64ValueChecker(IntValueChecker):
+ _MIN = 0
+ _MAX = (1 << 64) - 1
+
+
+# Type-checkers for all scalar CPPTYPEs.
+_VALUE_CHECKERS = {
+ _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
+ _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
+ _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
+ _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
+ _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker(
+ float, int, long),
+ _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker(
+ float, int, long),
+ _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int),
+ _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(),
+ _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str),
+ }
+
+
+# Map from field type to a function F, such that F(field_num, value)
+# gives the total byte size for a value of the given type. This
+# byte size includes tag information and any other additional space
+# associated with serializing "value".
+TYPE_TO_BYTE_SIZE_FN = {
+ _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
+ _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
+ _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
+ _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
+ _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
+ _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
+ _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
+ _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
+ _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
+ _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
+ _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
+ _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
+ _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
+ _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
+ _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
+ _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
+ _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
+ _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize
+ }
+
+
+# Maps from field type to an unbound Encoder method F, such that
+# F(encoder, field_number, value) will append the serialization
+# of a value of this type to the encoder.
+_Encoder = encoder.Encoder
+TYPE_TO_SERIALIZE_METHOD = {
+ _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble,
+ _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat,
+ _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64,
+ _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64,
+ _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32,
+ _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64,
+ _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32,
+ _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool,
+ _FieldDescriptor.TYPE_STRING: _Encoder.AppendString,
+ _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup,
+ _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage,
+ _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes,
+ _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32,
+ _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum,
+ _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32,
+ _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64,
+ _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32,
+ _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64,
+ }
+
+
+TYPE_TO_NOTAG_SERIALIZE_METHOD = {
+ _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag,
+ _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag,
+ _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag,
+ _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag,
+ _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag,
+ _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag,
+ _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag,
+ _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag,
+ _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag,
+ _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag,
+ _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag,
+ _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag,
+ _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag,
+ _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag,
+ }
+
+# Maps from field type to expected wiretype.
+FIELD_TYPE_TO_WIRE_TYPE = {
+ _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
+ _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
+ _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
+ _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
+ _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_STRING:
+ wire_format.WIRETYPE_LENGTH_DELIMITED,
+ _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
+ _FieldDescriptor.TYPE_MESSAGE:
+ wire_format.WIRETYPE_LENGTH_DELIMITED,
+ _FieldDescriptor.TYPE_BYTES:
+ wire_format.WIRETYPE_LENGTH_DELIMITED,
+ _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
+ _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
+ _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
+ _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
+ }
+
+
+# Maps from field type to an unbound Decoder method F,
+# such that F(decoder) will read a field of the requested type.
+#
+# Note that Message and Group are intentionally missing here.
+# They're handled by _RecursivelyMerge().
+_Decoder = decoder.Decoder
+TYPE_TO_DESERIALIZE_METHOD = {
+ _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble,
+ _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat,
+ _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64,
+ _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64,
+ _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32,
+ _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64,
+ _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32,
+ _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool,
+ _FieldDescriptor.TYPE_STRING: _Decoder.ReadString,
+ _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes,
+ _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32,
+ _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum,
+ _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32,
+ _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64,
+ _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32,
+ _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64,
+ }
diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py
new file mode 100755
index 0000000..da6464d
--- /dev/null
+++ b/python/google/protobuf/internal/wire_format.py
@@ -0,0 +1,247 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Constants and static functions to support protocol buffer wire format."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import struct
+from google.protobuf import message
+
+
+TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
+_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
+
+# These numbers identify the wire type of a protocol buffer value.
+# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
+# tag-and-type to store one of these WIRETYPE_* constants.
+# These values must match WireType enum in //net/proto2/public/wire_format.h.
+WIRETYPE_VARINT = 0
+WIRETYPE_FIXED64 = 1
+WIRETYPE_LENGTH_DELIMITED = 2
+WIRETYPE_START_GROUP = 3
+WIRETYPE_END_GROUP = 4
+WIRETYPE_FIXED32 = 5
+_WIRETYPE_MAX = 5
+
+
+# Bounds for various integer types.
+INT32_MAX = int((1 << 31) - 1)
+INT32_MIN = int(-(1 << 31))
+UINT32_MAX = (1 << 32) - 1
+
+INT64_MAX = (1 << 63) - 1
+INT64_MIN = -(1 << 63)
+UINT64_MAX = (1 << 64) - 1
+
+# "struct" format strings that will encode/decode the specified formats.
+FORMAT_UINT32_LITTLE_ENDIAN = '<I'
+FORMAT_UINT64_LITTLE_ENDIAN = '<Q'
+FORMAT_FLOAT_LITTLE_ENDIAN = '<f'
+FORMAT_DOUBLE_LITTLE_ENDIAN = '<d'
+
+
+# We'll have to provide alternate implementations of AppendLittleEndian*() on
+# any architectures where these checks fail.
+if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4:
+ raise AssertionError('Format "I" is not a 32-bit number.')
+if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8:
+ raise AssertionError('Format "Q" is not a 64-bit number.')
+
+
+def PackTag(field_number, wire_type):
+ """Returns an unsigned 32-bit integer that encodes the field number and
+ wire type information in standard protocol message wire format.
+
+ Args:
+ field_number: Expected to be an integer in the range [1, 1 << 29)
+ wire_type: One of the WIRETYPE_* constants.
+ """
+ if not 0 <= wire_type <= _WIRETYPE_MAX:
+ raise message.EncodeError('Unknown wire type: %d' % wire_type)
+ return (field_number << TAG_TYPE_BITS) | wire_type
+
+
+def UnpackTag(tag):
+ """The inverse of PackTag(). Given an unsigned 32-bit number,
+ returns a (field_number, wire_type) tuple.
+ """
+ return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK)
+
+
+def ZigZagEncode(value):
+ """ZigZag Transform: Encodes signed integers so that they can be
+ effectively used with varint encoding. See wire_format.h for
+ more details.
+ """
+ if value >= 0:
+ return value << 1
+ return (value << 1) ^ (~0)
+
+
+def ZigZagDecode(value):
+ """Inverse of ZigZagEncode()."""
+ if not value & 0x1:
+ return value >> 1
+ return (value >> 1) ^ (~0)
+
+
+
+# The *ByteSize() functions below return the number of bytes required to
+# serialize "field number + type" information and then serialize the value.
+
+
+def Int32ByteSize(field_number, int32):
+ return Int64ByteSize(field_number, int32)
+
+
+def Int32ByteSizeNoTag(int32):
+ return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32)
+
+
+def Int64ByteSize(field_number, int64):
+ # Have to convert to uint before calling UInt64ByteSize().
+ return UInt64ByteSize(field_number, 0xffffffffffffffff & int64)
+
+
+def UInt32ByteSize(field_number, uint32):
+ return UInt64ByteSize(field_number, uint32)
+
+
+def UInt64ByteSize(field_number, uint64):
+ return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
+
+
+def SInt32ByteSize(field_number, int32):
+ return UInt32ByteSize(field_number, ZigZagEncode(int32))
+
+
+def SInt64ByteSize(field_number, int64):
+ return UInt64ByteSize(field_number, ZigZagEncode(int64))
+
+
+def Fixed32ByteSize(field_number, fixed32):
+ return TagByteSize(field_number) + 4
+
+
+def Fixed64ByteSize(field_number, fixed64):
+ return TagByteSize(field_number) + 8
+
+
+def SFixed32ByteSize(field_number, sfixed32):
+ return TagByteSize(field_number) + 4
+
+
+def SFixed64ByteSize(field_number, sfixed64):
+ return TagByteSize(field_number) + 8
+
+
+def FloatByteSize(field_number, flt):
+ return TagByteSize(field_number) + 4
+
+
+def DoubleByteSize(field_number, double):
+ return TagByteSize(field_number) + 8
+
+
+def BoolByteSize(field_number, b):
+ return TagByteSize(field_number) + 1
+
+
+def EnumByteSize(field_number, enum):
+ return UInt32ByteSize(field_number, enum)
+
+
+def StringByteSize(field_number, string):
+ return BytesByteSize(field_number, string.encode('utf-8'))
+
+
+def BytesByteSize(field_number, b):
+ return (TagByteSize(field_number)
+ + _VarUInt64ByteSizeNoTag(len(b))
+ + len(b))
+
+
+def GroupByteSize(field_number, message):
+ return (2 * TagByteSize(field_number) # START and END group.
+ + message.ByteSize())
+
+
+def MessageByteSize(field_number, message):
+ return (TagByteSize(field_number)
+ + _VarUInt64ByteSizeNoTag(message.ByteSize())
+ + message.ByteSize())
+
+
+def MessageSetItemByteSize(field_number, msg):
+ # First compute the sizes of the tags.
+ # There are 2 tags for the beginning and ending of the repeated group, that
+ # is field number 1, one with field number 2 (type_id) and one with field
+ # number 3 (message).
+ total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3))
+
+ # Add the number of bytes for type_id.
+ total_size += _VarUInt64ByteSizeNoTag(field_number)
+
+ message_size = msg.ByteSize()
+
+ # The number of bytes for encoding the length of the message.
+ total_size += _VarUInt64ByteSizeNoTag(message_size)
+
+ # The size of the message.
+ total_size += message_size
+ return total_size
+
+
+def TagByteSize(field_number):
+ """Returns the bytes required to serialize a tag with this field number."""
+ # Just pass in type 0, since the type won't affect the tag+type size.
+ return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0))
+
+
+# Private helper function for the *ByteSize() functions above.
+
+def _VarUInt64ByteSizeNoTag(uint64):
+ """Returns the number of bytes required to serialize a single varint
+ using boundary value comparisons. (unrolled loop optimization -WPierce)
+ uint64 must be unsigned.
+ """
+ if uint64 <= 0x7f: return 1
+ if uint64 <= 0x3fff: return 2
+ if uint64 <= 0x1fffff: return 3
+ if uint64 <= 0xfffffff: return 4
+ if uint64 <= 0x7ffffffff: return 5
+ if uint64 <= 0x3ffffffffff: return 6
+ if uint64 <= 0x1ffffffffffff: return 7
+ if uint64 <= 0xffffffffffffff: return 8
+ if uint64 <= 0x7fffffffffffffff: return 9
+ if uint64 > UINT64_MAX:
+ raise message.EncodeError('Value out of range: %d' % uint64)
+ return 10
diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py
new file mode 100755
index 0000000..7600778
--- /dev/null
+++ b/python/google/protobuf/internal/wire_format_test.py
@@ -0,0 +1,253 @@
+#! /usr/bin/python
+#
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Test for google.protobuf.internal.wire_format."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import unittest
+from google.protobuf import message
+from google.protobuf.internal import wire_format
+
+
+class WireFormatTest(unittest.TestCase):
+
+ def testPackTag(self):
+ field_number = 0xabc
+ tag_type = 2
+ self.assertEqual((field_number << 3) | tag_type,
+ wire_format.PackTag(field_number, tag_type))
+ PackTag = wire_format.PackTag
+ # Number too high.
+ self.assertRaises(message.EncodeError, PackTag, field_number, 6)
+ # Number too low.
+ self.assertRaises(message.EncodeError, PackTag, field_number, -1)
+
+ def testUnpackTag(self):
+ # Test field numbers that will require various varint sizes.
+ for expected_field_number in (1, 15, 16, 2047, 2048):
+ for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
+ field_number, wire_type = wire_format.UnpackTag(
+ wire_format.PackTag(expected_field_number, expected_wire_type))
+ self.assertEqual(expected_field_number, field_number)
+ self.assertEqual(expected_wire_type, wire_type)
+
+ self.assertRaises(TypeError, wire_format.UnpackTag, None)
+ self.assertRaises(TypeError, wire_format.UnpackTag, 'abc')
+ self.assertRaises(TypeError, wire_format.UnpackTag, 0.0)
+ self.assertRaises(TypeError, wire_format.UnpackTag, object())
+
+ def testZigZagEncode(self):
+ Z = wire_format.ZigZagEncode
+ self.assertEqual(0, Z(0))
+ self.assertEqual(1, Z(-1))
+ self.assertEqual(2, Z(1))
+ self.assertEqual(3, Z(-2))
+ self.assertEqual(4, Z(2))
+ self.assertEqual(0xfffffffe, Z(0x7fffffff))
+ self.assertEqual(0xffffffff, Z(-0x80000000))
+ self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff))
+ self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000))
+
+ self.assertRaises(TypeError, Z, None)
+ self.assertRaises(TypeError, Z, 'abcd')
+ self.assertRaises(TypeError, Z, 0.0)
+ self.assertRaises(TypeError, Z, object())
+
+ def testZigZagDecode(self):
+ Z = wire_format.ZigZagDecode
+ self.assertEqual(0, Z(0))
+ self.assertEqual(-1, Z(1))
+ self.assertEqual(1, Z(2))
+ self.assertEqual(-2, Z(3))
+ self.assertEqual(2, Z(4))
+ self.assertEqual(0x7fffffff, Z(0xfffffffe))
+ self.assertEqual(-0x80000000, Z(0xffffffff))
+ self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe))
+ self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff))
+
+ self.assertRaises(TypeError, Z, None)
+ self.assertRaises(TypeError, Z, 'abcd')
+ self.assertRaises(TypeError, Z, 0.0)
+ self.assertRaises(TypeError, Z, object())
+
+ def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size):
+ # Use field numbers that cause various byte sizes for the tag information.
+ for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)):
+ expected_size = expected_value_size + tag_bytes
+ actual_size = byte_size_fn(field_number, value)
+ self.assertEqual(expected_size, actual_size,
+ 'byte_size_fn: %s, field_number: %d, value: %r\n'
+ 'Expected: %d, Actual: %d'% (
+ byte_size_fn, field_number, value, expected_size, actual_size))
+
+ def testByteSizeFunctions(self):
+ # Test all numeric *ByteSize() functions.
+ NUMERIC_ARGS = [
+ # Int32ByteSize().
+ [wire_format.Int32ByteSize, 0, 1],
+ [wire_format.Int32ByteSize, 127, 1],
+ [wire_format.Int32ByteSize, 128, 2],
+ [wire_format.Int32ByteSize, -1, 10],
+ # Int64ByteSize().
+ [wire_format.Int64ByteSize, 0, 1],
+ [wire_format.Int64ByteSize, 127, 1],
+ [wire_format.Int64ByteSize, 128, 2],
+ [wire_format.Int64ByteSize, -1, 10],
+ # UInt32ByteSize().
+ [wire_format.UInt32ByteSize, 0, 1],
+ [wire_format.UInt32ByteSize, 127, 1],
+ [wire_format.UInt32ByteSize, 128, 2],
+ [wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5],
+ # UInt64ByteSize().
+ [wire_format.UInt64ByteSize, 0, 1],
+ [wire_format.UInt64ByteSize, 127, 1],
+ [wire_format.UInt64ByteSize, 128, 2],
+ [wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10],
+ # SInt32ByteSize().
+ [wire_format.SInt32ByteSize, 0, 1],
+ [wire_format.SInt32ByteSize, -1, 1],
+ [wire_format.SInt32ByteSize, 1, 1],
+ [wire_format.SInt32ByteSize, -63, 1],
+ [wire_format.SInt32ByteSize, 63, 1],
+ [wire_format.SInt32ByteSize, -64, 1],
+ [wire_format.SInt32ByteSize, 64, 2],
+ # SInt64ByteSize().
+ [wire_format.SInt64ByteSize, 0, 1],
+ [wire_format.SInt64ByteSize, -1, 1],
+ [wire_format.SInt64ByteSize, 1, 1],
+ [wire_format.SInt64ByteSize, -63, 1],
+ [wire_format.SInt64ByteSize, 63, 1],
+ [wire_format.SInt64ByteSize, -64, 1],
+ [wire_format.SInt64ByteSize, 64, 2],
+ # Fixed32ByteSize().
+ [wire_format.Fixed32ByteSize, 0, 4],
+ [wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4],
+ # Fixed64ByteSize().
+ [wire_format.Fixed64ByteSize, 0, 8],
+ [wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8],
+ # SFixed32ByteSize().
+ [wire_format.SFixed32ByteSize, 0, 4],
+ [wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4],
+ [wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4],
+ # SFixed64ByteSize().
+ [wire_format.SFixed64ByteSize, 0, 8],
+ [wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8],
+ [wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8],
+ # FloatByteSize().
+ [wire_format.FloatByteSize, 0.0, 4],
+ [wire_format.FloatByteSize, 1000000000.0, 4],
+ [wire_format.FloatByteSize, -1000000000.0, 4],
+ # DoubleByteSize().
+ [wire_format.DoubleByteSize, 0.0, 8],
+ [wire_format.DoubleByteSize, 1000000000.0, 8],
+ [wire_format.DoubleByteSize, -1000000000.0, 8],
+ # BoolByteSize().
+ [wire_format.BoolByteSize, False, 1],
+ [wire_format.BoolByteSize, True, 1],
+ # EnumByteSize().
+ [wire_format.EnumByteSize, 0, 1],
+ [wire_format.EnumByteSize, 127, 1],
+ [wire_format.EnumByteSize, 128, 2],
+ [wire_format.EnumByteSize, wire_format.UINT32_MAX, 5],
+ ]
+ for args in NUMERIC_ARGS:
+ self.NumericByteSizeTestHelper(*args)
+
+ # Test strings and bytes.
+ for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize):
+ # 1 byte for tag, 1 byte for length, 3 bytes for contents.
+ self.assertEqual(5, byte_size_fn(10, 'abc'))
+ # 2 bytes for tag, 1 byte for length, 3 bytes for contents.
+ self.assertEqual(6, byte_size_fn(16, 'abc'))
+ # 2 bytes for tag, 2 bytes for length, 128 bytes for contents.
+ self.assertEqual(132, byte_size_fn(16, 'a' * 128))
+
+ # Test UTF-8 string byte size calculation.
+ # 1 byte for tag, 1 byte for length, 8 bytes for content.
+ self.assertEqual(10, wire_format.StringByteSize(
+ 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8')))
+
+ class MockMessage(object):
+ def __init__(self, byte_size):
+ self.byte_size = byte_size
+ def ByteSize(self):
+ return self.byte_size
+
+ message_byte_size = 10
+ mock_message = MockMessage(byte_size=message_byte_size)
+ # Test groups.
+ # (2 * 1) bytes for begin and end tags, plus message_byte_size.
+ self.assertEqual(2 + message_byte_size,
+ wire_format.GroupByteSize(1, mock_message))
+ # (2 * 2) bytes for begin and end tags, plus message_byte_size.
+ self.assertEqual(4 + message_byte_size,
+ wire_format.GroupByteSize(16, mock_message))
+
+ # Test messages.
+ # 1 byte for tag, plus 1 byte for length, plus contents.
+ self.assertEqual(2 + mock_message.byte_size,
+ wire_format.MessageByteSize(1, mock_message))
+ # 2 bytes for tag, plus 1 byte for length, plus contents.
+ self.assertEqual(3 + mock_message.byte_size,
+ wire_format.MessageByteSize(16, mock_message))
+ # 2 bytes for tag, plus 2 bytes for length, plus contents.
+ mock_message.byte_size = 128
+ self.assertEqual(4 + mock_message.byte_size,
+ wire_format.MessageByteSize(16, mock_message))
+
+
+ # Test message set item byte size.
+ # 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id,
+ # plus contents.
+ mock_message.byte_size = 10
+ self.assertEqual(mock_message.byte_size + 6,
+ wire_format.MessageSetItemByteSize(1, mock_message))
+
+ # 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id,
+ # plus contents.
+ mock_message.byte_size = 128
+ self.assertEqual(mock_message.byte_size + 7,
+ wire_format.MessageSetItemByteSize(1, mock_message))
+
+ # 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id,
+ # plus contents.
+ self.assertEqual(mock_message.byte_size + 8,
+ wire_format.MessageSetItemByteSize(128, mock_message))
+
+ # Too-long varint.
+ self.assertRaises(message.EncodeError,
+ wire_format.UInt64ByteSize, 1, 1 << 128)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
new file mode 100755
index 0000000..9a88bdc
--- /dev/null
+++ b/python/google/protobuf/message.py
@@ -0,0 +1,245 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# TODO(robinson): We should just make these methods all "pure-virtual" and move
+# all implementation out, into reflection.py for now.
+
+
+"""Contains an abstract base class for protocol messages."""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+
+class Error(Exception): pass
+class DecodeError(Error): pass
+class EncodeError(Error): pass
+
+
+class Message(object):
+
+ """Abstract base class for protocol messages.
+
+ Protocol message classes are almost always generated by the protocol
+ compiler. These generated types subclass Message and implement the methods
+ shown below.
+
+ TODO(robinson): Link to an HTML document here.
+
+ TODO(robinson): Document that instances of this class will also
+ have an Extensions attribute with __getitem__ and __setitem__.
+ Again, not sure how to best convey this.
+
+ TODO(robinson): Document that the class must also have a static
+ RegisterExtension(extension_field) method.
+ Not sure how to best express at this point.
+ """
+
+ # TODO(robinson): Document these fields and methods.
+
+ __slots__ = []
+
+ DESCRIPTOR = None
+
+ def __eq__(self, other_msg):
+ raise NotImplementedError
+
+ def __ne__(self, other_msg):
+ # Can't just say self != other_msg, since that would infinitely recurse. :)
+ return not self == other_msg
+
+ def __str__(self):
+ raise NotImplementedError
+
+ def MergeFrom(self, other_msg):
+ """Merges the contents of the specified message into current message.
+
+ This method merges the contents of the specified message into the current
+ message. Singular fields that are set in the specified message overwrite
+ the corresponding fields in the current message. Repeated fields are
+ appended. Singular sub-messages and groups are recursively merged.
+
+ Args:
+ other_msg: Message to merge into the current message.
+ """
+ raise NotImplementedError
+
+ def CopyFrom(self, other_msg):
+ """Copies the content of the specified message into the current message.
+
+ The method clears the current message and then merges the specified
+ message using MergeFrom.
+
+ Args:
+ other_msg: Message to copy into the current one.
+ """
+ if self == other_msg:
+ return
+ self.Clear()
+ self.MergeFrom(other_msg)
+
+ def Clear(self):
+ """Clears all data that was set in the message."""
+ raise NotImplementedError
+
+ def IsInitialized(self):
+ """Checks if the message is initialized.
+
+ Returns:
+ The method returns True if the message is initialized (i.e. all of its
+ required fields are set).
+ """
+ raise NotImplementedError
+
+ # TODO(robinson): MergeFromString() should probably return None and be
+ # implemented in terms of a helper that returns the # of bytes read. Our
+ # deserialization routines would use the helper when recursively
+ # deserializing, but the end user would almost always just want the no-return
+ # MergeFromString().
+
+ def MergeFromString(self, serialized):
+ """Merges serialized protocol buffer data into this message.
+
+ When we find a field in |serialized| that is already present
+ in this message:
+ - If it's a "repeated" field, we append to the end of our list.
+ - Else, if it's a scalar, we overwrite our field.
+ - Else, (it's a nonrepeated composite), we recursively merge
+ into the existing composite.
+
+ TODO(robinson): Document handling of unknown fields.
+
+ Args:
+ serialized: Any object that allows us to call buffer(serialized)
+ to access a string of bytes using the buffer interface.
+
+ TODO(robinson): When we switch to a helper, this will return None.
+
+ Returns:
+ The number of bytes read from |serialized|.
+ For non-group messages, this will always be len(serialized),
+ but for messages which are actually groups, this will
+ generally be less than len(serialized), since we must
+ stop when we reach an END_GROUP tag. Note that if
+ we *do* stop because of an END_GROUP tag, the number
+ of bytes returned does not include the bytes
+ for the END_GROUP tag information.
+ """
+ raise NotImplementedError
+
+ def ParseFromString(self, serialized):
+ """Like MergeFromString(), except we clear the object first."""
+ self.Clear()
+ self.MergeFromString(serialized)
+
+ def SerializeToString(self):
+ """Serializes the protocol message to a binary string.
+
+ Returns:
+ A binary string representation of the message if all of the required
+ fields in the message are set (i.e. the message is initialized).
+
+ Raises:
+ message.EncodeError if the message isn't initialized.
+ """
+ raise NotImplementedError
+
+ def SerializePartialToString(self):
+ """Serializes the protocol message to a binary string.
+
+ This method is similar to SerializeToString but doesn't check if the
+ message is initialized.
+
+ Returns:
+ A string representation of the partial message.
+ """
+ raise NotImplementedError
+
+ # TODO(robinson): Decide whether we like these better
+ # than auto-generated has_foo() and clear_foo() methods
+ # on the instances themselves. This way is less consistent
+ # with C++, but it makes reflection-type access easier and
+ # reduces the number of magically autogenerated things.
+ #
+ # TODO(robinson): Be sure to document (and test) exactly
+ # which field names are accepted here. Are we case-sensitive?
+ # What do we do with fields that share names with Python keywords
+ # like 'lambda' and 'yield'?
+ #
+ # nnorwitz says:
+ # """
+ # Typically (in python), an underscore is appended to names that are
+ # keywords. So they would become lambda_ or yield_.
+ # """
+ def ListFields(self):
+ """Returns a list of (FieldDescriptor, value) tuples for all
+ fields in the message which are not empty. A singular field is non-empty
+ if HasField() would return true, and a repeated field is non-empty if
+ it contains at least one element. The fields are ordered by field
+ number"""
+ raise NotImplementedError
+
+ def HasField(self, field_name):
+ raise NotImplementedError
+
+ def ClearField(self, field_name):
+ raise NotImplementedError
+
+ def HasExtension(self, extension_handle):
+ raise NotImplementedError
+
+ def ClearExtension(self, extension_handle):
+ raise NotImplementedError
+
+ def ByteSize(self):
+ """Returns the serialized size of this message.
+ Recursively calls ByteSize() on all contained messages.
+ """
+ raise NotImplementedError
+
+ def _SetListener(self, message_listener):
+ """Internal method used by the protocol message implementation.
+ Clients should not call this directly.
+
+ Sets a listener that this message will call on certain state transitions.
+
+ The purpose of this method is to register back-edges from children to
+ parents at runtime, for the purpose of setting "has" bits and
+ byte-size-dirty bits in the parent and ancestor objects whenever a child or
+ descendant object is modified.
+
+ If the client wants to disconnect this Message from the object tree, she
+ explicitly sets callback to None.
+
+ If message_listener is None, unregisters any existing listener. Otherwise,
+ message_listener must implement the MessageListener interface in
+ internal/message_listener.py, and we discard any listener registered
+ via a previous _SetListener() call.
+ """
+ raise NotImplementedError
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
new file mode 100755
index 0000000..d65d8b6
--- /dev/null
+++ b/python/google/protobuf/reflection.py
@@ -0,0 +1,1661 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# This code is meant to work on Python 2.4 and above only.
+#
+# TODO(robinson): Helpers for verbose, common checks like seeing if a
+# descriptor's cpp_type is CPPTYPE_MESSAGE.
+
+"""Contains a metaclass and helper functions used to create
+protocol message classes from Descriptor objects at runtime.
+
+Recall that a metaclass is the "type" of a class.
+(A class is to a metaclass what an instance is to a class.)
+
+In this case, we use the GeneratedProtocolMessageType metaclass
+to inject all the useful functionality into the classes
+output by the protocol compiler at compile-time.
+
+The upshot of all this is that the real implementation
+details for ALL pure-Python protocol buffers are *here in
+this file*.
+"""
+
+__author__ = 'robinson@google.com (Will Robinson)'
+
+import heapq
+import threading
+import weakref
+# We use "as" to avoid name collisions with variables.
+from google.protobuf.internal import containers
+from google.protobuf.internal import decoder
+from google.protobuf.internal import encoder
+from google.protobuf.internal import message_listener as message_listener_mod
+from google.protobuf.internal import type_checkers
+from google.protobuf.internal import wire_format
+from google.protobuf import descriptor as descriptor_mod
+from google.protobuf import message as message_mod
+from google.protobuf import text_format
+
+_FieldDescriptor = descriptor_mod.FieldDescriptor
+
+
+class GeneratedProtocolMessageType(type):
+
+ """Metaclass for protocol message classes created at runtime from Descriptors.
+
+ We add implementations for all methods described in the Message class. We
+ also create properties to allow getting/setting all fields in the protocol
+ message. Finally, we create slots to prevent users from accidentally
+ "setting" nonexistent fields in the protocol message, which then wouldn't get
+ serialized / deserialized properly.
+
+ The protocol compiler currently uses this metaclass to create protocol
+ message classes at runtime. Clients can also manually create their own
+ classes at runtime, as in this example:
+
+ mydescriptor = Descriptor(.....)
+ class MyProtoClass(Message):
+ __metaclass__ = GeneratedProtocolMessageType
+ DESCRIPTOR = mydescriptor
+ myproto_instance = MyProtoClass()
+ myproto.foo_field = 23
+ ...
+ """
+
+ # Must be consistent with the protocol-compiler code in
+ # proto2/compiler/internal/generator.*.
+ _DESCRIPTOR_KEY = 'DESCRIPTOR'
+
+ def __new__(cls, name, bases, dictionary):
+ """Custom allocation for runtime-generated class types.
+
+ We override __new__ because this is apparently the only place
+ where we can meaningfully set __slots__ on the class we're creating(?).
+ (The interplay between metaclasses and slots is not very well-documented).
+
+ Args:
+ name: Name of the class (ignored, but required by the
+ metaclass protocol).
+ bases: Base classes of the class we're constructing.
+ (Should be message.Message). We ignore this field, but
+ it's required by the metaclass protocol
+ dictionary: The class dictionary of the class we're
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain
+ a Descriptor object describing this protocol message
+ type.
+
+ Returns:
+ Newly-allocated class.
+ """
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ _AddSlots(descriptor, dictionary)
+ _AddClassAttributesForNestedExtensions(descriptor, dictionary)
+ superclass = super(GeneratedProtocolMessageType, cls)
+ return superclass.__new__(cls, name, bases, dictionary)
+
+ def __init__(cls, name, bases, dictionary):
+ """Here we perform the majority of our work on the class.
+ We add enum getters, an __init__ method, implementations
+ of all Message methods, and properties for all fields
+ in the protocol type.
+
+ Args:
+ name: Name of the class (ignored, but required by the
+ metaclass protocol).
+ bases: Base classes of the class we're constructing.
+ (Should be message.Message). We ignore this field, but
+ it's required by the metaclass protocol
+ dictionary: The class dictionary of the class we're
+ constructing. dictionary[_DESCRIPTOR_KEY] must contain
+ a Descriptor object describing this protocol message
+ type.
+ """
+ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+ # We act as a "friend" class of the descriptor, setting
+ # its _concrete_class attribute the first time we use a
+ # given descriptor to initialize a concrete protocol message
+ # class.
+ concrete_class_attr_name = '_concrete_class'
+ if not hasattr(descriptor, concrete_class_attr_name):
+ setattr(descriptor, concrete_class_attr_name, cls)
+ cls._known_extensions = []
+ _AddEnumValues(descriptor, cls)
+ _AddInitMethod(descriptor, cls)
+ _AddPropertiesForFields(descriptor, cls)
+ _AddPropertiesForExtensions(descriptor, cls)
+ _AddStaticMethods(cls)
+ _AddMessageMethods(descriptor, cls)
+ _AddPrivateHelperMethods(cls)
+ superclass = super(GeneratedProtocolMessageType, cls)
+ superclass.__init__(name, bases, dictionary)
+
+
+# Stateless helpers for GeneratedProtocolMessageType below.
+# Outside clients should not access these directly.
+#
+# I opted not to make any of these methods on the metaclass, to make it more
+# clear that I'm not really using any state there and to keep clients from
+# thinking that they have direct access to these construction helpers.
+
+
+def _PropertyName(proto_field_name):
+ """Returns the name of the public property attribute which
+ clients can use to get and (in some cases) set the value
+ of a protocol message field.
+
+ Args:
+ proto_field_name: The protocol message field name, exactly
+ as it appears (or would appear) in a .proto file.
+ """
+ # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
+ # nnorwitz makes my day by writing:
+ # """
+ # FYI. See the keyword module in the stdlib. This could be as simple as:
+ #
+ # if keyword.iskeyword(proto_field_name):
+ # return proto_field_name + "_"
+ # return proto_field_name
+ # """
+ return proto_field_name
+
+
+def _ValueFieldName(proto_field_name):
+ """Returns the name of the (internal) instance attribute which objects
+ should use to store the current value for a given protocol message field.
+
+ Args:
+ proto_field_name: The protocol message field name, exactly
+ as it appears (or would appear) in a .proto file.
+ """
+ return '_value_' + proto_field_name
+
+
+def _HasFieldName(proto_field_name):
+ """Returns the name of the (internal) instance attribute which
+ objects should use to store a boolean telling whether this field
+ is explicitly set or not.
+
+ Args:
+ proto_field_name: The protocol message field name, exactly
+ as it appears (or would appear) in a .proto file.
+ """
+ return '_has_' + proto_field_name
+
+
+def _AddSlots(message_descriptor, dictionary):
+ """Adds a __slots__ entry to dictionary, containing the names of all valid
+ attributes for this message type.
+
+ Args:
+ message_descriptor: A Descriptor instance describing this message type.
+ dictionary: Class dictionary to which we'll add a '__slots__' entry.
+ """
+ field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields]
+ field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields
+ if f.label != _FieldDescriptor.LABEL_REPEATED)
+ field_names.extend(('Extensions',
+ '_cached_byte_size',
+ '_cached_byte_size_dirty',
+ '_called_transition_to_nonempty',
+ '_listener',
+ '_lock', '__weakref__'))
+ dictionary['__slots__'] = field_names
+
+
+def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
+ extension_dict = descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ assert extension_name not in dictionary
+ dictionary[extension_name] = extension_field
+
+
+def _AddEnumValues(descriptor, cls):
+ """Sets class-level attributes for all enum fields defined in this message.
+
+ Args:
+ descriptor: Descriptor object for this message type.
+ cls: Class we're constructing for this message type.
+ """
+ for enum_type in descriptor.enum_types:
+ for enum_value in enum_type.values:
+ setattr(cls, enum_value.name, enum_value.number)
+
+
+def _DefaultValueForField(message, field):
+ """Returns a default value for a field.
+
+ Args:
+ message: Message instance containing this field, or a weakref proxy
+ of same.
+ field: FieldDescriptor object for this field.
+
+ Returns: A default value for this field. May refer back to |message|
+ via a weak reference.
+ """
+ # TODO(robinson): Only the repeated fields need a reference to 'message' (so
+ # that they can set the 'has' bit on the containing Message when someone
+ # append()s a value). We could special-case this, and avoid an extra
+ # function call on __init__() and Clear() for non-repeated fields.
+
+ # TODO(robinson): Find a better place for the default value assertion in this
+ # function. No need to repeat them every time the client calls Clear('foo').
+ # (We should probably just assert these things once and as early as possible,
+ # by tightening checking in the descriptor classes.)
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ if field.default_value != []:
+ raise ValueError('Repeated field default value not empty list: %s' % (
+ field.default_value))
+ listener = _Listener(message, None)
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ # We can't look at _concrete_class yet since it might not have
+ # been set. (Depends on order in which we initialize the classes).
+ return containers.RepeatedCompositeFieldContainer(
+ listener, field.message_type)
+ else:
+ return containers.RepeatedScalarFieldContainer(
+ listener, type_checkers.GetTypeChecker(field.cpp_type, field.type))
+
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ assert field.default_value is None
+
+ return field.default_value
+
+
+def _AddInitMethod(message_descriptor, cls):
+ """Adds an __init__ method to cls."""
+ fields = message_descriptor.fields
+ def init(self, **kwargs):
+ self._cached_byte_size = 0
+ self._cached_byte_size_dirty = False
+ self._listener = message_listener_mod.NullMessageListener()
+ self._called_transition_to_nonempty = False
+ # TODO(robinson): We should only create a lock if we really need one
+ # in this class.
+ self._lock = threading.Lock()
+ for field in fields:
+ default_value = _DefaultValueForField(self, field)
+ python_field_name = _ValueFieldName(field.name)
+ setattr(self, python_field_name, default_value)
+ if field.label != _FieldDescriptor.LABEL_REPEATED:
+ setattr(self, _HasFieldName(field.name), False)
+ self.Extensions = _ExtensionDict(self, cls._known_extensions)
+ for field_name, field_value in kwargs.iteritems():
+ field = _GetFieldByName(message_descriptor, field_name)
+ _MergeFieldOrExtension(self, field, field_value)
+
+ init.__module__ = None
+ init.__doc__ = None
+ cls.__init__ = init
+
+
+def _GetFieldByName(message_descriptor, field_name):
+ """Returns a field descriptor by field name.
+
+ Args:
+ message_descriptor: A Descriptor describing all fields in message.
+ field_name: The name of the field to retrieve.
+ Returns:
+ The field descriptor associated with the field name.
+ """
+ try:
+ return message_descriptor.fields_by_name[field_name]
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+
+
+def _AddPropertiesForFields(descriptor, cls):
+ """Adds properties for all fields in this protocol message type."""
+ for field in descriptor.fields:
+ _AddPropertiesForField(field, cls)
+
+
+def _AddPropertiesForField(field, cls):
+ """Adds a public property for a protocol message field.
+ Clients can use this property to get and (in the case
+ of non-repeated scalar fields) directly set the value
+ of a protocol message field.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ # Catch it if we add other types that we should
+ # handle specially here.
+ assert _FieldDescriptor.MAX_CPPTYPE == 10
+
+ constant_name = field.name.upper() + "_FIELD_NUMBER"
+ setattr(cls, constant_name, field.number)
+
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ _AddPropertiesForRepeatedField(field, cls)
+ elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ _AddPropertiesForNonRepeatedCompositeField(field, cls)
+ else:
+ _AddPropertiesForNonRepeatedScalarField(field, cls)
+
+
+def _AddPropertiesForRepeatedField(field, cls):
+ """Adds a public property for a "repeated" protocol message field. Clients
+ can use this property to get the value of the field, which will be either a
+ _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
+ below).
+
+ Note that when clients add values to these containers, we perform
+ type-checking in the case of repeated scalar fields, and we also set any
+ necessary "has" bits as a side-effect.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ proto_field_name = field.name
+ python_field_name = _ValueFieldName(proto_field_name)
+ property_name = _PropertyName(proto_field_name)
+
+ def getter(self):
+ return getattr(self, python_field_name)
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ # We define a setter just so we can throw an exception with a more
+ # helpful error message.
+ def setter(self, new_value):
+ raise AttributeError('Assignment not allowed to repeated field '
+ '"%s" in protocol message object.' % proto_field_name)
+
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForNonRepeatedScalarField(field, cls):
+ """Adds a public property for a nonrepeated, scalar protocol message field.
+ Clients can use this property to get and directly set the value of the field.
+ Note that when the client sets the value of a field by using this property,
+ all necessary "has" bits are set as a side-effect, and we also perform
+ type-checking.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ proto_field_name = field.name
+ python_field_name = _ValueFieldName(proto_field_name)
+ has_field_name = _HasFieldName(proto_field_name)
+ property_name = _PropertyName(proto_field_name)
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+
+ def getter(self):
+ return getattr(self, python_field_name)
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+ def setter(self, new_value):
+ type_checker.CheckValue(new_value)
+ setattr(self, has_field_name, True)
+ self._MarkByteSizeDirty()
+ self._MaybeCallTransitionToNonemptyCallback()
+ setattr(self, python_field_name, new_value)
+ setter.__module__ = None
+ setter.__doc__ = 'Setter for %s.' % proto_field_name
+
+ # Add a property to encapsulate the getter/setter.
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForNonRepeatedCompositeField(field, cls):
+ """Adds a public property for a nonrepeated, composite protocol message field.
+ A composite field is a "group" or "message" field.
+
+ Clients can use this property to get the value of the field, but cannot
+ assign to the property directly.
+
+ Args:
+ field: A FieldDescriptor for this field.
+ cls: The class we're constructing.
+ """
+ # TODO(robinson): Remove duplication with similar method
+ # for non-repeated scalars.
+ proto_field_name = field.name
+ python_field_name = _ValueFieldName(proto_field_name)
+ has_field_name = _HasFieldName(proto_field_name)
+ property_name = _PropertyName(proto_field_name)
+ message_type = field.message_type
+
+ def getter(self):
+ # TODO(robinson): Appropriately scary note about double-checked locking.
+ field_value = getattr(self, python_field_name)
+ if field_value is None:
+ self._lock.acquire()
+ try:
+ field_value = getattr(self, python_field_name)
+ if field_value is None:
+ field_class = message_type._concrete_class
+ field_value = field_class()
+ field_value._SetListener(_Listener(self, has_field_name))
+ setattr(self, python_field_name, field_value)
+ finally:
+ self._lock.release()
+ return field_value
+ getter.__module__ = None
+ getter.__doc__ = 'Getter for %s.' % proto_field_name
+
+ # We define a setter just so we can throw an exception with a more
+ # helpful error message.
+ def setter(self, new_value):
+ raise AttributeError('Assignment not allowed to composite field '
+ '"%s" in protocol message object.' % proto_field_name)
+
+ # Add a property to encapsulate the getter.
+ doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
+ setattr(cls, property_name, property(getter, setter, doc=doc))
+
+
+def _AddPropertiesForExtensions(descriptor, cls):
+ """Adds properties for all fields in this protocol message type."""
+ extension_dict = descriptor.extensions_by_name
+ for extension_name, extension_field in extension_dict.iteritems():
+ constant_name = extension_name.upper() + "_FIELD_NUMBER"
+ setattr(cls, constant_name, extension_field.number)
+
+
+def _AddStaticMethods(cls):
+ # TODO(robinson): This probably needs to be thread-safe(?)
+ def RegisterExtension(extension_handle):
+ extension_handle.containing_type = cls.DESCRIPTOR
+ cls._known_extensions.append(extension_handle)
+ cls.RegisterExtension = staticmethod(RegisterExtension)
+
+ def FromString(s):
+ message = cls()
+ message.MergeFromString(s)
+ return message
+ cls.FromString = staticmethod(FromString)
+
+
+def _AddListFieldsMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ # Ensure that we always list in ascending field-number order.
+ # For non-extension fields, we can do the sort once, here, at import-time.
+ # For extensions, we sort on each ListFields() call, though
+ # we could do better if we have to.
+ fields = sorted(message_descriptor.fields, key=lambda f: f.number)
+ has_field_names = (_HasFieldName(f.name) for f in fields)
+ value_field_names = (_ValueFieldName(f.name) for f in fields)
+ triplets = zip(has_field_names, value_field_names, fields)
+
+ def ListFields(self):
+ # We need to list all extension and non-extension fields
+ # together, in sorted order by field number.
+
+ # Step 0: Get an iterator over all "set" non-extension fields,
+ # sorted by field number.
+ # This iterator yields (field_number, field_descriptor, value) tuples.
+ def SortedSetFieldsIter():
+ # Note that triplets is already sorted by field number.
+ for has_field_name, value_field_name, field_descriptor in triplets:
+ if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
+ value = getattr(self, _ValueFieldName(field_descriptor.name))
+ if len(value) > 0:
+ yield (field_descriptor.number, field_descriptor, value)
+ elif getattr(self, _HasFieldName(field_descriptor.name)):
+ value = getattr(self, _ValueFieldName(field_descriptor.name))
+ yield (field_descriptor.number, field_descriptor, value)
+ sorted_fields = SortedSetFieldsIter()
+
+ # Step 1: Get an iterator over all "set" extension fields,
+ # sorted by field number.
+ # This iterator ALSO yields (field_number, field_descriptor, value) tuples.
+ # TODO(robinson): It's not necessary to repeat this with each
+ # serialization call. We can do better.
+ sorted_extension_fields = sorted(
+ [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()])
+
+ # Step 2: Create a composite iterator that merges the extension-
+ # and non-extension fields, and that still yields fields in
+ # sorted order.
+ all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields)
+
+ # Step 3: Strip off the field numbers and return.
+ return [field[1:] for field in all_set_fields]
+
+ cls.ListFields = ListFields
+
+def _AddHasFieldMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def HasField(self, field_name):
+ try:
+ return getattr(self, _HasFieldName(field_name))
+ except AttributeError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+ cls.HasField = HasField
+
+
+def _AddClearFieldMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def ClearField(self, field_name):
+ field = _GetFieldByName(self.DESCRIPTOR, field_name)
+ proto_field_name = field.name
+ python_field_name = _ValueFieldName(proto_field_name)
+ has_field_name = _HasFieldName(proto_field_name)
+ default_value = _DefaultValueForField(self, field)
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ self._MarkByteSizeDirty()
+ else:
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ old_field_value = getattr(self, python_field_name)
+ if old_field_value is not None:
+ # Snip the old object out of the object tree.
+ old_field_value._SetListener(None)
+ if getattr(self, has_field_name):
+ setattr(self, has_field_name, False)
+ # Set dirty bit on ourself and parents only if
+ # we're actually changing state.
+ self._MarkByteSizeDirty()
+ setattr(self, python_field_name, default_value)
+ cls.ClearField = ClearField
+
+
+def _AddClearExtensionMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def ClearExtension(self, extension_handle):
+ self.Extensions._ClearExtension(extension_handle)
+ cls.ClearExtension = ClearExtension
+
+
+def _AddClearMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def Clear(self):
+ # Clear fields.
+ fields = self.DESCRIPTOR.fields
+ for field in fields:
+ self.ClearField(field.name)
+ # Clear extensions.
+ extensions = self.Extensions._ListSetExtensions()
+ for extension in extensions:
+ self.ClearExtension(extension[0])
+ cls.Clear = Clear
+
+
+def _AddHasExtensionMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def HasExtension(self, extension_handle):
+ return self.Extensions._HasExtension(extension_handle)
+ cls.HasExtension = HasExtension
+
+
+def _AddEqualsMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __eq__(self, other):
+ if (not isinstance(other, message_mod.Message) or
+ other.DESCRIPTOR != self.DESCRIPTOR):
+ return False
+
+ if self is other:
+ return True
+
+ # Compare all fields contained directly in this message.
+ for field_descriptor in message_descriptor.fields:
+ label = field_descriptor.label
+ property_name = _PropertyName(field_descriptor.name)
+ # Non-repeated field equality requires matching "has" bits as well
+ # as having an equal value.
+ if label != _FieldDescriptor.LABEL_REPEATED:
+ self_has = self.HasField(property_name)
+ other_has = other.HasField(property_name)
+ if self_has != other_has:
+ return False
+ if not self_has:
+ # If the "has" bit for this field is False, we must stop here.
+ # Otherwise we will recurse forever on recursively-defined protos.
+ continue
+ if getattr(self, property_name) != getattr(other, property_name):
+ return False
+
+ # Compare the extensions present in both messages.
+ return self.Extensions == other.Extensions
+ cls.__eq__ = __eq__
+
+
+def _AddStrMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __str__(self):
+ return text_format.MessageToString(self)
+ cls.__str__ = __str__
+
+
+def _AddSetListenerMethod(cls):
+ """Helper for _AddMessageMethods()."""
+ def SetListener(self, listener):
+ if listener is None:
+ self._listener = message_listener_mod.NullMessageListener()
+ else:
+ self._listener = listener
+ cls._SetListener = SetListener
+
+
+def _BytesForNonRepeatedElement(value, field_number, field_type):
+ """Returns the number of bytes needed to serialize a non-repeated element.
+ The returned byte count includes space for tag information and any
+ other additional space associated with serializing value.
+
+ Args:
+ value: Value we're serializing.
+ field_number: Field number of this value. (Since the field number
+ is stored as part of a varint-encoded tag, this has an impact
+ on the total bytes required to serialize the value).
+ field_type: The type of the field. One of the TYPE_* constants
+ within FieldDescriptor.
+ """
+ try:
+ fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
+ return fn(field_number, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
+
+
+def _AddByteSizeMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def BytesForField(message, field, value):
+ """Returns the number of bytes required to serialize a single field
+ in message. The field may be repeated or not, composite or not.
+
+ Args:
+ message: The Message instance containing a field of the given type.
+ field: A FieldDescriptor describing the field of interest.
+ value: The value whose byte size we're interested in.
+
+ Returns: The number of bytes required to serialize the current value
+ of "field" in "message", including space for tags and any other
+ necessary information.
+ """
+
+ if _MessageSetField(field):
+ return wire_format.MessageSetItemByteSize(field.number, value)
+
+ field_number, field_type = field.number, field.type
+
+ # Repeated fields.
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ elements = value
+ else:
+ elements = [value]
+
+ if field.GetOptions().packed:
+ content_size = _ContentBytesForPackedField(message, field, elements)
+ if content_size:
+ tag_size = wire_format.TagByteSize(field_number)
+ length_size = wire_format.Int32ByteSizeNoTag(content_size)
+ return tag_size + length_size + content_size
+ else:
+ return 0
+ else:
+ return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
+ for element in elements)
+
+ def _ContentBytesForPackedField(self, field, value):
+ """Returns the number of bytes required to serialize the actual
+ content of a packed field (not including the tag or the encoding
+ of the length.
+
+ Args:
+ self: The Message instance containing a field of the given type.
+ field: A FieldDescriptor describing the field of interest.
+ value: The value whose byte size we're interested in.
+
+ Returns: The number of bytes required to serialize the current value
+ of the packed "field" in "message", excluding space for tags and the
+ length encoding.
+ """
+ size = sum(_BytesForNonRepeatedElement(element, field.number, field.type)
+ for element in value)
+ # In the packed case, there are no per element tags.
+ return size - wire_format.TagByteSize(field.number) * len(value)
+
+ fields = message_descriptor.fields
+ has_field_names = (_HasFieldName(f.name) for f in fields)
+ zipped = zip(has_field_names, fields)
+
+ def ByteSize(self):
+ if not self._cached_byte_size_dirty:
+ return self._cached_byte_size
+
+ size = 0
+ # Hardcoded fields first.
+ for has_field_name, field in zipped:
+ if (field.label == _FieldDescriptor.LABEL_REPEATED
+ or getattr(self, has_field_name)):
+ value = getattr(self, _ValueFieldName(field.name))
+ size += BytesForField(self, field, value)
+ # Extensions next.
+ for field, value in self.Extensions._ListSetExtensions():
+ size += BytesForField(self, field, value)
+
+ self._cached_byte_size = size
+ self._cached_byte_size_dirty = False
+ return size
+
+ cls._ContentBytesForPackedField = _ContentBytesForPackedField
+ cls.ByteSize = ByteSize
+
+
+def _MessageSetField(field_descriptor):
+ """Checks if a field should be serialized using the message set wire format.
+
+ Args:
+ field_descriptor: Descriptor of the field.
+
+ Returns:
+ True if the field should be serialized using the message set wire format,
+ false otherwise.
+ """
+ return (field_descriptor.is_extension and
+ field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and
+ field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
+ field_descriptor.containing_type.GetOptions().message_set_wire_format)
+
+
+def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder):
+ """Appends the serialization of a single value to encoder.
+
+ Args:
+ value: Value to serialize.
+ field_number: Field number of this value.
+ field_descriptor: Descriptor of the field to serialize.
+ encoder: encoder.Encoder object to which we should serialize this value.
+ """
+ if _MessageSetField(field_descriptor):
+ encoder.AppendMessageSetItem(field_number, value)
+ return
+
+ try:
+ method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type]
+ method(encoder, field_number, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' %
+ field_descriptor.type)
+
+
+def _ImergeSorted(*streams):
+ """Merges N sorted iterators into a single sorted iterator.
+ Each element in streams must be an iterable that yields
+ its elements in sorted order, and the elements contained
+ in each stream must all be comparable.
+
+ There may be repeated elements in the component streams or
+ across the streams; the repeated elements will all be repeated
+ in the merged iterator as well.
+
+ I believe that the heapq module at HEAD in the Python
+ sources has a method like this, but for now we roll our own.
+ """
+ iters = [iter(stream) for stream in streams]
+ heap = []
+ for index, it in enumerate(iters):
+ try:
+ heap.append((it.next(), index))
+ except StopIteration:
+ pass
+ heapq.heapify(heap)
+
+ while heap:
+ smallest_value, idx = heap[0]
+ yield smallest_value
+ try:
+ next_element = iters[idx].next()
+ heapq.heapreplace(heap, (next_element, idx))
+ except StopIteration:
+ heapq.heappop(heap)
+
+
+def _AddSerializeToStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+
+ def SerializeToString(self):
+ # Check if the message has all of its required fields set.
+ errors = []
+ if not _InternalIsInitialized(self, errors):
+ raise message_mod.EncodeError('\n'.join(errors))
+ return self.SerializePartialToString()
+ cls.SerializeToString = SerializeToString
+
+
+def _AddSerializePartialToStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ Encoder = encoder.Encoder
+
+ def SerializePartialToString(self):
+ encoder = Encoder()
+ # We need to serialize all extension and non-extension fields
+ # together, in sorted order by field number.
+ for field_descriptor, field_value in self.ListFields():
+ if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
+ repeated_value = field_value
+ else:
+ repeated_value = [field_value]
+ if field_descriptor.GetOptions().packed:
+ # First, write the field number and WIRETYPE_LENGTH_DELIMITED.
+ field_number = field_descriptor.number
+ encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ # Next, write the number of bytes.
+ content_bytes = self._ContentBytesForPackedField(
+ field_descriptor, field_value)
+ encoder.AppendInt32NoTag(content_bytes)
+ # Finally, write the actual values.
+ try:
+ method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[
+ field_descriptor.type]
+ for value in repeated_value:
+ method(encoder, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' %
+ field_descriptor.type)
+ else:
+ for element in repeated_value:
+ _SerializeValueToEncoder(element, field_descriptor.number,
+ field_descriptor, encoder)
+ return encoder.ToString()
+
+ cls.SerializePartialToString = SerializePartialToString
+
+
+def _WireTypeForFieldType(field_type):
+ """Given a field type, returns the expected wire type."""
+ try:
+ return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type]
+ except KeyError:
+ raise message_mod.DecodeError('Unknown field type: %d' % field_type)
+
+
+def _WireTypeForField(field_descriptor):
+ """Given a field descriptor, returns the expected wire type."""
+ if field_descriptor.GetOptions().packed:
+ return wire_format.WIRETYPE_LENGTH_DELIMITED
+ else:
+ return _WireTypeForFieldType(field_descriptor.type)
+
+
+def _RecursivelyMerge(field_number, field_type, decoder, message):
+ """Decodes a message from decoder into message.
+ message is either a group or a nested message within some containing
+ protocol message. If it's a group, we use the group protocol to
+ deserialize, and if it's a nested message, we use the nested-message
+ protocol.
+
+ Args:
+ field_number: The field number of message in its enclosing protocol buffer.
+ field_type: The field type of message. Must be either TYPE_MESSAGE
+ or TYPE_GROUP.
+ decoder: Decoder to read from.
+ message: Message to deserialize into.
+ """
+ if field_type == _FieldDescriptor.TYPE_MESSAGE:
+ decoder.ReadMessageInto(message)
+ elif field_type == _FieldDescriptor.TYPE_GROUP:
+ decoder.ReadGroupInto(field_number, message)
+ else:
+ raise message_mod.DecodeError('Unexpected field type: %d' % field_type)
+
+
+def _DeserializeScalarFromDecoder(field_type, decoder):
+ """Deserializes a scalar of the requested type from decoder. field_type must
+ be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant.
+ """
+ try:
+ method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type]
+ return method(decoder)
+ except KeyError:
+ raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)
+
+
+def _SkipField(field_number, wire_type, decoder):
+ """Skips a field with the specified wire type.
+
+ Args:
+ field_number: Tag number of the field to skip.
+ wire_type: Wire type of the field to skip.
+ decoder: Decoder used to deserialize the messsage. It must be positioned
+ just after reading the the tag and wire type of the field.
+ """
+ if wire_type == wire_format.WIRETYPE_VARINT:
+ decoder.ReadUInt64()
+ elif wire_type == wire_format.WIRETYPE_FIXED64:
+ decoder.ReadFixed64()
+ elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ decoder.SkipBytes(decoder.ReadInt32())
+ elif wire_type == wire_format.WIRETYPE_START_GROUP:
+ _SkipGroup(field_number, decoder)
+ elif wire_type == wire_format.WIRETYPE_END_GROUP:
+ pass
+ elif wire_type == wire_format.WIRETYPE_FIXED32:
+ decoder.ReadFixed32()
+ else:
+ raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)
+
+
+def _SkipGroup(group_number, decoder):
+ """Skips a nested group from the decoder.
+
+ Args:
+ group_number: Tag number of the group to skip.
+ decoder: Decoder used to deserialize the message. It must be positioned
+ exactly at the beginning of the message that should be skipped.
+ """
+ while True:
+ field_number, wire_type = decoder.ReadFieldNumberAndWireType()
+ if (wire_type == wire_format.WIRETYPE_END_GROUP and
+ field_number == group_number):
+ return
+ _SkipField(field_number, wire_type, decoder)
+
+
+def _DeserializeMessageSetItem(message, decoder):
+ """Deserializes a message using the message set wire format.
+
+ Args:
+ message: Message to be parsed to.
+ decoder: The decoder to be used to deserialize encoded data. Note that the
+ decoder should be positioned just after reading the START_GROUP tag that
+ began the messageset item.
+ """
+ field_number, wire_type = decoder.ReadFieldNumberAndWireType()
+ if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2:
+ raise message_mod.DecodeError(
+ 'Incorrect message set wire format. '
+ 'wire_type: %d, field_number: %d' % (wire_type, field_number))
+
+ type_id = decoder.ReadInt32()
+ field_number, wire_type = decoder.ReadFieldNumberAndWireType()
+ if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3:
+ raise message_mod.DecodeError(
+ 'Incorrect message set wire format. '
+ 'wire_type: %d, field_number: %d' % (wire_type, field_number))
+
+ extension_dict = message.Extensions
+ extensions_by_number = extension_dict._AllExtensionsByNumber()
+ if type_id not in extensions_by_number:
+ _SkipField(field_number, wire_type, decoder)
+ return
+
+ field_descriptor = extensions_by_number[type_id]
+ value = extension_dict[field_descriptor]
+ decoder.ReadMessageInto(value)
+ # Read the END_GROUP tag.
+ field_number, wire_type = decoder.ReadFieldNumberAndWireType()
+ if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1:
+ raise message_mod.DecodeError(
+ 'Incorrect message set wire format. '
+ 'wire_type: %d, field_number: %d' % (wire_type, field_number))
+
+
+def _DeserializeOneEntity(message_descriptor, message, decoder):
+ """Deserializes the next wire entity from decoder into message.
+
+ The next wire entity is either a scalar or a nested message, an
+ element in a repeated field (the wire encoding in this case is the
+ same), or a packed repeated field (in this case, the entire repeated
+ field is read by a single call to _DeserializeOneEntity).
+
+ Args:
+ message_descriptor: A Descriptor instance describing all fields
+ in message.
+ message: The Message instance into which we're decoding our fields.
+ decoder: The Decoder we're using to deserialize encoded data.
+
+ Returns: The number of bytes read from decoder during this method.
+ """
+ initial_position = decoder.Position()
+ field_number, wire_type = decoder.ReadFieldNumberAndWireType()
+ extension_dict = message.Extensions
+ extensions_by_number = extension_dict._AllExtensionsByNumber()
+ if field_number in message_descriptor.fields_by_number:
+ # Non-extension field.
+ field_descriptor = message_descriptor.fields_by_number[field_number]
+ value = getattr(message, _PropertyName(field_descriptor.name))
+ def nonextension_setter_fn(scalar):
+ setattr(message, _PropertyName(field_descriptor.name), scalar)
+ scalar_setter_fn = nonextension_setter_fn
+ elif field_number in extensions_by_number:
+ # Extension field.
+ field_descriptor = extensions_by_number[field_number]
+ value = extension_dict[field_descriptor]
+ def extension_setter_fn(scalar):
+ extension_dict[field_descriptor] = scalar
+ scalar_setter_fn = extension_setter_fn
+ elif wire_type == wire_format.WIRETYPE_END_GROUP:
+ # We assume we're being parsed as the group that's ended.
+ return 0
+ elif (wire_type == wire_format.WIRETYPE_START_GROUP and
+ field_number == 1 and
+ message_descriptor.GetOptions().message_set_wire_format):
+ # A Message Set item.
+ _DeserializeMessageSetItem(message, decoder)
+ return decoder.Position() - initial_position
+ else:
+ _SkipField(field_number, wire_type, decoder)
+ return decoder.Position() - initial_position
+
+ # If we reach this point, we've identified the field as either
+ # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|,
+ # and |value| appropriately. Now actually deserialize the thing.
+ #
+ # field_descriptor: Describes the field we're deserializing.
+ # value: The value currently stored in the field to deserialize.
+ # Used only if the field is composite and/or repeated.
+ # scalar_setter_fn: A function F such that F(scalar) will
+ # set a nonrepeated scalar value for this field. Used only
+ # if this field is a nonrepeated scalar.
+
+ field_number = field_descriptor.number
+ expected_wire_type = _WireTypeForField(field_descriptor)
+ if wire_type != expected_wire_type:
+ # Need to fill in uninterpreted_bytes. Work for the next CL.
+ raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')
+
+ property_name = _PropertyName(field_descriptor.name)
+ label = field_descriptor.label
+ field_type = field_descriptor.type
+ cpp_type = field_descriptor.cpp_type
+
+ # Nonrepeated scalar. Just set the field directly.
+ if (label != _FieldDescriptor.LABEL_REPEATED
+ and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
+ scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - initial_position
+
+ # Nonrepeated composite. Recursively deserialize.
+ if label != _FieldDescriptor.LABEL_REPEATED:
+ composite = value
+ _RecursivelyMerge(field_number, field_type, decoder, composite)
+ return decoder.Position() - initial_position
+
+ # Now we know we're dealing with a repeated field of some kind.
+ element_list = value
+
+ if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
+ # Repeated scalar.
+ if not field_descriptor.GetOptions().packed:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - initial_position
+ else:
+ # Packed repeated field.
+ length = _DeserializeScalarFromDecoder(
+ _FieldDescriptor.TYPE_INT32, decoder)
+ content_start = decoder.Position()
+ while decoder.Position() - content_start < length:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - initial_position
+ else:
+ # Repeated composite.
+ composite = element_list.add()
+ _RecursivelyMerge(field_number, field_type, decoder, composite)
+ return decoder.Position() - initial_position
+
+
+def _FieldOrExtensionValues(message, field_or_extension):
+ """Retrieves the list of values for the specified field or extension.
+
+ The target field or extension can be optional, required or repeated, but it
+ must have value(s) set. The assumption is that the target field or extension
+ is set (e.g. _HasFieldOrExtension holds true).
+
+ Args:
+ message: Message which contains the target field or extension.
+ field_or_extension: Field or extension for which the list of values is
+ required. Must be an instance of FieldDescriptor.
+
+ Returns:
+ A list of values for the specified field or extension. This list will only
+ contain a single element if the field is non-repeated.
+ """
+ if field_or_extension.is_extension:
+ value = message.Extensions[field_or_extension]
+ else:
+ value = getattr(message, _ValueFieldName(field_or_extension.name))
+ if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED:
+ return [value]
+ else:
+ # In this case value is a list or repeated values.
+ return value
+
+
+def _HasFieldOrExtension(message, field_or_extension):
+ """Checks if a message has the specified field or extension set.
+
+ The field or extension specified can be optional, required or repeated. If
+ it is repeated, this function returns True. Otherwise it checks the has bit
+ of the field or extension.
+
+ Args:
+ message: Message which contains the target field or extension.
+ field_or_extension: Field or extension to check. This must be a
+ FieldDescriptor instance.
+
+ Returns:
+ True if the message has a value set for the specified field or extension,
+ or if the field or extension is repeated.
+ """
+ if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED:
+ return True
+ if field_or_extension.is_extension:
+ return message.HasExtension(field_or_extension)
+ else:
+ return message.HasField(field_or_extension.name)
+
+
+def _IsFieldOrExtensionInitialized(message, field, errors=None):
+ """Checks if a message field or extension is initialized.
+
+ Args:
+ message: The message which contains the field or extension.
+ field: Field or extension to check. This must be a FieldDescriptor instance.
+ errors: Errors will be appended to it, if set to a meaningful value.
+
+ Returns:
+ True if the field/extension can be considered initialized.
+ """
+ # If the field is required and is not set, it isn't initialized.
+ if field.label == _FieldDescriptor.LABEL_REQUIRED:
+ if not _HasFieldOrExtension(message, field):
+ if errors is not None:
+ errors.append('Required field %s is not set.' % field.full_name)
+ return False
+
+ # If the field is optional and is not set, or if it
+ # isn't a submessage then the field is initialized.
+ if field.label == _FieldDescriptor.LABEL_OPTIONAL:
+ if not _HasFieldOrExtension(message, field):
+ return True
+ if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
+ return True
+
+ # The field is set and is either a single or a repeated submessage.
+ messages = _FieldOrExtensionValues(message, field)
+ # If all submessages in this field are initialized, the field is
+ # considered initialized.
+ for message in messages:
+ if not _InternalIsInitialized(message, errors):
+ return False
+ return True
+
+
+def _InternalIsInitialized(message, errors=None):
+ """Checks if all required fields of a message are set.
+
+ Args:
+ message: The message to check.
+ errors: If set, initialization errors will be appended to it.
+
+ Returns:
+ True iff the specified message has all required fields set.
+ """
+ fields_and_extensions = []
+ fields_and_extensions.extend(message.DESCRIPTOR.fields)
+ fields_and_extensions.extend(
+ [extension[0] for extension in message.Extensions._ListSetExtensions()])
+ for field_or_extension in fields_and_extensions:
+ if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors):
+ return False
+ return True
+
+
+def _AddMergeFromStringMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ Decoder = decoder.Decoder
+ def MergeFromString(self, serialized):
+ decoder = Decoder(serialized)
+ byte_count = 0
+ while not decoder.EndOfStream():
+ bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder)
+ if not bytes_read:
+ break
+ byte_count += bytes_read
+ return byte_count
+ cls.MergeFromString = MergeFromString
+
+
+def _AddIsInitializedMethod(cls):
+ """Adds the IsInitialized method to the protocol message class."""
+ cls.IsInitialized = _InternalIsInitialized
+
+
+def _MergeFieldOrExtension(destination_msg, field, value):
+ """Merges a specified message field into another message."""
+ property_name = _PropertyName(field.name)
+ is_extension = field.is_extension
+
+ if not is_extension:
+ destination = getattr(destination_msg, property_name)
+ elif (field.label == _FieldDescriptor.LABEL_REPEATED or
+ field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
+ destination = destination_msg.Extensions[field]
+
+ # Case 1 - a composite field.
+ if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for v in value:
+ destination.add().MergeFrom(v)
+ else:
+ destination.MergeFrom(value)
+ return
+
+ # Case 2 - a repeated field.
+ if field.label == _FieldDescriptor.LABEL_REPEATED:
+ for v in value:
+ destination.append(v)
+ return
+
+ # Case 3 - a singular field.
+ if is_extension:
+ destination_msg.Extensions[field] = value
+ else:
+ setattr(destination_msg, property_name, value)
+
+
+def _AddMergeFromMethod(cls):
+ def MergeFrom(self, msg):
+ assert msg is not self
+ for field in msg.ListFields():
+ _MergeFieldOrExtension(self, field[0], field[1])
+ cls.MergeFrom = MergeFrom
+
+
+def _AddMessageMethods(message_descriptor, cls):
+ """Adds implementations of all Message methods to cls."""
+ _AddListFieldsMethod(message_descriptor, cls)
+ _AddHasFieldMethod(cls)
+ _AddClearFieldMethod(cls)
+ _AddClearExtensionMethod(cls)
+ _AddClearMethod(cls)
+ _AddHasExtensionMethod(cls)
+ _AddEqualsMethod(message_descriptor, cls)
+ _AddStrMethod(message_descriptor, cls)
+ _AddSetListenerMethod(cls)
+ _AddByteSizeMethod(message_descriptor, cls)
+ _AddSerializeToStringMethod(message_descriptor, cls)
+ _AddSerializePartialToStringMethod(message_descriptor, cls)
+ _AddMergeFromStringMethod(message_descriptor, cls)
+ _AddIsInitializedMethod(cls)
+ _AddMergeFromMethod(cls)
+
+
+def _AddPrivateHelperMethods(cls):
+ """Adds implementation of private helper methods to cls."""
+
+ def MaybeCallTransitionToNonemptyCallback(self):
+ """Calls self._listener.TransitionToNonempty() the first time this
+ method is called. On all subsequent calls, this is a no-op.
+ """
+ if not self._called_transition_to_nonempty:
+ self._listener.TransitionToNonempty()
+ self._called_transition_to_nonempty = True
+ cls._MaybeCallTransitionToNonemptyCallback = (
+ MaybeCallTransitionToNonemptyCallback)
+
+ def MarkByteSizeDirty(self):
+ """Sets the _cached_byte_size_dirty bit to true,
+ and propagates this to our listener iff this was a state change.
+ """
+ if not self._cached_byte_size_dirty:
+ self._cached_byte_size_dirty = True
+ self._listener.ByteSizeDirty()
+ cls._MarkByteSizeDirty = MarkByteSizeDirty
+
+
+class _Listener(object):
+
+ """MessageListener implementation that a parent message registers with its
+ child message.
+
+ In order to support semantics like:
+
+ foo.bar.baz = 23
+ assert foo.HasField('bar')
+
+ ...child objects must have back references to their parents.
+ This helper class is at the heart of this support.
+ """
+
+ def __init__(self, parent_message, has_field_name):
+ """Args:
+ parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
+ and _MarkByteSizeDirty() methods we should call when we receive
+ TransitionToNonempty() and ByteSizeDirty() messages.
+ has_field_name: The name of the "has" field that we should set in
+ the parent message when we receive a TransitionToNonempty message,
+ or None if there's no "has" field to set. (This will be the case
+ for child objects in "repeated" fields).
+ """
+ # This listener establishes a back reference from a child (contained) object
+ # to its parent (containing) object. We make this a weak reference to avoid
+ # creating cyclic garbage when the client finishes with the 'parent' object
+ # in the tree.
+ if isinstance(parent_message, weakref.ProxyType):
+ self._parent_message_weakref = parent_message
+ else:
+ self._parent_message_weakref = weakref.proxy(parent_message)
+ self._has_field_name = has_field_name
+
+ def TransitionToNonempty(self):
+ try:
+ if self._has_field_name is not None:
+ setattr(self._parent_message_weakref, self._has_field_name, True)
+ # Propagate the signal to our parents iff this is the first field set.
+ self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback()
+ except ReferenceError:
+ # We can get here if a client has kept a reference to a child object,
+ # and is now setting a field on it, but the child's parent has been
+ # garbage-collected. This is not an error.
+ pass
+
+ def ByteSizeDirty(self):
+ try:
+ self._parent_message_weakref._MarkByteSizeDirty()
+ except ReferenceError:
+ # Same as above.
+ pass
+
+
+# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
+# TODO(robinson): Unify error handling of "unknown extension" crap.
+# TODO(robinson): There's so much similarity between the way that
+# extensions behave and the way that normal fields behave that it would
+# be really nice to unify more code. It's not immediately obvious
+# how to do this, though, and I'd rather get the full functionality
+# implemented (and, crucially, get all the tests and specs fleshed out
+# and passing), and then come back to this thorny unification problem.
+# TODO(robinson): Support iteritems()-style iteration over all
+# extensions with the "has" bits turned on?
+class _ExtensionDict(object):
+
+ """Dict-like container for supporting an indexable "Extensions"
+ field on proto instances.
+
+ Note that in all cases we expect extension handles to be
+ FieldDescriptors.
+ """
+
+ class _ExtensionListener(object):
+
+ """Adapts an _ExtensionDict to behave as a MessageListener."""
+
+ def __init__(self, extension_dict, handle_id):
+ self._extension_dict = extension_dict
+ self._handle_id = handle_id
+
+ def TransitionToNonempty(self):
+ self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id)
+
+ def ByteSizeDirty(self):
+ self._extension_dict._SubmessageByteSizeBecameDirty()
+
+ # TODO(robinson): Somewhere, we need to blow up if people
+ # try to register two extensions with the same field number.
+ # (And we need a test for this of course).
+
+ def __init__(self, extended_message, known_extensions):
+ """extended_message: Message instance for which we are the Extensions dict.
+ known_extensions: Iterable of known extension handles.
+ These must be FieldDescriptors.
+ """
+ # We keep a weak reference to extended_message, since
+ # it has a reference to this instance in turn.
+ self._extended_message = weakref.proxy(extended_message)
+ # We make a deep copy of known_extensions to avoid any
+ # thread-safety concerns, since the argument passed in
+ # is the global (class-level) dict of known extensions for
+ # this type of message, which could be modified at any time
+ # via a RegisterExtension() call.
+ #
+ # This dict maps from handle id to handle (a FieldDescriptor).
+ #
+ # XXX
+ # TODO(robinson): This isn't good enough. The client could
+ # instantiate an object in module A, then afterward import
+ # module B and pass the instance to B.Foo(). If B imports
+ # an extender of this proto and then tries to use it, B
+ # will get a KeyError, even though the extension *is* registered
+ # at the time of use.
+ # XXX
+ self._known_extensions = dict((id(e), e) for e in known_extensions)
+ # Read lock around self._values, which may be modified by multiple
+ # concurrent readers in the conceptually "const" __getitem__ method.
+ # So, we grab this lock in every "read-only" method to ensure
+ # that concurrent read access is safe without external locking.
+ self._lock = threading.Lock()
+ # Maps from extension handle ID to current value of that extension.
+ self._values = {}
+ # Maps from extension handle ID to a boolean "has" bit, but only
+ # for non-repeated extension fields.
+ keys = (id for id, extension in self._known_extensions.iteritems()
+ if extension.label != _FieldDescriptor.LABEL_REPEATED)
+ self._has_bits = dict.fromkeys(keys, False)
+
+ self._extensions_by_number = dict(
+ (f.number, f) for f in self._known_extensions.itervalues())
+
+ self._extensions_by_name = {}
+ for extension in self._known_extensions.itervalues():
+ if (extension.containing_type.GetOptions().message_set_wire_format and
+ extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and
+ extension.message_type == extension.extension_scope and
+ extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL):
+ extension_name = extension.message_type.full_name
+ else:
+ extension_name = extension.full_name
+ self._extensions_by_name[extension_name] = extension
+
+ def __getitem__(self, extension_handle):
+ """Returns the current value of the given extension handle."""
+ # We don't care as much about keeping critical sections short in the
+ # extension support, since it's presumably much less of a common case.
+ self._lock.acquire()
+ try:
+ handle_id = id(extension_handle)
+ if handle_id not in self._known_extensions:
+ raise KeyError('Extension not known to this class')
+ if handle_id not in self._values:
+ self._AddMissingHandle(extension_handle, handle_id)
+ return self._values[handle_id]
+ finally:
+ self._lock.release()
+
+ def __eq__(self, other):
+ # We have to grab read locks since we're accessing _values
+ # in a "const" method. See the comment in the constructor.
+ if self is other:
+ return True
+ self._lock.acquire()
+ try:
+ other._lock.acquire()
+ try:
+ if self._has_bits != other._has_bits:
+ return False
+ # If there's a "has" bit, then only compare values where it is true.
+ for k, v in self._values.iteritems():
+ if self._has_bits.get(k, False) and v != other._values[k]:
+ return False
+ return True
+ finally:
+ other._lock.release()
+ finally:
+ self._lock.release()
+
+ def __ne__(self, other):
+ return not self == other
+
+ # Note that this is only meaningful for non-repeated, scalar extension
+ # fields. Note also that we may have to call
+ # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field
+ # this way, to set any necssary "has" bits in the ancestors of the extended
+ # message.
+ def __setitem__(self, extension_handle, value):
+ """If extension_handle specifies a non-repeated, scalar extension
+ field, sets the value of that field.
+ """
+ handle_id = id(extension_handle)
+ if handle_id not in self._known_extensions:
+ raise KeyError('Extension not known to this class')
+ field = extension_handle # Just shorten the name.
+ if (field.label == _FieldDescriptor.LABEL_OPTIONAL
+ and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
+ # It's slightly wasteful to lookup the type checker each time,
+ # but we expect this to be a vanishingly uncommon case anyway.
+ type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
+ type_checker.CheckValue(value)
+ self._values[handle_id] = value
+ self._has_bits[handle_id] = True
+ self._extended_message._MarkByteSizeDirty()
+ self._extended_message._MaybeCallTransitionToNonemptyCallback()
+ else:
+ raise TypeError('Extension is repeated and/or a composite type.')
+
+ def _AddMissingHandle(self, extension_handle, handle_id):
+ """Helper internal to ExtensionDict."""
+ # Special handling for non-repeated message extensions, which (like
+ # normal fields of this kind) are initialized lazily.
+ # REQUIRES: _lock already held.
+ cpp_type = extension_handle.cpp_type
+ label = extension_handle.label
+ if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+ and label != _FieldDescriptor.LABEL_REPEATED):
+ self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id)
+ else:
+ self._values[handle_id] = _DefaultValueForField(
+ self._extended_message, extension_handle)
+
+ def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id):
+ """Helper internal to ExtensionDict."""
+ # REQUIRES: _lock already held.
+ value = extension_handle.message_type._concrete_class()
+ value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id))
+ self._values[handle_id] = value
+
+ def _SubmessageTransitionedToNonempty(self, handle_id):
+ """Called when a submessage with a given handle id first transitions to
+ being nonempty. Called by _ExtensionListener.
+ """
+ assert handle_id in self._has_bits
+ self._has_bits[handle_id] = True
+ self._extended_message._MaybeCallTransitionToNonemptyCallback()
+
+ def _SubmessageByteSizeBecameDirty(self):
+ """Called whenever a submessage's cached byte size becomes invalid
+ (goes from being "clean" to being "dirty"). Called by _ExtensionListener.
+ """
+ self._extended_message._MarkByteSizeDirty()
+
+ # We may wish to widen the public interface of Message.Extensions
+ # to expose some of this private functionality in the future.
+ # For now, we make all this functionality module-private and just
+ # implement what we need for serialization/deserialization,
+ # HasField()/ClearField(), etc.
+
+ def _HasExtension(self, extension_handle):
+ """Method for internal use by this module.
+ Returns true iff we "have" this extension in the sense of the
+ "has" bit being set.
+ """
+ handle_id = id(extension_handle)
+ # Note that this is different from the other checks.
+ if handle_id not in self._has_bits:
+ raise KeyError('Extension not known to this class, or is repeated field.')
+ return self._has_bits[handle_id]
+
+ # Intentionally pretty similar to ClearField() above.
+ def _ClearExtension(self, extension_handle):
+ """Method for internal use by this module.
+ Clears the specified extension, unsetting its "has" bit.
+ """
+ handle_id = id(extension_handle)
+ if handle_id not in self._known_extensions:
+ raise KeyError('Extension not known to this class')
+ default_value = _DefaultValueForField(self._extended_message,
+ extension_handle)
+ if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
+ self._extended_message._MarkByteSizeDirty()
+ else:
+ cpp_type = extension_handle.cpp_type
+ if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ if handle_id in self._values:
+ # Future modifications to this object shouldn't set any
+ # "has" bits here.
+ self._values[handle_id]._SetListener(None)
+ if self._has_bits[handle_id]:
+ self._has_bits[handle_id] = False
+ self._extended_message._MarkByteSizeDirty()
+ if handle_id in self._values:
+ del self._values[handle_id]
+
+ def _ListSetExtensions(self):
+ """Method for internal use by this module.
+
+ Returns an sequence of all extensions that are currently "set"
+ in this extension dict. A "set" extension is a repeated extension,
+ or a non-repeated extension with its "has" bit set.
+
+ The returned sequence contains (field_descriptor, value) pairs,
+ where value is the current value of the extension with the given
+ field descriptor.
+
+ The sequence values are in arbitrary order.
+ """
+ self._lock.acquire() # Read-only methods must lock around self._values.
+ try:
+ set_extensions = []
+ for handle_id, value in self._values.iteritems():
+ handle = self._known_extensions[handle_id]
+ if (handle.label == _FieldDescriptor.LABEL_REPEATED
+ or self._has_bits[handle_id]):
+ set_extensions.append((handle, value))
+ return set_extensions
+ finally:
+ self._lock.release()
+
+ def _AllExtensionsByNumber(self):
+ """Method for internal use by this module.
+
+ Returns: A dict mapping field_number to (handle, field_descriptor),
+ for *all* registered extensions for this dict.
+ """
+ return self._extensions_by_number
+
+ def _FindExtensionByName(self, name):
+ """Tries to find a known extension with the specified name.
+
+ Args:
+ name: Extension full name.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._extensions_by_name.get(name, None)
diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py
new file mode 100755
index 0000000..dd136c9
--- /dev/null
+++ b/python/google/protobuf/service.py
@@ -0,0 +1,222 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Declares the RPC service interfaces.
+
+This module declares the abstract interfaces underlying proto2 RPC
+services. These are intended to be independent of any particular RPC
+implementation, so that proto2 services can be used on top of a variety
+of implementations.
+"""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+
+class RpcException(Exception):
+ """Exception raised on failed blocking RPC method call."""
+ pass
+
+
+class Service(object):
+
+ """Abstract base interface for protocol-buffer-based RPC services.
+
+ Services themselves are abstract classes (implemented either by servers or as
+ stubs), but they subclass this base interface. The methods of this
+ interface can be used to call the methods of the service without knowing
+ its exact type at compile time (analogous to the Message interface).
+ """
+
+ def GetDescriptor():
+ """Retrieves this service's descriptor."""
+ raise NotImplementedError
+
+ def CallMethod(self, method_descriptor, rpc_controller,
+ request, done):
+ """Calls a method of the service specified by method_descriptor.
+
+ If "done" is None then the call is blocking and the response
+ message will be returned directly. Otherwise the call is asynchronous
+ and "done" will later be called with the response value.
+
+ In the blocking case, RpcException will be raised on error.
+
+ Preconditions:
+ * method_descriptor.service == GetDescriptor
+ * request is of the exact same classes as returned by
+ GetRequestClass(method).
+ * After the call has started, the request must not be modified.
+ * "rpc_controller" is of the correct type for the RPC implementation being
+ used by this Service. For stubs, the "correct type" depends on the
+ RpcChannel which the stub is using.
+
+ Postconditions:
+ * "done" will be called when the method is complete. This may be
+ before CallMethod() returns or it may be at some point in the future.
+ * If the RPC failed, the response value passed to "done" will be None.
+ Further details about the failure can be found by querying the
+ RpcController.
+ """
+ raise NotImplementedError
+
+ def GetRequestClass(self, method_descriptor):
+ """Returns the class of the request message for the specified method.
+
+ CallMethod() requires that the request is of a particular subclass of
+ Message. GetRequestClass() gets the default instance of this required
+ type.
+
+ Example:
+ method = service.GetDescriptor().FindMethodByName("Foo")
+ request = stub.GetRequestClass(method)()
+ request.ParseFromString(input)
+ service.CallMethod(method, request, callback)
+ """
+ raise NotImplementedError
+
+ def GetResponseClass(self, method_descriptor):
+ """Returns the class of the response message for the specified method.
+
+ This method isn't really needed, as the RpcChannel's CallMethod constructs
+ the response protocol message. It's provided anyway in case it is useful
+ for the caller to know the response type in advance.
+ """
+ raise NotImplementedError
+
+
+class RpcController(object):
+
+ """An RpcController mediates a single method call.
+
+ The primary purpose of the controller is to provide a way to manipulate
+ settings specific to the RPC implementation and to find out about RPC-level
+ errors. The methods provided by the RpcController interface are intended
+ to be a "least common denominator" set of features which we expect all
+ implementations to support. Specific implementations may provide more
+ advanced features (e.g. deadline propagation).
+ """
+
+ # Client-side methods below
+
+ def Reset(self):
+ """Resets the RpcController to its initial state.
+
+ After the RpcController has been reset, it may be reused in
+ a new call. Must not be called while an RPC is in progress.
+ """
+ raise NotImplementedError
+
+ def Failed(self):
+ """Returns true if the call failed.
+
+ After a call has finished, returns true if the call failed. The possible
+ reasons for failure depend on the RPC implementation. Failed() must not
+ be called before a call has finished. If Failed() returns true, the
+ contents of the response message are undefined.
+ """
+ raise NotImplementedError
+
+ def ErrorText(self):
+ """If Failed is true, returns a human-readable description of the error."""
+ raise NotImplementedError
+
+ def StartCancel(self):
+ """Initiate cancellation.
+
+ Advises the RPC system that the caller desires that the RPC call be
+ canceled. The RPC system may cancel it immediately, may wait awhile and
+ then cancel it, or may not even cancel the call at all. If the call is
+ canceled, the "done" callback will still be called and the RpcController
+ will indicate that the call failed at that time.
+ """
+ raise NotImplementedError
+
+ # Server-side methods below
+
+ def SetFailed(self, reason):
+ """Sets a failure reason.
+
+ Causes Failed() to return true on the client side. "reason" will be
+ incorporated into the message returned by ErrorText(). If you find
+ you need to return machine-readable information about failures, you
+ should incorporate it into your response protocol buffer and should
+ NOT call SetFailed().
+ """
+ raise NotImplementedError
+
+ def IsCanceled(self):
+ """Checks if the client cancelled the RPC.
+
+ If true, indicates that the client canceled the RPC, so the server may
+ as well give up on replying to it. The server should still call the
+ final "done" callback.
+ """
+ raise NotImplementedError
+
+ def NotifyOnCancel(self, callback):
+ """Sets a callback to invoke on cancel.
+
+ Asks that the given callback be called when the RPC is canceled. The
+ callback will always be called exactly once. If the RPC completes without
+ being canceled, the callback will be called after completion. If the RPC
+ has already been canceled when NotifyOnCancel() is called, the callback
+ will be called immediately.
+
+ NotifyOnCancel() must be called no more than once per request.
+ """
+ raise NotImplementedError
+
+
+class RpcChannel(object):
+
+ """Abstract interface for an RPC channel.
+
+ An RpcChannel represents a communication line to a service which can be used
+ to call that service's methods. The service may be running on another
+ machine. Normally, you should not use an RpcChannel directly, but instead
+ construct a stub {@link Service} wrapping it. Example:
+
+ Example:
+ RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234")
+ RpcController controller = rpcImpl.Controller()
+ MyService service = MyService_Stub(channel)
+ service.MyMethod(controller, request, callback)
+ """
+
+ def CallMethod(self, method_descriptor, rpc_controller,
+ request, response_class, done):
+ """Calls the method identified by the descriptor.
+
+ Call the given method of the remote service. The signature of this
+ procedure looks the same as Service.CallMethod(), but the requirements
+ are less strict in one important way: the request object doesn't have to
+ be of any specific class as long as its descriptor is method.input_type.
+ """
+ raise NotImplementedError
diff --git a/python/google/protobuf/service_reflection.py b/python/google/protobuf/service_reflection.py
new file mode 100755
index 0000000..851e83e
--- /dev/null
+++ b/python/google/protobuf/service_reflection.py
@@ -0,0 +1,284 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains metaclasses used to create protocol service and service stub
+classes from ServiceDescriptor objects at runtime.
+
+The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to
+inject all useful functionality into the classes output by the protocol
+compiler at compile-time.
+"""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+
+class GeneratedServiceType(type):
+
+ """Metaclass for service classes created at runtime from ServiceDescriptors.
+
+ Implementations for all methods described in the Service class are added here
+ by this class. We also create properties to allow getting/setting all fields
+ in the protocol message.
+
+ The protocol compiler currently uses this metaclass to create protocol service
+ classes at runtime. Clients can also manually create their own classes at
+ runtime, as in this example:
+
+ mydescriptor = ServiceDescriptor(.....)
+ class MyProtoService(service.Service):
+ __metaclass__ = GeneratedServiceType
+ DESCRIPTOR = mydescriptor
+ myservice_instance = MyProtoService()
+ ...
+ """
+
+ _DESCRIPTOR_KEY = 'DESCRIPTOR'
+
+ def __init__(cls, name, bases, dictionary):
+ """Creates a message service class.
+
+ Args:
+ name: Name of the class (ignored, but required by the metaclass
+ protocol).
+ bases: Base classes of the class being constructed.
+ dictionary: The class dictionary of the class being constructed.
+ dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object
+ describing this protocol service type.
+ """
+ # Don't do anything if this class doesn't have a descriptor. This happens
+ # when a service class is subclassed.
+ if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary:
+ return
+ descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY]
+ service_builder = _ServiceBuilder(descriptor)
+ service_builder.BuildService(cls)
+
+
+class GeneratedServiceStubType(GeneratedServiceType):
+
+ """Metaclass for service stubs created at runtime from ServiceDescriptors.
+
+ This class has similar responsibilities as GeneratedServiceType, except that
+ it creates the service stub classes.
+ """
+
+ _DESCRIPTOR_KEY = 'DESCRIPTOR'
+
+ def __init__(cls, name, bases, dictionary):
+ """Creates a message service stub class.
+
+ Args:
+ name: Name of the class (ignored, here).
+ bases: Base classes of the class being constructed.
+ dictionary: The class dictionary of the class being constructed.
+ dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object
+ describing this protocol service type.
+ """
+ super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary)
+ # Don't do anything if this class doesn't have a descriptor. This happens
+ # when a service stub is subclassed.
+ if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary:
+ return
+ descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY]
+ service_stub_builder = _ServiceStubBuilder(descriptor)
+ service_stub_builder.BuildServiceStub(cls)
+
+
+class _ServiceBuilder(object):
+
+ """This class constructs a protocol service class using a service descriptor.
+
+ Given a service descriptor, this class constructs a class that represents
+ the specified service descriptor. One service builder instance constructs
+ exactly one service class. That means all instances of that class share the
+ same builder.
+ """
+
+ def __init__(self, service_descriptor):
+ """Initializes an instance of the service class builder.
+
+ Args:
+ service_descriptor: ServiceDescriptor to use when constructing the
+ service class.
+ """
+ self.descriptor = service_descriptor
+
+ def BuildService(self, cls):
+ """Constructs the service class.
+
+ Args:
+ cls: The class that will be constructed.
+ """
+
+ # CallMethod needs to operate with an instance of the Service class. This
+ # internal wrapper function exists only to be able to pass the service
+ # instance to the method that does the real CallMethod work.
+ def _WrapCallMethod(srvc, method_descriptor,
+ rpc_controller, request, callback):
+ return self._CallMethod(srvc, method_descriptor,
+ rpc_controller, request, callback)
+ self.cls = cls
+ cls.CallMethod = _WrapCallMethod
+ cls.GetDescriptor = staticmethod(lambda: self.descriptor)
+ cls.GetDescriptor.__doc__ = "Returns the service descriptor."
+ cls.GetRequestClass = self._GetRequestClass
+ cls.GetResponseClass = self._GetResponseClass
+ for method in self.descriptor.methods:
+ setattr(cls, method.name, self._GenerateNonImplementedMethod(method))
+
+ def _CallMethod(self, srvc, method_descriptor,
+ rpc_controller, request, callback):
+ """Calls the method described by a given method descriptor.
+
+ Args:
+ srvc: Instance of the service for which this method is called.
+ method_descriptor: Descriptor that represent the method to call.
+ rpc_controller: RPC controller to use for this method's execution.
+ request: Request protocol message.
+ callback: A callback to invoke after the method has completed.
+ """
+ if method_descriptor.containing_service != self.descriptor:
+ raise RuntimeError(
+ 'CallMethod() given method descriptor for wrong service type.')
+ method = getattr(srvc, method_descriptor.name)
+ return method(rpc_controller, request, callback)
+
+ def _GetRequestClass(self, method_descriptor):
+ """Returns the class of the request protocol message.
+
+ Args:
+ method_descriptor: Descriptor of the method for which to return the
+ request protocol message class.
+
+ Returns:
+ A class that represents the input protocol message of the specified
+ method.
+ """
+ if method_descriptor.containing_service != self.descriptor:
+ raise RuntimeError(
+ 'GetRequestClass() given method descriptor for wrong service type.')
+ return method_descriptor.input_type._concrete_class
+
+ def _GetResponseClass(self, method_descriptor):
+ """Returns the class of the response protocol message.
+
+ Args:
+ method_descriptor: Descriptor of the method for which to return the
+ response protocol message class.
+
+ Returns:
+ A class that represents the output protocol message of the specified
+ method.
+ """
+ if method_descriptor.containing_service != self.descriptor:
+ raise RuntimeError(
+ 'GetResponseClass() given method descriptor for wrong service type.')
+ return method_descriptor.output_type._concrete_class
+
+ def _GenerateNonImplementedMethod(self, method):
+ """Generates and returns a method that can be set for a service methods.
+
+ Args:
+ method: Descriptor of the service method for which a method is to be
+ generated.
+
+ Returns:
+ A method that can be added to the service class.
+ """
+ return lambda inst, rpc_controller, request, callback: (
+ self._NonImplementedMethod(method.name, rpc_controller, callback))
+
+ def _NonImplementedMethod(self, method_name, rpc_controller, callback):
+ """The body of all methods in the generated service class.
+
+ Args:
+ method_name: Name of the method being executed.
+ rpc_controller: RPC controller used to execute this method.
+ callback: A callback which will be invoked when the method finishes.
+ """
+ rpc_controller.SetFailed('Method %s not implemented.' % method_name)
+ callback(None)
+
+
+class _ServiceStubBuilder(object):
+
+ """Constructs a protocol service stub class using a service descriptor.
+
+ Given a service descriptor, this class constructs a suitable stub class.
+ A stub is just a type-safe wrapper around an RpcChannel which emulates a
+ local implementation of the service.
+
+ One service stub builder instance constructs exactly one class. It means all
+ instances of that class share the same service stub builder.
+ """
+
+ def __init__(self, service_descriptor):
+ """Initializes an instance of the service stub class builder.
+
+ Args:
+ service_descriptor: ServiceDescriptor to use when constructing the
+ stub class.
+ """
+ self.descriptor = service_descriptor
+
+ def BuildServiceStub(self, cls):
+ """Constructs the stub class.
+
+ Args:
+ cls: The class that will be constructed.
+ """
+
+ def _ServiceStubInit(stub, rpc_channel):
+ stub.rpc_channel = rpc_channel
+ self.cls = cls
+ cls.__init__ = _ServiceStubInit
+ for method in self.descriptor.methods:
+ setattr(cls, method.name, self._GenerateStubMethod(method))
+
+ def _GenerateStubMethod(self, method):
+ return (lambda inst, rpc_controller, request, callback=None:
+ self._StubMethod(inst, method, rpc_controller, request, callback))
+
+ def _StubMethod(self, stub, method_descriptor,
+ rpc_controller, request, callback):
+ """The body of all service methods in the generated stub class.
+
+ Args:
+ stub: Stub instance.
+ method_descriptor: Descriptor of the invoked method.
+ rpc_controller: Rpc controller to execute the method.
+ request: Request protocol message.
+ callback: A callback to execute when the method finishes.
+ Returns:
+ Response message (in case of blocking call).
+ """
+ return stub.rpc_channel.CallMethod(
+ method_descriptor, rpc_controller, request,
+ method_descriptor.output_type._concrete_class, callback)
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
new file mode 100755
index 0000000..1cddce6
--- /dev/null
+++ b/python/google/protobuf/text_format.py
@@ -0,0 +1,650 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# http://code.google.com/p/protobuf/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Contains routines for printing protocol messages in text format."""
+
+__author__ = 'kenton@google.com (Kenton Varda)'
+
+import cStringIO
+import re
+
+from collections import deque
+from google.protobuf.internal import type_checkers
+from google.protobuf import descriptor
+
+__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField',
+ 'PrintFieldValue', 'Merge' ]
+
+
+class ParseError(Exception):
+ """Thrown in case of ASCII parsing error."""
+
+
+def MessageToString(message):
+ out = cStringIO.StringIO()
+ PrintMessage(message, out)
+ result = out.getvalue()
+ out.close()
+ return result
+
+
+def PrintMessage(message, out, indent = 0):
+ for field, value in message.ListFields():
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ for element in value:
+ PrintField(field, element, out, indent)
+ else:
+ PrintField(field, value, out, indent)
+
+
+def PrintField(field, value, out, indent = 0):
+ """Print a single field name/value pair. For repeated fields, the value
+ should be a single element."""
+
+ out.write(' ' * indent);
+ if field.is_extension:
+ out.write('[')
+ if (field.containing_type.GetOptions().message_set_wire_format and
+ field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
+ field.message_type == field.extension_scope and
+ field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
+ out.write(field.message_type.full_name)
+ else:
+ out.write(field.full_name)
+ out.write(']')
+ elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
+ # For groups, use the capitalized name.
+ out.write(field.message_type.name)
+ else:
+ out.write(field.name)
+
+ if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ # The colon is optional in this case, but our cross-language golden files
+ # don't include it.
+ out.write(': ')
+
+ PrintFieldValue(field, value, out, indent)
+ out.write('\n')
+
+
+def PrintFieldValue(field, value, out, indent = 0):
+ """Print a single field value (not including name). For repeated fields,
+ the value should be a single element."""
+
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ out.write(' {\n')
+ PrintMessage(value, out, indent + 2)
+ out.write(' ' * indent + '}')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
+ out.write(field.enum_type.values_by_number[value].name)
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
+ out.write('\"')
+ out.write(_CEscape(value))
+ out.write('\"')
+ elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
+ if value:
+ out.write("true")
+ else:
+ out.write("false")
+ else:
+ out.write(str(value))
+
+
+def Merge(text, message):
+ """Merges an ASCII representation of a protocol message into a message.
+
+ Args:
+ text: Message ASCII representation.
+ message: A protocol buffer message to merge into.
+
+ Raises:
+ ParseError: On ASCII parsing problems.
+ """
+ tokenizer = _Tokenizer(text)
+ while not tokenizer.AtEnd():
+ _MergeField(tokenizer, message)
+
+
+def _MergeField(tokenizer, message):
+ """Merges a single protocol message field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field name and values.
+ message: A protocol message to record the data.
+
+ Raises:
+ ParseError: In case of ASCII parsing problems.
+ """
+ message_descriptor = message.DESCRIPTOR
+ if tokenizer.TryConsume('['):
+ name = [tokenizer.ConsumeIdentifier()]
+ while tokenizer.TryConsume('.'):
+ name.append(tokenizer.ConsumeIdentifier())
+ name = '.'.join(name)
+
+ field = message.Extensions._FindExtensionByName(name)
+ if not field:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" not registered.' % name)
+ elif message_descriptor != field.containing_type:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Extension "%s" does not extend message type "%s".' % (
+ name, message_descriptor.full_name))
+ tokenizer.Consume(']')
+ else:
+ name = tokenizer.ConsumeIdentifier()
+ field = message_descriptor.fields_by_name.get(name, None)
+
+ # Group names are expected to be capitalized as they appear in the
+ # .proto file, which actually matches their type names, not their field
+ # names.
+ if not field:
+ field = message_descriptor.fields_by_name.get(name.lower(), None)
+ if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
+ field = None
+
+ if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
+ field.message_type.name != name):
+ field = None
+
+ if not field:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Message type "%s" has no field named "%s".' % (
+ message_descriptor.full_name, name))
+
+ if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
+ tokenizer.TryConsume(':')
+
+ if tokenizer.TryConsume('<'):
+ end_token = '>'
+ else:
+ tokenizer.Consume('{')
+ end_token = '}'
+
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.is_extension:
+ sub_message = message.Extensions[field].add()
+ else:
+ sub_message = getattr(message, field.name).add()
+ else:
+ if field.is_extension:
+ sub_message = message.Extensions[field]
+ else:
+ sub_message = getattr(message, field.name)
+
+ while not tokenizer.TryConsume(end_token):
+ if tokenizer.AtEnd():
+ raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
+ _MergeField(tokenizer, sub_message)
+ else:
+ _MergeScalarField(tokenizer, message, field)
+
+
+def _MergeScalarField(tokenizer, message, field):
+ """Merges a single protocol message scalar field into a message.
+
+ Args:
+ tokenizer: A tokenizer to parse the field value.
+ message: A protocol message to record the data.
+ field: The descriptor of the field to be merged.
+
+ Raises:
+ ParseError: In case of ASCII parsing problems.
+ RuntimeError: On runtime errors.
+ """
+ tokenizer.Consume(':')
+ value = None
+
+ if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
+ descriptor.FieldDescriptor.TYPE_SINT32,
+ descriptor.FieldDescriptor.TYPE_SFIXED32):
+ value = tokenizer.ConsumeInt32()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
+ descriptor.FieldDescriptor.TYPE_SINT64,
+ descriptor.FieldDescriptor.TYPE_SFIXED64):
+ value = tokenizer.ConsumeInt64()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
+ descriptor.FieldDescriptor.TYPE_FIXED32):
+ value = tokenizer.ConsumeUint32()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
+ descriptor.FieldDescriptor.TYPE_FIXED64):
+ value = tokenizer.ConsumeUint64()
+ elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
+ descriptor.FieldDescriptor.TYPE_DOUBLE):
+ value = tokenizer.ConsumeFloat()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
+ value = tokenizer.ConsumeBool()
+ elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
+ value = tokenizer.ConsumeString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
+ value = tokenizer.ConsumeByteString()
+ elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
+ # Enum can be specified by a number (the enum value), or by
+ # a string literal (the enum name).
+ enum_descriptor = field.enum_type
+ if tokenizer.LookingAtInteger():
+ number = tokenizer.ConsumeInt32()
+ enum_value = enum_descriptor.values_by_number.get(number, None)
+ if enum_value is None:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Enum type "%s" has no value with number %d.' % (
+ enum_descriptor.full_name, number))
+ else:
+ identifier = tokenizer.ConsumeIdentifier()
+ enum_value = enum_descriptor.values_by_name.get(identifier, None)
+ if enum_value is None:
+ raise tokenizer.ParseErrorPreviousToken(
+ 'Enum type "%s" has no value named %s.' % (
+ enum_descriptor.full_name, identifier))
+ value = enum_value.number
+ else:
+ raise RuntimeError('Unknown field type %d' % field.type)
+
+ if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ if field.is_extension:
+ message.Extensions[field].append(value)
+ else:
+ getattr(message, field.name).append(value)
+ else:
+ if field.is_extension:
+ message.Extensions[field] = value
+ else:
+ setattr(message, field.name, value)
+
+
+class _Tokenizer(object):
+ """Protocol buffer ASCII representation tokenizer.
+
+ This class handles the lower level string parsing by splitting it into
+ meaningful tokens.
+
+ It was directly ported from the Java protocol buffer API.
+ """
+
+ _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
+ _TOKEN = re.compile(
+ '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier
+ '[0-9+-][0-9a-zA-Z_.+-]*|' # a number
+ '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string
+ '\'([^\"\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string
+ _IDENTIFIER = re.compile('\w+')
+ _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(),
+ type_checkers.Int32ValueChecker(),
+ type_checkers.Uint64ValueChecker(),
+ type_checkers.Int64ValueChecker()]
+ _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE)
+ _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE)
+
+ def __init__(self, text_message):
+ self._text_message = text_message
+
+ self._position = 0
+ self._line = -1
+ self._column = 0
+ self._token_start = None
+ self.token = ''
+ self._lines = deque(text_message.split('\n'))
+ self._current_line = ''
+ self._previous_line = 0
+ self._previous_column = 0
+ self._SkipWhitespace()
+ self.NextToken()
+
+ def AtEnd(self):
+ """Checks the end of the text was reached.
+
+ Returns:
+ True iff the end was reached.
+ """
+ return not self._lines and not self._current_line
+
+ def _PopLine(self):
+ while not self._current_line:
+ if not self._lines:
+ self._current_line = ''
+ return
+ self._line += 1
+ self._column = 0
+ self._current_line = self._lines.popleft()
+
+ def _SkipWhitespace(self):
+ while True:
+ self._PopLine()
+ match = re.match(self._WHITESPACE, self._current_line)
+ if not match:
+ break
+ length = len(match.group(0))
+ self._current_line = self._current_line[length:]
+ self._column += length
+
+ def TryConsume(self, token):
+ """Tries to consume a given piece of text.
+
+ Args:
+ token: Text to consume.
+
+ Returns:
+ True iff the text was consumed.
+ """
+ if self.token == token:
+ self.NextToken()
+ return True
+ return False
+
+ def Consume(self, token):
+ """Consumes a piece of text.
+
+ Args:
+ token: Text to consume.
+
+ Raises:
+ ParseError: If the text couldn't be consumed.
+ """
+ if not self.TryConsume(token):
+ raise self._ParseError('Expected "%s".' % token)
+
+ def LookingAtInteger(self):
+ """Checks if the current token is an integer.
+
+ Returns:
+ True iff the current token is an integer.
+ """
+ if not self.token:
+ return False
+ c = self.token[0]
+ return (c >= '0' and c <= '9') or c == '-' or c == '+'
+
+ def ConsumeIdentifier(self):
+ """Consumes protocol message field identifier.
+
+ Returns:
+ Identifier string.
+
+ Raises:
+ ParseError: If an identifier couldn't be consumed.
+ """
+ result = self.token
+ if not re.match(self._IDENTIFIER, result):
+ raise self._ParseError('Expected identifier.')
+ self.NextToken()
+ return result
+
+ def ConsumeInt32(self):
+ """Consumes a signed 32bit integer number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 32bit integer couldn't be consumed.
+ """
+ try:
+ result = self._ParseInteger(self.token, is_signed=True, is_long=False)
+ except ValueError, e:
+ raise self._IntegerParseError(e)
+ self.NextToken()
+ return result
+
+ def ConsumeUint32(self):
+ """Consumes an unsigned 32bit integer number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 32bit integer couldn't be consumed.
+ """
+ try:
+ result = self._ParseInteger(self.token, is_signed=False, is_long=False)
+ except ValueError, e:
+ raise self._IntegerParseError(e)
+ self.NextToken()
+ return result
+
+ def ConsumeInt64(self):
+ """Consumes a signed 64bit integer number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If a signed 64bit integer couldn't be consumed.
+ """
+ try:
+ result = self._ParseInteger(self.token, is_signed=True, is_long=True)
+ except ValueError, e:
+ raise self._IntegerParseError(e)
+ self.NextToken()
+ return result
+
+ def ConsumeUint64(self):
+ """Consumes an unsigned 64bit integer number.
+
+ Returns:
+ The integer parsed.
+
+ Raises:
+ ParseError: If an unsigned 64bit integer couldn't be consumed.
+ """
+ try:
+ result = self._ParseInteger(self.token, is_signed=False, is_long=True)
+ except ValueError, e:
+ raise self._IntegerParseError(e)
+ self.NextToken()
+ return result
+
+ def ConsumeFloat(self):
+ """Consumes an floating point number.
+
+ Returns:
+ The number parsed.
+
+ Raises:
+ ParseError: If a floating point number couldn't be consumed.
+ """
+ text = self.token
+ if re.match(self._FLOAT_INFINITY, text):
+ self.NextToken()
+ if text.startswith('-'):
+ return float('-inf')
+ return float('inf')
+
+ if re.match(self._FLOAT_NAN, text):
+ self.NextToken()
+ return float('nan')
+
+ try:
+ result = float(text)
+ except ValueError, e:
+ raise self._FloatParseError(e)
+ self.NextToken()
+ return result
+
+ def ConsumeBool(self):
+ """Consumes a boolean value.
+
+ Returns:
+ The bool parsed.
+
+ Raises:
+ ParseError: If a boolean value couldn't be consumed.
+ """
+ if self.token == 'true':
+ self.NextToken()
+ return True
+ elif self.token == 'false':
+ self.NextToken()
+ return False
+ else:
+ raise self._ParseError('Expected "true" or "false".')
+
+ def ConsumeString(self):
+ """Consumes a string value.
+
+ Returns:
+ The string parsed.
+
+ Raises:
+ ParseError: If a string value couldn't be consumed.
+ """
+ return unicode(self.ConsumeByteString(), 'utf-8')
+
+ def ConsumeByteString(self):
+ """Consumes a byte array value.
+
+ Returns:
+ The array parsed (as a string).
+
+ Raises:
+ ParseError: If a byte array value couldn't be consumed.
+ """
+ text = self.token
+ if len(text) < 1 or text[0] not in ('\'', '"'):
+ raise self._ParseError('Exptected string.')
+
+ if len(text) < 2 or text[-1] != text[0]:
+ raise self._ParseError('String missing ending quote.')
+
+ try:
+ result = _CUnescape(text[1:-1])
+ except ValueError, e:
+ raise self._ParseError(str(e))
+ self.NextToken()
+ return result
+
+ def _ParseInteger(self, text, is_signed=False, is_long=False):
+ """Parses an integer.
+
+ Args:
+ text: The text to parse.
+ is_signed: True if a signed integer must be parsed.
+ is_long: True if a long integer must be parsed.
+
+ Returns:
+ The integer value.
+
+ Raises:
+ ValueError: Thrown Iff the text is not a valid integer.
+ """
+ pos = 0
+ if text.startswith('-'):
+ pos += 1
+
+ base = 10
+ if text.startswith('0x', pos) or text.startswith('0X', pos):
+ base = 16
+ elif text.startswith('0', pos):
+ base = 8
+
+ # Do the actual parsing. Exception handling is propagated to caller.
+ result = int(text, base)
+
+ # Check if the integer is sane. Exceptions handled by callers.
+ checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+ checker.CheckValue(result)
+ return result
+
+ def ParseErrorPreviousToken(self, message):
+ """Creates and *returns* a ParseError for the previously read token.
+
+ Args:
+ message: A message to set for the exception.
+
+ Returns:
+ A ParseError instance.
+ """
+ return ParseError('%d:%d : %s' % (
+ self._previous_line + 1, self._previous_column + 1, message))
+
+ def _ParseError(self, message):
+ """Creates and *returns* a ParseError for the current token."""
+ return ParseError('%d:%d : %s' % (
+ self._line + 1, self._column + 1, message))
+
+ def _IntegerParseError(self, e):
+ return self._ParseError('Couldn\'t parse integer: ' + str(e))
+
+ def _FloatParseError(self, e):
+ return self._ParseError('Couldn\'t parse number: ' + str(e))
+
+ def NextToken(self):
+ """Reads the next meaningful token."""
+ self._previous_line = self._line
+ self._previous_column = self._column
+ if self.AtEnd():
+ self.token = ''
+ return
+ self._column += len(self.token)
+
+ # Make sure there is data to work on.
+ self._PopLine()
+
+ match = re.match(self._TOKEN, self._current_line)
+ if match:
+ token = match.group(0)
+ self._current_line = self._current_line[len(token):]
+ self.token = token
+ else:
+ self.token = self._current_line[0]
+ self._current_line = self._current_line[1:]
+ self._SkipWhitespace()
+
+
+# text.encode('string_escape') does not seem to satisfy our needs as it
+# encodes unprintable characters using two-digit hex escapes whereas our
+# C++ unescaping function allows hex escapes to be any length. So,
+# "\0011".encode('string_escape') ends up being "\\x011", which will be
+# decoded in C++ as a single-character string with char code 0x11.
+def _CEscape(text):
+ def escape(c):
+ o = ord(c)
+ if o == 10: return r"\n" # optional escape
+ if o == 13: return r"\r" # optional escape
+ if o == 9: return r"\t" # optional escape
+ if o == 39: return r"\'" # optional escape
+
+ if o == 34: return r'\"' # necessary escape
+ if o == 92: return r"\\" # necessary escape
+
+ if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
+ return c
+ return "".join([escape(c) for c in text])
+
+
+_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])')
+
+
+def _CUnescape(text):
+ def ReplaceHex(m):
+ return chr(int(m.group(0)[2:], 16))
+ # This is required because the 'string_escape' encoding doesn't
+ # allow single-digit hex escapes (like '\xf').
+ result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
+ return result.decode('string_escape')
diff --git a/python/mox.py b/python/mox.py
new file mode 100755
index 0000000..ce80ba5
--- /dev/null
+++ b/python/mox.py
@@ -0,0 +1,1401 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is used for testing. The original is at:
+# http://code.google.com/p/pymox/
+
+"""Mox, an object-mocking framework for Python.
+
+Mox works in the record-replay-verify paradigm. When you first create
+a mock object, it is in record mode. You then programmatically set
+the expected behavior of the mock object (what methods are to be
+called on it, with what parameters, what they should return, and in
+what order).
+
+Once you have set up the expected mock behavior, you put it in replay
+mode. Now the mock responds to method calls just as you told it to.
+If an unexpected method (or an expected method with unexpected
+parameters) is called, then an exception will be raised.
+
+Once you are done interacting with the mock, you need to verify that
+all the expected interactions occured. (Maybe your code exited
+prematurely without calling some cleanup method!) The verify phase
+ensures that every expected method was called; otherwise, an exception
+will be raised.
+
+Suggested usage / workflow:
+
+ # Create Mox factory
+ my_mox = Mox()
+
+ # Create a mock data access object
+ mock_dao = my_mox.CreateMock(DAOClass)
+
+ # Set up expected behavior
+ mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
+ mock_dao.DeletePerson(person)
+
+ # Put mocks in replay mode
+ my_mox.ReplayAll()
+
+ # Inject mock object and run test
+ controller.SetDao(mock_dao)
+ controller.DeletePersonById('1')
+
+ # Verify all methods were called as expected
+ my_mox.VerifyAll()
+"""
+
+from collections import deque
+import re
+import types
+import unittest
+
+import stubout
+
+class Error(AssertionError):
+ """Base exception for this module."""
+
+ pass
+
+
+class ExpectedMethodCallsError(Error):
+ """Raised when Verify() is called before all expected methods have been called
+ """
+
+ def __init__(self, expected_methods):
+ """Init exception.
+
+ Args:
+ # expected_methods: A sequence of MockMethod objects that should have been
+ # called.
+ expected_methods: [MockMethod]
+
+ Raises:
+ ValueError: if expected_methods contains no methods.
+ """
+
+ if not expected_methods:
+ raise ValueError("There must be at least one expected method")
+ Error.__init__(self)
+ self._expected_methods = expected_methods
+
+ def __str__(self):
+ calls = "\n".join(["%3d. %s" % (i, m)
+ for i, m in enumerate(self._expected_methods)])
+ return "Verify: Expected methods never called:\n%s" % (calls,)
+
+
+class UnexpectedMethodCallError(Error):
+ """Raised when an unexpected method is called.
+
+ This can occur if a method is called with incorrect parameters, or out of the
+ specified order.
+ """
+
+ def __init__(self, unexpected_method, expected):
+ """Init exception.
+
+ Args:
+ # unexpected_method: MockMethod that was called but was not at the head of
+ # the expected_method queue.
+ # expected: MockMethod or UnorderedGroup the method should have
+ # been in.
+ unexpected_method: MockMethod
+ expected: MockMethod or UnorderedGroup
+ """
+
+ Error.__init__(self)
+ self._unexpected_method = unexpected_method
+ self._expected = expected
+
+ def __str__(self):
+ return "Unexpected method call: %s. Expecting: %s" % \
+ (self._unexpected_method, self._expected)
+
+
+class UnknownMethodCallError(Error):
+ """Raised if an unknown method is requested of the mock object."""
+
+ def __init__(self, unknown_method_name):
+ """Init exception.
+
+ Args:
+ # unknown_method_name: Method call that is not part of the mocked class's
+ # public interface.
+ unknown_method_name: str
+ """
+
+ Error.__init__(self)
+ self._unknown_method_name = unknown_method_name
+
+ def __str__(self):
+ return "Method called is not a member of the object: %s" % \
+ self._unknown_method_name
+
+
+class Mox(object):
+ """Mox: a factory for creating mock objects."""
+
+ # A list of types that should be stubbed out with MockObjects (as
+ # opposed to MockAnythings).
+ _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
+ types.ObjectType, types.TypeType]
+
+ def __init__(self):
+ """Initialize a new Mox."""
+
+ self._mock_objects = []
+ self.stubs = stubout.StubOutForTesting()
+
+ def CreateMock(self, class_to_mock):
+ """Create a new mock object.
+
+ Args:
+ # class_to_mock: the class to be mocked
+ class_to_mock: class
+
+ Returns:
+ MockObject that can be used as the class_to_mock would be.
+ """
+
+ new_mock = MockObject(class_to_mock)
+ self._mock_objects.append(new_mock)
+ return new_mock
+
+ def CreateMockAnything(self):
+ """Create a mock that will accept any method calls.
+
+ This does not enforce an interface.
+ """
+
+ new_mock = MockAnything()
+ self._mock_objects.append(new_mock)
+ return new_mock
+
+ def ReplayAll(self):
+ """Set all mock objects to replay mode."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Replay()
+
+
+ def VerifyAll(self):
+ """Call verify on all mock objects created."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Verify()
+
+ def ResetAll(self):
+ """Call reset on all mock objects. This does not unset stubs."""
+
+ for mock_obj in self._mock_objects:
+ mock_obj._Reset()
+
+ def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
+ """Replace a method, attribute, etc. with a Mock.
+
+ This will replace a class or module with a MockObject, and everything else
+ (method, function, etc) with a MockAnything. This can be overridden to
+ always use a MockAnything by setting use_mock_anything to True.
+
+ Args:
+ obj: A Python object (class, module, instance, callable).
+ attr_name: str. The name of the attribute to replace with a mock.
+ use_mock_anything: bool. True if a MockAnything should be used regardless
+ of the type of attribute.
+ """
+
+ attr_to_replace = getattr(obj, attr_name)
+ if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
+ stub = self.CreateMock(attr_to_replace)
+ else:
+ stub = self.CreateMockAnything()
+
+ self.stubs.Set(obj, attr_name, stub)
+
+ def UnsetStubs(self):
+ """Restore stubs to their original state."""
+
+ self.stubs.UnsetAll()
+
+def Replay(*args):
+ """Put mocks into Replay mode.
+
+ Args:
+ # args is any number of mocks to put into replay mode.
+ """
+
+ for mock in args:
+ mock._Replay()
+
+
+def Verify(*args):
+ """Verify mocks.
+
+ Args:
+ # args is any number of mocks to be verified.
+ """
+
+ for mock in args:
+ mock._Verify()
+
+
+def Reset(*args):
+ """Reset mocks.
+
+ Args:
+ # args is any number of mocks to be reset.
+ """
+
+ for mock in args:
+ mock._Reset()
+
+
+class MockAnything:
+ """A mock that can be used to mock anything.
+
+ This is helpful for mocking classes that do not provide a public interface.
+ """
+
+ def __init__(self):
+ """ """
+ self._Reset()
+
+ def __getattr__(self, method_name):
+ """Intercept method calls on this object.
+
+ A new MockMethod is returned that is aware of the MockAnything's
+ state (record or replay). The call will be recorded or replayed
+ by the MockMethod's __call__.
+
+ Args:
+ # method name: the name of the method being called.
+ method_name: str
+
+ Returns:
+ A new MockMethod aware of MockAnything's state (record or replay).
+ """
+
+ return self._CreateMockMethod(method_name)
+
+ def _CreateMockMethod(self, method_name):
+ """Create a new mock method call and return it.
+
+ Args:
+ # method name: the name of the method being called.
+ method_name: str
+
+ Returns:
+ A new MockMethod aware of MockAnything's state (record or replay).
+ """
+
+ return MockMethod(method_name, self._expected_calls_queue,
+ self._replay_mode)
+
+ def __nonzero__(self):
+ """Return 1 for nonzero so the mock can be used as a conditional."""
+
+ return 1
+
+ def __eq__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return (isinstance(rhs, MockAnything) and
+ self._replay_mode == rhs._replay_mode and
+ self._expected_calls_queue == rhs._expected_calls_queue)
+
+ def __ne__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return not self == rhs
+
+ def _Replay(self):
+ """Start replaying expected method calls."""
+
+ self._replay_mode = True
+
+ def _Verify(self):
+ """Verify that all of the expected calls have been made.
+
+ Raises:
+ ExpectedMethodCallsError: if there are still more method calls in the
+ expected queue.
+ """
+
+ # If the list of expected calls is not empty, raise an exception
+ if self._expected_calls_queue:
+ # The last MultipleTimesGroup is not popped from the queue.
+ if (len(self._expected_calls_queue) == 1 and
+ isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
+ self._expected_calls_queue[0].IsSatisfied()):
+ pass
+ else:
+ raise ExpectedMethodCallsError(self._expected_calls_queue)
+
+ def _Reset(self):
+ """Reset the state of this mock to record mode with an empty queue."""
+
+ # Maintain a list of method calls we are expecting
+ self._expected_calls_queue = deque()
+
+ # Make sure we are in setup mode, not replay mode
+ self._replay_mode = False
+
+
+class MockObject(MockAnything, object):
+ """A mock object that simulates the public/protected interface of a class."""
+
+ def __init__(self, class_to_mock):
+ """Initialize a mock object.
+
+ This determines the methods and properties of the class and stores them.
+
+ Args:
+ # class_to_mock: class to be mocked
+ class_to_mock: class
+ """
+
+ # This is used to hack around the mixin/inheritance of MockAnything, which
+ # is not a proper object (it can be anything. :-)
+ MockAnything.__dict__['__init__'](self)
+
+ # Get a list of all the public and special methods we should mock.
+ self._known_methods = set()
+ self._known_vars = set()
+ self._class_to_mock = class_to_mock
+ for method in dir(class_to_mock):
+ if callable(getattr(class_to_mock, method)):
+ self._known_methods.add(method)
+ else:
+ self._known_vars.add(method)
+
+ def __getattr__(self, name):
+ """Intercept attribute request on this object.
+
+ If the attribute is a public class variable, it will be returned and not
+ recorded as a call.
+
+ If the attribute is not a variable, it is handled like a method
+ call. The method name is checked against the set of mockable
+ methods, and a new MockMethod is returned that is aware of the
+ MockObject's state (record or replay). The call will be recorded
+ or replayed by the MockMethod's __call__.
+
+ Args:
+ # name: the name of the attribute being requested.
+ name: str
+
+ Returns:
+ Either a class variable or a new MockMethod that is aware of the state
+ of the mock (record or replay).
+
+ Raises:
+ UnknownMethodCallError if the MockObject does not mock the requested
+ method.
+ """
+
+ if name in self._known_vars:
+ return getattr(self._class_to_mock, name)
+
+ if name in self._known_methods:
+ return self._CreateMockMethod(name)
+
+ raise UnknownMethodCallError(name)
+
+ def __eq__(self, rhs):
+ """Provide custom logic to compare objects."""
+
+ return (isinstance(rhs, MockObject) and
+ self._class_to_mock == rhs._class_to_mock and
+ self._replay_mode == rhs._replay_mode and
+ self._expected_calls_queue == rhs._expected_calls_queue)
+
+ def __setitem__(self, key, value):
+ """Provide custom logic for mocking classes that support item assignment.
+
+ Args:
+ key: Key to set the value for.
+ value: Value to set.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __setitem__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class does not support item assignment.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __setitem__.
+
+ """
+ setitem = self._class_to_mock.__dict__.get('__setitem__', None)
+
+ # Verify the class supports item assignment.
+ if setitem is None:
+ raise TypeError('object does not support item assignment')
+
+ # If we are in replay mode then simply call the mock __setitem__ method.
+ if self._replay_mode:
+ return MockMethod('__setitem__', self._expected_calls_queue,
+ self._replay_mode)(key, value)
+
+
+ # Otherwise, create a mock method __setitem__.
+ return self._CreateMockMethod('__setitem__')(key, value)
+
+ def __getitem__(self, key):
+ """Provide custom logic for mocking classes that are subscriptable.
+
+ Args:
+ key: Key to return the value for.
+
+ Returns:
+ Expected return value in replay mode. A MockMethod object for the
+ __getitem__ method that has already been called if not in replay mode.
+
+ Raises:
+ TypeError if the underlying class is not subscriptable.
+ UnexpectedMethodCallError if the object does not expect the call to
+ __setitem__.
+
+ """
+ getitem = self._class_to_mock.__dict__.get('__getitem__', None)
+
+ # Verify the class supports item assignment.
+ if getitem is None:
+ raise TypeError('unsubscriptable object')
+
+ # If we are in replay mode then simply call the mock __getitem__ method.
+ if self._replay_mode:
+ return MockMethod('__getitem__', self._expected_calls_queue,
+ self._replay_mode)(key)
+
+
+ # Otherwise, create a mock method __getitem__.
+ return self._CreateMockMethod('__getitem__')(key)
+
+ def __call__(self, *params, **named_params):
+ """Provide custom logic for mocking classes that are callable."""
+
+ # Verify the class we are mocking is callable
+ callable = self._class_to_mock.__dict__.get('__call__', None)
+ if callable is None:
+ raise TypeError('Not callable')
+
+ # Because the call is happening directly on this object instead of a method,
+ # the call on the mock method is made right here
+ mock_method = self._CreateMockMethod('__call__')
+ return mock_method(*params, **named_params)
+
+ @property
+ def __class__(self):
+ """Return the class that is being mocked."""
+
+ return self._class_to_mock
+
+
+class MockMethod(object):
+ """Callable mock method.
+
+ A MockMethod should act exactly like the method it mocks, accepting parameters
+ and returning a value, or throwing an exception (as specified). When this
+ method is called, it can optionally verify whether the called method (name and
+ signature) matches the expected method.
+ """
+
+ def __init__(self, method_name, call_queue, replay_mode):
+ """Construct a new mock method.
+
+ Args:
+ # method_name: the name of the method
+ # call_queue: deque of calls, verify this call against the head, or add
+ # this call to the queue.
+ # replay_mode: False if we are recording, True if we are verifying calls
+ # against the call queue.
+ method_name: str
+ call_queue: list or deque
+ replay_mode: bool
+ """
+
+ self._name = method_name
+ self._call_queue = call_queue
+ if not isinstance(call_queue, deque):
+ self._call_queue = deque(self._call_queue)
+ self._replay_mode = replay_mode
+
+ self._params = None
+ self._named_params = None
+ self._return_value = None
+ self._exception = None
+ self._side_effects = None
+
+ def __call__(self, *params, **named_params):
+ """Log parameters and return the specified return value.
+
+ If the Mock(Anything/Object) associated with this call is in record mode,
+ this MockMethod will be pushed onto the expected call queue. If the mock
+ is in replay mode, this will pop a MockMethod off the top of the queue and
+ verify this call is equal to the expected call.
+
+ Raises:
+ UnexpectedMethodCall if this call is supposed to match an expected method
+ call and it does not.
+ """
+
+ self._params = params
+ self._named_params = named_params
+
+ if not self._replay_mode:
+ self._call_queue.append(self)
+ return self
+
+ expected_method = self._VerifyMethodCall()
+
+ if expected_method._side_effects:
+ expected_method._side_effects(*params, **named_params)
+
+ if expected_method._exception:
+ raise expected_method._exception
+
+ return expected_method._return_value
+
+ def __getattr__(self, name):
+ """Raise an AttributeError with a helpful message."""
+
+ raise AttributeError('MockMethod has no attribute "%s". '
+ 'Did you remember to put your mocks in replay mode?' % name)
+
+ def _PopNextMethod(self):
+ """Pop the next method from our call queue."""
+ try:
+ return self._call_queue.popleft()
+ except IndexError:
+ raise UnexpectedMethodCallError(self, None)
+
+ def _VerifyMethodCall(self):
+ """Verify the called method is expected.
+
+ This can be an ordered method, or part of an unordered set.
+
+ Returns:
+ The expected mock method.
+
+ Raises:
+ UnexpectedMethodCall if the method called was not expected.
+ """
+
+ expected = self._PopNextMethod()
+
+ # Loop here, because we might have a MethodGroup followed by another
+ # group.
+ while isinstance(expected, MethodGroup):
+ expected, method = expected.MethodCalled(self)
+ if method is not None:
+ return method
+
+ # This is a mock method, so just check equality.
+ if expected != self:
+ raise UnexpectedMethodCallError(self, expected)
+
+ return expected
+
+ def __str__(self):
+ params = ', '.join(
+ [repr(p) for p in self._params or []] +
+ ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
+ desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
+ return desc
+
+ def __eq__(self, rhs):
+ """Test whether this MockMethod is equivalent to another MockMethod.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: MockMethod
+ """
+
+ return (isinstance(rhs, MockMethod) and
+ self._name == rhs._name and
+ self._params == rhs._params and
+ self._named_params == rhs._named_params)
+
+ def __ne__(self, rhs):
+ """Test whether this MockMethod is not equivalent to another MockMethod.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: MockMethod
+ """
+
+ return not self == rhs
+
+ def GetPossibleGroup(self):
+ """Returns a possible group from the end of the call queue or None if no
+ other methods are on the stack.
+ """
+
+ # Remove this method from the tail of the queue so we can add it to a group.
+ this_method = self._call_queue.pop()
+ assert this_method == self
+
+ # Determine if the tail of the queue is a group, or just a regular ordered
+ # mock method.
+ group = None
+ try:
+ group = self._call_queue[-1]
+ except IndexError:
+ pass
+
+ return group
+
+ def _CheckAndCreateNewGroup(self, group_name, group_class):
+ """Checks if the last method (a possible group) is an instance of our
+ group_class. Adds the current method to this group or creates a new one.
+
+ Args:
+
+ group_name: the name of the group.
+ group_class: the class used to create instance of this new group
+ """
+ group = self.GetPossibleGroup()
+
+ # If this is a group, and it is the correct group, add the method.
+ if isinstance(group, group_class) and group.group_name() == group_name:
+ group.AddMethod(self)
+ return self
+
+ # Create a new group and add the method.
+ new_group = group_class(group_name)
+ new_group.AddMethod(self)
+ self._call_queue.append(new_group)
+ return self
+
+ def InAnyOrder(self, group_name="default"):
+ """Move this method into a group of unordered calls.
+
+ A group of unordered calls must be defined together, and must be executed
+ in full before the next expected method can be called. There can be
+ multiple groups that are expected serially, if they are given
+ different group names. The same group name can be reused if there is a
+ standard method call, or a group with a different name, spliced between
+ usages.
+
+ Args:
+ group_name: the name of the unordered group.
+
+ Returns:
+ self
+ """
+ return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
+
+ def MultipleTimes(self, group_name="default"):
+ """Move this method into group of calls which may be called multiple times.
+
+ A group of repeating calls must be defined together, and must be executed in
+ full before the next expected mehtod can be called.
+
+ Args:
+ group_name: the name of the unordered group.
+
+ Returns:
+ self
+ """
+ return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
+
+ def AndReturn(self, return_value):
+ """Set the value to return when this method is called.
+
+ Args:
+ # return_value can be anything.
+ """
+
+ self._return_value = return_value
+ return return_value
+
+ def AndRaise(self, exception):
+ """Set the exception to raise when this method is called.
+
+ Args:
+ # exception: the exception to raise when this method is called.
+ exception: Exception
+ """
+
+ self._exception = exception
+
+ def WithSideEffects(self, side_effects):
+ """Set the side effects that are simulated when this method is called.
+
+ Args:
+ side_effects: A callable which modifies the parameters or other relevant
+ state which a given test case depends on.
+
+ Returns:
+ Self for chaining with AndReturn and AndRaise.
+ """
+ self._side_effects = side_effects
+ return self
+
+class Comparator:
+ """Base class for all Mox comparators.
+
+ A Comparator can be used as a parameter to a mocked method when the exact
+ value is not known. For example, the code you are testing might build up a
+ long SQL string that is passed to your mock DAO. You're only interested that
+ the IN clause contains the proper primary keys, so you can set your mock
+ up as follows:
+
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
+
+ Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
+
+ A Comparator may replace one or more parameters, for example:
+ # return at most 10 rows
+ mock_dao.RunQuery(StrContains('SELECT'), 10)
+
+ or
+
+ # Return some non-deterministic number of rows
+ mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
+ """
+
+ def equals(self, rhs):
+ """Special equals method that all comparators must implement.
+
+ Args:
+ rhs: any python object
+ """
+
+ raise NotImplementedError, 'method must be implemented by a subclass.'
+
+ def __eq__(self, rhs):
+ return self.equals(rhs)
+
+ def __ne__(self, rhs):
+ return not self.equals(rhs)
+
+
+class IsA(Comparator):
+ """This class wraps a basic Python type or class. It is used to verify
+ that a parameter is of the given type or class.
+
+ Example:
+ mock_dao.Connect(IsA(DbConnectInfo))
+ """
+
+ def __init__(self, class_name):
+ """Initialize IsA
+
+ Args:
+ class_name: basic python type or a class
+ """
+
+ self._class_name = class_name
+
+ def equals(self, rhs):
+ """Check to see if the RHS is an instance of class_name.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return isinstance(rhs, self._class_name)
+ except TypeError:
+ # Check raw types if there was a type error. This is helpful for
+ # things like cStringIO.StringIO.
+ return type(rhs) == type(self._class_name)
+
+ def __repr__(self):
+ return str(self._class_name)
+
+class IsAlmost(Comparator):
+ """Comparison class used to check whether a parameter is nearly equal
+ to a given value. Generally useful for floating point numbers.
+
+ Example mock_dao.SetTimeout((IsAlmost(3.9)))
+ """
+
+ def __init__(self, float_value, places=7):
+ """Initialize IsAlmost.
+
+ Args:
+ float_value: The value for making the comparison.
+ places: The number of decimal places to round to.
+ """
+
+ self._float_value = float_value
+ self._places = places
+
+ def equals(self, rhs):
+ """Check to see if RHS is almost equal to float_value
+
+ Args:
+ rhs: the value to compare to float_value
+
+ Returns:
+ bool
+ """
+
+ try:
+ return round(rhs-self._float_value, self._places) == 0
+ except TypeError:
+ # This is probably because either float_value or rhs is not a number.
+ return False
+
+ def __repr__(self):
+ return str(self._float_value)
+
+class StrContains(Comparator):
+ """Comparison class used to check whether a substring exists in a
+ string parameter. This can be useful in mocking a database with SQL
+ passed in as a string parameter, for example.
+
+ Example:
+ mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
+ """
+
+ def __init__(self, search_string):
+ """Initialize.
+
+ Args:
+ # search_string: the string you are searching for
+ search_string: str
+ """
+
+ self._search_string = search_string
+
+ def equals(self, rhs):
+ """Check to see if the search_string is contained in the rhs string.
+
+ Args:
+ # rhs: the right hand side of the test
+ rhs: object
+
+ Returns:
+ bool
+ """
+
+ try:
+ return rhs.find(self._search_string) > -1
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<str containing \'%s\'>' % self._search_string
+
+
+class Regex(Comparator):
+ """Checks if a string matches a regular expression.
+
+ This uses a given regular expression to determine equality.
+ """
+
+ def __init__(self, pattern, flags=0):
+ """Initialize.
+
+ Args:
+ # pattern is the regular expression to search for
+ pattern: str
+ # flags passed to re.compile function as the second argument
+ flags: int
+ """
+
+ self.regex = re.compile(pattern, flags=flags)
+
+ def equals(self, rhs):
+ """Check to see if rhs matches regular expression pattern.
+
+ Returns:
+ bool
+ """
+
+ return self.regex.search(rhs) is not None
+
+ def __repr__(self):
+ s = '<regular expression \'%s\'' % self.regex.pattern
+ if self.regex.flags:
+ s += ', flags=%d' % self.regex.flags
+ s += '>'
+ return s
+
+
+class In(Comparator):
+ """Checks whether an item (or key) is in a list (or dict) parameter.
+
+ Example:
+ mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
+ """
+
+ def __init__(self, key):
+ """Initialize.
+
+ Args:
+ # key is any thing that could be in a list or a key in a dict
+ """
+
+ self._key = key
+
+ def equals(self, rhs):
+ """Check to see whether key is in rhs.
+
+ Args:
+ rhs: dict
+
+ Returns:
+ bool
+ """
+
+ return self._key in rhs
+
+ def __repr__(self):
+ return '<sequence or map containing \'%s\'>' % self._key
+
+
+class ContainsKeyValue(Comparator):
+ """Checks whether a key/value pair is in a dict parameter.
+
+ Example:
+ mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
+ """
+
+ def __init__(self, key, value):
+ """Initialize.
+
+ Args:
+ # key: a key in a dict
+ # value: the corresponding value
+ """
+
+ self._key = key
+ self._value = value
+
+ def equals(self, rhs):
+ """Check whether the given key/value pair is in the rhs dict.
+
+ Returns:
+ bool
+ """
+
+ try:
+ return rhs[self._key] == self._value
+ except Exception:
+ return False
+
+ def __repr__(self):
+ return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
+
+
+class SameElementsAs(Comparator):
+ """Checks whether iterables contain the same elements (ignoring order).
+
+ Example:
+ mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
+ """
+
+ def __init__(self, expected_seq):
+ """Initialize.
+
+ Args:
+ expected_seq: a sequence
+ """
+
+ self._expected_seq = expected_seq
+
+ def equals(self, actual_seq):
+ """Check to see whether actual_seq has same elements as expected_seq.
+
+ Args:
+ actual_seq: sequence
+
+ Returns:
+ bool
+ """
+
+ try:
+ expected = dict([(element, None) for element in self._expected_seq])
+ actual = dict([(element, None) for element in actual_seq])
+ except TypeError:
+ # Fall back to slower list-compare if any of the objects are unhashable.
+ expected = list(self._expected_seq)
+ actual = list(actual_seq)
+ expected.sort()
+ actual.sort()
+ return expected == actual
+
+ def __repr__(self):
+ return '<sequence with same elements as \'%s\'>' % self._expected_seq
+
+
+class And(Comparator):
+ """Evaluates one or more Comparators on RHS and returns an AND of the results.
+ """
+
+ def __init__(self, *args):
+ """Initialize.
+
+ Args:
+ *args: One or more Comparator
+ """
+
+ self._comparators = args
+
+ def equals(self, rhs):
+ """Checks whether all Comparators are equal to rhs.
+
+ Args:
+ # rhs: can be anything
+
+ Returns:
+ bool
+ """
+
+ for comparator in self._comparators:
+ if not comparator.equals(rhs):
+ return False
+
+ return True
+
+ def __repr__(self):
+ return '<AND %s>' % str(self._comparators)
+
+
+class Or(Comparator):
+ """Evaluates one or more Comparators on RHS and returns an OR of the results.
+ """
+
+ def __init__(self, *args):
+ """Initialize.
+
+ Args:
+ *args: One or more Mox comparators
+ """
+
+ self._comparators = args
+
+ def equals(self, rhs):
+ """Checks whether any Comparator is equal to rhs.
+
+ Args:
+ # rhs: can be anything
+
+ Returns:
+ bool
+ """
+
+ for comparator in self._comparators:
+ if comparator.equals(rhs):
+ return True
+
+ return False
+
+ def __repr__(self):
+ return '<OR %s>' % str(self._comparators)
+
+
+class Func(Comparator):
+ """Call a function that should verify the parameter passed in is correct.
+
+ You may need the ability to perform more advanced operations on the parameter
+ in order to validate it. You can use this to have a callable validate any
+ parameter. The callable should return either True or False.
+
+
+ Example:
+
+ def myParamValidator(param):
+ # Advanced logic here
+ return True
+
+ mock_dao.DoSomething(Func(myParamValidator), true)
+ """
+
+ def __init__(self, func):
+ """Initialize.
+
+ Args:
+ func: callable that takes one parameter and returns a bool
+ """
+
+ self._func = func
+
+ def equals(self, rhs):
+ """Test whether rhs passes the function test.
+
+ rhs is passed into func.
+
+ Args:
+ rhs: any python object
+
+ Returns:
+ the result of func(rhs)
+ """
+
+ return self._func(rhs)
+
+ def __repr__(self):
+ return str(self._func)
+
+
+class IgnoreArg(Comparator):
+ """Ignore an argument.
+
+ This can be used when we don't care about an argument of a method call.
+
+ Example:
+ # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
+ mymock.CastMagic(3, IgnoreArg(), 'disappear')
+ """
+
+ def equals(self, unused_rhs):
+ """Ignores arguments and returns True.
+
+ Args:
+ unused_rhs: any python object
+
+ Returns:
+ always returns True
+ """
+
+ return True
+
+ def __repr__(self):
+ return '<IgnoreArg>'
+
+
+class MethodGroup(object):
+ """Base class containing common behaviour for MethodGroups."""
+
+ def __init__(self, group_name):
+ self._group_name = group_name
+
+ def group_name(self):
+ return self._group_name
+
+ def __str__(self):
+ return '<%s "%s">' % (self.__class__.__name__, self._group_name)
+
+ def AddMethod(self, mock_method):
+ raise NotImplementedError
+
+ def MethodCalled(self, mock_method):
+ raise NotImplementedError
+
+ def IsSatisfied(self):
+ raise NotImplementedError
+
+class UnorderedGroup(MethodGroup):
+ """UnorderedGroup holds a set of method calls that may occur in any order.
+
+ This construct is helpful for non-deterministic events, such as iterating
+ over the keys of a dict.
+ """
+
+ def __init__(self, group_name):
+ super(UnorderedGroup, self).__init__(group_name)
+ self._methods = []
+
+ def AddMethod(self, mock_method):
+ """Add a method to this group.
+
+ Args:
+ mock_method: A mock method to be added to this group.
+ """
+
+ self._methods.append(mock_method)
+
+ def MethodCalled(self, mock_method):
+ """Remove a method call from the group.
+
+ If the method is not in the set, an UnexpectedMethodCallError will be
+ raised.
+
+ Args:
+ mock_method: a mock method that should be equal to a method in the group.
+
+ Returns:
+ The mock method from the group
+
+ Raises:
+ UnexpectedMethodCallError if the mock_method was not in the group.
+ """
+
+ # Check to see if this method exists, and if so, remove it from the set
+ # and return it.
+ for method in self._methods:
+ if method == mock_method:
+ # Remove the called mock_method instead of the method in the group.
+ # The called method will match any comparators when equality is checked
+ # during removal. The method in the group could pass a comparator to
+ # another comparator during the equality check.
+ self._methods.remove(mock_method)
+
+ # If this group is not empty, put it back at the head of the queue.
+ if not self.IsSatisfied():
+ mock_method._call_queue.appendleft(self)
+
+ return self, method
+
+ raise UnexpectedMethodCallError(mock_method, self)
+
+ def IsSatisfied(self):
+ """Return True if there are not any methods in this group."""
+
+ return len(self._methods) == 0
+
+
+class MultipleTimesGroup(MethodGroup):
+ """MultipleTimesGroup holds methods that may be called any number of times.
+
+ Note: Each method must be called at least once.
+
+ This is helpful, if you don't know or care how many times a method is called.
+ """
+
+ def __init__(self, group_name):
+ super(MultipleTimesGroup, self).__init__(group_name)
+ self._methods = set()
+ self._methods_called = set()
+
+ def AddMethod(self, mock_method):
+ """Add a method to this group.
+
+ Args:
+ mock_method: A mock method to be added to this group.
+ """
+
+ self._methods.add(mock_method)
+
+ def MethodCalled(self, mock_method):
+ """Remove a method call from the group.
+
+ If the method is not in the set, an UnexpectedMethodCallError will be
+ raised.
+
+ Args:
+ mock_method: a mock method that should be equal to a method in the group.
+
+ Returns:
+ The mock method from the group
+
+ Raises:
+ UnexpectedMethodCallError if the mock_method was not in the group.
+ """
+
+ # Check to see if this method exists, and if so add it to the set of
+ # called methods.
+
+ for method in self._methods:
+ if method == mock_method:
+ self._methods_called.add(mock_method)
+ # Always put this group back on top of the queue, because we don't know
+ # when we are done.
+ mock_method._call_queue.appendleft(self)
+ return self, method
+
+ if self.IsSatisfied():
+ next_method = mock_method._PopNextMethod();
+ return next_method, None
+ else:
+ raise UnexpectedMethodCallError(mock_method, self)
+
+ def IsSatisfied(self):
+ """Return True if all methods in this group are called at least once."""
+ # NOTE(psycho): We can't use the simple set difference here because we want
+ # to match different parameters which are considered the same e.g. IsA(str)
+ # and some string. This solution is O(n^2) but n should be small.
+ tmp = self._methods.copy()
+ for called in self._methods_called:
+ for expected in tmp:
+ if called == expected:
+ tmp.remove(expected)
+ if not tmp:
+ return True
+ break
+ return False
+
+
+class MoxMetaTestBase(type):
+ """Metaclass to add mox cleanup and verification to every test.
+
+ As the mox unit testing class is being constructed (MoxTestBase or a
+ subclass), this metaclass will modify all test functions to call the
+ CleanUpMox method of the test class after they finish. This means that
+ unstubbing and verifying will happen for every test with no additional code,
+ and any failures will result in test failures as opposed to errors.
+ """
+
+ def __init__(cls, name, bases, d):
+ type.__init__(cls, name, bases, d)
+
+ # also get all the attributes from the base classes to account
+ # for a case when test class is not the immediate child of MoxTestBase
+ for base in bases:
+ for attr_name in dir(base):
+ d[attr_name] = getattr(base, attr_name)
+
+ for func_name, func in d.items():
+ if func_name.startswith('test') and callable(func):
+ setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
+
+ @staticmethod
+ def CleanUpTest(cls, func):
+ """Adds Mox cleanup code to any MoxTestBase method.
+
+ Always unsets stubs after a test. Will verify all mocks for tests that
+ otherwise pass.
+
+ Args:
+ cls: MoxTestBase or subclass; the class whose test method we are altering.
+ func: method; the method of the MoxTestBase test class we wish to alter.
+
+ Returns:
+ The modified method.
+ """
+ def new_method(self, *args, **kwargs):
+ mox_obj = getattr(self, 'mox', None)
+ cleanup_mox = False
+ if mox_obj and isinstance(mox_obj, Mox):
+ cleanup_mox = True
+ try:
+ func(self, *args, **kwargs)
+ finally:
+ if cleanup_mox:
+ mox_obj.UnsetStubs()
+ if cleanup_mox:
+ mox_obj.VerifyAll()
+ new_method.__name__ = func.__name__
+ new_method.__doc__ = func.__doc__
+ new_method.__module__ = func.__module__
+ return new_method
+
+
+class MoxTestBase(unittest.TestCase):
+ """Convenience test class to make stubbing easier.
+
+ Sets up a "mox" attribute which is an instance of Mox - any mox tests will
+ want this. Also automatically unsets any stubs and verifies that all mock
+ methods have been called at the end of each test, eliminating boilerplate
+ code.
+ """
+
+ __metaclass__ = MoxMetaTestBase
+
+ def setUp(self):
+ self.mox = Mox()
diff --git a/python/setup.py b/python/setup.py
new file mode 100755
index 0000000..d2997da
--- /dev/null
+++ b/python/setup.py
@@ -0,0 +1,136 @@
+#! /usr/bin/python
+#
+# See README for usage instructions.
+
+# We must use setuptools, not distutils, because we need to use the
+# namespace_packages option for the "google" package.
+from ez_setup import use_setuptools
+use_setuptools()
+
+from setuptools import setup
+from distutils.spawn import find_executable
+import sys
+import os
+import subprocess
+
+maintainer_email = "protobuf@googlegroups.com"
+
+# Find the Protocol Compiler.
+if os.path.exists("../src/protoc"):
+ protoc = "../src/protoc"
+elif os.path.exists("../src/protoc.exe"):
+ protoc = "../src/protoc.exe"
+else:
+ protoc = find_executable("protoc")
+
+def generate_proto(source):
+ """Invokes the Protocol Compiler to generate a _pb2.py from the given
+ .proto file. Does nothing if the output already exists and is newer than
+ the input."""
+
+ output = source.replace(".proto", "_pb2.py").replace("../src/", "")
+
+ if not os.path.exists(source):
+ print "Can't find required file: " + source
+ sys.exit(-1)
+
+ if (not os.path.exists(output) or
+ (os.path.exists(source) and
+ os.path.getmtime(source) > os.path.getmtime(output))):
+ print "Generating %s..." % output
+
+ if protoc == None:
+ sys.stderr.write(
+ "protoc is not installed nor found in ../src. Please compile it "
+ "or install the binary package.\n")
+ sys.exit(-1)
+
+ protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ]
+ if subprocess.call(protoc_command) != 0:
+ sys.exit(-1)
+
+def MakeTestSuite():
+ # This is apparently needed on some systems to make sure that the tests
+ # work even if a previous version is already installed.
+ if 'google' in sys.modules:
+ del sys.modules['google']
+
+ generate_proto("../src/google/protobuf/unittest.proto")
+ generate_proto("../src/google/protobuf/unittest_import.proto")
+ generate_proto("../src/google/protobuf/unittest_mset.proto")
+ generate_proto("google/protobuf/internal/more_extensions.proto")
+ generate_proto("google/protobuf/internal/more_messages.proto")
+
+ import unittest
+ import google.protobuf.internal.generator_test as generator_test
+ import google.protobuf.internal.decoder_test as decoder_test
+ import google.protobuf.internal.descriptor_test as descriptor_test
+ import google.protobuf.internal.encoder_test as encoder_test
+ import google.protobuf.internal.input_stream_test as input_stream_test
+ import google.protobuf.internal.output_stream_test as output_stream_test
+ import google.protobuf.internal.reflection_test as reflection_test
+ import google.protobuf.internal.service_reflection_test \
+ as service_reflection_test
+ import google.protobuf.internal.text_format_test as text_format_test
+ import google.protobuf.internal.wire_format_test as wire_format_test
+
+ loader = unittest.defaultTestLoader
+ suite = unittest.TestSuite()
+ for test in [ generator_test,
+ decoder_test,
+ descriptor_test,
+ encoder_test,
+ input_stream_test,
+ output_stream_test,
+ reflection_test,
+ service_reflection_test,
+ text_format_test,
+ wire_format_test ]:
+ suite.addTest(loader.loadTestsFromModule(test))
+
+ return suite
+
+if __name__ == '__main__':
+ # TODO(kenton): Integrate this into setuptools somehow?
+ if len(sys.argv) >= 2 and sys.argv[1] == "clean":
+ # Delete generated _pb2.py files and .pyc files in the code tree.
+ for (dirpath, dirnames, filenames) in os.walk("."):
+ for filename in filenames:
+ filepath = os.path.join(dirpath, filename)
+ if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"):
+ os.remove(filepath)
+ else:
+ # Generate necessary .proto file if it doesn't exist.
+ # TODO(kenton): Maybe we should hook this into a distutils command?
+ generate_proto("../src/google/protobuf/descriptor.proto")
+
+ setup(name = 'protobuf',
+ version = '2.2.0',
+ packages = [ 'google' ],
+ namespace_packages = [ 'google' ],
+ test_suite = 'setup.MakeTestSuite',
+ # Must list modules explicitly so that we don't install tests.
+ py_modules = [
+ 'google.protobuf.internal.containers',
+ 'google.protobuf.internal.decoder',
+ 'google.protobuf.internal.encoder',
+ 'google.protobuf.internal.input_stream',
+ 'google.protobuf.internal.message_listener',
+ 'google.protobuf.internal.output_stream',
+ 'google.protobuf.internal.type_checkers',
+ 'google.protobuf.internal.wire_format',
+ 'google.protobuf.descriptor',
+ 'google.protobuf.descriptor_pb2',
+ 'google.protobuf.message',
+ 'google.protobuf.reflection',
+ 'google.protobuf.service',
+ 'google.protobuf.service_reflection',
+ 'google.protobuf.text_format' ],
+ url = 'http://code.google.com/p/protobuf/',
+ maintainer = maintainer_email,
+ maintainer_email = 'protobuf@googlegroups.com',
+ license = 'New BSD License',
+ description = 'Protocol Buffers',
+ long_description =
+ "Protocol Buffers are Google's data interchange format.",
+ )
diff --git a/python/stubout.py b/python/stubout.py
new file mode 100755
index 0000000..aee4f2d
--- /dev/null
+++ b/python/stubout.py
@@ -0,0 +1,140 @@
+#!/usr/bin/python2.4
+#
+# Copyright 2008 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file is used for testing. The original is at:
+# http://code.google.com/p/pymox/
+
+class StubOutForTesting:
+ """Sample Usage:
+ You want os.path.exists() to always return true during testing.
+
+ stubs = StubOutForTesting()
+ stubs.Set(os.path, 'exists', lambda x: 1)
+ ...
+ stubs.UnsetAll()
+
+ The above changes os.path.exists into a lambda that returns 1. Once
+ the ... part of the code finishes, the UnsetAll() looks up the old value
+ of os.path.exists and restores it.
+
+ """
+ def __init__(self):
+ self.cache = []
+ self.stubs = []
+
+ def __del__(self):
+ self.SmartUnsetAll()
+ self.UnsetAll()
+
+ def SmartSet(self, obj, attr_name, new_attr):
+ """Replace obj.attr_name with new_attr. This method is smart and works
+ at the module, class, and instance level while preserving proper
+ inheritance. It will not stub out C types however unless that has been
+ explicitly allowed by the type.
+
+ This method supports the case where attr_name is a staticmethod or a
+ classmethod of obj.
+
+ Notes:
+ - If obj is an instance, then it is its class that will actually be
+ stubbed. Note that the method Set() does not do that: if obj is
+ an instance, it (and not its class) will be stubbed.
+ - The stubbing is using the builtin getattr and setattr. So, the __get__
+ and __set__ will be called when stubbing (TODO: A better idea would
+ probably be to manipulate obj.__dict__ instead of getattr() and
+ setattr()).
+
+ Raises AttributeError if the attribute cannot be found.
+ """
+ if (inspect.ismodule(obj) or
+ (not inspect.isclass(obj) and obj.__dict__.has_key(attr_name))):
+ orig_obj = obj
+ orig_attr = getattr(obj, attr_name)
+
+ else:
+ if not inspect.isclass(obj):
+ mro = list(inspect.getmro(obj.__class__))
+ else:
+ mro = list(inspect.getmro(obj))
+
+ mro.reverse()
+
+ orig_attr = None
+
+ for cls in mro:
+ try:
+ orig_obj = cls
+ orig_attr = getattr(obj, attr_name)
+ except AttributeError:
+ continue
+
+ if orig_attr is None:
+ raise AttributeError("Attribute not found.")
+
+ # Calling getattr() on a staticmethod transforms it to a 'normal' function.
+ # We need to ensure that we put it back as a staticmethod.
+ old_attribute = obj.__dict__.get(attr_name)
+ if old_attribute is not None and isinstance(old_attribute, staticmethod):
+ orig_attr = staticmethod(orig_attr)
+
+ self.stubs.append((orig_obj, attr_name, orig_attr))
+ setattr(orig_obj, attr_name, new_attr)
+
+ def SmartUnsetAll(self):
+ """Reverses all the SmartSet() calls, restoring things to their original
+ definition. Its okay to call SmartUnsetAll() repeatedly, as later calls
+ have no effect if no SmartSet() calls have been made.
+
+ """
+ self.stubs.reverse()
+
+ for args in self.stubs:
+ setattr(*args)
+
+ self.stubs = []
+
+ def Set(self, parent, child_name, new_child):
+ """Replace child_name's old definition with new_child, in the context
+ of the given parent. The parent could be a module when the child is a
+ function at module scope. Or the parent could be a class when a class'
+ method is being replaced. The named child is set to new_child, while
+ the prior definition is saved away for later, when UnsetAll() is called.
+
+ This method supports the case where child_name is a staticmethod or a
+ classmethod of parent.
+ """
+ old_child = getattr(parent, child_name)
+
+ old_attribute = parent.__dict__.get(child_name)
+ if old_attribute is not None and isinstance(old_attribute, staticmethod):
+ old_child = staticmethod(old_child)
+
+ self.cache.append((parent, old_child, child_name))
+ setattr(parent, child_name, new_child)
+
+ def UnsetAll(self):
+ """Reverses all the Set() calls, restoring things to their original
+ definition. Its okay to call UnsetAll() repeatedly, as later calls have
+ no effect if no Set() calls have been made.
+
+ """
+ # Undo calls to Set() in reverse order, in case Set() was called on the
+ # same arguments repeatedly (want the original call to be last one undone)
+ self.cache.reverse()
+
+ for (parent, old_child, child_name) in self.cache:
+ setattr(parent, child_name, old_child)
+ self.cache = []