diff options
author | Wink Saville <wink@google.com> | 2010-05-27 16:25:37 -0700 |
---|---|---|
committer | Wink Saville <wink@google.com> | 2010-05-27 16:25:37 -0700 |
commit | fbaaef999ba563838ebd00874ed8a1c01fbf286d (patch) | |
tree | 24ff5c76344e90abc5b0fe6f07120ea0d2d011ee /python | |
parent | 79a4a60053f74ab71c7c3ec436d2f6caedc5be61 (diff) | |
download | external_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.zip external_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.tar.gz external_protobuf-fbaaef999ba563838ebd00874ed8a1c01fbf286d.tar.bz2 |
Add protobuf 2.2.0a sources
This is the contents of protobuf-2.2.0a.tar.bz2 from
http://code.google.com/p/protobuf/downloads/list and
is the base code for the javamicro code generator.
Change-Id: Ie9a0440a824d615086445b6636944484b3155afa
Diffstat (limited to 'python')
36 files changed, 12092 insertions, 0 deletions
diff --git a/python/README.txt b/python/README.txt new file mode 100644 index 0000000..96f1a73 --- /dev/null +++ b/python/README.txt @@ -0,0 +1,73 @@ +Protocol Buffers - Google's data interchange format +Copyright 2008 Google Inc. + +This directory contains the Python Protocol Buffers runtime library. + +Normally, this directory comes as part of the protobuf package, available +from: + + http://code.google.com/p/protobuf + +The complete package includes the C++ source code, which includes the +Protocol Compiler (protoc). If you downloaded this package from PyPI +or some other Python-specific source, you may have received only the +Python part of the code. In this case, you will need to obtain the +Protocol Compiler from some other source before you can use this +package. + +Development Warning +=================== + +The Python implementation of Protocol Buffers is not as mature as the C++ +and Java implementations. It may be more buggy, and it is known to be +pretty slow at this time. If you would like to help fix these issues, +join the Protocol Buffers discussion list and let us know! + +Installation +============ + +1) Make sure you have Python 2.4 or newer. If in doubt, run: + + $ python -V + +2) If you do not have setuptools installed, note that it will be + downloaded and installed automatically as soon as you run setup.py. + If you would rather install it manually, you may do so by following + the instructions on this page: + + http://peak.telecommunity.com/DevCenter/EasyInstall#installation-instructions + +3) Build the C++ code, or install a binary distribution of protoc. If + you install a binary distribution, make sure that it is the same + version as this package. If in doubt, run: + + $ protoc --version + +4) Run the tests: + + $ python setup.py test + + If some tests fail, this library may not work correctly on your + system. Continue at your own risk. + + Please note that there is a known problem with some versions of + Python on Cygwin which causes the tests to fail after printing the + error: "sem_init: Resource temporarily unavailable". This appears + to be a bug either in Cygwin or in Python: + http://www.cygwin.com/ml/cygwin/2005-07/msg01378.html + We do not know if or when it might me fixed. We also do not know + how likely it is that this bug will affect users in practice. + +5) Install: + + $ python setup.py install + + This step may require superuser privileges. + +Usage +===== + +The complete documentation for Protocol Buffers is available via the +web at: + + http://code.google.com/apis/protocolbuffers/ diff --git a/python/ez_setup.py b/python/ez_setup.py new file mode 100755 index 0000000..b7a9849 --- /dev/null +++ b/python/ez_setup.py @@ -0,0 +1,281 @@ +#!python + +# This file was obtained from: +# http://peak.telecommunity.com/dist/ez_setup.py +# on 2009/4/17. + +"""Bootstrap setuptools installation + +If you want to use setuptools in your package's setup.py, just include this +file in the same directory with it, and add this to the top of your setup.py:: + + from ez_setup import use_setuptools + use_setuptools() + +If you want to require a specific version of setuptools, set a download +mirror, or use an alternate download directory, you can do so by supplying +the appropriate options to ``use_setuptools()``. + +This file can also be run as a script to install or upgrade setuptools. +""" +import sys +DEFAULT_VERSION = "0.6c9" +DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] + +md5_data = { + 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca', + 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb', + 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b', + 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a', + 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618', + 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac', + 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5', + 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', + 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', + 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', + 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', + 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', + 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', + 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e', + 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e', + 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f', + 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2', + 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc', + 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167', + 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64', + 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d', + 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20', + 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab', + 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53', + 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2', + 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e', + 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372', + 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902', + 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de', + 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b', + 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03', + 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a', + 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6', + 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a', +} + +import sys, os +try: from hashlib import md5 +except ImportError: from md5 import md5 + +def _validate_md5(egg_name, data): + if egg_name in md5_data: + digest = md5(data).hexdigest() + if digest != md5_data[egg_name]: + print >>sys.stderr, ( + "md5 validation of %s failed! (Possible download problem?)" + % egg_name + ) + sys.exit(2) + return data + +def use_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + download_delay=15 +): + """Automatically find/download setuptools and make it available on sys.path + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end with + a '/'). `to_dir` is the directory where setuptools will be downloaded, if + it is not already available. If `download_delay` is specified, it should + be the number of seconds that will be paused before initiating a download, + should one be required. If an older version of setuptools is installed, + this routine will print a message to ``sys.stderr`` and raise SystemExit in + an attempt to abort the calling script. + """ + was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules + def do_download(): + egg = download_setuptools(version, download_base, to_dir, download_delay) + sys.path.insert(0, egg) + import setuptools; setuptools.bootstrap_install_from = egg + try: + import pkg_resources + except ImportError: + return do_download() + try: + pkg_resources.require("setuptools>="+version); return + except pkg_resources.VersionConflict, e: + if was_imported: + print >>sys.stderr, ( + "The required version of setuptools (>=%s) is not available, and\n" + "can't be installed while this script is running. Please install\n" + " a more recent version first, using 'easy_install -U setuptools'." + "\n\n(Currently using %r)" + ) % (version, e.args[0]) + sys.exit(2) + else: + del pkg_resources, sys.modules['pkg_resources'] # reload ok + return do_download() + except pkg_resources.DistributionNotFound: + return do_download() + +def download_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + delay = 15 +): + """Download setuptools from a specified location and return its filename + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end + with a '/'). `to_dir` is the directory where the egg will be downloaded. + `delay` is the number of seconds to pause before an actual download attempt. + """ + import urllib2, shutil + egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3]) + url = download_base + egg_name + saveto = os.path.join(to_dir, egg_name) + src = dst = None + if not os.path.exists(saveto): # Avoid repeated downloads + try: + from distutils import log + if delay: + log.warn(""" +--------------------------------------------------------------------------- +This script requires setuptools version %s to run (even to display +help). I will attempt to download it for you (from +%s), but +you may need to enable firewall access for this script first. +I will start the download in %d seconds. + +(Note: if this machine does not have network access, please obtain the file + + %s + +and place it in this directory before rerunning this script.) +---------------------------------------------------------------------------""", + version, download_base, delay, url + ); from time import sleep; sleep(delay) + log.warn("Downloading %s", url) + src = urllib2.urlopen(url) + # Read/write all in one block, so we don't create a corrupt file + # if the download is interrupted. + data = _validate_md5(egg_name, src.read()) + dst = open(saveto,"wb"); dst.write(data) + finally: + if src: src.close() + if dst: dst.close() + return os.path.realpath(saveto) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +def main(argv, version=DEFAULT_VERSION): + """Install or upgrade setuptools and EasyInstall""" + try: + import setuptools + except ImportError: + egg = None + try: + egg = download_setuptools(version, delay=0) + sys.path.insert(0,egg) + from setuptools.command.easy_install import main + return main(list(argv)+[egg]) # we're done here + finally: + if egg and os.path.exists(egg): + os.unlink(egg) + else: + if setuptools.__version__ == '0.0.1': + print >>sys.stderr, ( + "You have an obsolete version of setuptools installed. Please\n" + "remove it from your system entirely before rerunning this script." + ) + sys.exit(2) + + req = "setuptools>="+version + import pkg_resources + try: + pkg_resources.require(req) + except pkg_resources.VersionConflict: + try: + from setuptools.command.easy_install import main + except ImportError: + from easy_install import main + main(list(argv)+[download_setuptools(delay=0)]) + sys.exit(0) # try to force an exit + else: + if argv: + from setuptools.command.easy_install import main + main(argv) + else: + print "Setuptools version",version,"or greater has been installed." + print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)' + +def update_md5(filenames): + """Update our built-in md5 registry""" + + import re + + for name in filenames: + base = os.path.basename(name) + f = open(name,'rb') + md5_data[base] = md5(f.read()).hexdigest() + f.close() + + data = [" %r: %r,\n" % it for it in md5_data.items()] + data.sort() + repl = "".join(data) + + import inspect + srcfile = inspect.getsourcefile(sys.modules[__name__]) + f = open(srcfile, 'rb'); src = f.read(); f.close() + + match = re.search("\nmd5_data = {\n([^}]+)}", src) + if not match: + print >>sys.stderr, "Internal error!" + sys.exit(2) + + src = src[:match.start(1)] + repl + src[match.end(1):] + f = open(srcfile,'w') + f.write(src) + f.close() + + +if __name__=='__main__': + if len(sys.argv)>2 and sys.argv[1]=='--md5update': + update_md5(sys.argv[2:]) + else: + main(sys.argv[1:]) + + + + + + diff --git a/python/google/__init__.py b/python/google/__init__.py new file mode 100755 index 0000000..de40ea7 --- /dev/null +++ b/python/google/__init__.py @@ -0,0 +1 @@ +__import__('pkg_resources').declare_namespace(__name__) diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py new file mode 100755 index 0000000..e69de29 --- /dev/null +++ b/python/google/protobuf/__init__.py diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py new file mode 100755 index 0000000..8e3fc2e --- /dev/null +++ b/python/google/protobuf/descriptor.py @@ -0,0 +1,433 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We probably need to provide deep-copy methods for +# descriptor types. When a FieldDescriptor is passed into +# Descriptor.__init__(), we should make a deep copy and then set +# containing_type on it. Alternatively, we could just get +# rid of containing_type (iit's not needed for reflection.py, at least). +# +# TODO(robinson): Print method? +# +# TODO(robinson): Useful __repr__? + +"""Descriptors essentially contain exactly the information found in a .proto +file, in types that make this information accessible in Python. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +class DescriptorBase(object): + + """Descriptors base class. + + This class is the base of all descriptor classes. It provides common options + related functionaility. + """ + + def __init__(self, options, options_class_name): + """Initialize the descriptor given its options message and the name of the + class of the options message. The name of the class is required in case + the options message is None and has to be created. + """ + self._options = options + self._options_class_name = options_class_name + + def GetOptions(self): + """Retrieves descriptor options. + + This method returns the options set or creates the default options for the + descriptor. + """ + if self._options: + return self._options + from google.protobuf import descriptor_pb2 + try: + options_class = getattr(descriptor_pb2, self._options_class_name) + except AttributeError: + raise RuntimeError('Unknown options class name %s!' % + (self._options_class_name)) + self._options = options_class() + return self._options + + +class Descriptor(DescriptorBase): + + """Descriptor for a protocol message type. + + A Descriptor instance has the following attributes: + + name: (str) Name of this protocol message type. + full_name: (str) Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + + filename: (str) Name of the .proto file containing this message. + + containing_type: (Descriptor) Reference to the descriptor of the + type containing us, or None if we have no containing type. + + fields: (list of FieldDescriptors) Field descriptors for all + fields in this type. + fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "number" attribute in each + FieldDescriptor. + fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "name" attribute in each + FieldDescriptor. + + nested_types: (list of Descriptors) Descriptor references + for all protocol message types nested within this one. + nested_types_by_name: (dict str -> Descriptor) Same Descriptor + objects as in |nested_types|, but indexed by "name" attribute + in each Descriptor. + + enum_types: (list of EnumDescriptors) EnumDescriptor references + for all enums contained within this type. + enum_types_by_name: (dict str ->EnumDescriptor) Same EnumDescriptor + objects as in |enum_types|, but indexed by "name" attribute + in each EnumDescriptor. + enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping + from enum value name to EnumValueDescriptor for that value. + + extensions: (list of FieldDescriptor) All extensions defined directly + within this message type (NOT within a nested type). + extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor + objects as |extensions|, but indexed by "name" attribute of each + FieldDescriptor. + + options: (descriptor_pb2.MessageOptions) Protocol message options or None + to use default message options. + """ + + def __init__(self, name, full_name, filename, containing_type, + fields, nested_types, enum_types, extensions, options=None): + """Arguments to __init__() are as described in the description + of Descriptor fields above. + """ + super(Descriptor, self).__init__(options, 'MessageOptions') + self.name = name + self.full_name = full_name + self.filename = filename + self.containing_type = containing_type + + # We have fields in addition to fields_by_name and fields_by_number, + # so that: + # 1. Clients can index fields by "order in which they're listed." + # 2. Clients can easily iterate over all fields with the terse + # syntax: for f in descriptor.fields: ... + self.fields = fields + for field in self.fields: + field.containing_type = self + self.fields_by_number = dict((f.number, f) for f in fields) + self.fields_by_name = dict((f.name, f) for f in fields) + + self.nested_types = nested_types + self.nested_types_by_name = dict((t.name, t) for t in nested_types) + + self.enum_types = enum_types + for enum_type in self.enum_types: + enum_type.containing_type = self + self.enum_types_by_name = dict((t.name, t) for t in enum_types) + self.enum_values_by_name = dict( + (v.name, v) for t in enum_types for v in t.values) + + self.extensions = extensions + for extension in self.extensions: + extension.extension_scope = self + self.extensions_by_name = dict((f.name, f) for f in extensions) + + +# TODO(robinson): We should have aggressive checking here, +# for example: +# * If you specify a repeated field, you should not be allowed +# to specify a default value. +# * [Other examples here as needed]. +# +# TODO(robinson): for this and other *Descriptor classes, we +# might also want to lock things down aggressively (e.g., +# prevent clients from setting the attributes). Having +# stronger invariants here in general will reduce the number +# of runtime checks we must do in reflection.py... +class FieldDescriptor(DescriptorBase): + + """Descriptor for a single field in a .proto file. + + A FieldDescriptor instance has the following attriubtes: + + name: (str) Name of this field, exactly as it appears in .proto. + full_name: (str) Name of this field, including containing scope. This is + particularly relevant for extensions. + index: (int) Dense, 0-indexed index giving the order that this + field textually appears within its message in the .proto file. + number: (int) Tag number declared for this field in the .proto file. + + type: (One of the TYPE_* constants below) Declared type. + cpp_type: (One of the CPPTYPE_* constants below) C++ type used to + represent this field. + + label: (One of the LABEL_* constants below) Tells whether this + field is optional, required, or repeated. + default_value: (Varies) Default value of this field. Only + meaningful for non-repeated scalar fields. Repeated fields + should always set this to [], and non-repeated composite + fields should always set this to None. + + containing_type: (Descriptor) Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + Somewhat confusingly, for extension fields, this is the + descriptor of the EXTENDED message, not the descriptor + of the message containing this field. (See is_extension and + extension_scope below). + message_type: (Descriptor) If a composite field, a descriptor + of the message type contained in this field. Otherwise, this is None. + enum_type: (EnumDescriptor) If this field contains an enum, a + descriptor of that enum. Otherwise, this is None. + + is_extension: True iff this describes an extension field. + extension_scope: (Descriptor) Only meaningful if is_extension is True. + Gives the message that immediately contains this extension field. + Will be None iff we're a top-level (file-level) extension field. + + options: (descriptor_pb2.FieldOptions) Protocol message field options or + None to use default field options. + """ + + # Must be consistent with C++ FieldDescriptor::Type enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + TYPE_DOUBLE = 1 + TYPE_FLOAT = 2 + TYPE_INT64 = 3 + TYPE_UINT64 = 4 + TYPE_INT32 = 5 + TYPE_FIXED64 = 6 + TYPE_FIXED32 = 7 + TYPE_BOOL = 8 + TYPE_STRING = 9 + TYPE_GROUP = 10 + TYPE_MESSAGE = 11 + TYPE_BYTES = 12 + TYPE_UINT32 = 13 + TYPE_ENUM = 14 + TYPE_SFIXED32 = 15 + TYPE_SFIXED64 = 16 + TYPE_SINT32 = 17 + TYPE_SINT64 = 18 + MAX_TYPE = 18 + + # Must be consistent with C++ FieldDescriptor::CppType enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + CPPTYPE_INT32 = 1 + CPPTYPE_INT64 = 2 + CPPTYPE_UINT32 = 3 + CPPTYPE_UINT64 = 4 + CPPTYPE_DOUBLE = 5 + CPPTYPE_FLOAT = 6 + CPPTYPE_BOOL = 7 + CPPTYPE_ENUM = 8 + CPPTYPE_STRING = 9 + CPPTYPE_MESSAGE = 10 + MAX_CPPTYPE = 10 + + # Must be consistent with C++ FieldDescriptor::Label enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + LABEL_OPTIONAL = 1 + LABEL_REQUIRED = 2 + LABEL_REPEATED = 3 + MAX_LABEL = 3 + + def __init__(self, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None): + """The arguments are as described in the description of FieldDescriptor + attributes above. + + Note that containing_type may be None, and may be set later if necessary + (to deal with circular references between message types, for example). + Likewise for extension_scope. + """ + super(FieldDescriptor, self).__init__(options, 'FieldOptions') + self.name = name + self.full_name = full_name + self.index = index + self.number = number + self.type = type + self.cpp_type = cpp_type + self.label = label + self.default_value = default_value + self.containing_type = containing_type + self.message_type = message_type + self.enum_type = enum_type + self.is_extension = is_extension + self.extension_scope = extension_scope + + +class EnumDescriptor(DescriptorBase): + + """Descriptor for an enum defined in a .proto file. + + An EnumDescriptor instance has the following attributes: + + name: (str) Name of the enum type. + full_name: (str) Full name of the type, including package name + and any enclosing type(s). + filename: (str) Name of the .proto file in which this appears. + + values: (list of EnumValueDescriptors) List of the values + in this enum. + values_by_name: (dict str -> EnumValueDescriptor) Same as |values|, + but indexed by the "name" field of each EnumValueDescriptor. + values_by_number: (dict int -> EnumValueDescriptor) Same as |values|, + but indexed by the "number" field of each EnumValueDescriptor. + containing_type: (Descriptor) Descriptor of the immediate containing + type of this enum, or None if this is an enum defined at the + top level in a .proto file. Set by Descriptor's constructor + if we're passed into one. + options: (descriptor_pb2.EnumOptions) Enum options message or + None to use default enum options. + """ + + def __init__(self, name, full_name, filename, values, + containing_type=None, options=None): + """Arguments are as described in the attribute description above.""" + super(EnumDescriptor, self).__init__(options, 'EnumOptions') + self.name = name + self.full_name = full_name + self.filename = filename + self.values = values + for value in self.values: + value.type = self + self.values_by_name = dict((v.name, v) for v in values) + self.values_by_number = dict((v.number, v) for v in values) + self.containing_type = containing_type + + +class EnumValueDescriptor(DescriptorBase): + + """Descriptor for a single value within an enum. + + name: (str) Name of this value. + index: (int) Dense, 0-indexed index giving the order that this + value appears textually within its enum in the .proto file. + number: (int) Actual number assigned to this enum value. + type: (EnumDescriptor) EnumDescriptor to which this value + belongs. Set by EnumDescriptor's constructor if we're + passed into one. + options: (descriptor_pb2.EnumValueOptions) Enum value options message or + None to use default enum value options options. + """ + + def __init__(self, name, index, number, type=None, options=None): + """Arguments are as described in the attribute description above.""" + super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions') + self.name = name + self.index = index + self.number = number + self.type = type + + +class ServiceDescriptor(DescriptorBase): + + """Descriptor for a service. + + name: (str) Name of the service. + full_name: (str) Full name of the service, including package name. + index: (int) 0-indexed index giving the order that this services + definition appears withing the .proto file. + methods: (list of MethodDescriptor) List of methods provided by this + service. + options: (descriptor_pb2.ServiceOptions) Service options message or + None to use default service options. + """ + + def __init__(self, name, full_name, index, methods, options=None): + super(ServiceDescriptor, self).__init__(options, 'ServiceOptions') + self.name = name + self.full_name = full_name + self.index = index + self.methods = methods + # Set the containing service for each method in this service. + for method in self.methods: + method.containing_service = self + + def FindMethodByName(self, name): + """Searches for the specified method, and returns its descriptor.""" + for method in self.methods: + if name == method.name: + return method + return None + + +class MethodDescriptor(DescriptorBase): + + """Descriptor for a method in a service. + + name: (str) Name of the method within the service. + full_name: (str) Full name of method. + index: (int) 0-indexed index of the method inside the service. + containing_service: (ServiceDescriptor) The service that contains this + method. + input_type: The descriptor of the message that this method accepts. + output_type: The descriptor of the message that this method returns. + options: (descriptor_pb2.MethodOptions) Method options message or + None to use default method options. + """ + + def __init__(self, name, full_name, index, containing_service, + input_type, output_type, options=None): + """The arguments are as described in the description of MethodDescriptor + attributes above. + + Note that containing_service may be None, and may be set later if necessary. + """ + super(MethodDescriptor, self).__init__(options, 'MethodOptions') + self.name = name + self.full_name = full_name + self.index = index + self.containing_service = containing_service + self.input_type = input_type + self.output_type = output_type + + +def _ParseOptions(message, string): + """Parses serialized options. + + This helper function is used to parse serialized options in generated + proto2 files. It must not be used outside proto2. + """ + message.ParseFromString(string) + return message; diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py new file mode 100755 index 0000000..e69de29 --- /dev/null +++ b/python/google/protobuf/internal/__init__.py diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py new file mode 100755 index 0000000..d8a825d --- /dev/null +++ b/python/google/protobuf/internal/containers.py @@ -0,0 +1,227 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains container classes to represent different protocol buffer types. + +This file defines container classes which represent categories of protocol +buffer field types which need extra maintenance. Currently these categories +are: + - Repeated scalar fields - These are all repeated fields which aren't + composite (e.g. they are of simple types like int32, string, etc). + - Repeated composite fields - Repeated fields which are composite. This + includes groups and nested messages. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class BaseContainer(object): + + """Base container class.""" + + # Minimizes memory usage and disallows assignment to other attributes. + __slots__ = ['_message_listener', '_values'] + + def __init__(self, message_listener): + """ + Args: + message_listener: A MessageListener implementation. + The RepeatedScalarFieldContainer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + """ + self._message_listener = message_listener + self._values = [] + + def __getitem__(self, key): + """Retrieves item by the specified key.""" + return self._values[key] + + def __len__(self): + """Returns the number of elements in the container.""" + return len(self._values) + + def __ne__(self, other): + """Checks if another instance isn't equal to this one.""" + # The concrete classes should define __eq__. + return not self == other + + +class RepeatedScalarFieldContainer(BaseContainer): + + """Simple, type-checked, list-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_type_checker'] + + def __init__(self, message_listener, type_checker): + """ + Args: + message_listener: A MessageListener implementation. + The RepeatedScalarFieldContainer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + type_checker: A type_checkers.ValueChecker instance to run on elements + inserted into this container. + """ + super(RepeatedScalarFieldContainer, self).__init__(message_listener) + self._type_checker = type_checker + + def append(self, value): + """Appends an item to the list. Similar to list.append().""" + self.insert(len(self._values), value) + + def insert(self, key, value): + """Inserts the item at the specified position. Similar to list.insert().""" + self._type_checker.CheckValue(value) + self._values.insert(key, value) + self._message_listener.ByteSizeDirty() + if len(self._values) == 1: + self._message_listener.TransitionToNonempty() + + def extend(self, elem_seq): + """Extends by appending the given sequence. Similar to list.extend().""" + if not elem_seq: + return + + orig_empty = len(self._values) == 0 + new_values = [] + for elem in elem_seq: + self._type_checker.CheckValue(elem) + new_values.append(elem) + self._values.extend(new_values) + self._message_listener.ByteSizeDirty() + if orig_empty: + self._message_listener.TransitionToNonempty() + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.ByteSizeDirty() + + def __setitem__(self, key, value): + """Sets the item on the specified position.""" + # No need to call TransitionToNonempty(), since if we're able to + # set the element at this index, we were already nonempty before + # this method was called. + self._message_listener.ByteSizeDirty() + self._type_checker.CheckValue(value) + self._values[key] = value + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __setslice__(self, start, stop, values): + """Sets the subset of items from between the specified indices.""" + new_values = [] + for value in values: + self._type_checker.CheckValue(value) + new_values.append(value) + self._values[start:stop] = new_values + self._message_listener.ByteSizeDirty() + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.ByteSizeDirty() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.ByteSizeDirty() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + # Special case for the same type which should be common and fast. + if isinstance(other, self.__class__): + return other._values == self._values + # We are presumably comparing against some other sequence type. + return other == self._values + + +class RepeatedCompositeFieldContainer(BaseContainer): + + """Simple, list-like container for holding repeated composite fields.""" + + # Disallows assignment to other attributes. + __slots__ = ['_message_descriptor'] + + def __init__(self, message_listener, message_descriptor): + """ + Note that we pass in a descriptor instead of the generated directly, + since at the time we construct a _RepeatedCompositeFieldContainer we + haven't yet necessarily initialized the type that will be contained in the + container. + + Args: + message_listener: A MessageListener implementation. + The RepeatedCompositeFieldContainer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + message_descriptor: A Descriptor instance describing the protocol type + that should be present in this container. We'll use the + _concrete_class field of this descriptor when the client calls add(). + """ + super(RepeatedCompositeFieldContainer, self).__init__(message_listener) + self._message_descriptor = message_descriptor + + def add(self): + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values.append(new_element) + self._message_listener.ByteSizeDirty() + self._message_listener.TransitionToNonempty() + return new_element + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.ByteSizeDirty() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.ByteSizeDirty() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + return self._values == other._values diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py new file mode 100755 index 0000000..83d6fe0 --- /dev/null +++ b/python/google/protobuf/internal/decoder.py @@ -0,0 +1,209 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Class for decoding protocol buffer primitives. + +Contains the logic for decoding every logical protocol field type +from one of the 5 physical wire types. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message +from google.protobuf.internal import input_stream +from google.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Decoder(object): + + """Decodes logical protocol buffer fields from the wire.""" + + def __init__(self, s): + """Initializes the decoder to read from s. + + Args: + s: An immutable sequence of bytes, which must be accessible + via the Python buffer() primitive (i.e., buffer(s)). + """ + self._stream = input_stream.InputStream(s) + + def EndOfStream(self): + """Returns true iff we've reached the end of the bytes we're reading.""" + return self._stream.EndOfStream() + + def Position(self): + """Returns the 0-indexed position in |s|.""" + return self._stream.Position() + + def ReadFieldNumberAndWireType(self): + """Reads a tag from the wire. Returns a (field_number, wire_type) pair.""" + tag_and_type = self.ReadUInt32() + return wire_format.UnpackTag(tag_and_type) + + def SkipBytes(self, bytes): + """Skips the specified number of bytes on the wire.""" + self._stream.SkipBytes(bytes) + + # Note that the Read*() methods below are not exactly symmetrical with the + # corresponding Encoder.Append*() methods. Those Encoder methods first + # encode a tag, but the Read*() methods below assume that the tag has already + # been read, and that the client wishes to read a field of the specified type + # starting at the current position. + + def ReadInt32(self): + """Reads and returns a signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarint32() + + def ReadInt64(self): + """Reads and returns a signed, varint-encoded, 64-bit integer.""" + return self._stream.ReadVarint64() + + def ReadUInt32(self): + """Reads and returns an signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarUInt32() + + def ReadUInt64(self): + """Reads and returns an signed, varint-encoded,64-bit integer.""" + return self._stream.ReadVarUInt64() + + def ReadSInt32(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 32-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt32()) + + def ReadSInt64(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 64-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt64()) + + def ReadFixed32(self): + """Reads and returns an unsigned, fixed-width, 32-bit integer.""" + return self._stream.ReadLittleEndian32() + + def ReadFixed64(self): + """Reads and returns an unsigned, fixed-width, 64-bit integer.""" + return self._stream.ReadLittleEndian64() + + def ReadSFixed32(self): + """Reads and returns a signed, fixed-width, 32-bit integer.""" + value = self._stream.ReadLittleEndian32() + if value >= (1 << 31): + value -= (1 << 32) + return value + + def ReadSFixed64(self): + """Reads and returns a signed, fixed-width, 64-bit integer.""" + value = self._stream.ReadLittleEndian64() + if value >= (1 << 63): + value -= (1 << 64) + return value + + def ReadFloat(self): + """Reads and returns a 4-byte floating-point number.""" + serialized = self._stream.ReadBytes(4) + return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0] + + def ReadDouble(self): + """Reads and returns an 8-byte floating-point number.""" + serialized = self._stream.ReadBytes(8) + return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0] + + def ReadBool(self): + """Reads and returns a bool.""" + i = self._stream.ReadVarUInt32() + return bool(i) + + def ReadEnum(self): + """Reads and returns an enum value.""" + return self._stream.ReadVarUInt32() + + def ReadString(self): + """Reads and returns a length-delimited string.""" + bytes = self.ReadBytes() + return unicode(bytes, 'utf-8') + + def ReadBytes(self): + """Reads and returns a length-delimited byte sequence.""" + length = self._stream.ReadVarUInt32() + return self._stream.ReadBytes(length) + + def ReadMessageInto(self, msg): + """Calls msg.MergeFromString() to merge + length-delimited serialized message data into |msg|. + + REQUIRES: The decoder must be positioned at the serialized "length" + prefix to a length-delmiited serialized message. + + POSTCONDITION: The decoder is positioned just after the + serialized message, and we have merged those serialized + contents into |msg|. + """ + length = self._stream.ReadVarUInt32() + sub_buffer = self._stream.GetSubBuffer(length) + num_bytes_used = msg.MergeFromString(sub_buffer) + if num_bytes_used != length: + raise message.DecodeError( + 'Submessage told to deserialize from %d-byte encoding, ' + 'but used only %d bytes' % (length, num_bytes_used)) + self._stream.SkipBytes(num_bytes_used) + + def ReadGroupInto(self, expected_field_number, group): + """Calls group.MergeFromString() to merge + END_GROUP-delimited serialized message data into |group|. + We'll raise an exception if we don't find an END_GROUP + tag immediately after the serialized message contents. + + REQUIRES: The decoder is positioned just after the START_GROUP + tag for this group. + + POSTCONDITION: The decoder is positioned just after the + END_GROUP tag for this group, and we have merged + the contents of the group into |group|. + """ + sub_buffer = self._stream.GetSubBuffer() # No a priori length limit. + num_bytes_used = group.MergeFromString(sub_buffer) + if num_bytes_used < 0: + raise message.DecodeError('Group message reported negative bytes read.') + self._stream.SkipBytes(num_bytes_used) + field_number, field_type = self.ReadFieldNumberAndWireType() + if field_type != wire_format.WIRETYPE_END_GROUP: + raise message.DecodeError('Group message did not end with an END_GROUP.') + if field_number != expected_field_number: + raise message.DecodeError('END_GROUP tag had field ' + 'number %d, was expecting field number %d' % ( + field_number, expected_field_number)) + # We're now positioned just after the END_GROUP tag. Perfect. diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py new file mode 100755 index 0000000..98e4647 --- /dev/null +++ b/python/google/protobuf/internal/decoder_test.py @@ -0,0 +1,256 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.decoder.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +import unittest +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import input_stream +from google.protobuf.internal import wire_format +from google.protobuf import message +import logging +import mox + + +class DecoderTest(unittest.TestCase): + + def setUp(self): + self.mox = mox.Mox() + self.mock_stream = self.mox.CreateMock(input_stream.InputStream) + self.mock_message = self.mox.CreateMock(message.Message) + + def testReadFieldNumberAndWireType(self): + # Test field numbers that will require various varint sizes. + for expected_field_number in (1, 15, 16, 2047, 2048): + for expected_wire_type in range(6): # Highest-numbered wiretype is 5. + e = encoder.Encoder() + e.AppendTag(expected_field_number, expected_wire_type) + s = e.ToString() + d = decoder.Decoder(s) + field_number, wire_type = d.ReadFieldNumberAndWireType() + self.assertEqual(expected_field_number, field_number) + self.assertEqual(expected_wire_type, wire_type) + + def ReadScalarTestHelper(self, test_name, decoder_method, expected_result, + expected_stream_method_name, + stream_method_return, *args): + """Helper for testReadScalars below. + + Calls one of the Decoder.Read*() methods and ensures that the results are + as expected. + + Args: + test_name: Name of this test, used for logging only. + decoder_method: Unbound decoder.Decoder method to call. + expected_result: Value we expect returned from decoder_method(). + expected_stream_method_name: (string) Name of the InputStream + method we expect Decoder to call to actually read the value + on the wire. + stream_method_return: Value our mocked-out stream method should + return to the decoder. + args: Additional arguments that we expect to be passed to the + stream method. + """ + logging.info('Testing %s scalar input.\n' + 'Calling %r(), and expecting that to call the ' + 'stream method %s(%r), which will return %r. Finally, ' + 'expecting the Decoder method to return %r'% ( + test_name, decoder_method, + expected_stream_method_name, args, stream_method_return, + expected_result)) + + d = decoder.Decoder('') + d._stream = self.mock_stream + if decoder_method in (decoder.Decoder.ReadString, + decoder.Decoder.ReadBytes): + self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return)) + # We have to use names instead of methods to work around some + # mox weirdness. (ResetAll() is overzealous). + expected_stream_method = getattr(self.mock_stream, + expected_stream_method_name) + expected_stream_method(*args).AndReturn(stream_method_return) + + self.mox.ReplayAll() + result = decoder_method(d) + self.assertEqual(expected_result, result) + self.assert_(isinstance(result, type(expected_result))) + self.mox.VerifyAll() + self.mox.ResetAll() + + VAL = 1.125 # Perfectly representable as a float (no rounding error). + LITTLE_FLOAT_VAL = '\x00\x00\x90?' + LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?' + + def testReadScalars(self): + test_string = 'I can feel myself getting sutpider.' + scalar_tests = [ + ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0], + ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0], + ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0], + ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0], + ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff, + 'ReadLittleEndian32', 0xffffffff], + ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff, + 'ReadLittleEndian64', 0xffffffffffffffff], + ['sfixed32', decoder.Decoder.ReadSFixed32, long(-1), + 'ReadLittleEndian32', long(0xffffffff)], + ['sfixed64', decoder.Decoder.ReadSFixed64, long(-1), + 'ReadLittleEndian64', 0xffffffffffffffff], + ['float', decoder.Decoder.ReadFloat, self.VAL, + 'ReadBytes', self.LITTLE_FLOAT_VAL, 4], + ['double', decoder.Decoder.ReadDouble, self.VAL, + 'ReadBytes', self.LITTLE_DOUBLE_VAL, 8], + ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1], + ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23], + ['string', decoder.Decoder.ReadString, + unicode(test_string, 'utf-8'), 'ReadBytes', test_string, + len(test_string)], + ['utf8-string', decoder.Decoder.ReadString, + unicode(test_string, 'utf-8'), 'ReadBytes', test_string, + len(test_string)], + ['bytes', decoder.Decoder.ReadBytes, + test_string, 'ReadBytes', test_string, len(test_string)], + # We test zigzag decoding routines more extensively below. + ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1], + ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1], + ] + # Ensure that we're testing different Decoder methods and using + # different test names in all test cases above. + self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) + self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests))) + for args in scalar_tests: + self.ReadScalarTestHelper(*args) + + def testReadMessageInto(self): + length = 23 + def Test(simulate_error): + d = decoder.Decoder('') + d._stream = self.mock_stream + self.mock_stream.ReadVarUInt32().AndReturn(length) + sub_buffer = object() + self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer) + + if simulate_error: + self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1) + self.mox.ReplayAll() + self.assertRaises( + message.DecodeError, d.ReadMessageInto, self.mock_message) + else: + self.mock_message.MergeFromString(sub_buffer).AndReturn(length) + self.mock_stream.SkipBytes(length) + self.mox.ReplayAll() + d.ReadMessageInto(self.mock_message) + + self.mox.VerifyAll() + self.mox.ResetAll() + + Test(simulate_error=False) + Test(simulate_error=True) + + def testReadGroupInto_Success(self): + # Test both the empty and nonempty cases. + for num_bytes in (5, 0): + field_number = expected_field_number = 10 + d = decoder.Decoder('') + d._stream = self.mock_stream + sub_buffer = object() + self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) + self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes) + self.mock_stream.SkipBytes(num_bytes) + self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_END_GROUP)) + self.mox.ReplayAll() + d.ReadGroupInto(expected_field_number, self.mock_message) + self.mox.VerifyAll() + self.mox.ResetAll() + + def ReadGroupInto_FailureTestHelper(self, bytes_read): + d = decoder.Decoder('') + d._stream = self.mock_stream + sub_buffer = object() + self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) + self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read) + return d + + def testReadGroupInto_NegativeBytesReported(self): + expected_field_number = 10 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testReadGroupInto_NoEndGroupTag(self): + field_number = expected_field_number = 10 + num_bytes = 5 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) + self.mock_stream.SkipBytes(num_bytes) + # Right field number, wrong wire type. + self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testReadGroupInto_WrongFieldNumberInEndGroupTag(self): + expected_field_number = 10 + field_number = expected_field_number + 1 + num_bytes = 5 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) + self.mock_stream.SkipBytes(num_bytes) + # Wrong field number, right wire type. + self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_END_GROUP)) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testSkipBytes(self): + d = decoder.Decoder('') + num_bytes = 1024 + self.mock_stream.SkipBytes(num_bytes) + d._stream = self.mock_stream + self.mox.ReplayAll() + d.SkipBytes(num_bytes) + self.mox.VerifyAll() + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py new file mode 100755 index 0000000..eb9f2be --- /dev/null +++ b/python/google/protobuf/internal/descriptor_test.py @@ -0,0 +1,113 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for google.protobuf.internal.descriptor.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor + +class DescriptorTest(unittest.TestCase): + + def setUp(self): + self.my_enum = descriptor.EnumDescriptor( + name='ForeignEnum', + full_name='protobuf_unittest.ForeignEnum', + filename='ForeignEnum', + values=[ + descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), + descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), + descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6), + ]) + self.my_message = descriptor.Descriptor( + name='NestedMessage', + full_name='protobuf_unittest.TestAllTypes.NestedMessage', + filename='some/filename/some.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='bb', + full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', + index=0, number=1, + type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None), + ], + nested_types=[], + enum_types=[ + self.my_enum, + ], + extensions=[]) + self.my_method = descriptor.MethodDescriptor( + name='Bar', + full_name='protobuf_unittest.TestService.Bar', + index=0, + containing_service=None, + input_type=None, + output_type=None) + self.my_service = descriptor.ServiceDescriptor( + name='TestServiceWithOptions', + full_name='protobuf_unittest.TestServiceWithOptions', + index=0, + methods=[ + self.my_method + ]) + + def testEnumFixups(self): + self.assertEqual(self.my_enum, self.my_enum.values[0].type) + + def testContainingTypeFixups(self): + self.assertEqual(self.my_message, self.my_message.fields[0].containing_type) + self.assertEqual(self.my_message, self.my_enum.containing_type) + + def testContainingServiceFixups(self): + self.assertEqual(self.my_service, self.my_method.containing_service) + + def testGetOptions(self): + self.assertEqual(self.my_enum.GetOptions(), + descriptor_pb2.EnumOptions()) + self.assertEqual(self.my_enum.values[0].GetOptions(), + descriptor_pb2.EnumValueOptions()) + self.assertEqual(self.my_message.GetOptions(), + descriptor_pb2.MessageOptions()) + self.assertEqual(self.my_message.fields[0].GetOptions(), + descriptor_pb2.FieldOptions()) + self.assertEqual(self.my_method.GetOptions(), + descriptor_pb2.MethodOptions()) + self.assertEqual(self.my_service.GetOptions(), + descriptor_pb2.ServiceOptions()) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py new file mode 100755 index 0000000..3ec3b2b --- /dev/null +++ b/python/google/protobuf/internal/encoder.py @@ -0,0 +1,280 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Class for encoding protocol message primitives. + +Contains the logic for encoding every logical protocol field type +into one of the 5 physical wire types. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format +from google.protobuf.internal import output_stream + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Encoder(object): + + """Encodes logical protocol buffer fields to the wire format.""" + + def __init__(self): + self._stream = output_stream.OutputStream() + + def ToString(self): + """Returns all values encoded in this object as a string.""" + return self._stream.ToString() + + # Append*NoTag methods. These are necessary for serializing packed + # repeated fields. The Append*() methods call these methods to do + # the actual serialization. + def AppendInt32NoTag(self, value): + """Appends a 32-bit integer to our buffer, varint-encoded.""" + self._stream.AppendVarint32(value) + + def AppendInt64NoTag(self, value): + """Appends a 64-bit integer to our buffer, varint-encoded.""" + self._stream.AppendVarint64(value) + + def AppendUInt32NoTag(self, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" + self._stream.AppendVarUInt32(unsigned_value) + + def AppendUInt64NoTag(self, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" + self._stream.AppendVarUInt64(unsigned_value) + + def AppendSInt32NoTag(self, value): + """Appends a 32-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt32(zigzag_value) + + def AppendSInt64NoTag(self, value): + """Appends a 64-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt64(zigzag_value) + + def AppendFixed32NoTag(self, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, in little-endian + byte-order. + """ + self._stream.AppendLittleEndian32(unsigned_value) + + def AppendFixed64NoTag(self, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, in little-endian + byte-order. + """ + self._stream.AppendLittleEndian64(unsigned_value) + + def AppendSFixed32NoTag(self, value): + """Appends a signed 32-bit integer to our buffer, in little-endian + byte-order. + """ + sign = (value & 0x80000000) and -1 or 0 + if value >> 32 != sign: + raise message.EncodeError('SFixed32 out of range: %d' % value) + self._stream.AppendLittleEndian32(value & 0xffffffff) + + def AppendSFixed64NoTag(self, value): + """Appends a signed 64-bit integer to our buffer, in little-endian + byte-order. + """ + sign = (value & 0x8000000000000000) and -1 or 0 + if value >> 64 != sign: + raise message.EncodeError('SFixed64 out of range: %d' % value) + self._stream.AppendLittleEndian64(value & 0xffffffffffffffff) + + def AppendFloatNoTag(self, value): + """Appends a floating-point number to our buffer.""" + self._stream.AppendRawBytes( + struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value)) + + def AppendDoubleNoTag(self, value): + """Appends a double-precision floating-point number to our buffer.""" + self._stream.AppendRawBytes( + struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value)) + + def AppendBoolNoTag(self, value): + """Appends a boolean to our buffer.""" + self.AppendInt32NoTag(value) + + def AppendEnumNoTag(self, value): + """Appends an enum value to our buffer.""" + self.AppendInt32NoTag(value) + + + # All the Append*() methods below first append a tag+type pair to the buffer + # before appending the specified value. + + def AppendInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendInt32NoTag(value) + + def AppendInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendInt64NoTag(value) + + def AppendUInt32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendUInt32NoTag(unsigned_value) + + def AppendUInt64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendUInt64NoTag(unsigned_value) + + def AppendSInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendSInt32NoTag(value) + + def AppendSInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendSInt64NoTag(value) + + def AppendFixed32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendFixed32NoTag(unsigned_value) + + def AppendFixed64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendFixed64NoTag(unsigned_value) + + def AppendSFixed32(self, field_number, value): + """Appends a signed 32-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendSFixed32NoTag(value) + + def AppendSFixed64(self, field_number, value): + """Appends a signed 64-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendSFixed64NoTag(value) + + def AppendFloat(self, field_number, value): + """Appends a floating-point number to our buffer.""" + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendFloatNoTag(value) + + def AppendDouble(self, field_number, value): + """Appends a double-precision floating-point number to our buffer.""" + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendDoubleNoTag(value) + + def AppendBool(self, field_number, value): + """Appends a boolean to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendEnum(self, field_number, value): + """Appends an enum value to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendString(self, field_number, value): + """Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer, + with the length varint-encoded. + """ + self.AppendBytes(field_number, value.encode('utf-8')) + + def AppendBytes(self, field_number, value): + """Appends a length-prefixed sequence of bytes to our buffer, with the + length varint-encoded. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(len(value)) + self._stream.AppendRawBytes(value) + + # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to + # avoid the extra string copy here. We can do so if we widen the Message + # interface to be able to serialize to a stream in addition to a string. The + # challenge when thinking ahead to the Python/C API implementation of Message + # is finding a stream-like Python thing to which we can write raw bytes + # from C. I'm not sure such a thing exists(?). (array.array is pretty much + # what we want, but it's not directly exposed in the Python/C API). + + def AppendGroup(self, field_number, group): + """Appends a group to our buffer. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) + self._stream.AppendRawBytes(group.SerializeToString()) + self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) + + def AppendMessage(self, field_number, msg): + """Appends a nested message to our buffer. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(msg.ByteSize()) + self._stream.AppendRawBytes(msg.SerializeToString()) + + def AppendMessageSetItem(self, field_number, msg): + """Appends an item using the message set wire format. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + self.AppendTag(1, wire_format.WIRETYPE_START_GROUP) + self.AppendInt32(2, field_number) + self.AppendMessage(3, msg) + self.AppendTag(1, wire_format.WIRETYPE_END_GROUP) + + def AppendTag(self, field_number, wire_type): + """Appends a tag containing field number and wire type information.""" + self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py new file mode 100755 index 0000000..bf75ea8 --- /dev/null +++ b/python/google/protobuf/internal/encoder_test.py @@ -0,0 +1,286 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.encoder.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +import logging +import unittest +from google.protobuf.internal import wire_format +from google.protobuf.internal import encoder +from google.protobuf.internal import output_stream +from google.protobuf import message +import mox + + +class EncoderTest(unittest.TestCase): + + def setUp(self): + self.mox = mox.Mox() + self.encoder = encoder.Encoder() + self.mock_stream = self.mox.CreateMock(output_stream.OutputStream) + self.mock_message = self.mox.CreateMock(message.Message) + self.encoder._stream = self.mock_stream + + def PackTag(self, field_number, wire_type): + return wire_format.PackTag(field_number, wire_type) + + def AppendScalarTestHelper(self, test_name, encoder_method, + expected_stream_method_name, + wire_type, field_value, + expected_value=None, expected_length=None, + is_tag_test=True): + """Helper for testAppendScalars. + + Calls one of the Encoder methods, and ensures that the Encoder + in turn makes the expected calls into its OutputStream. + + Args: + test_name: Name of this test, used only for logging. + encoder_method: Callable on self.encoder. This is the Encoder + method we're testing. If is_tag_test=True, the encoder method + accepts a field_number and field_value. if is_tag_test=False, + the encoder method accepts a field_value. + expected_stream_method_name: (string) Name of the OutputStream + method we expect Encoder to call to actually put the value + on the wire. + wire_type: The WIRETYPE_* constant we expect encoder to + use in the specified encoder_method. + field_value: The value we're trying to encode. Passed + into encoder_method. + expected_value: The value we expect Encoder to pass into + the OutputStream method. If None, we expect field_value + to pass through unmodified. + expected_length: The length we expect Encoder to pass to the + AppendVarUInt32 method. If None we expect the length of the + field_value. + is_tag_test: A Boolean. If True (the default), we append the + the packed field number and wire_type to the stream before + the field value. + """ + if expected_value is None: + expected_value = field_value + + logging.info('Testing %s scalar output.\n' + 'Calling %r(%r), and expecting that to call the ' + 'stream method %s(%r).' % ( + test_name, encoder_method, field_value, + expected_stream_method_name, expected_value)) + + if is_tag_test: + field_number = 10 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type)) + # If we're length-delimited, we should then append the length. + if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + if expected_length is None: + expected_length = len(field_value) + self.mock_stream.AppendVarUInt32(expected_length) + + # Should then append the value itself. + # We have to use names instead of methods to work around some + # mox weirdness. (ResetAll() is overzealous). + expected_stream_method = getattr(self.mock_stream, + expected_stream_method_name) + expected_stream_method(expected_value) + + self.mox.ReplayAll() + if is_tag_test: + encoder_method(field_number, field_value) + else: + encoder_method(field_value) + self.mox.VerifyAll() + self.mox.ResetAll() + + VAL = 1.125 # Perfectly representable as a float (no rounding error). + LITTLE_FLOAT_VAL = '\x00\x00\x90?' + LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?' + + def testAppendScalars(self): + utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82' + utf8_string = unicode(utf8_bytes, 'utf-8') + scalar_tests = [ + ['int32', self.encoder.AppendInt32, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, 0], + ['int64', self.encoder.AppendInt64, 'AppendVarint64', + wire_format.WIRETYPE_VARINT, 0], + ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32', + wire_format.WIRETYPE_VARINT, 0], + ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64', + wire_format.WIRETYPE_VARINT, 0], + ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32', + wire_format.WIRETYPE_FIXED32, 0], + ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64', + wire_format.WIRETYPE_FIXED64, 0], + ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32', + wire_format.WIRETYPE_FIXED32, -1, 0xffffffff], + ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64', + wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff], + ['float', self.encoder.AppendFloat, 'AppendRawBytes', + wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL], + ['double', self.encoder.AppendDouble, 'AppendRawBytes', + wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL], + ['bool', self.encoder.AppendBool, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, False], + ['enum', self.encoder.AppendEnum, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, 0], + ['string', self.encoder.AppendString, 'AppendRawBytes', + wire_format.WIRETYPE_LENGTH_DELIMITED, + "You're in a maze of twisty little passages, all alike."], + ['utf8-string', self.encoder.AppendString, 'AppendRawBytes', + wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string, + utf8_bytes, len(utf8_bytes)], + # We test zigzag encoding routines more extensively below. + ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32', + wire_format.WIRETYPE_VARINT, -1, 1], + ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64', + wire_format.WIRETYPE_VARINT, -1, 1], + ] + # Ensure that we're testing different Encoder methods and using + # different test names in all test cases above. + self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) + self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests))) + for args in scalar_tests: + self.AppendScalarTestHelper(*args) + + def testAppendScalarsWithoutTags(self): + scalar_no_tag_tests = [ + ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0], + ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0], + ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0], + ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0], + ['fixed32', self.encoder.AppendFixed32NoTag, + 'AppendLittleEndian32', None, 0], + ['fixed64', self.encoder.AppendFixed64NoTag, + 'AppendLittleEndian64', None, 0], + ['sfixed32', self.encoder.AppendSFixed32NoTag, + 'AppendLittleEndian32', None, 0], + ['sfixed64', self.encoder.AppendSFixed64NoTag, + 'AppendLittleEndian64', None, 0], + ['float', self.encoder.AppendFloatNoTag, + 'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL], + ['double', self.encoder.AppendDoubleNoTag, + 'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL], + ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0], + ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0], + ['sint32', self.encoder.AppendSInt32NoTag, + 'AppendVarUInt32', None, -1, 1], + ['sint64', self.encoder.AppendSInt64NoTag, + 'AppendVarUInt64', None, -1, 1], + ] + + self.assertEqual(len(scalar_no_tag_tests), + len(set(t[0] for t in scalar_no_tag_tests))) + self.assert_(len(scalar_no_tag_tests) >= + len(set(t[1] for t in scalar_no_tag_tests))) + for args in scalar_no_tag_tests: + # For no tag tests, the wire_type is not used, so we put in None. + self.AppendScalarTestHelper(is_tag_test=False, *args) + + def testAppendGroup(self): + field_number = 23 + # Should first append the start-group marker. + self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP)) + # Should then serialize itself. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + # Should finally append the end-group marker. + self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP)) + + self.mox.ReplayAll() + self.encoder.AppendGroup(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendMessage(self): + field_number = 23 + byte_size = 42 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) + # Should then append its length. + self.mock_message.ByteSize().AndReturn(byte_size) + self.mock_stream.AppendVarUInt32(byte_size) + # Should then serialize itself to the encoder. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + + self.mox.ReplayAll() + self.encoder.AppendMessage(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendMessageSetItem(self): + field_number = 23 + byte_size = 42 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32( + self.PackTag(1, wire_format.WIRETYPE_START_GROUP)) + self.mock_stream.AppendVarUInt32( + self.PackTag(2, wire_format.WIRETYPE_VARINT)) + self.mock_stream.AppendVarint32(field_number) + self.mock_stream.AppendVarUInt32( + self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED)) + # Should then append its length. + self.mock_message.ByteSize().AndReturn(byte_size) + self.mock_stream.AppendVarUInt32(byte_size) + # Should then serialize itself to the encoder. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + self.mock_stream.AppendVarUInt32( + self.PackTag(1, wire_format.WIRETYPE_END_GROUP)) + + self.mox.ReplayAll() + self.encoder.AppendMessageSetItem(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendSFixed(self): + # Most of our bounds-checking is done in output_stream.py, + # but encoder.py is responsible for transforming signed + # fixed-width integers into unsigned ones, so we test here + # to ensure that we're not losing any entropy when we do + # that conversion. + field_number = 10 + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, + 10, wire_format.UINT32_MAX + 1) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, + 10, -(1 << 32)) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, + 10, wire_format.UINT64_MAX + 1) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, + 10, -(1 << 64)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py new file mode 100755 index 0000000..11fcfa0 --- /dev/null +++ b/python/google/protobuf/internal/generator_test.py @@ -0,0 +1,100 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py +# first, since it's testing the subtler code, and since it provides decent +# indirect testing of the protocol compiler output. + +"""Unittest that directly tests the output of the pure-Python protocol +compiler. See //net/proto2/internal/reflection_test.py for a test which +further ensures that we can use Python protocol message objects as we expect. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 + + +class GeneratorTest(unittest.TestCase): + + def testNestedMessageDescriptor(self): + field_name = 'optional_nested_message' + proto_type = unittest_pb2.TestAllTypes + self.assertEqual( + proto_type.NestedMessage.DESCRIPTOR, + proto_type.DESCRIPTOR.fields_by_name[field_name].message_type) + + def testEnums(self): + # We test only module-level enums here. + # TODO(robinson): Examine descriptors directly to check + # enum descriptor output. + self.assertEqual(4, unittest_pb2.FOREIGN_FOO) + self.assertEqual(5, unittest_pb2.FOREIGN_BAR) + self.assertEqual(6, unittest_pb2.FOREIGN_BAZ) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testContainingTypeBehaviorForExtensions(self): + self.assertEqual(unittest_pb2.optional_int32_extension.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + self.assertEqual(unittest_pb2.TestRequired.single.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + + def testExtensionScope(self): + self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope, + None) + self.assertEqual(unittest_pb2.TestRequired.single.extension_scope, + unittest_pb2.TestRequired.DESCRIPTOR) + + def testIsExtension(self): + self.assertTrue(unittest_pb2.optional_int32_extension.is_extension) + self.assertTrue(unittest_pb2.TestRequired.single.is_extension) + + message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR + non_extension_descriptor = message_descriptor.fields_by_name['a'] + self.assertTrue(not non_extension_descriptor.is_extension) + + def testOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py new file mode 100755 index 0000000..7bda17e --- /dev/null +++ b/python/google/protobuf/internal/input_stream.py @@ -0,0 +1,338 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""InputStream is the primitive interface for reading bits from the wire. + +All protocol buffer deserialization can be expressed in terms of +the InputStream primitives provided here. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import array +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedInputStream from the C++ +# proto2 implementation. + + +class InputStreamBuffer(object): + + """Contains all logic for reading bits, and dealing with stream position. + + If an InputStream method ever raises an exception, the stream is left + in an indeterminate state and is not safe for further use. + """ + + def __init__(self, s): + # What we really want is something like array('B', s), where elements we + # read from the array are already given to us as one-byte integers. BUT + # using array() instead of buffer() would force full string copies to result + # from each GetSubBuffer() call. + # + # So, if the N serialized bytes of a single protocol buffer object are + # split evenly between 2 child messages, and so on recursively, using + # array('B', s) instead of buffer() would incur an additional N*logN bytes + # copied during deserialization. + # + # The higher constant overhead of having to ord() for every byte we read + # from the buffer in _ReadVarintHelper() could definitely lead to worse + # performance in many real-world scenarios, even if the asymptotic + # complexity is better. However, our real answer is that the mythical + # Python/C extension module output mode for the protocol compiler will + # be blazing-fast and will eliminate most use of this class anyway. + self._buffer = buffer(s) + self._pos = 0 + + def EndOfStream(self): + """Returns true iff we're at the end of the stream. + If this returns true, then a call to any other InputStream method + will raise an exception. + """ + return self._pos >= len(self._buffer) + + def Position(self): + """Returns the current position in the stream, or equivalently, the + number of bytes read so far. + """ + return self._pos + + def GetSubBuffer(self, size=None): + """Returns a sequence-like object that represents a portion of our + underlying sequence. + + Position 0 in the returned object corresponds to self.Position() + in this stream. + + If size is specified, then the returned object ends after the + next "size" bytes in this stream. If size is not specified, + then the returned object ends at the end of this stream. + + We guarantee that the returned object R supports the Python buffer + interface (and thus that the call buffer(R) will work). + + Note that the returned buffer is read-only. + + The intended use for this method is for nested-message and nested-group + deserialization, where we want to make a recursive MergeFromString() + call on the portion of the original sequence that contains the serialized + nested message. (And we'd like to do so without making unnecessary string + copies). + + REQUIRES: size is nonnegative. + """ + # Note that buffer() doesn't perform any actual string copy. + if size is None: + return buffer(self._buffer, self._pos) + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return buffer(self._buffer, self._pos, size) + + def SkipBytes(self, num_bytes): + """Skip num_bytes bytes ahead, or go to the end of the stream, whichever + comes first. + + REQUIRES: num_bytes is nonnegative. + """ + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadBytes(self, size): + """Reads up to 'size' bytes from the stream, stopping early + only if we reach the end of the stream. Returns the bytes read + as a string. + """ + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = (self._buffer[self._pos : self._pos + size]) + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + """Interprets the next 4 bytes of the stream as a little-endian + encoded, unsiged 32-bit integer, and returns that integer. + """ + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + """Interprets the next 8 bytes of the stream as a little-endian + encoded, unsiged 64-bit integer, and returns that integer. + """ + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + """Reads a varint from the stream, interprets this varint + as a signed, 32-bit integer, and returns the integer. + """ + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 32-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + """Reads a varint from the stream, interprets this varint + as a signed, 64-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 64-bit integer, and returns the integer. + """ + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + """Helper for the various varint-reading methods above. + Reads an unsigned, varint-encoded integer from the stream and + returns this integer. + + Does no bounds checking except to ensure that we read at most as many bytes + as could possibly be present in a varint-encoded 64-bit number. + """ + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = ord(self._buffer[self._pos]) + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result + + +class InputStreamArray(object): + + """Contains all logic for reading bits, and dealing with stream position. + + If an InputStream method ever raises an exception, the stream is left + in an indeterminate state and is not safe for further use. + + This alternative to InputStreamBuffer is used in environments where buffer() + is unavailble, such as Google App Engine. + """ + + def __init__(self, s): + self._buffer = array.array('B', s) + self._pos = 0 + + def EndOfStream(self): + return self._pos >= len(self._buffer) + + def Position(self): + return self._pos + + def GetSubBuffer(self, size=None): + if size is None: + return self._buffer[self._pos : ].tostring() + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return self._buffer[self._pos : self._pos + size].tostring() + + def SkipBytes(self, num_bytes): + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadBytes(self, size): + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = self._buffer[self._pos : self._pos + size].tostring() + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = self._buffer[self._pos] + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result + + +try: + buffer('') + InputStream = InputStreamBuffer +except NotImplementedError: + # Google App Engine: dev_appserver.py + InputStream = InputStreamArray +except RuntimeError: + # Google App Engine: production + InputStream = InputStreamArray diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py new file mode 100755 index 0000000..ecec7f7 --- /dev/null +++ b/python/google/protobuf/internal/input_stream_test.py @@ -0,0 +1,314 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.input_stream.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import wire_format +from google.protobuf.internal import input_stream + + +class InputStreamBufferTest(unittest.TestCase): + + def setUp(self): + self.__original_input_stream = input_stream.InputStream + input_stream.InputStream = input_stream.InputStreamBuffer + + def tearDown(self): + input_stream.InputStream = self.__original_input_stream + + def testEndOfStream(self): + stream = input_stream.InputStream('abcd') + self.assertFalse(stream.EndOfStream()) + self.assertEqual('abcd', stream.ReadBytes(10)) + self.assertTrue(stream.EndOfStream()) + + def testPosition(self): + stream = input_stream.InputStream('abcd') + self.assertEqual(0, stream.Position()) + self.assertEqual(0, stream.Position()) # No side-effects. + stream.ReadBytes(1) + self.assertEqual(1, stream.Position()) + stream.ReadBytes(1) + self.assertEqual(2, stream.Position()) + stream.ReadBytes(10) + self.assertEqual(4, stream.Position()) # Can't go past end of stream. + + def testGetSubBuffer(self): + stream = input_stream.InputStream('abcd') + # Try leaving out the size. + self.assertEqual('abcd', str(stream.GetSubBuffer())) + stream.SkipBytes(1) + # GetSubBuffer() always starts at current size. + self.assertEqual('bcd', str(stream.GetSubBuffer())) + # Try 0-size. + self.assertEqual('', str(stream.GetSubBuffer(0))) + # Negative sizes should raise an error. + self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1) + # Positive sizes should work as expected. + self.assertEqual('b', str(stream.GetSubBuffer(1))) + self.assertEqual('bc', str(stream.GetSubBuffer(2))) + # Sizes longer than remaining bytes in the buffer should + # return the whole remaining buffer. + self.assertEqual('bcd', str(stream.GetSubBuffer(1000))) + + def testSkipBytes(self): + stream = input_stream.InputStream('') + # Skipping bytes when at the end of stream + # should have no effect. + stream.SkipBytes(0) + stream.SkipBytes(1) + stream.SkipBytes(2) + self.assertTrue(stream.EndOfStream()) + self.assertEqual(0, stream.Position()) + + # Try skipping within a stream. + stream = input_stream.InputStream('abcd') + self.assertEqual(0, stream.Position()) + stream.SkipBytes(1) + self.assertEqual(1, stream.Position()) + stream.SkipBytes(10) # Can't skip past the end. + self.assertEqual(4, stream.Position()) + + # Ensure that a negative skip raises an exception. + stream = input_stream.InputStream('abcd') + stream.SkipBytes(1) + self.assertRaises(message.DecodeError, stream.SkipBytes, -1) + + def testReadBytes(self): + s = 'abcd' + # Also test going past the total stream length. + for i in range(len(s) + 10): + stream = input_stream.InputStream(s) + self.assertEqual(s[:i], stream.ReadBytes(i)) + self.assertEqual(min(i, len(s)), stream.Position()) + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadBytes, -1) + + def EnsureFailureOnEmptyStream(self, input_stream_method): + """Helper for integer-parsing tests below. + Ensures that the given InputStream method raises a DecodeError + if called on a stream with no bytes remaining. + """ + stream = input_stream.InputStream('') + self.assertRaises(message.DecodeError, input_stream_method, stream) + + def testReadLittleEndian32(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32) + s = '' + # Read 0. + s += '\x00\x00\x00\x00' + # Read 1. + s += '\x01\x00\x00\x00' + # Read a bunch of different bytes. + s += '\x01\x02\x03\x04' + # Read max unsigned 32-bit int. + s += '\xff\xff\xff\xff' + # Try a read with fewer than 4 bytes left in the stream. + s += '\x00\x00\x00' + stream = input_stream.InputStream(s) + self.assertEqual(0, stream.ReadLittleEndian32()) + self.assertEqual(4, stream.Position()) + self.assertEqual(1, stream.ReadLittleEndian32()) + self.assertEqual(8, stream.Position()) + self.assertEqual(0x04030201, stream.ReadLittleEndian32()) + self.assertEqual(12, stream.Position()) + self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32()) + self.assertEqual(16, stream.Position()) + self.assertRaises(message.DecodeError, stream.ReadLittleEndian32) + + def testReadLittleEndian64(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64) + s = '' + # Read 0. + s += '\x00\x00\x00\x00\x00\x00\x00\x00' + # Read 1. + s += '\x01\x00\x00\x00\x00\x00\x00\x00' + # Read a bunch of different bytes. + s += '\x01\x02\x03\x04\x05\x06\x07\x08' + # Read max unsigned 64-bit int. + s += '\xff\xff\xff\xff\xff\xff\xff\xff' + # Try a read with fewer than 8 bytes left in the stream. + s += '\x00\x00\x00' + stream = input_stream.InputStream(s) + self.assertEqual(0, stream.ReadLittleEndian64()) + self.assertEqual(8, stream.Position()) + self.assertEqual(1, stream.ReadLittleEndian64()) + self.assertEqual(16, stream.Position()) + self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64()) + self.assertEqual(24, stream.Position()) + self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64()) + self.assertEqual(32, stream.Position()) + self.assertRaises(message.DecodeError, stream.ReadLittleEndian64) + + def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method): + """Helper for tests below that test successful reads of various varints. + + Args: + varints_and_ints: Iterable of (str, integer) pairs, where the string + gives the wire encoding and the integer gives the value we expect + to be returned by the read_method upon encountering this string. + read_method: Unbound InputStream method that is capable of reading + the encoded strings provided in the first elements of varints_and_ints. + """ + s = ''.join(s for s, i in varints_and_ints) + stream = input_stream.InputStream(s) + expected_pos = 0 + self.assertEqual(expected_pos, stream.Position()) + for s, expected_int in varints_and_ints: + self.assertEqual(expected_int, read_method(stream)) + expected_pos += len(s) + self.assertEqual(expected_pos, stream.Position()) + + def testReadVarint32Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), + ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX), + ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarint32) + + def testReadVarint32Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32) + + # Try and fail to read INT32_MAX + 1. + s = '\x80\x80\x80\x80\x08' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + # Try and fail to read INT32_MIN - 1. + s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + def testReadVarUInt32Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarUInt32) + + def testReadVarUInt32Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32) + # Try and fail to read UINT32_MAX + 1 + s = '\x80\x80\x80\x80\x10' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt32) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt32) + + def testReadVarint64Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX), + ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarint64) + + def testReadVarint64Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64) + # Try and fail to read something with the mythical 64th bit set. + s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint64) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint64) + + def testReadVarUInt64Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarUInt64) + + def testReadVarUInt64Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64) + # Try and fail to read something with the mythical 64th bit set. + s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt64) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt64) + + +class InputStreamArrayTest(InputStreamBufferTest): + + def setUp(self): + # Test InputStreamArray against the same tests in InputStreamBuffer + self.__original_input_stream = input_stream.InputStream + input_stream.InputStream = input_stream.InputStreamArray + + def tearDown(self): + input_stream.InputStream = self.__original_input_stream + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py new file mode 100755 index 0000000..4397895 --- /dev/null +++ b/python/google/protobuf/internal/message_listener.py @@ -0,0 +1,69 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a listener interface for observing certain +state transitions on Message objects. + +Also defines a null implementation of this interface. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +class MessageListener(object): + + """Listens for transitions to nonempty and for invalidations of cached + byte sizes. Meant to be registered via Message._SetListener(). + """ + + def TransitionToNonempty(self): + """Called the *first* time that this message becomes nonempty. + Implementations are free (but not required) to call this method multiple + times after the message has become nonempty. + """ + raise NotImplementedError + + def ByteSizeDirty(self): + """Called *every* time the cached byte size value + for this object is invalidated (transitions from being + "clean" to "dirty"). + """ + raise NotImplementedError + + +class NullMessageListener(object): + + """No-op MessageListener implementation.""" + + def TransitionToNonempty(self): + pass + + def ByteSizeDirty(self): + pass diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py new file mode 100755 index 0000000..df344cf --- /dev/null +++ b/python/google/protobuf/internal/message_test.py @@ -0,0 +1,53 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests python protocol buffers against the golden message.""" + +__author__ = 'gps@google.com (Gregory P. Smith)' + +import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import test_util + + +class MessageTest(test_util.GoldenMessageTestCase): + + def testGoldenMessage(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.ExpectAllFieldsSet(golden_message) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto new file mode 100644 index 0000000..e2d9701 --- /dev/null +++ b/python/google/protobuf/internal/more_extensions.proto @@ -0,0 +1,58 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: robinson@google.com (Will Robinson) + + +package google.protobuf.internal; + + +message TopLevelMessage { + optional ExtendedMessage submessage = 1; +} + + +message ExtendedMessage { + extensions 1 to max; +} + + +message ForeignMessage { + optional int32 foreign_message_int = 1; +} + + +extend ExtendedMessage { + optional int32 optional_int_extension = 1; + optional ForeignMessage optional_message_extension = 2; + + repeated int32 repeated_int_extension = 3; + repeated ForeignMessage repeated_message_extension = 4; +} diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto new file mode 100644 index 0000000..c701b44 --- /dev/null +++ b/python/google/protobuf/internal/more_messages.proto @@ -0,0 +1,51 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: robinson@google.com (Will Robinson) + + +package google.protobuf.internal; + +// A message where tag numbers are listed out of order, to allow us to test our +// canonicalization of serialized output, which should always be in tag order. +// We also mix in some extensions for extra fun. +message OutOfOrderFields { + optional sint32 optional_sint32 = 5; + extensions 4 to 4; + optional uint32 optional_uint32 = 3; + extensions 2 to 2; + optional int32 optional_int32 = 1; +}; + + +extend OutOfOrderFields { + optional uint64 optional_uint64 = 4; + optional int64 optional_int64 = 2; +} diff --git a/python/google/protobuf/internal/output_stream.py b/python/google/protobuf/internal/output_stream.py new file mode 100755 index 0000000..6c2d6f6 --- /dev/null +++ b/python/google/protobuf/internal/output_stream.py @@ -0,0 +1,125 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""OutputStream is the primitive interface for sticking bits on the wire. + +All protocol buffer serialization can be expressed in terms of +the OutputStream primitives provided here. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import array +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedOutputStream from the C++ +# proto2 implementation. + + +class OutputStream(object): + + """Contains all logic for writing bits, and ToString() to get the result.""" + + def __init__(self): + self._buffer = array.array('B') + + def AppendRawBytes(self, raw_bytes): + """Appends raw_bytes to our internal buffer.""" + self._buffer.fromstring(raw_bytes) + + def AppendLittleEndian32(self, unsigned_value): + """Appends an unsigned 32-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT32_MAX: + raise message.EncodeError( + 'Unsigned 32-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value)) + + def AppendLittleEndian64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError( + 'Unsigned 64-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value)) + + def AppendVarint32(self, value): + """Appends a signed 32-bit integer to the internal buffer, + encoded as a varint. (Note that a negative varint32 will + always require 10 bytes of space.) + """ + if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarint64(value) + + def AppendVarUInt32(self, value): + """Appends an unsigned 32-bit integer to the internal buffer, + encoded as a varint. + """ + if not 0 <= value <= wire_format.UINT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarUInt64(value) + + def AppendVarint64(self, value): + """Appends a signed 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX: + raise message.EncodeError('Value out of range: %d' % value) + if value < 0: + value += (1 << 64) + self.AppendVarUInt64(value) + + def AppendVarUInt64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % unsigned_value) + while True: + bits = unsigned_value & 0x7f + unsigned_value >>= 7 + if not unsigned_value: + self._buffer.append(bits) + break + self._buffer.append(0x80|bits) + + def ToString(self): + """Returns a string containing the bytes in our internal buffer.""" + return self._buffer.tostring() diff --git a/python/google/protobuf/internal/output_stream_test.py b/python/google/protobuf/internal/output_stream_test.py new file mode 100755 index 0000000..df92eec --- /dev/null +++ b/python/google/protobuf/internal/output_stream_test.py @@ -0,0 +1,178 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.output_stream.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import output_stream +from google.protobuf.internal import wire_format + + +class OutputStreamTest(unittest.TestCase): + + def setUp(self): + self.stream = output_stream.OutputStream() + + def testAppendRawBytes(self): + # Empty string. + self.stream.AppendRawBytes('') + self.assertEqual('', self.stream.ToString()) + + # Nonempty string. + self.stream.AppendRawBytes('abc') + self.assertEqual('abc', self.stream.ToString()) + + # Ensure that we're actually appending. + self.stream.AppendRawBytes('def') + self.assertEqual('abcdef', self.stream.ToString()) + + def AppendNumericTestHelper(self, append_fn, values_and_strings): + """For each (value, expected_string) pair in values_and_strings, + calls an OutputStream.Append*(value) method on an OutputStream and ensures + that the string written to that stream matches expected_string. + + Args: + append_fn: Unbound OutputStream method that takes an integer or + long value as input. + values_and_strings: Iterable of (value, expected_string) pairs. + """ + for conversion in (int, long): + for value, string in values_and_strings: + stream = output_stream.OutputStream() + expected_string = '' + append_fn(stream, conversion(value)) + expected_string += string + self.assertEqual(expected_string, stream.ToString()) + + def AppendOverflowTestHelper(self, append_fn, value): + """Calls an OutputStream.Append*(value) method and asserts + that the method raises message.EncodeError. + + Args: + append_fn: Unbound OutputStream method that takes an integer or + long value as input. + value: Value to pass to append_fn which should cause an + message.EncodeError. + """ + stream = output_stream.OutputStream() + self.assertRaises(message.EncodeError, append_fn, stream, value) + + def testAppendLittleEndian32(self): + append_fn = output_stream.OutputStream.AppendLittleEndian32 + values_and_expected_strings = [ + (0, '\x00\x00\x00\x00'), + (1, '\x01\x00\x00\x00'), + ((1 << 32) - 1, '\xff\xff\xff\xff'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, 1 << 32) + self.AppendOverflowTestHelper(append_fn, -1) + + def testAppendLittleEndian64(self): + append_fn = output_stream.OutputStream.AppendLittleEndian64 + values_and_expected_strings = [ + (0, '\x00\x00\x00\x00\x00\x00\x00\x00'), + (1, '\x01\x00\x00\x00\x00\x00\x00\x00'), + ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, 1 << 64) + self.AppendOverflowTestHelper(append_fn, -1) + + def testAppendVarint32(self): + append_fn = output_stream.OutputStream.AppendVarint32 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'), + (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1) + self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1) + + def testAppendVarUInt32(self): + append_fn = output_stream.OutputStream.AppendVarUInt32 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, -1) + self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1) + + def testAppendVarint64(self): + append_fn = output_stream.OutputStream.AppendVarint64 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'), + (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1) + self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1) + + def testAppendVarUInt64(self): + append_fn = output_stream.OutputStream.AppendVarUInt64 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, -1) + self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py new file mode 100755 index 0000000..8610177 --- /dev/null +++ b/python/google/protobuf/internal/reflection_test.py @@ -0,0 +1,1976 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which also indirectly tests the output of the +pure-Python protocol compiler. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import operator + +import unittest +# TODO(robinson): When we split this test in two, only some of these imports +# will be necessary in each test. +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import message +from google.protobuf import reflection +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import wire_format +from google.protobuf.internal import test_util +from google.protobuf.internal import decoder + + +class ReflectionTest(unittest.TestCase): + + def assertIs(self, values, others): + self.assertEqual(len(values), len(others)) + for i in range(len(values)): + self.assertTrue(values[i] is others[i]) + + def testSimpleHasBits(self): + # Test a scalar. + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_int32')) + self.assertEqual(0, proto.optional_int32) + # HasField() shouldn't be true if all we've done is + # read the default value. + self.assertTrue(not proto.HasField('optional_int32')) + proto.optional_int32 = 1 + # Setting a value however *should* set the "has" bit. + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + # And clearing that value should unset the "has" bit. + self.assertTrue(not proto.HasField('optional_int32')) + + def testHasBitsWithSinglyNestedScalar(self): + # Helper used to test foreign messages and groups. + # + # composite_field_name should be the name of a non-repeated + # composite (i.e., foreign or group) field in TestAllTypes, + # and scalar_field_name should be the name of an integer-valued + # scalar field within that composite. + # + # I never thought I'd miss C++ macros and templates so much. :( + # This helper is semantically just: + # + # assert proto.composite_field.scalar_field == 0 + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + # + # proto.composite_field.scalar_field = 10 + # old_composite_field = proto.composite_field + # + # assert proto.composite_field.scalar_field == 10 + # assert proto.composite_field.HasField('scalar_field') + # assert proto.HasField('composite_field') + # + # proto.ClearField('composite_field') + # + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + # assert proto.composite_field.scalar_field == 0 + # + # # Now ensure that ClearField('composite_field') disconnected + # # the old field object from the object tree... + # assert old_composite_field is not proto.composite_field + # old_composite_field.scalar_field = 20 + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + def TestCompositeHasBits(composite_field_name, scalar_field_name): + proto = unittest_pb2.TestAllTypes() + # First, check that we can get the scalar value, and see that it's the + # default (0), but that proto.HasField('omposite') and + # proto.composite.HasField('scalar') will still return False. + composite_field = getattr(proto, composite_field_name) + original_scalar_value = getattr(composite_field, scalar_field_name) + self.assertEqual(0, original_scalar_value) + # Assert that the composite object does not "have" the scalar. + self.assertTrue(not composite_field.HasField(scalar_field_name)) + # Assert that proto does not "have" the composite field. + self.assertTrue(not proto.HasField(composite_field_name)) + + # Now set the scalar within the composite field. Ensure that the setting + # is reflected, and that proto.HasField('composite') and + # proto.composite.HasField('scalar') now both return True. + new_val = 20 + setattr(composite_field, scalar_field_name, new_val) + self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) + # Hold on to a reference to the current composite_field object. + old_composite_field = composite_field + # Assert that the has methods now return true. + self.assertTrue(composite_field.HasField(scalar_field_name)) + self.assertTrue(proto.HasField(composite_field_name)) + + # Now call the clear method... + proto.ClearField(composite_field_name) + + # ...and ensure that the "has" bits are all back to False... + composite_field = getattr(proto, composite_field_name) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + # ...and ensure that the scalar field has returned to its default. + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Finally, ensure that modifications to the old composite field object + # don't have any effect on the parent. + # + # (NOTE that when we clear the composite field in the parent, we actually + # don't recursively clear down the tree. Instead, we just disconnect the + # cleared composite from the tree.) + self.assertTrue(old_composite_field is not composite_field) + setattr(old_composite_field, scalar_field_name, new_val) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Test simple, single-level nesting when we set a scalar. + TestCompositeHasBits('optionalgroup', 'a') + TestCompositeHasBits('optional_nested_message', 'bb') + TestCompositeHasBits('optional_foreign_message', 'c') + TestCompositeHasBits('optional_import_message', 'd') + + def testReferencesToNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + del proto + # A previous version had a bug where this would raise an exception when + # hitting a now-dead weak reference. + nested.bb = 23 + + def testDisconnectingNestedMessageBeforeSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testHasBitsWhenModifyingRepeatedFields(self): + # Test nesting when we add an element to a repeated field in a submessage. + proto = unittest_pb2.TestNestedMessageHasBits() + proto.optional_nested_message.nestedmessage_repeated_int32.append(5) + self.assertEqual( + [5], proto.optional_nested_message.nestedmessage_repeated_int32) + self.assertTrue(proto.HasField('optional_nested_message')) + + # Do the same test, but with a repeated composite field within the + # submessage. + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() + self.assertTrue(proto.HasField('optional_nested_message')) + + def testHasBitsForManyLevelsOfNesting(self): + # Test nesting many levels deep. + recursive_proto = unittest_pb2.TestMutualRecursionA() + self.assertTrue(not recursive_proto.HasField('bb')) + self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(not recursive_proto.HasField('bb')) + recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 + self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(recursive_proto.HasField('bb')) + self.assertTrue(recursive_proto.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.HasField('bb')) + self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) + self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) + + def testSingularListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_fixed32 = 1 + proto.optional_int32 = 5 + proto.optional_string = 'foo' + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], + proto.ListFields()) + + def testRepeatedListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.repeated_fixed32.append(1) + proto.repeated_int32.append(5) + proto.repeated_int32.append(11) + proto.repeated_string.extend(['foo', 'bar']) + proto.repeated_string.extend([]) + proto.repeated_string.append('baz') + proto.repeated_string.extend(str(x) for x in xrange(2)) + proto.optional_int32 = 21 + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), + (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), + (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), + (proto.DESCRIPTOR.fields_by_name['repeated_string' ], + ['foo', 'bar', 'baz', '0', '1']) ], + proto.ListFields()) + + def testSingularListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 + proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 + proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 5), + (unittest_pb2.optional_fixed32_extension, 1), + (unittest_pb2.optional_string_extension , 'foo') ], + proto.ListFields()) + + def testRepeatedListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) + proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') + proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 21), + (unittest_pb2.repeated_int32_extension , [5, 11]), + (unittest_pb2.repeated_fixed32_extension, [1]), + (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], + proto.ListFields()) + + def testListFieldsAndExtensions(self): + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + unittest_pb2.my_extension_int + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), + (unittest_pb2.my_extension_int , 23), + (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), + (unittest_pb2.my_extension_string , 'bar'), + (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], + proto.ListFields()) + + def testDefaultValues(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.optional_int32) + self.assertEqual(0, proto.optional_int64) + self.assertEqual(0, proto.optional_uint32) + self.assertEqual(0, proto.optional_uint64) + self.assertEqual(0, proto.optional_sint32) + self.assertEqual(0, proto.optional_sint64) + self.assertEqual(0, proto.optional_fixed32) + self.assertEqual(0, proto.optional_fixed64) + self.assertEqual(0, proto.optional_sfixed32) + self.assertEqual(0, proto.optional_sfixed64) + self.assertEqual(0.0, proto.optional_float) + self.assertEqual(0.0, proto.optional_double) + self.assertEqual(False, proto.optional_bool) + self.assertEqual('', proto.optional_string) + self.assertEqual('', proto.optional_bytes) + + self.assertEqual(41, proto.default_int32) + self.assertEqual(42, proto.default_int64) + self.assertEqual(43, proto.default_uint32) + self.assertEqual(44, proto.default_uint64) + self.assertEqual(-45, proto.default_sint32) + self.assertEqual(46, proto.default_sint64) + self.assertEqual(47, proto.default_fixed32) + self.assertEqual(48, proto.default_fixed64) + self.assertEqual(49, proto.default_sfixed32) + self.assertEqual(-50, proto.default_sfixed64) + self.assertEqual(51.5, proto.default_float) + self.assertEqual(52e3, proto.default_double) + self.assertEqual(True, proto.default_bool) + self.assertEqual('hello', proto.default_string) + self.assertEqual('world', proto.default_bytes) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) + self.assertEqual(unittest_import_pb2.IMPORT_BAR, + proto.default_import_enum) + + proto = unittest_pb2.TestExtremeDefaultValues() + self.assertEqual(u'\u1234', proto.utf8_string) + + def testHasFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') + + def testClearFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + + def testDisallowedAssignments(self): + # It's illegal to assign values directly to repeated fields + # or to nonrepeated composite fields. Ensure that this fails. + proto = unittest_pb2.TestAllTypes() + # Repeated fields. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) + # Lists shouldn't work, either. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) + # Composite fields. + self.assertRaises(AttributeError, setattr, proto, + 'optional_nested_message', 23) + # Assignment to a repeated nested message field without specifying + # the index in the array of nested messages. + self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, + 'bb', 34) + # Assignment to an attribute of a repeated field. + self.assertRaises(AttributeError, setattr, proto.repeated_float, + 'some_attribute', 34) + # proto.nonexistent_field = 23 should fail as well. + self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) + + # TODO(robinson): Add type-safety check for enums. + def testSingleScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') + self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) + self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + + def testSingleScalarBoundsChecking(self): + def TestMinAndMaxIntegers(field_name, expected_min, expected_max): + pb = unittest_pb2.TestAllTypes() + setattr(pb, field_name, expected_min) + setattr(pb, field_name, expected_max) + self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) + self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) + + TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) + TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) + TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) + TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) + TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + def testRepeatedScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) + self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') + self.assertRaises(TypeError, proto.repeated_string, 10) + self.assertRaises(TypeError, proto.repeated_bytes, 10) + + proto.repeated_int32.append(10) + proto.repeated_int32[0] = 23 + self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) + self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + + def testSingleScalarGettersAndSetters(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.optional_int32) + proto.optional_int32 = 1 + self.assertEqual(1, proto.optional_int32) + # TODO(robinson): Test all other scalar field types. + + def testSingleScalarClearField(self): + proto = unittest_pb2.TestAllTypes() + # Should be allowed to clear something that's not there (a no-op). + proto.ClearField('optional_int32') + proto.optional_int32 = 1 + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + self.assertEqual(0, proto.optional_int32) + self.assertTrue(not proto.HasField('optional_int32')) + # TODO(robinson): Test all other scalar field types. + + def testEnums(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testRepeatedScalars(self): + proto = unittest_pb2.TestAllTypes() + + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(5) + proto.repeated_int32.append(10) + proto.repeated_int32.append(15) + self.assertTrue(proto.repeated_int32) + self.assertEqual(3, len(proto.repeated_int32)) + + self.assertEqual([5, 10, 15], proto.repeated_int32) + + # Test single retrieval. + self.assertEqual(5, proto.repeated_int32[0]) + self.assertEqual(15, proto.repeated_int32[-1]) + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) + + # Test single assignment. + proto.repeated_int32[1] = 20 + self.assertEqual([5, 20, 15], proto.repeated_int32) + + # Test insertion. + proto.repeated_int32.insert(1, 25) + self.assertEqual([5, 25, 20, 15], proto.repeated_int32) + + # Test slice retrieval. + proto.repeated_int32.append(30) + self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) + self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) + + # Test slice assignment with an iterator + proto.repeated_int32[1:4] = (i for i in xrange(3)) + self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) + + # Test slice assignment. + proto.repeated_int32[1:4] = [35, 40, 45] + self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_int32: + result.append(i) + self.assertEqual([5, 35, 40, 45, 30], result) + + # Test single deletion. + del proto.repeated_int32[2] + self.assertEqual([5, 35, 45, 30], proto.repeated_int32) + + # Test slice deletion. + del proto.repeated_int32[2:] + self.assertEqual([5, 35], proto.repeated_int32) + + # Test clearing. + proto.ClearField('repeated_int32') + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + + def testRepeatedScalarsRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(5) + proto.repeated_int32.append(10) + proto.repeated_int32.append(5) + proto.repeated_int32.append(5) + + self.assertEqual(4, len(proto.repeated_int32)) + proto.repeated_int32.remove(5) + self.assertEqual(3, len(proto.repeated_int32)) + self.assertEqual(10, proto.repeated_int32[0]) + self.assertEqual(5, proto.repeated_int32[1]) + self.assertEqual(5, proto.repeated_int32[2]) + + proto.repeated_int32.remove(5) + self.assertEqual(2, len(proto.repeated_int32)) + self.assertEqual(10, proto.repeated_int32[0]) + self.assertEqual(5, proto.repeated_int32[1]) + + proto.repeated_int32.remove(10) + self.assertEqual(1, len(proto.repeated_int32)) + self.assertEqual(5, proto.repeated_int32[0]) + + # Remove a non-existent element. + self.assertRaises(ValueError, proto.repeated_int32.remove, 123) + + def testRepeatedComposites(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + m1 = proto.repeated_nested_message.add() + self.assertTrue(proto.repeated_nested_message) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) + + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + 1234) + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + -1234) + + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + 'foo') + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + None) + + # Test slice retrieval. + m2 = proto.repeated_nested_message.add() + m3 = proto.repeated_nested_message.add() + m4 = proto.repeated_nested_message.add() + self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) + self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_nested_message: + result.append(i) + self.assertIs([m0, m1, m2, m3, m4], result) + + # Test single deletion. + del proto.repeated_nested_message[2] + self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) + + # Test slice deletion. + del proto.repeated_nested_message[2:] + self.assertIs([m0, m1], proto.repeated_nested_message) + + # Test clearing. + proto.ClearField('repeated_nested_message') + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + + def testHandWrittenReflection(self): + # TODO(robinson): We probably need a better way to specify + # protocol types by hand. But then again, this isn't something + # we expect many people to do. Hmm. + FieldDescriptor = descriptor.FieldDescriptor + foo_field_descriptor = FieldDescriptor( + name='foo_field', full_name='MyProto.foo_field', + index=0, number=1, type=FieldDescriptor.TYPE_INT64, + cpp_type=FieldDescriptor.CPPTYPE_INT64, + label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, + containing_type=None, message_type=None, enum_type=None, + is_extension=False, extension_scope=None, + options=descriptor_pb2.FieldOptions()) + mydescriptor = descriptor.Descriptor( + name='MyProto', full_name='MyProto', filename='ignored', + containing_type=None, nested_types=[], enum_types=[], + fields=[foo_field_descriptor], extensions=[], + options=descriptor_pb2.MessageOptions()) + class MyProtoClass(message.Message): + DESCRIPTOR = mydescriptor + __metaclass__ = reflection.GeneratedProtocolMessageType + myproto_instance = MyProtoClass() + self.assertEqual(0, myproto_instance.foo_field) + self.assertTrue(not myproto_instance.HasField('foo_field')) + myproto_instance.foo_field = 23 + self.assertEqual(23, myproto_instance.foo_field) + self.assertTrue(myproto_instance.HasField('foo_field')) + + def testTopLevelExtensionsForOptionalScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_int32_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension]) + # As with normal scalar fields, just doing a read doesn't actually set the + # "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Actually set the thing. + extendee_proto.Extensions[extension] = 23 + self.assertEqual(23, extendee_proto.Extensions[extension]) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Ensure that clearing works as well. + extendee_proto.ClearExtension(extension) + self.assertEqual(0, extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + def testTopLevelExtensionsForRepeatedScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeated_string_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + extendee_proto.Extensions[extension].append('foo') + self.assertEqual(['foo'], extendee_proto.Extensions[extension]) + string_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(string_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForOptionalMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_foreign_message_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension].c) + # As with normal (non-extension) fields, merely reading from the + # thing shouldn't set the "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + extendee_proto.Extensions[extension].c = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].c) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Save a reference here. + foreign_message = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) + # Setting a field on foreign_message now shouldn't set + # any "has" bits on extendee_proto. + foreign_message.c = 42 + self.assertEqual(42, foreign_message.c) + self.assertTrue(foreign_message.HasField('c')) + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForRepeatedMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeatedgroup_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + group = extendee_proto.Extensions[extension].add() + group.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension][0].a) + group.a = 42 + self.assertEqual(42, extendee_proto.Extensions[extension][0].a) + group_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(group_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testNestedExtensions(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + + # We just test the non-repeated case. + self.assertTrue(not extendee_proto.HasExtension(extension)) + required = extendee_proto.Extensions[extension] + self.assertEqual(0, required.a) + self.assertTrue(not extendee_proto.HasExtension(extension)) + required.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].a) + self.assertTrue(extendee_proto.HasExtension(extension)) + extendee_proto.ClearExtension(extension) + self.assertTrue(required is not extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + # If message A directly contains message B, and + # a.HasField('b') is currently False, then mutating any + # extension in B should change a.HasField('b') to True + # (and so on up the object tree). + def testHasBitsForAncestorsOfExtendedMessage(self): + # Optional scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension] = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual([], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension].append(23) + self.assertEqual([23], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Optional message extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated message extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, len(toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension])) + self.assertTrue(not toplevel.HasField('submessage')) + foreign = toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension].add() + self.assertTrue(foreign is toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension][0]) + self.assertTrue(toplevel.HasField('submessage')) + + def testDisconnectionAfterClearingEmptyMessage(self): + toplevel = more_extensions_pb2.TopLevelMessage() + extendee_proto = toplevel.submessage + extension = more_extensions_pb2.optional_message_extension + extension_proto = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + extension_proto.foreign_message_int = 23 + + self.assertTrue(not toplevel.HasField('submessage')) + self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) + + def testExtensionFailureModes(self): + extendee_proto = unittest_pb2.TestAllExtensions() + + # Try non-extension-handle arguments to HasExtension, + # ClearExtension(), and Extensions[]... + self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) + self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) + + # Try something that *is* an extension handle, just not for + # this message... + unknown_handle = more_extensions_pb2.optional_int_extension + self.assertRaises(KeyError, extendee_proto.HasExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.ClearExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, + unknown_handle, 5) + + # Try call HasExtension() with a valid handle, but for a + # *repeated* field. (Just as with non-extension repeated + # fields, Has*() isn't supported for extension repeated fields). + self.assertRaises(KeyError, extendee_proto.HasExtension, + unittest_pb2.repeated_string_extension) + + def testStaticParseFrom(self): + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + + string1 = proto1.SerializeToString() + proto2 = unittest_pb2.TestAllTypes.FromString(string1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + def testMergeFromSingularField(self): + # Test merge with just a singular field. + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + + proto2 = unittest_pb2.TestAllTypes() + # This shouldn't get overwritten. + proto2.optional_string = 'value' + + proto2.MergeFrom(proto1) + self.assertEqual(1, proto2.optional_int32) + self.assertEqual('value', proto2.optional_string) + + def testMergeFromRepeatedField(self): + # Test merge with just a repeated field. + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.repeated_int32.append(2) + + proto2 = unittest_pb2.TestAllTypes() + proto2.repeated_int32.append(0) + proto2.MergeFrom(proto1) + + self.assertEqual(0, proto2.repeated_int32[0]) + self.assertEqual(1, proto2.repeated_int32[1]) + self.assertEqual(2, proto2.repeated_int32[2]) + + def testMergeFromOptionalGroup(self): + # Test merge with an optional group. + proto1 = unittest_pb2.TestAllTypes() + proto1.optionalgroup.a = 12 + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFrom(proto1) + self.assertEqual(12, proto2.optionalgroup.a) + + def testMergeFromRepeatedNestedMessage(self): + # Test merge with a repeated nested message. + proto1 = unittest_pb2.TestAllTypes() + m = proto1.repeated_nested_message.add() + m.bb = 123 + m = proto1.repeated_nested_message.add() + m.bb = 321 + + proto2 = unittest_pb2.TestAllTypes() + m = proto2.repeated_nested_message.add() + m.bb = 999 + proto2.MergeFrom(proto1) + self.assertEqual(999, proto2.repeated_nested_message[0].bb) + self.assertEqual(123, proto2.repeated_nested_message[1].bb) + self.assertEqual(321, proto2.repeated_nested_message[2].bb) + + def testMergeFromAllFields(self): + # With all fields set. + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFrom(proto1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + # Serialized string should be equal too. + string1 = proto1.SerializeToString() + string2 = proto2.SerializeToString() + self.assertEqual(string1, string2) + + def testMergeFromExtensionsSingular(self): + proto1 = unittest_pb2.TestAllExtensions() + proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 + + proto2 = unittest_pb2.TestAllExtensions() + proto2.MergeFrom(proto1) + self.assertEqual( + 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) + + def testMergeFromExtensionsRepeated(self): + proto1 = unittest_pb2.TestAllExtensions() + proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) + proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) + + proto2 = unittest_pb2.TestAllExtensions() + proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) + proto2.MergeFrom(proto1) + self.assertEqual( + 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) + self.assertEqual( + 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) + self.assertEqual( + 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) + self.assertEqual( + 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) + + def testMergeFromExtensionsNestedMessage(self): + proto1 = unittest_pb2.TestAllExtensions() + ext1 = proto1.Extensions[ + unittest_pb2.repeated_nested_message_extension] + m = ext1.add() + m.bb = 222 + m = ext1.add() + m.bb = 333 + + proto2 = unittest_pb2.TestAllExtensions() + ext2 = proto2.Extensions[ + unittest_pb2.repeated_nested_message_extension] + m = ext2.add() + m.bb = 111 + + proto2.MergeFrom(proto1) + ext2 = proto2.Extensions[ + unittest_pb2.repeated_nested_message_extension] + self.assertEqual(3, len(ext2)) + self.assertEqual(111, ext2[0].bb) + self.assertEqual(222, ext2[1].bb) + self.assertEqual(333, ext2[2].bb) + + def testCopyFromSingularField(self): + # Test copy with just a singular field. + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto1.optional_string = 'important-text' + + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_string = 'value' + + proto2.CopyFrom(proto1) + self.assertEqual(1, proto2.optional_int32) + self.assertEqual('important-text', proto2.optional_string) + + def testCopyFromRepeatedField(self): + # Test copy with a repeated field. + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.repeated_int32.append(2) + + proto2 = unittest_pb2.TestAllTypes() + proto2.repeated_int32.append(0) + proto2.CopyFrom(proto1) + + self.assertEqual(1, proto2.repeated_int32[0]) + self.assertEqual(2, proto2.repeated_int32[1]) + + def testCopyFromAllFields(self): + # With all fields set. + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + proto2 = unittest_pb2.TestAllTypes() + proto2.CopyFrom(proto1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + # Serialized string should be equal too. + string1 = proto1.SerializeToString() + string2 = proto2.SerializeToString() + self.assertEqual(string1, string2) + + def testCopyFromSelf(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.optional_int32 = 2 + proto1.optional_string = 'important-text' + + proto1.CopyFrom(proto1) + self.assertEqual(1, proto1.repeated_int32[0]) + self.assertEqual(2, proto1.optional_int32) + self.assertEqual('important-text', proto1.optional_string) + + def testClear(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllTypes() + self.assertEquals(proto, empty_proto) + + # Test if extensions which were set are cleared. + proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllExtensions() + self.assertEquals(proto, empty_proto) + + def testIsInitialized(self): + # Trivial cases - all optional fields and extensions. + proto = unittest_pb2.TestAllTypes() + self.assertTrue(proto.IsInitialized()) + proto = unittest_pb2.TestAllExtensions() + self.assertTrue(proto.IsInitialized()) + + # The case of uninitialized required fields. + proto = unittest_pb2.TestRequired() + self.assertFalse(proto.IsInitialized()) + proto.a = proto.b = proto.c = 2 + self.assertTrue(proto.IsInitialized()) + + # The case of uninitialized submessage. + proto = unittest_pb2.TestRequiredForeign() + self.assertTrue(proto.IsInitialized()) + proto.optional_message.a = 1 + self.assertFalse(proto.IsInitialized()) + proto.optional_message.b = 0 + proto.optional_message.c = 0 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized repeated submessage. + message1 = proto.repeated_message.add() + self.assertFalse(proto.IsInitialized()) + message1.a = message1.b = message1.c = 0 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized repeated group in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.multi + message1 = proto.Extensions[extension].add() + message2 = proto.Extensions[extension].add() + self.assertFalse(proto.IsInitialized()) + message1.a = 1 + message1.b = 1 + message1.c = 1 + self.assertFalse(proto.IsInitialized()) + message2.a = 2 + message2.b = 2 + message2.c = 2 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized nonrepeated message in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + proto.Extensions[extension].a = 1 + self.assertFalse(proto.IsInitialized()) + proto.Extensions[extension].b = 2 + proto.Extensions[extension].c = 3 + self.assertTrue(proto.IsInitialized()) + + def testStringUTF8Encoding(self): + proto = unittest_pb2.TestAllTypes() + + # Assignment of a unicode object to a field of type 'bytes' is not allowed. + self.assertRaises(TypeError, + setattr, proto, 'optional_bytes', u'unicode object') + + # Check that the default value is of python's 'unicode' type. + self.assertEqual(type(proto.optional_string), unicode) + + proto.optional_string = unicode('Testing') + self.assertEqual(proto.optional_string, str('Testing')) + + # Assign a value of type 'str' which can be encoded in UTF-8. + proto.optional_string = str('Testing') + self.assertEqual(proto.optional_string, unicode('Testing')) + + # Values of type 'str' are also accepted as long as they can be encoded in + # UTF-8. + self.assertEqual(type(proto.optional_string), str) + + # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', str('a\x80a')) + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + # No exception thrown. + proto.optional_string = 'abc' + + def testStringUTF8Serialization(self): + proto = unittest_mset_pb2.TestMessageSet() + extension_message = unittest_mset_pb2.TestMessageSetExtension2 + extension = extension_message.message_set_extension + + test_utf8 = u'Тест' + test_utf8_bytes = test_utf8.encode('utf-8') + + # 'Test' in another language, using UTF-8 charset. + proto.Extensions[extension].str = test_utf8 + + # Serialize using the MessageSet wire format (this is specified in the + # .proto file). + serialized = proto.SerializeToString() + + # Check byte size. + self.assertEqual(proto.ByteSize(), len(serialized)) + + raw = unittest_mset_pb2.RawMessageSet() + raw.MergeFromString(serialized) + + message2 = unittest_mset_pb2.TestMessageSetExtension2() + + self.assertEqual(1, len(raw.item)) + # Check that the type_id is the same as the tag ID in the .proto file. + self.assertEqual(raw.item[0].type_id, 1547769) + + # Check the actually bytes on the wire. + self.assertTrue( + raw.item[0].message.endswith(test_utf8_bytes)) + message2.MergeFromString(raw.item[0].message) + + self.assertEqual(type(message2.str), unicode) + self.assertEqual(message2.str, test_utf8) + + # How about if the bytes on the wire aren't a valid UTF-8 encoded string. + bytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * '\xff') + self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + + +# Since we had so many tests for protocol buffer equality, we broke these out +# into separate TestCase classes. + + +class TestAllTypesEqualityTest(unittest.TestCase): + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + + def testSelfEquality(self): + self.assertEqual(self.first_proto, self.first_proto) + + def testEmptyProtosEqual(self): + self.assertEqual(self.first_proto, self.second_proto) + + +class FullProtosEqualityTest(unittest.TestCase): + + """Equality tests using completely-full protos as a starting point.""" + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.first_proto) + test_util.SetAllFields(self.second_proto) + + def testNoneNotEqual(self): + self.assertNotEqual(self.first_proto, None) + self.assertNotEqual(None, self.second_proto) + + def testNotEqualToOtherMessage(self): + third_proto = unittest_pb2.TestRequired() + self.assertNotEqual(self.first_proto, third_proto) + self.assertNotEqual(third_proto, self.second_proto) + + def testAllFieldsFilledEquality(self): + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalar(self): + # Nonrepeated scalar field change should cause inequality. + self.first_proto.optional_int32 += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + # ...as should clearing a field. + self.first_proto.ClearField('optional_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedComposite(self): + # Change a nonrepeated composite field. + self.first_proto.optional_nested_message.bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Clear a field in the nested message. + self.first_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb = ( + self.second_proto.optional_nested_message.bb) + self.assertEqual(self.first_proto, self.second_proto) + # Remove the nested message entirely. + self.first_proto.ClearField('optional_nested_message') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedScalar(self): + # Change a repeated scalar field. + self.first_proto.repeated_int32.append(5) + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.ClearField('repeated_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedComposite(self): + # Change value within a repeated composite field. + self.first_proto.repeated_nested_message[0].bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.repeated_nested_message[0].bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Add a value to a repeated composite field. + self.first_proto.repeated_nested_message.add() + self.assertNotEqual(self.first_proto, self.second_proto) + self.second_proto.repeated_nested_message.add() + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalarHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated scalar field. + self.first_proto.ClearField('optional_int32') + self.second_proto.optional_int32 = 0 + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedCompositeHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated composite field. + self.first_proto.ClearField('optional_nested_message') + self.second_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + # TODO(robinson): Replace next two lines with method + # to set the "has" bit without changing the value, + # if/when such a method exists. + self.first_proto.optional_nested_message.bb = 0 + self.first_proto.optional_nested_message.ClearField('bb') + self.assertEqual(self.first_proto, self.second_proto) + + +class ExtensionEqualityTest(unittest.TestCase): + + def testExtensionEquality(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual(first_proto, second_proto) + test_util.SetAllExtensions(first_proto) + self.assertNotEqual(first_proto, second_proto) + test_util.SetAllExtensions(second_proto) + self.assertEqual(first_proto, second_proto) + + # Ensure that we check value equality. + first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1 + self.assertEqual(first_proto, second_proto) + + # Ensure that we also look at "has" bits. + first_proto.ClearExtension(unittest_pb2.optional_int32_extension) + second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertEqual(first_proto, second_proto) + + # Ensure that differences in cached values + # don't matter if "has" bits are both false. + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual( + 0, first_proto.Extensions[unittest_pb2.optional_int32_extension]) + self.assertEqual(first_proto, second_proto) + + +class MutualRecursionEqualityTest(unittest.TestCase): + + def testEqualityWithMutualRecursion(self): + first_proto = unittest_pb2.TestMutualRecursionA() + second_proto = unittest_pb2.TestMutualRecursionA() + self.assertEqual(first_proto, second_proto) + first_proto.bb.a.bb.optional_int32 = 23 + self.assertNotEqual(first_proto, second_proto) + second_proto.bb.a.bb.optional_int32 = 23 + self.assertEqual(first_proto, second_proto) + + +class ByteSizeTest(unittest.TestCase): + + def setUp(self): + self.proto = unittest_pb2.TestAllTypes() + self.extended_proto = more_extensions_pb2.ExtendedMessage() + self.packed_proto = unittest_pb2.TestPackedTypes() + self.packed_extended_proto = unittest_pb2.TestPackedExtensions() + + def Size(self): + return self.proto.ByteSize() + + def testEmptyMessage(self): + self.assertEqual(0, self.proto.ByteSize()) + + def testVarints(self): + def Test(i, expected_varint_size): + self.proto.Clear() + self.proto.optional_int64 = i + # Add one to the varint size for the tag info + # for tag 1. + self.assertEqual(expected_varint_size + 1, self.Size()) + Test(0, 1) + Test(1, 1) + for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)): + Test((1 << i) - 1, num_bytes) + Test(-1, 10) + Test(-2, 10) + Test(-(1 << 63), 10) + + def testStrings(self): + self.proto.optional_string = '' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2, self.Size()) + + self.proto.optional_string = 'abc' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2 + len(self.proto.optional_string), self.Size()) + + self.proto.optional_string = 'x' * 128 + # Need one byte for tag info (tag #14), and TWO bytes for length. + self.assertEqual(3 + len(self.proto.optional_string), self.Size()) + + def testOtherNumerics(self): + self.proto.optional_fixed32 = 1234 + # One byte for tag and 4 bytes for fixed32. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_fixed64 = 1234 + # One byte for tag and 8 bytes for fixed64. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_float = 1.234 + # One byte for tag and 4 bytes for float. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_double = 1.234 + # One byte for tag and 8 bytes for float. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_sint32 = 64 + # One byte for tag and 2 bytes for zig-zag-encoded 64. + self.assertEqual(3, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + def testComposites(self): + # 3 bytes. + self.proto.optional_nested_message.bb = (1 << 14) + # Plus one byte for bb tag. + # Plus 1 byte for optional_nested_message serialized size. + # Plus two bytes for optional_nested_message tag. + self.assertEqual(3 + 1 + 1 + 2, self.Size()) + + def testGroups(self): + # 4 bytes. + self.proto.optionalgroup.a = (1 << 21) + # Plus two bytes for |a| tag. + # Plus 2 * two bytes for START_GROUP and END_GROUP tags. + self.assertEqual(4 + 2 + 2*2, self.Size()) + + def testRepeatedScalars(self): + self.proto.repeated_int32.append(10) # 1 byte. + self.proto.repeated_int32.append(128) # 2 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + + def testRepeatedScalarsExtend(self): + self.proto.repeated_int32.extend([10, 128]) # 3 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + + def testRepeatedScalarsRemove(self): + self.proto.repeated_int32.append(10) # 1 byte. + self.proto.repeated_int32.append(128) # 2 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + self.proto.repeated_int32.remove(128) + self.assertEqual(1 + 2, self.Size()) + + def testRepeatedComposites(self): + # Empty message. 2 bytes tag plus 1 byte length. + foreign_message_0 = self.proto.repeated_nested_message.add() + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + foreign_message_1 = self.proto.repeated_nested_message.add() + foreign_message_1.bb = 7 + self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + + def testRepeatedCompositesDelete(self): + # Empty message. 2 bytes tag plus 1 byte length. + foreign_message_0 = self.proto.repeated_nested_message.add() + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + foreign_message_1 = self.proto.repeated_nested_message.add() + foreign_message_1.bb = 9 + self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + del self.proto.repeated_nested_message[0] + self.assertEqual(2 + 1 + 1 + 1, self.Size()) + + # Now add a new message. + foreign_message_2 = self.proto.repeated_nested_message.add() + foreign_message_2.bb = 12 + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size()) + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + del self.proto.repeated_nested_message[1] + self.assertEqual(2 + 1 + 1 + 1, self.Size()) + + del self.proto.repeated_nested_message[0] + self.assertEqual(0, self.Size()) + + def testRepeatedGroups(self): + # 2-byte START_GROUP plus 2-byte END_GROUP. + group_0 = self.proto.repeatedgroup.add() + # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| + # plus 2-byte END_GROUP. + group_1 = self.proto.repeatedgroup.add() + group_1.a = 7 + self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) + + def testExtensions(self): + proto = unittest_pb2.TestAllExtensions() + self.assertEqual(0, proto.ByteSize()) + extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. + proto.Extensions[extension] = 23 + # 1 byte for tag, 1 byte for value. + self.assertEqual(2, proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedScalar(self): + # Test non-extension. + self.proto.optional_int32 = 1 + self.assertEqual(2, self.proto.ByteSize()) + self.proto.optional_int32 = 128 + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_int_extension + self.extended_proto.Extensions[extension] = 1 + self.assertEqual(2, self.extended_proto.ByteSize()) + self.extended_proto.Extensions[extension] = 128 + self.assertEqual(3, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedScalar(self): + # Test non-extension. + self.proto.repeated_int32.append(1) + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_int32.append(1) + self.assertEqual(6, self.proto.ByteSize()) + self.proto.repeated_int32[1] = 128 + self.assertEqual(7, self.proto.ByteSize()) + self.proto.ClearField('repeated_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.repeated_int_extension + repeated = self.extended_proto.Extensions[extension] + repeated.append(1) + self.assertEqual(2, self.extended_proto.ByteSize()) + repeated.append(1) + self.assertEqual(4, self.extended_proto.ByteSize()) + repeated[1] = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedMessage(self): + # Test non-extension. + self.proto.optional_foreign_message.c = 1 + self.assertEqual(5, self.proto.ByteSize()) + self.proto.optional_foreign_message.c = 128 + self.assertEqual(6, self.proto.ByteSize()) + self.proto.optional_foreign_message.ClearField('c') + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_message_extension + child = self.extended_proto.Extensions[extension] + self.assertEqual(0, self.extended_proto.ByteSize()) + child.foreign_message_int = 1 + self.assertEqual(4, self.extended_proto.ByteSize()) + child.foreign_message_int = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedMessage(self): + # Test non-extension. + child0 = self.proto.repeated_foreign_message.add() + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_foreign_message.add() + self.assertEqual(6, self.proto.ByteSize()) + child0.c = 1 + self.assertEqual(8, self.proto.ByteSize()) + self.proto.ClearField('repeated_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.repeated_message_extension + child_list = self.extended_proto.Extensions[extension] + child0 = child_list.add() + self.assertEqual(2, self.extended_proto.ByteSize()) + child_list.add() + self.assertEqual(4, self.extended_proto.ByteSize()) + child0.foreign_message_int = 1 + self.assertEqual(6, self.extended_proto.ByteSize()) + child0.ClearField('foreign_message_int') + self.assertEqual(4, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testPackedRepeatedScalars(self): + self.assertEqual(0, self.packed_proto.ByteSize()) + + self.packed_proto.packed_int32.append(10) # 1 byte. + self.packed_proto.packed_int32.append(128) # 2 bytes. + # The tag is 2 bytes (the field number is 90), and the varint + # storing the length is 1 byte. + int_size = 1 + 2 + 3 + self.assertEqual(int_size, self.packed_proto.ByteSize()) + + self.packed_proto.packed_double.append(4.2) # 8 bytes + self.packed_proto.packed_double.append(3.25) # 8 bytes + # 2 more tag bytes, 1 more length byte. + double_size = 8 + 8 + 3 + self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) + + self.packed_proto.ClearField('packed_int32') + self.assertEqual(double_size, self.packed_proto.ByteSize()) + + def testPackedExtensions(self): + self.assertEqual(0, self.packed_extended_proto.ByteSize()) + extension = self.packed_extended_proto.Extensions[ + unittest_pb2.packed_fixed32_extension] + extension.extend([1, 2, 3, 4]) # 16 bytes + # Tag is 3 bytes. + self.assertEqual(19, self.packed_extended_proto.ByteSize()) + + +# TODO(robinson): We need cross-language serialization consistency tests. +# Issues to be sure to cover include: +# * Handling of unrecognized tags ("uninterpreted_bytes"). +# * Handling of MessageSets. +# * Consistent ordering of tags in the wire format, +# including ordering between extensions and non-extension +# fields. +# * Consistent serialization of negative numbers, especially +# negative int32s. +# * Handling of empty submessages (with and without "has" +# bits set). + +class SerializationTest(unittest.TestCase): + + def testSerializeEmtpyMessage(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllFields(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllExtensions(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(first_proto) + serialized = first_proto.SerializeToString() + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testCanonicalSerializationOrder(self): + proto = more_messages_pb2.OutOfOrderFields() + # These are also their tag numbers. Even though we're setting these in + # reverse-tag order AND they're listed in reverse tag-order in the .proto + # file, they should nonetheless be serialized in tag order. + proto.optional_sint32 = 5 + proto.Extensions[more_messages_pb2.optional_uint64] = 4 + proto.optional_uint32 = 3 + proto.Extensions[more_messages_pb2.optional_int64] = 2 + proto.optional_int32 = 1 + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = decoder.Decoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(2, d.ReadInt64()) + self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(3, d.ReadUInt32()) + self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(4, d.ReadUInt64()) + self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(5, d.ReadSInt32()) + + def testCanonicalSerializationOrderSameAsCpp(self): + # Copy of the same test we use for C++. + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + serialized = proto.SerializeToString() + test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) + + def testMergeFromStringWhenFieldsAlreadySet(self): + first_proto = unittest_pb2.TestAllTypes() + first_proto.repeated_string.append('foobar') + first_proto.optional_int32 = 23 + first_proto.optional_nested_message.bb = 42 + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestAllTypes() + second_proto.repeated_string.append('baz') + second_proto.optional_int32 = 100 + second_proto.optional_nested_message.bb = 999 + + second_proto.MergeFromString(serialized) + # Ensure that we append to repeated fields. + self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) + # Ensure that we overwrite nonrepeatd scalars. + self.assertEqual(23, second_proto.optional_int32) + # Ensure that we recursively call MergeFromString() on + # submessages. + self.assertEqual(42, second_proto.optional_nested_message.bb) + + def testMessageSetWireFormat(self): + proto = unittest_mset_pb2.TestMessageSet() + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + extension1 = extension_message1.message_set_extension + extension2 = extension_message2.message_set_extension + proto.Extensions[extension1].i = 123 + proto.Extensions[extension2].str = 'foo' + + # Serialize using the MessageSet wire format (this is specified in the + # .proto file). + serialized = proto.SerializeToString() + + raw = unittest_mset_pb2.RawMessageSet() + self.assertEqual(False, + raw.DESCRIPTOR.GetOptions().message_set_wire_format) + raw.MergeFromString(serialized) + self.assertEqual(2, len(raw.item)) + + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.MergeFromString(raw.item[0].message) + self.assertEqual(123, message1.i) + + message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2.MergeFromString(raw.item[1].message) + self.assertEqual('foo', message2.str) + + # Deserialize using the MessageSet wire format. + proto2 = unittest_mset_pb2.TestMessageSet() + proto2.MergeFromString(serialized) + self.assertEqual(123, proto2.Extensions[extension1].i) + self.assertEqual('foo', proto2.Extensions[extension2].str) + + # Check byte size. + self.assertEqual(proto2.ByteSize(), len(serialized)) + self.assertEqual(proto.ByteSize(), len(serialized)) + + def testMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an item. + item = raw.item.add() + item.type_id = 1545008 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + # Add a second, unknown extension. + item = raw.item.add() + item.type_id = 1545009 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12346 + item.message = message1.SerializeToString() + + # Add another unknown extension. + item = raw.item.add() + item.type_id = 1545010 + message1 = unittest_mset_pb2.TestMessageSetExtension2() + message1.str = 'foo' + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Check that the message parsed well. + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension1 = extension_message1.message_set_extension + self.assertEquals(12345, proto.Extensions[extension1].i) + + def testUnknownFields(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + + serialized = proto.SerializeToString() + + # The empty message should be parsable with all of the fields + # unknown. + proto2 = unittest_pb2.TestEmptyMessage() + + # Parsing this message should succeed. + proto2.MergeFromString(serialized) + + # Now test with a int64 field set. + proto = unittest_pb2.TestAllTypes() + proto.optional_int64 = 0x0fffffffffffffff + serialized = proto.SerializeToString() + # The empty message should be parsable with all of the fields + # unknown. + proto2 = unittest_pb2.TestEmptyMessage() + # Parsing this message should succeed. + proto2.MergeFromString(serialized) + + def _CheckRaises(self, exc_class, callable_obj, exception): + """This method checks if the excpetion type and message are as expected.""" + try: + callable_obj() + except exc_class, ex: + # Check if the exception message is the right one. + self.assertEqual(exception, str(ex)) + return + else: + raise self.failureException('%s not raised' % str(exc_class)) + + def testSerializeUninitialized(self): + proto = unittest_pb2.TestRequired() + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Required field protobuf_unittest.TestRequired.a is not set.') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.a = 1 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Required field protobuf_unittest.TestRequired.b is not set.') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.b = 2 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Required field protobuf_unittest.TestRequired.c is not set.') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.c = 3 + serialized = proto.SerializeToString() + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto2 = unittest_pb2.TestRequired() + proto2.MergeFromString(serialized) + self.assertEqual(1, proto2.a) + self.assertEqual(2, proto2.b) + self.assertEqual(3, proto2.c) + proto2.ParseFromString(partial) + self.assertEqual(1, proto2.a) + self.assertEqual(2, proto2.b) + self.assertEqual(3, proto2.c) + + def testSerializeAllPackedFields(self): + first_proto = unittest_pb2.TestPackedTypes() + second_proto = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + bytes_read = second_proto.MergeFromString(serialized) + self.assertEqual(second_proto.ByteSize(), bytes_read) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllPackedExtensions(self): + first_proto = unittest_pb2.TestPackedExtensions() + second_proto = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(first_proto) + serialized = first_proto.SerializeToString() + bytes_read = second_proto.MergeFromString(serialized) + self.assertEqual(second_proto.ByteSize(), bytes_read) + self.assertEqual(first_proto, second_proto) + + def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): + first_proto = unittest_pb2.TestPackedTypes() + first_proto.packed_int32.extend([1, 2]) + first_proto.packed_double.append(3.0) + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestPackedTypes() + second_proto.packed_int32.append(3) + second_proto.packed_double.extend([1.0, 2.0]) + second_proto.packed_sint32.append(4) + + second_proto.MergeFromString(serialized) + self.assertEqual([3, 1, 2], second_proto.packed_int32) + self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) + self.assertEqual([4], second_proto.packed_sint32) + + def testPackedFieldsWireFormat(self): + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes + proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes + proto.packed_float.append(2.0) # 4 bytes, will be before double + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = decoder.Decoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(1+1+1+2, d.ReadInt32()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual(2, d.ReadInt32()) + self.assertEqual(150, d.ReadInt32()) + self.assertEqual(3, d.ReadInt32()) + self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(4, d.ReadInt32()) + self.assertEqual(2.0, d.ReadFloat()) + self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(8+8, d.ReadInt32()) + self.assertEqual(1.0, d.ReadDouble()) + self.assertEqual(1000.0, d.ReadDouble()) + self.assertTrue(d.EndOfStream()) + + def testFieldNumbers(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) + self.assertEqual( + unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) + self.assertEqual( + unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) + self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) + self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) + self.assertEqual( + unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) + self.assertEqual( + unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) + + def testExtensionFieldNumbers(self): + self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) + self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) + self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) + self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) + self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) + self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) + self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) + self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) + self.assertEqual( + unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) + self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) + self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, + 21) + self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) + self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) + self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) + self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) + self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) + self.assertEqual( + unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) + self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) + self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, + 51) + + def testInitKwargs(self): + proto = unittest_pb2.TestAllTypes( + optional_int32=1, + optional_string='foo', + optional_bool=True, + optional_bytes='bar', + optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), + optional_foreign_message=unittest_pb2.ForeignMessage(c=1), + optional_nested_enum=unittest_pb2.TestAllTypes.FOO, + optional_foreign_enum=unittest_pb2.FOREIGN_FOO, + repeated_int32=[1, 2, 3]) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('optional_int32')) + self.assertTrue(proto.HasField('optional_string')) + self.assertTrue(proto.HasField('optional_bool')) + self.assertTrue(proto.HasField('optional_bytes')) + self.assertTrue(proto.HasField('optional_nested_message')) + self.assertTrue(proto.HasField('optional_foreign_message')) + self.assertTrue(proto.HasField('optional_nested_enum')) + self.assertTrue(proto.HasField('optional_foreign_enum')) + self.assertEqual(1, proto.optional_int32) + self.assertEqual('foo', proto.optional_string) + self.assertEqual(True, proto.optional_bool) + self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(1, proto.optional_nested_message.bb) + self.assertEqual(1, proto.optional_foreign_message.c) + self.assertEqual(unittest_pb2.TestAllTypes.FOO, + proto.optional_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) + self.assertEqual([1, 2, 3], proto.repeated_int32) + + def testInitArgsUnknownFieldName(self): + def InitalizeEmptyMessageWithExtraKeywordArg(): + unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') + self._CheckRaises(ValueError, + InitalizeEmptyMessageWithExtraKeywordArg, + 'Protocol message has no "unknown" field.') + + def testInitRequiredKwargs(self): + proto = unittest_pb2.TestRequired(a=1, b=1, c=1) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('a')) + self.assertTrue(proto.HasField('b')) + self.assertTrue(proto.HasField('c')) + self.assertTrue(not proto.HasField('dummy2')) + self.assertEqual(1, proto.a) + self.assertEqual(1, proto.b) + self.assertEqual(1, proto.c) + + def testInitRequiredForeignKwargs(self): + proto = unittest_pb2.TestRequiredForeign( + optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('optional_message')) + self.assertTrue(proto.optional_message.IsInitialized()) + self.assertTrue(proto.optional_message.HasField('a')) + self.assertTrue(proto.optional_message.HasField('b')) + self.assertTrue(proto.optional_message.HasField('c')) + self.assertTrue(not proto.optional_message.HasField('dummy2')) + self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), + proto.optional_message) + self.assertEqual(1, proto.optional_message.a) + self.assertEqual(1, proto.optional_message.b) + self.assertEqual(1, proto.optional_message.c) + + def testInitRepeatedKwargs(self): + proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) + self.assertTrue(proto.IsInitialized()) + self.assertEqual(1, proto.repeated_int32[0]) + self.assertEqual(2, proto.repeated_int32[1]) + self.assertEqual(3, proto.repeated_int32[2]) + + +class OptionsTest(unittest.TestCase): + + def testMessageOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + self.assertEqual(True, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + proto = unittest_pb2.TestAllTypes() + self.assertEqual(False, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + def testPackedOptions(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_int32 = 1 + proto.optional_double = 3.0 + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(False, field_descriptor.GetOptions().packed) + + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.append(1) + proto.packed_double.append(3.0) + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(True, field_descriptor.GetOptions().packed) + self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + field_descriptor.label) + + +class UtilityTest(unittest.TestCase): + + def testImergeSorted(self): + ImergeSorted = reflection._ImergeSorted + # Various types of emptiness. + self.assertEqual([], list(ImergeSorted())) + self.assertEqual([], list(ImergeSorted([]))) + self.assertEqual([], list(ImergeSorted([], []))) + + # One nonempty list. + self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], []))) + self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3]))) + + # Merging some nonempty lists together. + self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], []))) + + # Elements repeated across component iterators. + self.assertEqual([1, 2, 2, 3, 3], + list(ImergeSorted([1, 2], [3], [2, 3]))) + + # Elements repeated within an iterator. + self.assertEqual([1, 2, 2, 3, 3], + list(ImergeSorted([1, 2, 2], [3], [3]))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py new file mode 100755 index 0000000..e04f825 --- /dev/null +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -0,0 +1,136 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.internal.service_reflection.""" + +__author__ = 'petar@google.com (Petar Petrov)' + +import unittest +from google.protobuf import unittest_pb2 +from google.protobuf import service_reflection +from google.protobuf import service + + +class FooUnitTest(unittest.TestCase): + + def testService(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, response, callback): + self.method = method + self.controller = controller + self.request = request + callback(response) + + class MockRpcController(service.RpcController): + def SetFailed(self, msg): + self.failure_message = msg + + self.callback_response = None + + class MyService(unittest_pb2.TestService): + pass + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + rpc_controller = MockRpcController() + channel = MockRpcChannel() + srvc = MyService() + srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback) + self.assertEqual('Method Foo not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + rpc_controller.failure_message = None + + service_descriptor = unittest_pb2.TestService.GetDescriptor() + srvc.CallMethod(service_descriptor.methods[1], rpc_controller, + unittest_pb2.BarRequest(), MyCallback) + self.assertEqual('Method Bar not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + class MyServiceImpl(unittest_pb2.TestService): + def Foo(self, rpc_controller, request, done): + self.foo_called = True + def Bar(self, rpc_controller, request, done): + self.bar_called = True + + srvc = MyServiceImpl() + rpc_controller.failure_message = None + srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback) + self.assertEqual(None, rpc_controller.failure_message) + self.assertEqual(True, srvc.foo_called) + + rpc_controller.failure_message = None + srvc.CallMethod(service_descriptor.methods[1], rpc_controller, + unittest_pb2.BarRequest(), MyCallback) + self.assertEqual(None, rpc_controller.failure_message) + self.assertEqual(True, srvc.bar_called) + + def testServiceStub(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, + response_class, callback): + self.method = method + self.controller = controller + self.request = request + callback(response_class()) + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + channel = MockRpcChannel() + stub = unittest_pb2.TestService_Stub(channel) + rpc_controller = 'controller' + request = 'request' + + # GetDescriptor now static, still works as instance method for compatability + self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(), + stub.GetDescriptor()) + + # Invoke method. + stub.Foo(rpc_controller, request, MyCallback) + + self.assertTrue(isinstance(self.callback_response, + unittest_pb2.FooResponse)) + self.assertEqual(request, channel.request) + self.assertEqual(rpc_controller, channel.controller) + self.assertEqual(stub.GetDescriptor().methods[0], channel.method) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py new file mode 100755 index 0000000..1a0da55 --- /dev/null +++ b/python/google/protobuf/internal/test_util.py @@ -0,0 +1,612 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Utilities for Python proto2 tests. + +This is intentionally modeled on C++ code in +//net/proto2/internal/test_util.*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import os.path + +import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 + + +def SetAllFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllTypes instance. + """ + + # + # Optional fields. + # + + message.optional_int32 = 101 + message.optional_int64 = 102 + message.optional_uint32 = 103 + message.optional_uint64 = 104 + message.optional_sint32 = 105 + message.optional_sint64 = 106 + message.optional_fixed32 = 107 + message.optional_fixed64 = 108 + message.optional_sfixed32 = 109 + message.optional_sfixed64 = 110 + message.optional_float = 111 + message.optional_double = 112 + message.optional_bool = True + # TODO(robinson): Firmly spec out and test how + # protos interact with unicode. One specific example: + # what happens if we change the literal below to + # u'115'? What *should* happen? Still some discussion + # to finish with Kenton about bytes vs. strings + # and forcing everything to be utf8. :-/ + message.optional_string = '115' + message.optional_bytes = '116' + + message.optionalgroup.a = 117 + message.optional_nested_message.bb = 118 + message.optional_foreign_message.c = 119 + message.optional_import_message.d = 120 + + message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ + message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ + message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ + + message.optional_string_piece = '124' + message.optional_cord = '125' + + # + # Repeated fields. + # + + message.repeated_int32.append(201) + message.repeated_int64.append(202) + message.repeated_uint32.append(203) + message.repeated_uint64.append(204) + message.repeated_sint32.append(205) + message.repeated_sint64.append(206) + message.repeated_fixed32.append(207) + message.repeated_fixed64.append(208) + message.repeated_sfixed32.append(209) + message.repeated_sfixed64.append(210) + message.repeated_float.append(211) + message.repeated_double.append(212) + message.repeated_bool.append(True) + message.repeated_string.append('215') + message.repeated_bytes.append('216') + + message.repeatedgroup.add().a = 217 + message.repeated_nested_message.add().bb = 218 + message.repeated_foreign_message.add().c = 219 + message.repeated_import_message.add().d = 220 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) + + message.repeated_string_piece.append('224') + message.repeated_cord.append('225') + + # Add a second one of each field. + message.repeated_int32.append(301) + message.repeated_int64.append(302) + message.repeated_uint32.append(303) + message.repeated_uint64.append(304) + message.repeated_sint32.append(305) + message.repeated_sint64.append(306) + message.repeated_fixed32.append(307) + message.repeated_fixed64.append(308) + message.repeated_sfixed32.append(309) + message.repeated_sfixed64.append(310) + message.repeated_float.append(311) + message.repeated_double.append(312) + message.repeated_bool.append(False) + message.repeated_string.append('315') + message.repeated_bytes.append('316') + + message.repeatedgroup.add().a = 317 + message.repeated_nested_message.add().bb = 318 + message.repeated_foreign_message.add().c = 319 + message.repeated_import_message.add().d = 320 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) + + message.repeated_string_piece.append('324') + message.repeated_cord.append('325') + + # + # Fields that have defaults. + # + + message.default_int32 = 401 + message.default_int64 = 402 + message.default_uint32 = 403 + message.default_uint64 = 404 + message.default_sint32 = 405 + message.default_sint64 = 406 + message.default_fixed32 = 407 + message.default_fixed64 = 408 + message.default_sfixed32 = 409 + message.default_sfixed64 = 410 + message.default_float = 411 + message.default_double = 412 + message.default_bool = False + message.default_string = '415' + message.default_bytes = '416' + + message.default_nested_enum = unittest_pb2.TestAllTypes.FOO + message.default_foreign_enum = unittest_pb2.FOREIGN_FOO + message.default_import_enum = unittest_import_pb2.IMPORT_FOO + + message.default_string_piece = '424' + message.default_cord = '425' + + +def SetAllExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions instance. + """ + + extensions = message.Extensions + pb2 = unittest_pb2 + import_pb2 = unittest_import_pb2 + + # + # Optional fields. + # + + extensions[pb2.optional_int32_extension] = 101 + extensions[pb2.optional_int64_extension] = 102 + extensions[pb2.optional_uint32_extension] = 103 + extensions[pb2.optional_uint64_extension] = 104 + extensions[pb2.optional_sint32_extension] = 105 + extensions[pb2.optional_sint64_extension] = 106 + extensions[pb2.optional_fixed32_extension] = 107 + extensions[pb2.optional_fixed64_extension] = 108 + extensions[pb2.optional_sfixed32_extension] = 109 + extensions[pb2.optional_sfixed64_extension] = 110 + extensions[pb2.optional_float_extension] = 111 + extensions[pb2.optional_double_extension] = 112 + extensions[pb2.optional_bool_extension] = True + extensions[pb2.optional_string_extension] = '115' + extensions[pb2.optional_bytes_extension] = '116' + + extensions[pb2.optionalgroup_extension].a = 117 + extensions[pb2.optional_nested_message_extension].bb = 118 + extensions[pb2.optional_foreign_message_extension].c = 119 + extensions[pb2.optional_import_message_extension].d = 120 + + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ + extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ + + extensions[pb2.optional_string_piece_extension] = '124' + extensions[pb2.optional_cord_extension] = '125' + + # + # Repeated fields. + # + + extensions[pb2.repeated_int32_extension].append(201) + extensions[pb2.repeated_int64_extension].append(202) + extensions[pb2.repeated_uint32_extension].append(203) + extensions[pb2.repeated_uint64_extension].append(204) + extensions[pb2.repeated_sint32_extension].append(205) + extensions[pb2.repeated_sint64_extension].append(206) + extensions[pb2.repeated_fixed32_extension].append(207) + extensions[pb2.repeated_fixed64_extension].append(208) + extensions[pb2.repeated_sfixed32_extension].append(209) + extensions[pb2.repeated_sfixed64_extension].append(210) + extensions[pb2.repeated_float_extension].append(211) + extensions[pb2.repeated_double_extension].append(212) + extensions[pb2.repeated_bool_extension].append(True) + extensions[pb2.repeated_string_extension].append('215') + extensions[pb2.repeated_bytes_extension].append('216') + + extensions[pb2.repeatedgroup_extension].add().a = 217 + extensions[pb2.repeated_nested_message_extension].add().bb = 218 + extensions[pb2.repeated_foreign_message_extension].add().c = 219 + extensions[pb2.repeated_import_message_extension].add().d = 220 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) + + extensions[pb2.repeated_string_piece_extension].append('224') + extensions[pb2.repeated_cord_extension].append('225') + + # Append a second one of each field. + extensions[pb2.repeated_int32_extension].append(301) + extensions[pb2.repeated_int64_extension].append(302) + extensions[pb2.repeated_uint32_extension].append(303) + extensions[pb2.repeated_uint64_extension].append(304) + extensions[pb2.repeated_sint32_extension].append(305) + extensions[pb2.repeated_sint64_extension].append(306) + extensions[pb2.repeated_fixed32_extension].append(307) + extensions[pb2.repeated_fixed64_extension].append(308) + extensions[pb2.repeated_sfixed32_extension].append(309) + extensions[pb2.repeated_sfixed64_extension].append(310) + extensions[pb2.repeated_float_extension].append(311) + extensions[pb2.repeated_double_extension].append(312) + extensions[pb2.repeated_bool_extension].append(False) + extensions[pb2.repeated_string_extension].append('315') + extensions[pb2.repeated_bytes_extension].append('316') + + extensions[pb2.repeatedgroup_extension].add().a = 317 + extensions[pb2.repeated_nested_message_extension].add().bb = 318 + extensions[pb2.repeated_foreign_message_extension].add().c = 319 + extensions[pb2.repeated_import_message_extension].add().d = 320 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) + + extensions[pb2.repeated_string_piece_extension].append('324') + extensions[pb2.repeated_cord_extension].append('325') + + # + # Fields with defaults. + # + + extensions[pb2.default_int32_extension] = 401 + extensions[pb2.default_int64_extension] = 402 + extensions[pb2.default_uint32_extension] = 403 + extensions[pb2.default_uint64_extension] = 404 + extensions[pb2.default_sint32_extension] = 405 + extensions[pb2.default_sint64_extension] = 406 + extensions[pb2.default_fixed32_extension] = 407 + extensions[pb2.default_fixed64_extension] = 408 + extensions[pb2.default_sfixed32_extension] = 409 + extensions[pb2.default_sfixed64_extension] = 410 + extensions[pb2.default_float_extension] = 411 + extensions[pb2.default_double_extension] = 412 + extensions[pb2.default_bool_extension] = False + extensions[pb2.default_string_extension] = '415' + extensions[pb2.default_bytes_extension] = '416' + + extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO + extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO + extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO + + extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_cord_extension] = '425' + + +def SetAllFieldsAndExtensions(message): + """Sets every field and extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions message. + """ + message.my_int = 1 + message.my_string = 'foo' + message.my_float = 1.0 + message.Extensions[unittest_pb2.my_extension_int] = 23 + message.Extensions[unittest_pb2.my_extension_string] = 'bar' + + +def ExpectAllFieldsAndExtensionsInOrder(serialized): + """Ensures that serialized is the serialization we expect for a message + filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the + serialization is in canonical, tag-number order). + """ + my_extension_int = unittest_pb2.my_extension_int + my_extension_string = unittest_pb2.my_extension_string + expected_strings = [] + message = unittest_pb2.TestFieldOrderings() + message.my_int = 1 # Field 1. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_int] = 23 # Field 5. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_string = 'foo' # Field 11. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_string] = 'bar' # Field 50. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_float = 1.0 + expected_strings.append(message.SerializeToString()) + message.Clear() + expected = ''.join(expected_strings) + + if expected != serialized: + raise ValueError('Expected %r, found %r' % (expected, serialized)) + + +class GoldenMessageTestCase(unittest.TestCase): + """This adds methods to TestCase useful for verifying our Golden Message.""" + + def ExpectAllFieldsSet(self, message): + """Check all fields for correct values have after Set*Fields() is called.""" + self.assertTrue(message.HasField('optional_int32')) + self.assertTrue(message.HasField('optional_int64')) + self.assertTrue(message.HasField('optional_uint32')) + self.assertTrue(message.HasField('optional_uint64')) + self.assertTrue(message.HasField('optional_sint32')) + self.assertTrue(message.HasField('optional_sint64')) + self.assertTrue(message.HasField('optional_fixed32')) + self.assertTrue(message.HasField('optional_fixed64')) + self.assertTrue(message.HasField('optional_sfixed32')) + self.assertTrue(message.HasField('optional_sfixed64')) + self.assertTrue(message.HasField('optional_float')) + self.assertTrue(message.HasField('optional_double')) + self.assertTrue(message.HasField('optional_bool')) + self.assertTrue(message.HasField('optional_string')) + self.assertTrue(message.HasField('optional_bytes')) + + self.assertTrue(message.HasField('optionalgroup')) + self.assertTrue(message.HasField('optional_nested_message')) + self.assertTrue(message.HasField('optional_foreign_message')) + self.assertTrue(message.HasField('optional_import_message')) + + self.assertTrue(message.optionalgroup.HasField('a')) + self.assertTrue(message.optional_nested_message.HasField('bb')) + self.assertTrue(message.optional_foreign_message.HasField('c')) + self.assertTrue(message.optional_import_message.HasField('d')) + + self.assertTrue(message.HasField('optional_nested_enum')) + self.assertTrue(message.HasField('optional_foreign_enum')) + self.assertTrue(message.HasField('optional_import_enum')) + + self.assertTrue(message.HasField('optional_string_piece')) + self.assertTrue(message.HasField('optional_cord')) + + self.assertEqual(101, message.optional_int32) + self.assertEqual(102, message.optional_int64) + self.assertEqual(103, message.optional_uint32) + self.assertEqual(104, message.optional_uint64) + self.assertEqual(105, message.optional_sint32) + self.assertEqual(106, message.optional_sint64) + self.assertEqual(107, message.optional_fixed32) + self.assertEqual(108, message.optional_fixed64) + self.assertEqual(109, message.optional_sfixed32) + self.assertEqual(110, message.optional_sfixed64) + self.assertEqual(111, message.optional_float) + self.assertEqual(112, message.optional_double) + self.assertEqual(True, message.optional_bool) + self.assertEqual('115', message.optional_string) + self.assertEqual('116', message.optional_bytes) + + self.assertEqual(117, message.optionalgroup.a); + self.assertEqual(118, message.optional_nested_message.bb) + self.assertEqual(119, message.optional_foreign_message.c) + self.assertEqual(120, message.optional_import_message.d) + + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum) + self.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.optional_import_enum) + + # ----------------------------------------------------------------- + + self.assertEqual(2, len(message.repeated_int32)) + self.assertEqual(2, len(message.repeated_int64)) + self.assertEqual(2, len(message.repeated_uint32)) + self.assertEqual(2, len(message.repeated_uint64)) + self.assertEqual(2, len(message.repeated_sint32)) + self.assertEqual(2, len(message.repeated_sint64)) + self.assertEqual(2, len(message.repeated_fixed32)) + self.assertEqual(2, len(message.repeated_fixed64)) + self.assertEqual(2, len(message.repeated_sfixed32)) + self.assertEqual(2, len(message.repeated_sfixed64)) + self.assertEqual(2, len(message.repeated_float)) + self.assertEqual(2, len(message.repeated_double)) + self.assertEqual(2, len(message.repeated_bool)) + self.assertEqual(2, len(message.repeated_string)) + self.assertEqual(2, len(message.repeated_bytes)) + + self.assertEqual(2, len(message.repeatedgroup)) + self.assertEqual(2, len(message.repeated_nested_message)) + self.assertEqual(2, len(message.repeated_foreign_message)) + self.assertEqual(2, len(message.repeated_import_message)) + self.assertEqual(2, len(message.repeated_nested_enum)) + self.assertEqual(2, len(message.repeated_foreign_enum)) + self.assertEqual(2, len(message.repeated_import_enum)) + + self.assertEqual(2, len(message.repeated_string_piece)) + self.assertEqual(2, len(message.repeated_cord)) + + self.assertEqual(201, message.repeated_int32[0]) + self.assertEqual(202, message.repeated_int64[0]) + self.assertEqual(203, message.repeated_uint32[0]) + self.assertEqual(204, message.repeated_uint64[0]) + self.assertEqual(205, message.repeated_sint32[0]) + self.assertEqual(206, message.repeated_sint64[0]) + self.assertEqual(207, message.repeated_fixed32[0]) + self.assertEqual(208, message.repeated_fixed64[0]) + self.assertEqual(209, message.repeated_sfixed32[0]) + self.assertEqual(210, message.repeated_sfixed64[0]) + self.assertEqual(211, message.repeated_float[0]) + self.assertEqual(212, message.repeated_double[0]) + self.assertEqual(True, message.repeated_bool[0]) + self.assertEqual('215', message.repeated_string[0]) + self.assertEqual('216', message.repeated_bytes[0]) + + self.assertEqual(217, message.repeatedgroup[0].a) + self.assertEqual(218, message.repeated_nested_message[0].bb) + self.assertEqual(219, message.repeated_foreign_message[0].c) + self.assertEqual(220, message.repeated_import_message[0].d) + + self.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[0]) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + message.repeated_foreign_enum[0]) + self.assertEqual(unittest_import_pb2.IMPORT_BAR, + message.repeated_import_enum[0]) + + self.assertEqual(301, message.repeated_int32[1]) + self.assertEqual(302, message.repeated_int64[1]) + self.assertEqual(303, message.repeated_uint32[1]) + self.assertEqual(304, message.repeated_uint64[1]) + self.assertEqual(305, message.repeated_sint32[1]) + self.assertEqual(306, message.repeated_sint64[1]) + self.assertEqual(307, message.repeated_fixed32[1]) + self.assertEqual(308, message.repeated_fixed64[1]) + self.assertEqual(309, message.repeated_sfixed32[1]) + self.assertEqual(310, message.repeated_sfixed64[1]) + self.assertEqual(311, message.repeated_float[1]) + self.assertEqual(312, message.repeated_double[1]) + self.assertEqual(False, message.repeated_bool[1]) + self.assertEqual('315', message.repeated_string[1]) + self.assertEqual('316', message.repeated_bytes[1]) + + self.assertEqual(317, message.repeatedgroup[1].a) + self.assertEqual(318, message.repeated_nested_message[1].bb) + self.assertEqual(319, message.repeated_foreign_message[1].c) + self.assertEqual(320, message.repeated_import_message[1].d) + + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.repeated_nested_enum[1]) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + message.repeated_foreign_enum[1]) + self.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.repeated_import_enum[1]) + + # ----------------------------------------------------------------- + + self.assertTrue(message.HasField('default_int32')) + self.assertTrue(message.HasField('default_int64')) + self.assertTrue(message.HasField('default_uint32')) + self.assertTrue(message.HasField('default_uint64')) + self.assertTrue(message.HasField('default_sint32')) + self.assertTrue(message.HasField('default_sint64')) + self.assertTrue(message.HasField('default_fixed32')) + self.assertTrue(message.HasField('default_fixed64')) + self.assertTrue(message.HasField('default_sfixed32')) + self.assertTrue(message.HasField('default_sfixed64')) + self.assertTrue(message.HasField('default_float')) + self.assertTrue(message.HasField('default_double')) + self.assertTrue(message.HasField('default_bool')) + self.assertTrue(message.HasField('default_string')) + self.assertTrue(message.HasField('default_bytes')) + + self.assertTrue(message.HasField('default_nested_enum')) + self.assertTrue(message.HasField('default_foreign_enum')) + self.assertTrue(message.HasField('default_import_enum')) + + self.assertEqual(401, message.default_int32) + self.assertEqual(402, message.default_int64) + self.assertEqual(403, message.default_uint32) + self.assertEqual(404, message.default_uint64) + self.assertEqual(405, message.default_sint32) + self.assertEqual(406, message.default_sint64) + self.assertEqual(407, message.default_fixed32) + self.assertEqual(408, message.default_fixed64) + self.assertEqual(409, message.default_sfixed32) + self.assertEqual(410, message.default_sfixed64) + self.assertEqual(411, message.default_float) + self.assertEqual(412, message.default_double) + self.assertEqual(False, message.default_bool) + self.assertEqual('415', message.default_string) + self.assertEqual('416', message.default_bytes) + + self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum) + self.assertEqual(unittest_import_pb2.IMPORT_FOO, + message.default_import_enum) + +def GoldenFile(filename): + """Finds the given golden file and returns a file object representing it.""" + + # Search up the directory tree looking for the C++ protobuf source code. + path = '.' + while os.path.exists(path): + if os.path.exists(os.path.join(path, 'src/google/protobuf')): + # Found it. Load the golden file from the testdata directory. + full_path = os.path.join(path, 'src/google/protobuf/testdata', filename) + return open(full_path, 'rb') + path = os.path.join(path, '..') + + raise RuntimeError( + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def SetAllPackedFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedTypes instance. + """ + message.packed_int32.extend([101, 102]) + message.packed_int64.extend([103, 104]) + message.packed_uint32.extend([105, 106]) + message.packed_uint64.extend([107, 108]) + message.packed_sint32.extend([109, 110]) + message.packed_sint64.extend([111, 112]) + message.packed_fixed32.extend([113, 114]) + message.packed_fixed64.extend([115, 116]) + message.packed_sfixed32.extend([117, 118]) + message.packed_sfixed64.extend([119, 120]) + message.packed_float.extend([121.0, 122.0]) + message.packed_double.extend([122.0, 123.0]) + message.packed_bool.extend([True, False]) + message.packed_enum.extend([unittest_pb2.FOREIGN_FOO, + unittest_pb2.FOREIGN_BAR]) + + +def SetAllPackedExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedExtensions instance. + """ + extensions = message.Extensions + pb2 = unittest_pb2 + + extensions[pb2.packed_int32_extension].append(101) + extensions[pb2.packed_int64_extension].append(102) + extensions[pb2.packed_uint32_extension].append(103) + extensions[pb2.packed_uint64_extension].append(104) + extensions[pb2.packed_sint32_extension].append(105) + extensions[pb2.packed_sint64_extension].append(106) + extensions[pb2.packed_fixed32_extension].append(107) + extensions[pb2.packed_fixed64_extension].append(108) + extensions[pb2.packed_sfixed32_extension].append(109) + extensions[pb2.packed_sfixed64_extension].append(110) + extensions[pb2.packed_float_extension].append(111.0) + extensions[pb2.packed_double_extension].append(112.0) + extensions[pb2.packed_bool_extension].append(True) + extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py new file mode 100755 index 0000000..0cf2718 --- /dev/null +++ b/python/google/protobuf/internal/text_format_test.py @@ -0,0 +1,397 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.text_format.""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import difflib + +import unittest +from google.protobuf import text_format +from google.protobuf.internal import test_util +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_mset_pb2 + + +class TextFormatTest(test_util.GoldenMessageTestCase): + def ReadGolden(self, golden_filename): + f = test_util.GoldenFile(golden_filename) + golden_lines = f.readlines() + f.close() + return golden_lines + + def CompareToGoldenFile(self, text, golden_filename): + golden_lines = self.ReadGolden(golden_filename) + self.CompareToGoldenLines(text, golden_lines) + + def CompareToGoldenText(self, text, golden_text): + self.CompareToGoldenLines(text, golden_text.splitlines(1)) + + def CompareToGoldenLines(self, text, golden_lines): + actual_lines = text.splitlines(1) + self.assertEqual(golden_lines, actual_lines, + "Text doesn't match golden. Diff:\n" + + ''.join(difflib.ndiff(golden_lines, actual_lines))) + + def testPrintAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data.txt') + + def testPrintAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintMessageSet(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText(text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + + def testPrintExotic(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808); + message.repeated_uint64.append(18446744073709551615); + message.repeated_double.append(123.456); + message.repeated_double.append(1.23e22); + message.repeated_double.append(1.23e-18); + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"'); + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string: ' + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + + def testMessageToString(self): + message = unittest_pb2.ForeignMessage() + message.c = 123 + self.assertEqual('c: 123\n', str(message)) + + def RemoveRedundantZeros(self, text): + # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove + # these zeros in order to match the golden file. + return text.replace('e+0','e+').replace('e+0','e+') \ + .replace('e-0','e-').replace('e-0','e-') + + def testMergeGolden(self): + golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) + parsed_message = unittest_pb2.TestAllTypes() + text_format.Merge(golden_text, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeGoldenExtensions(self): + golden_text = '\n'.join(self.ReadGolden( + 'text_format_unittest_extensions_data.txt')) + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Merge(golden_text, parsed_message) + + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.assertEquals(message, parsed_message) + + def testMergeAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllTypes() + text_format.Merge(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + self.ExpectAllFieldsSet(message) + + def testMergeAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Merge(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testMergeMessageSet(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_uint64: 1\n' + 'repeated_uint64: 2\n') + text_format.Merge(text, message) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Merge(text, message) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEquals(23, message.message_set.Extensions[ext1].i) + self.assertEquals('foo', message.message_set.Extensions[ext2].str) + + def testMergeExotic(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string: \n' + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + text_format.Merge(text, message) + + self.assertEqual(-9223372036854775808, message.repeated_int64[0]) + self.assertEqual(18446744073709551615, message.repeated_uint64[0]) + self.assertEqual(123.456, message.repeated_double[0]) + self.assertEqual(1.23e22, message.repeated_double[1]) + self.assertEqual(1.23e-18, message.repeated_double[2]) + self.assertEqual( + '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0]) + + def testMergeUnknownField(self): + message = unittest_pb2.TestAllTypes() + text = 'unknown_field: 8\n' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' + '"unknown_field".'), + text_format.Merge, text, message) + + def testMergeBadExtension(self): + message = unittest_pb2.TestAllTypes() + text = '[unknown_extension]: 8\n' + self.assertRaisesWithMessage( + text_format.ParseError, + '1:2 : Extension "unknown_extension" not registered.', + text_format.Merge, text, message) + + def testMergeGroupNotClosed(self): + message = unittest_pb2.TestAllTypes() + text = 'RepeatedGroup: <' + self.assertRaisesWithMessage( + text_format.ParseError, '1:16 : Expected ">".', + text_format.Merge, text, message) + + text = 'RepeatedGroup: {' + self.assertRaisesWithMessage( + text_format.ParseError, '1:16 : Expected "}".', + text_format.Merge, text, message) + + def testMergeBadEnumValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_nested_enum: BARR' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + 'has no value named BARR.'), + text_format.Merge, text, message) + + message = unittest_pb2.TestAllTypes() + text = 'optional_nested_enum: 100' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + 'has no value with number 100.'), + text_format.Merge, text, message) + + def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): + """Same as assertRaises, but also compares the exception message.""" + if hasattr(e_class, '__name__'): + exc_name = e_class.__name__ + else: + exc_name = str(e_class) + + try: + func(*args, **kwargs) + except e_class, expr: + if str(expr) != e: + msg = '%s raised, but with wrong message: "%s" instead of "%s"' + raise self.failureException(msg % (exc_name, + str(expr).encode('string_escape'), + e.encode('string_escape'))) + return + else: + raise self.failureException('%s not raised' % exc_name) + + +class TokenizerTest(unittest.TestCase): + + def testSimpleTokenCases(self): + text = ('identifier1:"string1"\n \n\n' + 'identifier2 : \n \n123 \n identifier3 :\'string\'\n' + 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' + 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' + 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' + 'ID12: 2222222222222222222') + tokenizer = text_format._Tokenizer(text) + methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), + ':', + (tokenizer.ConsumeString, 'string1'), + (tokenizer.ConsumeIdentifier, 'identifier2'), + ':', + (tokenizer.ConsumeInt32, 123), + (tokenizer.ConsumeIdentifier, 'identifier3'), + ':', + (tokenizer.ConsumeString, 'string'), + (tokenizer.ConsumeIdentifier, 'identifiER_4'), + ':', + (tokenizer.ConsumeFloat, 1.1e+2), + (tokenizer.ConsumeIdentifier, 'ID5'), + ':', + (tokenizer.ConsumeFloat, -0.23), + (tokenizer.ConsumeIdentifier, 'ID6'), + ':', + (tokenizer.ConsumeString, 'aaaa\'bbbb'), + (tokenizer.ConsumeIdentifier, 'ID7'), + ':', + (tokenizer.ConsumeString, 'aa\"bb'), + (tokenizer.ConsumeIdentifier, 'ID8'), + ':', + '{', + (tokenizer.ConsumeIdentifier, 'A'), + ':', + (tokenizer.ConsumeFloat, float('inf')), + (tokenizer.ConsumeIdentifier, 'B'), + ':', + (tokenizer.ConsumeFloat, float('-inf')), + (tokenizer.ConsumeIdentifier, 'C'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'D'), + ':', + (tokenizer.ConsumeBool, False), + '}', + (tokenizer.ConsumeIdentifier, 'ID9'), + ':', + (tokenizer.ConsumeUint32, 22), + (tokenizer.ConsumeIdentifier, 'ID10'), + ':', + (tokenizer.ConsumeInt64, -111111111111111111), + (tokenizer.ConsumeIdentifier, 'ID11'), + ':', + (tokenizer.ConsumeInt32, -22), + (tokenizer.ConsumeIdentifier, 'ID12'), + ':', + (tokenizer.ConsumeUint64, 2222222222222222222)] + + i = 0 + while not tokenizer.AtEnd(): + m = methods[i] + if type(m) == str: + token = tokenizer.token + self.assertEqual(token, m) + tokenizer.NextToken() + else: + self.assertEqual(m[1], m[0]()) + i += 1 + + def testConsumeIntegers(self): + # This test only tests the failures in the integer parsing methods as well + # as the '0' special cases. + int64_max = (1 << 63) - 1 + uint32_max = (1 << 32) - 1 + text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) + self.assertEqual(-1, tokenizer.ConsumeInt32()) + + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32) + self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64()) + + self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64) + self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64()) + self.assertTrue(tokenizer.AtEnd()) + + text = '-0 -0 0 0' + tokenizer = text_format._Tokenizer(text) + self.assertEqual(0, tokenizer.ConsumeUint32()) + self.assertEqual(0, tokenizer.ConsumeUint64()) + self.assertEqual(0, tokenizer.ConsumeUint32()) + self.assertEqual(0, tokenizer.ConsumeUint64()) + self.assertTrue(tokenizer.AtEnd()) + + def testConsumeByteString(self): + text = '"string1\'' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = 'string1"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\xt"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\x"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + def testConsumeBool(self): + text = 'not-a-bool' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py new file mode 100755 index 0000000..a3bc57f --- /dev/null +++ b/python/google/protobuf/internal/type_checkers.py @@ -0,0 +1,287 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides type checking routines. + +This module defines type checking utilities in the forms of dictionaries: + +VALUE_CHECKERS: A dictionary of field types and a value validation object. +TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing + function. +TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization + function. +FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their + coresponding wire types. +TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization + function. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import wire_format +from google.protobuf import descriptor + +_FieldDescriptor = descriptor.FieldDescriptor + + +def GetTypeChecker(cpp_type, field_type): + """Returns a type checker for a message field of the specified types. + + Args: + cpp_type: C++ type of the field (see descriptor.py). + field_type: Protocol message field type (see descriptor.py). + + Returns: + An instance of TypeChecker which can be used to verify the types + of values assigned to a field of the specified type. + """ + if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field_type == _FieldDescriptor.TYPE_STRING): + return UnicodeValueChecker() + return _VALUE_CHECKERS[cpp_type] + + +# None of the typecheckers below make any attempt to guard against people +# subclassing builtin types and doing weird things. We're not trying to +# protect against malicious clients here, just people accidentally shooting +# themselves in the foot in obvious ways. + +class TypeChecker(object): + + """Type checker used to catch type errors as early as possible + when the client is setting scalar fields in protocol messages. + """ + + def __init__(self, *acceptable_types): + self._acceptable_types = acceptable_types + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, self._acceptable_types): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), self._acceptable_types)) + raise TypeError(message) + + +# IntValueChecker and its subclasses perform integer type-checks +# and bounds-checks. +class IntValueChecker(object): + + """Checker used for integer fields. Performs type-check and range check.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if not self._MIN <= proposed_value <= self._MAX: + raise ValueError('Value out of range: %d' % proposed_value) + + +class UnicodeValueChecker(object): + + """Checker used for string fields.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (str, unicode)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (str, unicode))) + raise TypeError(message) + + # If the value is of type 'str' make sure that it is in 7-bit ASCII + # encoding. + if isinstance(proposed_value, str): + try: + unicode(proposed_value, 'ascii') + except UnicodeDecodeError: + raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + 'encoding. Non-ASCII strings must be converted to ' + 'unicode objects before being added.' % + (proposed_value)) + + +class Int32ValueChecker(IntValueChecker): + # We're sure to use ints instead of longs here since comparison may be more + # efficient. + _MIN = -2147483648 + _MAX = 2147483647 + + +class Uint32ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 32) - 1 + + +class Int64ValueChecker(IntValueChecker): + _MIN = -(1 << 63) + _MAX = (1 << 63) - 1 + + +class Uint64ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 64) - 1 + + +# Type-checkers for all scalar CPPTYPEs. +_VALUE_CHECKERS = { + _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), + _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + } + + +# Map from field type to a function F, such that F(field_num, value) +# gives the total byte size for a value of the given type. This +# byte size includes tag information and any other additional space +# associated with serializing "value". +TYPE_TO_BYTE_SIZE_FN = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, + _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, + _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, + _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, + _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, + _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, + _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, + _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, + _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, + _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, + _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, + _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, + _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, + _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, + _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, + _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, + _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, + _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize + } + + +# Maps from field type to an unbound Encoder method F, such that +# F(encoder, field_number, value) will append the serialization +# of a value of this type to the encoder. +_Encoder = encoder.Encoder +TYPE_TO_SERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, + _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, + _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, + _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, + _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, + _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, + _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, + _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, + _FieldDescriptor.TYPE_STRING: _Encoder.AppendString, + _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, + _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, + _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, + _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, + _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, + _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64, + _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32, + _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64, + } + + +TYPE_TO_NOTAG_SERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag, + _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag, + _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag, + _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag, + _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag, + _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag, + _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag, + _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag, + _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag, + _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag, + _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag, + _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag, + _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag, + _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag, + } + +# Maps from field type to expected wiretype. +FIELD_TYPE_TO_WIRE_TYPE = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_STRING: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, + _FieldDescriptor.TYPE_MESSAGE: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_BYTES: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, + } + + +# Maps from field type to an unbound Decoder method F, +# such that F(decoder) will read a field of the requested type. +# +# Note that Message and Group are intentionally missing here. +# They're handled by _RecursivelyMerge(). +_Decoder = decoder.Decoder +TYPE_TO_DESERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble, + _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat, + _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64, + _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64, + _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32, + _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64, + _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32, + _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool, + _FieldDescriptor.TYPE_STRING: _Decoder.ReadString, + _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes, + _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32, + _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum, + _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64, + _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32, + _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64, + } diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py new file mode 100755 index 0000000..da6464d --- /dev/null +++ b/python/google/protobuf/internal/wire_format.py @@ -0,0 +1,247 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Constants and static functions to support protocol buffer wire format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message + + +TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. +_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 + +# These numbers identify the wire type of a protocol buffer value. +# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded +# tag-and-type to store one of these WIRETYPE_* constants. +# These values must match WireType enum in //net/proto2/public/wire_format.h. +WIRETYPE_VARINT = 0 +WIRETYPE_FIXED64 = 1 +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 +WIRETYPE_END_GROUP = 4 +WIRETYPE_FIXED32 = 5 +_WIRETYPE_MAX = 5 + + +# Bounds for various integer types. +INT32_MAX = int((1 << 31) - 1) +INT32_MIN = int(-(1 << 31)) +UINT32_MAX = (1 << 32) - 1 + +INT64_MAX = (1 << 63) - 1 +INT64_MIN = -(1 << 63) +UINT64_MAX = (1 << 64) - 1 + +# "struct" format strings that will encode/decode the specified formats. +FORMAT_UINT32_LITTLE_ENDIAN = '<I' +FORMAT_UINT64_LITTLE_ENDIAN = '<Q' +FORMAT_FLOAT_LITTLE_ENDIAN = '<f' +FORMAT_DOUBLE_LITTLE_ENDIAN = '<d' + + +# We'll have to provide alternate implementations of AppendLittleEndian*() on +# any architectures where these checks fail. +if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4: + raise AssertionError('Format "I" is not a 32-bit number.') +if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8: + raise AssertionError('Format "Q" is not a 64-bit number.') + + +def PackTag(field_number, wire_type): + """Returns an unsigned 32-bit integer that encodes the field number and + wire type information in standard protocol message wire format. + + Args: + field_number: Expected to be an integer in the range [1, 1 << 29) + wire_type: One of the WIRETYPE_* constants. + """ + if not 0 <= wire_type <= _WIRETYPE_MAX: + raise message.EncodeError('Unknown wire type: %d' % wire_type) + return (field_number << TAG_TYPE_BITS) | wire_type + + +def UnpackTag(tag): + """The inverse of PackTag(). Given an unsigned 32-bit number, + returns a (field_number, wire_type) tuple. + """ + return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK) + + +def ZigZagEncode(value): + """ZigZag Transform: Encodes signed integers so that they can be + effectively used with varint encoding. See wire_format.h for + more details. + """ + if value >= 0: + return value << 1 + return (value << 1) ^ (~0) + + +def ZigZagDecode(value): + """Inverse of ZigZagEncode().""" + if not value & 0x1: + return value >> 1 + return (value >> 1) ^ (~0) + + + +# The *ByteSize() functions below return the number of bytes required to +# serialize "field number + type" information and then serialize the value. + + +def Int32ByteSize(field_number, int32): + return Int64ByteSize(field_number, int32) + + +def Int32ByteSizeNoTag(int32): + return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32) + + +def Int64ByteSize(field_number, int64): + # Have to convert to uint before calling UInt64ByteSize(). + return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) + + +def UInt32ByteSize(field_number, uint32): + return UInt64ByteSize(field_number, uint32) + + +def UInt64ByteSize(field_number, uint64): + return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + + +def SInt32ByteSize(field_number, int32): + return UInt32ByteSize(field_number, ZigZagEncode(int32)) + + +def SInt64ByteSize(field_number, int64): + return UInt64ByteSize(field_number, ZigZagEncode(int64)) + + +def Fixed32ByteSize(field_number, fixed32): + return TagByteSize(field_number) + 4 + + +def Fixed64ByteSize(field_number, fixed64): + return TagByteSize(field_number) + 8 + + +def SFixed32ByteSize(field_number, sfixed32): + return TagByteSize(field_number) + 4 + + +def SFixed64ByteSize(field_number, sfixed64): + return TagByteSize(field_number) + 8 + + +def FloatByteSize(field_number, flt): + return TagByteSize(field_number) + 4 + + +def DoubleByteSize(field_number, double): + return TagByteSize(field_number) + 8 + + +def BoolByteSize(field_number, b): + return TagByteSize(field_number) + 1 + + +def EnumByteSize(field_number, enum): + return UInt32ByteSize(field_number, enum) + + +def StringByteSize(field_number, string): + return BytesByteSize(field_number, string.encode('utf-8')) + + +def BytesByteSize(field_number, b): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(len(b)) + + len(b)) + + +def GroupByteSize(field_number, message): + return (2 * TagByteSize(field_number) # START and END group. + + message.ByteSize()) + + +def MessageByteSize(field_number, message): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(message.ByteSize()) + + message.ByteSize()) + + +def MessageSetItemByteSize(field_number, msg): + # First compute the sizes of the tags. + # There are 2 tags for the beginning and ending of the repeated group, that + # is field number 1, one with field number 2 (type_id) and one with field + # number 3 (message). + total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3)) + + # Add the number of bytes for type_id. + total_size += _VarUInt64ByteSizeNoTag(field_number) + + message_size = msg.ByteSize() + + # The number of bytes for encoding the length of the message. + total_size += _VarUInt64ByteSizeNoTag(message_size) + + # The size of the message. + total_size += message_size + return total_size + + +def TagByteSize(field_number): + """Returns the bytes required to serialize a tag with this field number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) + + +# Private helper function for the *ByteSize() functions above. + +def _VarUInt64ByteSizeNoTag(uint64): + """Returns the number of bytes required to serialize a single varint + using boundary value comparisons. (unrolled loop optimization -WPierce) + uint64 must be unsigned. + """ + if uint64 <= 0x7f: return 1 + if uint64 <= 0x3fff: return 2 + if uint64 <= 0x1fffff: return 3 + if uint64 <= 0xfffffff: return 4 + if uint64 <= 0x7ffffffff: return 5 + if uint64 <= 0x3ffffffffff: return 6 + if uint64 <= 0x1ffffffffffff: return 7 + if uint64 <= 0xffffffffffffff: return 8 + if uint64 <= 0x7fffffffffffffff: return 9 + if uint64 > UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % uint64) + return 10 diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py new file mode 100755 index 0000000..7600778 --- /dev/null +++ b/python/google/protobuf/internal/wire_format_test.py @@ -0,0 +1,253 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.wire_format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import wire_format + + +class WireFormatTest(unittest.TestCase): + + def testPackTag(self): + field_number = 0xabc + tag_type = 2 + self.assertEqual((field_number << 3) | tag_type, + wire_format.PackTag(field_number, tag_type)) + PackTag = wire_format.PackTag + # Number too high. + self.assertRaises(message.EncodeError, PackTag, field_number, 6) + # Number too low. + self.assertRaises(message.EncodeError, PackTag, field_number, -1) + + def testUnpackTag(self): + # Test field numbers that will require various varint sizes. + for expected_field_number in (1, 15, 16, 2047, 2048): + for expected_wire_type in range(6): # Highest-numbered wiretype is 5. + field_number, wire_type = wire_format.UnpackTag( + wire_format.PackTag(expected_field_number, expected_wire_type)) + self.assertEqual(expected_field_number, field_number) + self.assertEqual(expected_wire_type, wire_type) + + self.assertRaises(TypeError, wire_format.UnpackTag, None) + self.assertRaises(TypeError, wire_format.UnpackTag, 'abc') + self.assertRaises(TypeError, wire_format.UnpackTag, 0.0) + self.assertRaises(TypeError, wire_format.UnpackTag, object()) + + def testZigZagEncode(self): + Z = wire_format.ZigZagEncode + self.assertEqual(0, Z(0)) + self.assertEqual(1, Z(-1)) + self.assertEqual(2, Z(1)) + self.assertEqual(3, Z(-2)) + self.assertEqual(4, Z(2)) + self.assertEqual(0xfffffffe, Z(0x7fffffff)) + self.assertEqual(0xffffffff, Z(-0x80000000)) + self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff)) + self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def testZigZagDecode(self): + Z = wire_format.ZigZagDecode + self.assertEqual(0, Z(0)) + self.assertEqual(-1, Z(1)) + self.assertEqual(1, Z(2)) + self.assertEqual(-2, Z(3)) + self.assertEqual(2, Z(4)) + self.assertEqual(0x7fffffff, Z(0xfffffffe)) + self.assertEqual(-0x80000000, Z(0xffffffff)) + self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe)) + self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size): + # Use field numbers that cause various byte sizes for the tag information. + for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)): + expected_size = expected_value_size + tag_bytes + actual_size = byte_size_fn(field_number, value) + self.assertEqual(expected_size, actual_size, + 'byte_size_fn: %s, field_number: %d, value: %r\n' + 'Expected: %d, Actual: %d'% ( + byte_size_fn, field_number, value, expected_size, actual_size)) + + def testByteSizeFunctions(self): + # Test all numeric *ByteSize() functions. + NUMERIC_ARGS = [ + # Int32ByteSize(). + [wire_format.Int32ByteSize, 0, 1], + [wire_format.Int32ByteSize, 127, 1], + [wire_format.Int32ByteSize, 128, 2], + [wire_format.Int32ByteSize, -1, 10], + # Int64ByteSize(). + [wire_format.Int64ByteSize, 0, 1], + [wire_format.Int64ByteSize, 127, 1], + [wire_format.Int64ByteSize, 128, 2], + [wire_format.Int64ByteSize, -1, 10], + # UInt32ByteSize(). + [wire_format.UInt32ByteSize, 0, 1], + [wire_format.UInt32ByteSize, 127, 1], + [wire_format.UInt32ByteSize, 128, 2], + [wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5], + # UInt64ByteSize(). + [wire_format.UInt64ByteSize, 0, 1], + [wire_format.UInt64ByteSize, 127, 1], + [wire_format.UInt64ByteSize, 128, 2], + [wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10], + # SInt32ByteSize(). + [wire_format.SInt32ByteSize, 0, 1], + [wire_format.SInt32ByteSize, -1, 1], + [wire_format.SInt32ByteSize, 1, 1], + [wire_format.SInt32ByteSize, -63, 1], + [wire_format.SInt32ByteSize, 63, 1], + [wire_format.SInt32ByteSize, -64, 1], + [wire_format.SInt32ByteSize, 64, 2], + # SInt64ByteSize(). + [wire_format.SInt64ByteSize, 0, 1], + [wire_format.SInt64ByteSize, -1, 1], + [wire_format.SInt64ByteSize, 1, 1], + [wire_format.SInt64ByteSize, -63, 1], + [wire_format.SInt64ByteSize, 63, 1], + [wire_format.SInt64ByteSize, -64, 1], + [wire_format.SInt64ByteSize, 64, 2], + # Fixed32ByteSize(). + [wire_format.Fixed32ByteSize, 0, 4], + [wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4], + # Fixed64ByteSize(). + [wire_format.Fixed64ByteSize, 0, 8], + [wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8], + # SFixed32ByteSize(). + [wire_format.SFixed32ByteSize, 0, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4], + # SFixed64ByteSize(). + [wire_format.SFixed64ByteSize, 0, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8], + # FloatByteSize(). + [wire_format.FloatByteSize, 0.0, 4], + [wire_format.FloatByteSize, 1000000000.0, 4], + [wire_format.FloatByteSize, -1000000000.0, 4], + # DoubleByteSize(). + [wire_format.DoubleByteSize, 0.0, 8], + [wire_format.DoubleByteSize, 1000000000.0, 8], + [wire_format.DoubleByteSize, -1000000000.0, 8], + # BoolByteSize(). + [wire_format.BoolByteSize, False, 1], + [wire_format.BoolByteSize, True, 1], + # EnumByteSize(). + [wire_format.EnumByteSize, 0, 1], + [wire_format.EnumByteSize, 127, 1], + [wire_format.EnumByteSize, 128, 2], + [wire_format.EnumByteSize, wire_format.UINT32_MAX, 5], + ] + for args in NUMERIC_ARGS: + self.NumericByteSizeTestHelper(*args) + + # Test strings and bytes. + for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize): + # 1 byte for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(5, byte_size_fn(10, 'abc')) + # 2 bytes for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(6, byte_size_fn(16, 'abc')) + # 2 bytes for tag, 2 bytes for length, 128 bytes for contents. + self.assertEqual(132, byte_size_fn(16, 'a' * 128)) + + # Test UTF-8 string byte size calculation. + # 1 byte for tag, 1 byte for length, 8 bytes for content. + self.assertEqual(10, wire_format.StringByteSize( + 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + + class MockMessage(object): + def __init__(self, byte_size): + self.byte_size = byte_size + def ByteSize(self): + return self.byte_size + + message_byte_size = 10 + mock_message = MockMessage(byte_size=message_byte_size) + # Test groups. + # (2 * 1) bytes for begin and end tags, plus message_byte_size. + self.assertEqual(2 + message_byte_size, + wire_format.GroupByteSize(1, mock_message)) + # (2 * 2) bytes for begin and end tags, plus message_byte_size. + self.assertEqual(4 + message_byte_size, + wire_format.GroupByteSize(16, mock_message)) + + # Test messages. + # 1 byte for tag, plus 1 byte for length, plus contents. + self.assertEqual(2 + mock_message.byte_size, + wire_format.MessageByteSize(1, mock_message)) + # 2 bytes for tag, plus 1 byte for length, plus contents. + self.assertEqual(3 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + # 2 bytes for tag, plus 2 bytes for length, plus contents. + mock_message.byte_size = 128 + self.assertEqual(4 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + + + # Test message set item byte size. + # 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 10 + self.assertEqual(mock_message.byte_size + 6, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 128 + self.assertEqual(mock_message.byte_size + 7, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id, + # plus contents. + self.assertEqual(mock_message.byte_size + 8, + wire_format.MessageSetItemByteSize(128, mock_message)) + + # Too-long varint. + self.assertRaises(message.EncodeError, + wire_format.UInt64ByteSize, 1, 1 << 128) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py new file mode 100755 index 0000000..9a88bdc --- /dev/null +++ b/python/google/protobuf/message.py @@ -0,0 +1,245 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We should just make these methods all "pure-virtual" and move +# all implementation out, into reflection.py for now. + + +"""Contains an abstract base class for protocol messages.""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +class Error(Exception): pass +class DecodeError(Error): pass +class EncodeError(Error): pass + + +class Message(object): + + """Abstract base class for protocol messages. + + Protocol message classes are almost always generated by the protocol + compiler. These generated types subclass Message and implement the methods + shown below. + + TODO(robinson): Link to an HTML document here. + + TODO(robinson): Document that instances of this class will also + have an Extensions attribute with __getitem__ and __setitem__. + Again, not sure how to best convey this. + + TODO(robinson): Document that the class must also have a static + RegisterExtension(extension_field) method. + Not sure how to best express at this point. + """ + + # TODO(robinson): Document these fields and methods. + + __slots__ = [] + + DESCRIPTOR = None + + def __eq__(self, other_msg): + raise NotImplementedError + + def __ne__(self, other_msg): + # Can't just say self != other_msg, since that would infinitely recurse. :) + return not self == other_msg + + def __str__(self): + raise NotImplementedError + + def MergeFrom(self, other_msg): + """Merges the contents of the specified message into current message. + + This method merges the contents of the specified message into the current + message. Singular fields that are set in the specified message overwrite + the corresponding fields in the current message. Repeated fields are + appended. Singular sub-messages and groups are recursively merged. + + Args: + other_msg: Message to merge into the current message. + """ + raise NotImplementedError + + def CopyFrom(self, other_msg): + """Copies the content of the specified message into the current message. + + The method clears the current message and then merges the specified + message using MergeFrom. + + Args: + other_msg: Message to copy into the current one. + """ + if self == other_msg: + return + self.Clear() + self.MergeFrom(other_msg) + + def Clear(self): + """Clears all data that was set in the message.""" + raise NotImplementedError + + def IsInitialized(self): + """Checks if the message is initialized. + + Returns: + The method returns True if the message is initialized (i.e. all of its + required fields are set). + """ + raise NotImplementedError + + # TODO(robinson): MergeFromString() should probably return None and be + # implemented in terms of a helper that returns the # of bytes read. Our + # deserialization routines would use the helper when recursively + # deserializing, but the end user would almost always just want the no-return + # MergeFromString(). + + def MergeFromString(self, serialized): + """Merges serialized protocol buffer data into this message. + + When we find a field in |serialized| that is already present + in this message: + - If it's a "repeated" field, we append to the end of our list. + - Else, if it's a scalar, we overwrite our field. + - Else, (it's a nonrepeated composite), we recursively merge + into the existing composite. + + TODO(robinson): Document handling of unknown fields. + + Args: + serialized: Any object that allows us to call buffer(serialized) + to access a string of bytes using the buffer interface. + + TODO(robinson): When we switch to a helper, this will return None. + + Returns: + The number of bytes read from |serialized|. + For non-group messages, this will always be len(serialized), + but for messages which are actually groups, this will + generally be less than len(serialized), since we must + stop when we reach an END_GROUP tag. Note that if + we *do* stop because of an END_GROUP tag, the number + of bytes returned does not include the bytes + for the END_GROUP tag information. + """ + raise NotImplementedError + + def ParseFromString(self, serialized): + """Like MergeFromString(), except we clear the object first.""" + self.Clear() + self.MergeFromString(serialized) + + def SerializeToString(self): + """Serializes the protocol message to a binary string. + + Returns: + A binary string representation of the message if all of the required + fields in the message are set (i.e. the message is initialized). + + Raises: + message.EncodeError if the message isn't initialized. + """ + raise NotImplementedError + + def SerializePartialToString(self): + """Serializes the protocol message to a binary string. + + This method is similar to SerializeToString but doesn't check if the + message is initialized. + + Returns: + A string representation of the partial message. + """ + raise NotImplementedError + + # TODO(robinson): Decide whether we like these better + # than auto-generated has_foo() and clear_foo() methods + # on the instances themselves. This way is less consistent + # with C++, but it makes reflection-type access easier and + # reduces the number of magically autogenerated things. + # + # TODO(robinson): Be sure to document (and test) exactly + # which field names are accepted here. Are we case-sensitive? + # What do we do with fields that share names with Python keywords + # like 'lambda' and 'yield'? + # + # nnorwitz says: + # """ + # Typically (in python), an underscore is appended to names that are + # keywords. So they would become lambda_ or yield_. + # """ + def ListFields(self): + """Returns a list of (FieldDescriptor, value) tuples for all + fields in the message which are not empty. A singular field is non-empty + if HasField() would return true, and a repeated field is non-empty if + it contains at least one element. The fields are ordered by field + number""" + raise NotImplementedError + + def HasField(self, field_name): + raise NotImplementedError + + def ClearField(self, field_name): + raise NotImplementedError + + def HasExtension(self, extension_handle): + raise NotImplementedError + + def ClearExtension(self, extension_handle): + raise NotImplementedError + + def ByteSize(self): + """Returns the serialized size of this message. + Recursively calls ByteSize() on all contained messages. + """ + raise NotImplementedError + + def _SetListener(self, message_listener): + """Internal method used by the protocol message implementation. + Clients should not call this directly. + + Sets a listener that this message will call on certain state transitions. + + The purpose of this method is to register back-edges from children to + parents at runtime, for the purpose of setting "has" bits and + byte-size-dirty bits in the parent and ancestor objects whenever a child or + descendant object is modified. + + If the client wants to disconnect this Message from the object tree, she + explicitly sets callback to None. + + If message_listener is None, unregisters any existing listener. Otherwise, + message_listener must implement the MessageListener interface in + internal/message_listener.py, and we discard any listener registered + via a previous _SetListener() call. + """ + raise NotImplementedError diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py new file mode 100755 index 0000000..d65d8b6 --- /dev/null +++ b/python/google/protobuf/reflection.py @@ -0,0 +1,1661 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import heapq +import threading +import weakref +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import containers +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod +from google.protobuf import text_format + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddSlots(descriptor, dictionary) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + superclass = super(GeneratedProtocolMessageType, cls) + return superclass.__new__(cls, name, bases, dictionary) + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + # We act as a "friend" class of the descriptor, setting + # its _concrete_class attribute the first time we use a + # given descriptor to initialize a concrete protocol message + # class. + concrete_class_attr_name = '_concrete_class' + if not hasattr(descriptor, concrete_class_attr_name): + setattr(descriptor, concrete_class_attr_name, cls) + cls._known_extensions = [] + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(cls) + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + return proto_field_name + + +def _ValueFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which objects + should use to store the current value for a given protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + return '_value_' + proto_field_name + + +def _HasFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which + objects should use to store a boolean telling whether this field + is explicitly set or not. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + return '_has_' + proto_field_name + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] + field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields + if f.label != _FieldDescriptor.LABEL_REPEATED) + field_names.extend(('Extensions', + '_cached_byte_size', + '_cached_byte_size_dirty', + '_called_transition_to_nonempty', + '_listener', + '_lock', '__weakref__')) + dictionary['__slots__'] = field_names + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueForField(message, field): + """Returns a default value for a field. + + Args: + message: Message instance containing this field, or a weakref proxy + of same. + field: FieldDescriptor object for this field. + + Returns: A default value for this field. May refer back to |message| + via a weak reference. + """ + # TODO(robinson): Only the repeated fields need a reference to 'message' (so + # that they can set the 'has' bit on the containing Message when someone + # append()s a value). We could special-case this, and avoid an extra + # function call on __init__() and Clear() for non-repeated fields. + + # TODO(robinson): Find a better place for the default value assertion in this + # function. No need to repeat them every time the client calls Clear('foo'). + # (We should probably just assert these things once and as early as possible, + # by tightening checking in the descriptor classes.) + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + listener = _Listener(message, None) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + return containers.RepeatedCompositeFieldContainer( + listener, field.message_type) + else: + return containers.RepeatedScalarFieldContainer( + listener, type_checkers.GetTypeChecker(field.cpp_type, field.type)) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + assert field.default_value is None + + return field.default_value + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self, **kwargs): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = False + self._listener = message_listener_mod.NullMessageListener() + self._called_transition_to_nonempty = False + # TODO(robinson): We should only create a lock if we really need one + # in this class. + self._lock = threading.Lock() + for field in fields: + default_value = _DefaultValueForField(self, field) + python_field_name = _ValueFieldName(field.name) + setattr(self, python_field_name, default_value) + if field.label != _FieldDescriptor.LABEL_REPEATED: + setattr(self, _HasFieldName(field.name), False) + self.Extensions = _ExtensionDict(self, cls._known_extensions) + for field_name, field_value in kwargs.iteritems(): + field = _GetFieldByName(message_descriptor, field_name) + _MergeFieldOrExtension(self, field, field_value) + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + constant_name = field.name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, field.number) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + + def getter(self): + return getattr(self, python_field_name) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + + def getter(self): + return getattr(self, python_field_name) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + def setter(self, new_value): + type_checker.CheckValue(new_value) + setattr(self, has_field_name, True) + self._MarkByteSizeDirty() + self._MaybeCallTransitionToNonemptyCallback() + setattr(self, python_field_name, new_value) + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + property_name = _PropertyName(proto_field_name) + message_type = field.message_type + + def getter(self): + # TODO(robinson): Appropriately scary note about double-checked locking. + field_value = getattr(self, python_field_name) + if field_value is None: + self._lock.acquire() + try: + field_value = getattr(self, python_field_name) + if field_value is None: + field_class = message_type._concrete_class + field_value = field_class() + field_value._SetListener(_Listener(self, has_field_name)) + setattr(self, python_field_name, field_value) + finally: + self._lock.release() + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForExtensions(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, extension_field.number) + + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + cls._known_extensions.append(extension_handle) + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + # Ensure that we always list in ascending field-number order. + # For non-extension fields, we can do the sort once, here, at import-time. + # For extensions, we sort on each ListFields() call, though + # we could do better if we have to. + fields = sorted(message_descriptor.fields, key=lambda f: f.number) + has_field_names = (_HasFieldName(f.name) for f in fields) + value_field_names = (_ValueFieldName(f.name) for f in fields) + triplets = zip(has_field_names, value_field_names, fields) + + def ListFields(self): + # We need to list all extension and non-extension fields + # together, in sorted order by field number. + + # Step 0: Get an iterator over all "set" non-extension fields, + # sorted by field number. + # This iterator yields (field_number, field_descriptor, value) tuples. + def SortedSetFieldsIter(): + # Note that triplets is already sorted by field number. + for has_field_name, value_field_name, field_descriptor in triplets: + if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: + value = getattr(self, _ValueFieldName(field_descriptor.name)) + if len(value) > 0: + yield (field_descriptor.number, field_descriptor, value) + elif getattr(self, _HasFieldName(field_descriptor.name)): + value = getattr(self, _ValueFieldName(field_descriptor.name)) + yield (field_descriptor.number, field_descriptor, value) + sorted_fields = SortedSetFieldsIter() + + # Step 1: Get an iterator over all "set" extension fields, + # sorted by field number. + # This iterator ALSO yields (field_number, field_descriptor, value) tuples. + # TODO(robinson): It's not necessary to repeat this with each + # serialization call. We can do better. + sorted_extension_fields = sorted( + [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) + + # Step 2: Create a composite iterator that merges the extension- + # and non-extension fields, and that still yields fields in + # sorted order. + all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) + + # Step 3: Strip off the field numbers and return. + return [field[1:] for field in all_set_fields] + + cls.ListFields = ListFields + +def _AddHasFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def HasField(self, field_name): + try: + return getattr(self, _HasFieldName(field_name)) + except AttributeError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + cls.HasField = HasField + + +def _AddClearFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + field = _GetFieldByName(self.DESCRIPTOR, field_name) + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + default_value = _DefaultValueForField(self, field) + if field.label == _FieldDescriptor.LABEL_REPEATED: + self._MarkByteSizeDirty() + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + old_field_value = getattr(self, python_field_name) + if old_field_value is not None: + # Snip the old object out of the object tree. + old_field_value._SetListener(None) + if getattr(self, has_field_name): + setattr(self, has_field_name, False) + # Set dirty bit on ourself and parents only if + # we're actually changing state. + self._MarkByteSizeDirty() + setattr(self, python_field_name, default_value) + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + self.Extensions._ClearExtension(extension_handle) + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + fields = self.DESCRIPTOR.fields + for field in fields: + self.ClearField(field.name) + # Clear extensions. + extensions = self.Extensions._ListSetExtensions() + for extension in extensions: + self.ClearExtension(extension[0]) + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + return self.Extensions._HasExtension(extension_handle) + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + + if self is other: + return True + + # Compare all fields contained directly in this message. + for field_descriptor in message_descriptor.fields: + label = field_descriptor.label + property_name = _PropertyName(field_descriptor.name) + # Non-repeated field equality requires matching "has" bits as well + # as having an equal value. + if label != _FieldDescriptor.LABEL_REPEATED: + self_has = self.HasField(property_name) + other_has = other.HasField(property_name) + if self_has != other_has: + return False + if not self_has: + # If the "has" bit for this field is False, we must stop here. + # Otherwise we will recurse forever on recursively-defined protos. + continue + if getattr(self, property_name) != getattr(other, property_name): + return False + + # Compare the extensions present in both messages. + return self.Extensions == other.Extensions + cls.__eq__ = __eq__ + + +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def BytesForField(message, field, value): + """Returns the number of bytes required to serialize a single field + in message. The field may be repeated or not, composite or not. + + Args: + message: The Message instance containing a field of the given type. + field: A FieldDescriptor describing the field of interest. + value: The value whose byte size we're interested in. + + Returns: The number of bytes required to serialize the current value + of "field" in "message", including space for tags and any other + necessary information. + """ + + if _MessageSetField(field): + return wire_format.MessageSetItemByteSize(field.number, value) + + field_number, field_type = field.number, field.type + + # Repeated fields. + if field.label == _FieldDescriptor.LABEL_REPEATED: + elements = value + else: + elements = [value] + + if field.GetOptions().packed: + content_size = _ContentBytesForPackedField(message, field, elements) + if content_size: + tag_size = wire_format.TagByteSize(field_number) + length_size = wire_format.Int32ByteSizeNoTag(content_size) + return tag_size + length_size + content_size + else: + return 0 + else: + return sum(_BytesForNonRepeatedElement(element, field_number, field_type) + for element in elements) + + def _ContentBytesForPackedField(self, field, value): + """Returns the number of bytes required to serialize the actual + content of a packed field (not including the tag or the encoding + of the length. + + Args: + self: The Message instance containing a field of the given type. + field: A FieldDescriptor describing the field of interest. + value: The value whose byte size we're interested in. + + Returns: The number of bytes required to serialize the current value + of the packed "field" in "message", excluding space for tags and the + length encoding. + """ + size = sum(_BytesForNonRepeatedElement(element, field.number, field.type) + for element in value) + # In the packed case, there are no per element tags. + return size - wire_format.TagByteSize(field.number) * len(value) + + fields = message_descriptor.fields + has_field_names = (_HasFieldName(f.name) for f in fields) + zipped = zip(has_field_names, fields) + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + # Hardcoded fields first. + for has_field_name, field in zipped: + if (field.label == _FieldDescriptor.LABEL_REPEATED + or getattr(self, has_field_name)): + value = getattr(self, _ValueFieldName(field.name)) + size += BytesForField(self, field, value) + # Extensions next. + for field, value in self.Extensions._ListSetExtensions(): + size += BytesForField(self, field, value) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + return size + + cls._ContentBytesForPackedField = _ContentBytesForPackedField + cls.ByteSize = ByteSize + + +def _MessageSetField(field_descriptor): + """Checks if a field should be serialized using the message set wire format. + + Args: + field_descriptor: Descriptor of the field. + + Returns: + True if the field should be serialized using the message set wire format, + false otherwise. + """ + return (field_descriptor.is_extension and + field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and + field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + field_descriptor.containing_type.GetOptions().message_set_wire_format) + + +def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): + """Appends the serialization of a single value to encoder. + + Args: + value: Value to serialize. + field_number: Field number of this value. + field_descriptor: Descriptor of the field to serialize. + encoder: encoder.Encoder object to which we should serialize this value. + """ + if _MessageSetField(field_descriptor): + encoder.AppendMessageSetItem(field_number, value) + return + + try: + method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] + method(encoder, field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % + field_descriptor.type) + + +def _ImergeSorted(*streams): + """Merges N sorted iterators into a single sorted iterator. + Each element in streams must be an iterable that yields + its elements in sorted order, and the elements contained + in each stream must all be comparable. + + There may be repeated elements in the component streams or + across the streams; the repeated elements will all be repeated + in the merged iterator as well. + + I believe that the heapq module at HEAD in the Python + sources has a method like this, but for now we roll our own. + """ + iters = [iter(stream) for stream in streams] + heap = [] + for index, it in enumerate(iters): + try: + heap.append((it.next(), index)) + except StopIteration: + pass + heapq.heapify(heap) + + while heap: + smallest_value, idx = heap[0] + yield smallest_value + try: + next_element = iters[idx].next() + heapq.heapreplace(heap, (next_element, idx)) + except StopIteration: + heapq.heappop(heap) + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not _InternalIsInitialized(self, errors): + raise message_mod.EncodeError('\n'.join(errors)) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Encoder = encoder.Encoder + + def SerializePartialToString(self): + encoder = Encoder() + # We need to serialize all extension and non-extension fields + # together, in sorted order by field number. + for field_descriptor, field_value in self.ListFields(): + if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: + repeated_value = field_value + else: + repeated_value = [field_value] + if field_descriptor.GetOptions().packed: + # First, write the field number and WIRETYPE_LENGTH_DELIMITED. + field_number = field_descriptor.number + encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + # Next, write the number of bytes. + content_bytes = self._ContentBytesForPackedField( + field_descriptor, field_value) + encoder.AppendInt32NoTag(content_bytes) + # Finally, write the actual values. + try: + method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[ + field_descriptor.type] + for value in repeated_value: + method(encoder, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % + field_descriptor.type) + else: + for element in repeated_value: + _SerializeValueToEncoder(element, field_descriptor.number, + field_descriptor, encoder) + return encoder.ToString() + + cls.SerializePartialToString = SerializePartialToString + + +def _WireTypeForFieldType(field_type): + """Given a field type, returns the expected wire type.""" + try: + return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] + except KeyError: + raise message_mod.DecodeError('Unknown field type: %d' % field_type) + + +def _WireTypeForField(field_descriptor): + """Given a field descriptor, returns the expected wire type.""" + if field_descriptor.GetOptions().packed: + return wire_format.WIRETYPE_LENGTH_DELIMITED + else: + return _WireTypeForFieldType(field_descriptor.type) + + +def _RecursivelyMerge(field_number, field_type, decoder, message): + """Decodes a message from decoder into message. + message is either a group or a nested message within some containing + protocol message. If it's a group, we use the group protocol to + deserialize, and if it's a nested message, we use the nested-message + protocol. + + Args: + field_number: The field number of message in its enclosing protocol buffer. + field_type: The field type of message. Must be either TYPE_MESSAGE + or TYPE_GROUP. + decoder: Decoder to read from. + message: Message to deserialize into. + """ + if field_type == _FieldDescriptor.TYPE_MESSAGE: + decoder.ReadMessageInto(message) + elif field_type == _FieldDescriptor.TYPE_GROUP: + decoder.ReadGroupInto(field_number, message) + else: + raise message_mod.DecodeError('Unexpected field type: %d' % field_type) + + +def _DeserializeScalarFromDecoder(field_type, decoder): + """Deserializes a scalar of the requested type from decoder. field_type must + be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. + """ + try: + method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] + return method(decoder) + except KeyError: + raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) + + +def _SkipField(field_number, wire_type, decoder): + """Skips a field with the specified wire type. + + Args: + field_number: Tag number of the field to skip. + wire_type: Wire type of the field to skip. + decoder: Decoder used to deserialize the messsage. It must be positioned + just after reading the the tag and wire type of the field. + """ + if wire_type == wire_format.WIRETYPE_VARINT: + decoder.ReadUInt64() + elif wire_type == wire_format.WIRETYPE_FIXED64: + decoder.ReadFixed64() + elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + decoder.SkipBytes(decoder.ReadInt32()) + elif wire_type == wire_format.WIRETYPE_START_GROUP: + _SkipGroup(field_number, decoder) + elif wire_type == wire_format.WIRETYPE_END_GROUP: + pass + elif wire_type == wire_format.WIRETYPE_FIXED32: + decoder.ReadFixed32() + else: + raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type) + + +def _SkipGroup(group_number, decoder): + """Skips a nested group from the decoder. + + Args: + group_number: Tag number of the group to skip. + decoder: Decoder used to deserialize the message. It must be positioned + exactly at the beginning of the message that should be skipped. + """ + while True: + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if (wire_type == wire_format.WIRETYPE_END_GROUP and + field_number == group_number): + return + _SkipField(field_number, wire_type, decoder) + + +def _DeserializeMessageSetItem(message, decoder): + """Deserializes a message using the message set wire format. + + Args: + message: Message to be parsed to. + decoder: The decoder to be used to deserialize encoded data. Note that the + decoder should be positioned just after reading the START_GROUP tag that + began the messageset item. + """ + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + type_id = decoder.ReadInt32() + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + extension_dict = message.Extensions + extensions_by_number = extension_dict._AllExtensionsByNumber() + if type_id not in extensions_by_number: + _SkipField(field_number, wire_type, decoder) + return + + field_descriptor = extensions_by_number[type_id] + value = extension_dict[field_descriptor] + decoder.ReadMessageInto(value) + # Read the END_GROUP tag. + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1: + raise message_mod.DecodeError( + 'Incorrect message set wire format. ' + 'wire_type: %d, field_number: %d' % (wire_type, field_number)) + + +def _DeserializeOneEntity(message_descriptor, message, decoder): + """Deserializes the next wire entity from decoder into message. + + The next wire entity is either a scalar or a nested message, an + element in a repeated field (the wire encoding in this case is the + same), or a packed repeated field (in this case, the entire repeated + field is read by a single call to _DeserializeOneEntity). + + Args: + message_descriptor: A Descriptor instance describing all fields + in message. + message: The Message instance into which we're decoding our fields. + decoder: The Decoder we're using to deserialize encoded data. + + Returns: The number of bytes read from decoder during this method. + """ + initial_position = decoder.Position() + field_number, wire_type = decoder.ReadFieldNumberAndWireType() + extension_dict = message.Extensions + extensions_by_number = extension_dict._AllExtensionsByNumber() + if field_number in message_descriptor.fields_by_number: + # Non-extension field. + field_descriptor = message_descriptor.fields_by_number[field_number] + value = getattr(message, _PropertyName(field_descriptor.name)) + def nonextension_setter_fn(scalar): + setattr(message, _PropertyName(field_descriptor.name), scalar) + scalar_setter_fn = nonextension_setter_fn + elif field_number in extensions_by_number: + # Extension field. + field_descriptor = extensions_by_number[field_number] + value = extension_dict[field_descriptor] + def extension_setter_fn(scalar): + extension_dict[field_descriptor] = scalar + scalar_setter_fn = extension_setter_fn + elif wire_type == wire_format.WIRETYPE_END_GROUP: + # We assume we're being parsed as the group that's ended. + return 0 + elif (wire_type == wire_format.WIRETYPE_START_GROUP and + field_number == 1 and + message_descriptor.GetOptions().message_set_wire_format): + # A Message Set item. + _DeserializeMessageSetItem(message, decoder) + return decoder.Position() - initial_position + else: + _SkipField(field_number, wire_type, decoder) + return decoder.Position() - initial_position + + # If we reach this point, we've identified the field as either + # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, + # and |value| appropriately. Now actually deserialize the thing. + # + # field_descriptor: Describes the field we're deserializing. + # value: The value currently stored in the field to deserialize. + # Used only if the field is composite and/or repeated. + # scalar_setter_fn: A function F such that F(scalar) will + # set a nonrepeated scalar value for this field. Used only + # if this field is a nonrepeated scalar. + + field_number = field_descriptor.number + expected_wire_type = _WireTypeForField(field_descriptor) + if wire_type != expected_wire_type: + # Need to fill in uninterpreted_bytes. Work for the next CL. + raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') + + property_name = _PropertyName(field_descriptor.name) + label = field_descriptor.label + field_type = field_descriptor.type + cpp_type = field_descriptor.cpp_type + + # Nonrepeated scalar. Just set the field directly. + if (label != _FieldDescriptor.LABEL_REPEATED + and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + + # Nonrepeated composite. Recursively deserialize. + if label != _FieldDescriptor.LABEL_REPEATED: + composite = value + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + # Now we know we're dealing with a repeated field of some kind. + element_list = value + + if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated scalar. + if not field_descriptor.GetOptions().packed: + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + else: + # Packed repeated field. + length = _DeserializeScalarFromDecoder( + _FieldDescriptor.TYPE_INT32, decoder) + content_start = decoder.Position() + while decoder.Position() - content_start < length: + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + else: + # Repeated composite. + composite = element_list.add() + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + +def _FieldOrExtensionValues(message, field_or_extension): + """Retrieves the list of values for the specified field or extension. + + The target field or extension can be optional, required or repeated, but it + must have value(s) set. The assumption is that the target field or extension + is set (e.g. _HasFieldOrExtension holds true). + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension for which the list of values is + required. Must be an instance of FieldDescriptor. + + Returns: + A list of values for the specified field or extension. This list will only + contain a single element if the field is non-repeated. + """ + if field_or_extension.is_extension: + value = message.Extensions[field_or_extension] + else: + value = getattr(message, _ValueFieldName(field_or_extension.name)) + if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: + return [value] + else: + # In this case value is a list or repeated values. + return value + + +def _HasFieldOrExtension(message, field_or_extension): + """Checks if a message has the specified field or extension set. + + The field or extension specified can be optional, required or repeated. If + it is repeated, this function returns True. Otherwise it checks the has bit + of the field or extension. + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension to check. This must be a + FieldDescriptor instance. + + Returns: + True if the message has a value set for the specified field or extension, + or if the field or extension is repeated. + """ + if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: + return True + if field_or_extension.is_extension: + return message.HasExtension(field_or_extension) + else: + return message.HasField(field_or_extension.name) + + +def _IsFieldOrExtensionInitialized(message, field, errors=None): + """Checks if a message field or extension is initialized. + + Args: + message: The message which contains the field or extension. + field: Field or extension to check. This must be a FieldDescriptor instance. + errors: Errors will be appended to it, if set to a meaningful value. + + Returns: + True if the field/extension can be considered initialized. + """ + # If the field is required and is not set, it isn't initialized. + if field.label == _FieldDescriptor.LABEL_REQUIRED: + if not _HasFieldOrExtension(message, field): + if errors is not None: + errors.append('Required field %s is not set.' % field.full_name) + return False + + # If the field is optional and is not set, or if it + # isn't a submessage then the field is initialized. + if field.label == _FieldDescriptor.LABEL_OPTIONAL: + if not _HasFieldOrExtension(message, field): + return True + if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + return True + + # The field is set and is either a single or a repeated submessage. + messages = _FieldOrExtensionValues(message, field) + # If all submessages in this field are initialized, the field is + # considered initialized. + for message in messages: + if not _InternalIsInitialized(message, errors): + return False + return True + + +def _InternalIsInitialized(message, errors=None): + """Checks if all required fields of a message are set. + + Args: + message: The message to check. + errors: If set, initialization errors will be appended to it. + + Returns: + True iff the specified message has all required fields set. + """ + fields_and_extensions = [] + fields_and_extensions.extend(message.DESCRIPTOR.fields) + fields_and_extensions.extend( + [extension[0] for extension in message.Extensions._ListSetExtensions()]) + for field_or_extension in fields_and_extensions: + if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): + return False + return True + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Decoder = decoder.Decoder + def MergeFromString(self, serialized): + decoder = Decoder(serialized) + byte_count = 0 + while not decoder.EndOfStream(): + bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) + if not bytes_read: + break + byte_count += bytes_read + return byte_count + cls.MergeFromString = MergeFromString + + +def _AddIsInitializedMethod(cls): + """Adds the IsInitialized method to the protocol message class.""" + cls.IsInitialized = _InternalIsInitialized + + +def _MergeFieldOrExtension(destination_msg, field, value): + """Merges a specified message field into another message.""" + property_name = _PropertyName(field.name) + is_extension = field.is_extension + + if not is_extension: + destination = getattr(destination_msg, property_name) + elif (field.label == _FieldDescriptor.LABEL_REPEATED or + field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + destination = destination_msg.Extensions[field] + + # Case 1 - a composite field. + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.add().MergeFrom(v) + else: + destination.MergeFrom(value) + return + + # Case 2 - a repeated field. + if field.label == _FieldDescriptor.LABEL_REPEATED: + for v in value: + destination.append(v) + return + + # Case 3 - a singular field. + if is_extension: + destination_msg.Extensions[field] = value + else: + setattr(destination_msg, property_name, value) + + +def _AddMergeFromMethod(cls): + def MergeFrom(self, msg): + assert msg is not self + for field in msg.ListFields(): + _MergeFieldOrExtension(self, field[0], field[1]) + cls.MergeFrom = MergeFrom + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(cls) + _AddClearFieldMethod(cls) + _AddClearExtensionMethod(cls) + _AddClearMethod(cls) + _AddHasExtensionMethod(cls) + _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(cls) + _AddMergeFromMethod(cls) + + +def _AddPrivateHelperMethods(cls): + """Adds implementation of private helper methods to cls.""" + + def MaybeCallTransitionToNonemptyCallback(self): + """Calls self._listener.TransitionToNonempty() the first time this + method is called. On all subsequent calls, this is a no-op. + """ + if not self._called_transition_to_nonempty: + self._listener.TransitionToNonempty() + self._called_transition_to_nonempty = True + cls._MaybeCallTransitionToNonemptyCallback = ( + MaybeCallTransitionToNonemptyCallback) + + def MarkByteSizeDirty(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener.ByteSizeDirty() + cls._MarkByteSizeDirty = MarkByteSizeDirty + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message, has_field_name): + """Args: + parent_message: The message whose _MaybeCallTransitionToNonemptyCallback() + and _MarkByteSizeDirty() methods we should call when we receive + TransitionToNonempty() and ByteSizeDirty() messages. + has_field_name: The name of the "has" field that we should set in + the parent message when we receive a TransitionToNonempty message, + or None if there's no "has" field to set. (This will be the case + for child objects in "repeated" fields). + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + self._has_field_name = has_field_name + + def TransitionToNonempty(self): + try: + if self._has_field_name is not None: + setattr(self._parent_message_weakref, self._has_field_name, True) + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + def ByteSizeDirty(self): + try: + self._parent_message_weakref._MarkByteSizeDirty() + except ReferenceError: + # Same as above. + pass + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): There's so much similarity between the way that +# extensions behave and the way that normal fields behave that it would +# be really nice to unify more code. It's not immediately obvious +# how to do this, though, and I'd rather get the full functionality +# implemented (and, crucially, get all the tests and specs fleshed out +# and passing), and then come back to this thorny unification problem. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + class _ExtensionListener(object): + + """Adapts an _ExtensionDict to behave as a MessageListener.""" + + def __init__(self, extension_dict, handle_id): + self._extension_dict = extension_dict + self._handle_id = handle_id + + def TransitionToNonempty(self): + self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) + + def ByteSizeDirty(self): + self._extension_dict._SubmessageByteSizeBecameDirty() + + # TODO(robinson): Somewhere, we need to blow up if people + # try to register two extensions with the same field number. + # (And we need a test for this of course). + + def __init__(self, extended_message, known_extensions): + """extended_message: Message instance for which we are the Extensions dict. + known_extensions: Iterable of known extension handles. + These must be FieldDescriptors. + """ + # We keep a weak reference to extended_message, since + # it has a reference to this instance in turn. + self._extended_message = weakref.proxy(extended_message) + # We make a deep copy of known_extensions to avoid any + # thread-safety concerns, since the argument passed in + # is the global (class-level) dict of known extensions for + # this type of message, which could be modified at any time + # via a RegisterExtension() call. + # + # This dict maps from handle id to handle (a FieldDescriptor). + # + # XXX + # TODO(robinson): This isn't good enough. The client could + # instantiate an object in module A, then afterward import + # module B and pass the instance to B.Foo(). If B imports + # an extender of this proto and then tries to use it, B + # will get a KeyError, even though the extension *is* registered + # at the time of use. + # XXX + self._known_extensions = dict((id(e), e) for e in known_extensions) + # Read lock around self._values, which may be modified by multiple + # concurrent readers in the conceptually "const" __getitem__ method. + # So, we grab this lock in every "read-only" method to ensure + # that concurrent read access is safe without external locking. + self._lock = threading.Lock() + # Maps from extension handle ID to current value of that extension. + self._values = {} + # Maps from extension handle ID to a boolean "has" bit, but only + # for non-repeated extension fields. + keys = (id for id, extension in self._known_extensions.iteritems() + if extension.label != _FieldDescriptor.LABEL_REPEATED) + self._has_bits = dict.fromkeys(keys, False) + + self._extensions_by_number = dict( + (f.number, f) for f in self._known_extensions.itervalues()) + + self._extensions_by_name = {} + for extension in self._known_extensions.itervalues(): + if (extension.containing_type.GetOptions().message_set_wire_format and + extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and + extension.message_type == extension.extension_scope and + extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL): + extension_name = extension.message_type.full_name + else: + extension_name = extension.full_name + self._extensions_by_name[extension_name] = extension + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + # We don't care as much about keeping critical sections short in the + # extension support, since it's presumably much less of a common case. + self._lock.acquire() + try: + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + if handle_id not in self._values: + self._AddMissingHandle(extension_handle, handle_id) + return self._values[handle_id] + finally: + self._lock.release() + + def __eq__(self, other): + # We have to grab read locks since we're accessing _values + # in a "const" method. See the comment in the constructor. + if self is other: + return True + self._lock.acquire() + try: + other._lock.acquire() + try: + if self._has_bits != other._has_bits: + return False + # If there's a "has" bit, then only compare values where it is true. + for k, v in self._values.iteritems(): + if self._has_bits.get(k, False) and v != other._values[k]: + return False + return True + finally: + other._lock.release() + finally: + self._lock.release() + + def __ne__(self, other): + return not self == other + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call + # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field + # this way, to set any necssary "has" bits in the ancestors of the extended + # message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + field = extension_handle # Just shorten the name. + if (field.label == _FieldDescriptor.LABEL_OPTIONAL + and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + type_checker.CheckValue(value) + self._values[handle_id] = value + self._has_bits[handle_id] = True + self._extended_message._MarkByteSizeDirty() + self._extended_message._MaybeCallTransitionToNonemptyCallback() + else: + raise TypeError('Extension is repeated and/or a composite type.') + + def _AddMissingHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # Special handling for non-repeated message extensions, which (like + # normal fields of this kind) are initialized lazily. + # REQUIRES: _lock already held. + cpp_type = extension_handle.cpp_type + label = extension_handle.label + if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + and label != _FieldDescriptor.LABEL_REPEATED): + self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) + else: + self._values[handle_id] = _DefaultValueForField( + self._extended_message, extension_handle) + + def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # REQUIRES: _lock already held. + value = extension_handle.message_type._concrete_class() + value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) + self._values[handle_id] = value + + def _SubmessageTransitionedToNonempty(self, handle_id): + """Called when a submessage with a given handle id first transitions to + being nonempty. Called by _ExtensionListener. + """ + assert handle_id in self._has_bits + self._has_bits[handle_id] = True + self._extended_message._MaybeCallTransitionToNonemptyCallback() + + def _SubmessageByteSizeBecameDirty(self): + """Called whenever a submessage's cached byte size becomes invalid + (goes from being "clean" to being "dirty"). Called by _ExtensionListener. + """ + self._extended_message._MarkByteSizeDirty() + + # We may wish to widen the public interface of Message.Extensions + # to expose some of this private functionality in the future. + # For now, we make all this functionality module-private and just + # implement what we need for serialization/deserialization, + # HasField()/ClearField(), etc. + + def _HasExtension(self, extension_handle): + """Method for internal use by this module. + Returns true iff we "have" this extension in the sense of the + "has" bit being set. + """ + handle_id = id(extension_handle) + # Note that this is different from the other checks. + if handle_id not in self._has_bits: + raise KeyError('Extension not known to this class, or is repeated field.') + return self._has_bits[handle_id] + + # Intentionally pretty similar to ClearField() above. + def _ClearExtension(self, extension_handle): + """Method for internal use by this module. + Clears the specified extension, unsetting its "has" bit. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + default_value = _DefaultValueForField(self._extended_message, + extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + self._extended_message._MarkByteSizeDirty() + else: + cpp_type = extension_handle.cpp_type + if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if handle_id in self._values: + # Future modifications to this object shouldn't set any + # "has" bits here. + self._values[handle_id]._SetListener(None) + if self._has_bits[handle_id]: + self._has_bits[handle_id] = False + self._extended_message._MarkByteSizeDirty() + if handle_id in self._values: + del self._values[handle_id] + + def _ListSetExtensions(self): + """Method for internal use by this module. + + Returns an sequence of all extensions that are currently "set" + in this extension dict. A "set" extension is a repeated extension, + or a non-repeated extension with its "has" bit set. + + The returned sequence contains (field_descriptor, value) pairs, + where value is the current value of the extension with the given + field descriptor. + + The sequence values are in arbitrary order. + """ + self._lock.acquire() # Read-only methods must lock around self._values. + try: + set_extensions = [] + for handle_id, value in self._values.iteritems(): + handle = self._known_extensions[handle_id] + if (handle.label == _FieldDescriptor.LABEL_REPEATED + or self._has_bits[handle_id]): + set_extensions.append((handle, value)) + return set_extensions + finally: + self._lock.release() + + def _AllExtensionsByNumber(self): + """Method for internal use by this module. + + Returns: A dict mapping field_number to (handle, field_descriptor), + for *all* registered extensions for this dict. + """ + return self._extensions_by_number + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extensions_by_name.get(name, None) diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py new file mode 100755 index 0000000..dd136c9 --- /dev/null +++ b/python/google/protobuf/service.py @@ -0,0 +1,222 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Declares the RPC service interfaces. + +This module declares the abstract interfaces underlying proto2 RPC +services. These are intended to be independent of any particular RPC +implementation, so that proto2 services can be used on top of a variety +of implementations. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class RpcException(Exception): + """Exception raised on failed blocking RPC method call.""" + pass + + +class Service(object): + + """Abstract base interface for protocol-buffer-based RPC services. + + Services themselves are abstract classes (implemented either by servers or as + stubs), but they subclass this base interface. The methods of this + interface can be used to call the methods of the service without knowing + its exact type at compile time (analogous to the Message interface). + """ + + def GetDescriptor(): + """Retrieves this service's descriptor.""" + raise NotImplementedError + + def CallMethod(self, method_descriptor, rpc_controller, + request, done): + """Calls a method of the service specified by method_descriptor. + + If "done" is None then the call is blocking and the response + message will be returned directly. Otherwise the call is asynchronous + and "done" will later be called with the response value. + + In the blocking case, RpcException will be raised on error. + + Preconditions: + * method_descriptor.service == GetDescriptor + * request is of the exact same classes as returned by + GetRequestClass(method). + * After the call has started, the request must not be modified. + * "rpc_controller" is of the correct type for the RPC implementation being + used by this Service. For stubs, the "correct type" depends on the + RpcChannel which the stub is using. + + Postconditions: + * "done" will be called when the method is complete. This may be + before CallMethod() returns or it may be at some point in the future. + * If the RPC failed, the response value passed to "done" will be None. + Further details about the failure can be found by querying the + RpcController. + """ + raise NotImplementedError + + def GetRequestClass(self, method_descriptor): + """Returns the class of the request message for the specified method. + + CallMethod() requires that the request is of a particular subclass of + Message. GetRequestClass() gets the default instance of this required + type. + + Example: + method = service.GetDescriptor().FindMethodByName("Foo") + request = stub.GetRequestClass(method)() + request.ParseFromString(input) + service.CallMethod(method, request, callback) + """ + raise NotImplementedError + + def GetResponseClass(self, method_descriptor): + """Returns the class of the response message for the specified method. + + This method isn't really needed, as the RpcChannel's CallMethod constructs + the response protocol message. It's provided anyway in case it is useful + for the caller to know the response type in advance. + """ + raise NotImplementedError + + +class RpcController(object): + + """An RpcController mediates a single method call. + + The primary purpose of the controller is to provide a way to manipulate + settings specific to the RPC implementation and to find out about RPC-level + errors. The methods provided by the RpcController interface are intended + to be a "least common denominator" set of features which we expect all + implementations to support. Specific implementations may provide more + advanced features (e.g. deadline propagation). + """ + + # Client-side methods below + + def Reset(self): + """Resets the RpcController to its initial state. + + After the RpcController has been reset, it may be reused in + a new call. Must not be called while an RPC is in progress. + """ + raise NotImplementedError + + def Failed(self): + """Returns true if the call failed. + + After a call has finished, returns true if the call failed. The possible + reasons for failure depend on the RPC implementation. Failed() must not + be called before a call has finished. If Failed() returns true, the + contents of the response message are undefined. + """ + raise NotImplementedError + + def ErrorText(self): + """If Failed is true, returns a human-readable description of the error.""" + raise NotImplementedError + + def StartCancel(self): + """Initiate cancellation. + + Advises the RPC system that the caller desires that the RPC call be + canceled. The RPC system may cancel it immediately, may wait awhile and + then cancel it, or may not even cancel the call at all. If the call is + canceled, the "done" callback will still be called and the RpcController + will indicate that the call failed at that time. + """ + raise NotImplementedError + + # Server-side methods below + + def SetFailed(self, reason): + """Sets a failure reason. + + Causes Failed() to return true on the client side. "reason" will be + incorporated into the message returned by ErrorText(). If you find + you need to return machine-readable information about failures, you + should incorporate it into your response protocol buffer and should + NOT call SetFailed(). + """ + raise NotImplementedError + + def IsCanceled(self): + """Checks if the client cancelled the RPC. + + If true, indicates that the client canceled the RPC, so the server may + as well give up on replying to it. The server should still call the + final "done" callback. + """ + raise NotImplementedError + + def NotifyOnCancel(self, callback): + """Sets a callback to invoke on cancel. + + Asks that the given callback be called when the RPC is canceled. The + callback will always be called exactly once. If the RPC completes without + being canceled, the callback will be called after completion. If the RPC + has already been canceled when NotifyOnCancel() is called, the callback + will be called immediately. + + NotifyOnCancel() must be called no more than once per request. + """ + raise NotImplementedError + + +class RpcChannel(object): + + """Abstract interface for an RPC channel. + + An RpcChannel represents a communication line to a service which can be used + to call that service's methods. The service may be running on another + machine. Normally, you should not use an RpcChannel directly, but instead + construct a stub {@link Service} wrapping it. Example: + + Example: + RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234") + RpcController controller = rpcImpl.Controller() + MyService service = MyService_Stub(channel) + service.MyMethod(controller, request, callback) + """ + + def CallMethod(self, method_descriptor, rpc_controller, + request, response_class, done): + """Calls the method identified by the descriptor. + + Call the given method of the remote service. The signature of this + procedure looks the same as Service.CallMethod(), but the requirements + are less strict in one important way: the request object doesn't have to + be of any specific class as long as its descriptor is method.input_type. + """ + raise NotImplementedError diff --git a/python/google/protobuf/service_reflection.py b/python/google/protobuf/service_reflection.py new file mode 100755 index 0000000..851e83e --- /dev/null +++ b/python/google/protobuf/service_reflection.py @@ -0,0 +1,284 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains metaclasses used to create protocol service and service stub +classes from ServiceDescriptor objects at runtime. + +The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to +inject all useful functionality into the classes output by the protocol +compiler at compile-time. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class GeneratedServiceType(type): + + """Metaclass for service classes created at runtime from ServiceDescriptors. + + Implementations for all methods described in the Service class are added here + by this class. We also create properties to allow getting/setting all fields + in the protocol message. + + The protocol compiler currently uses this metaclass to create protocol service + classes at runtime. Clients can also manually create their own classes at + runtime, as in this example: + + mydescriptor = ServiceDescriptor(.....) + class MyProtoService(service.Service): + __metaclass__ = GeneratedServiceType + DESCRIPTOR = mydescriptor + myservice_instance = MyProtoService() + ... + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service class. + + Args: + name: Name of the class (ignored, but required by the metaclass + protocol). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service class is subclassed. + if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY] + service_builder = _ServiceBuilder(descriptor) + service_builder.BuildService(cls) + + +class GeneratedServiceStubType(GeneratedServiceType): + + """Metaclass for service stubs created at runtime from ServiceDescriptors. + + This class has similar responsibilities as GeneratedServiceType, except that + it creates the service stub classes. + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service stub class. + + Args: + name: Name of the class (ignored, here). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary) + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service stub is subclassed. + if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY] + service_stub_builder = _ServiceStubBuilder(descriptor) + service_stub_builder.BuildServiceStub(cls) + + +class _ServiceBuilder(object): + + """This class constructs a protocol service class using a service descriptor. + + Given a service descriptor, this class constructs a class that represents + the specified service descriptor. One service builder instance constructs + exactly one service class. That means all instances of that class share the + same builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + service class. + """ + self.descriptor = service_descriptor + + def BuildService(self, cls): + """Constructs the service class. + + Args: + cls: The class that will be constructed. + """ + + # CallMethod needs to operate with an instance of the Service class. This + # internal wrapper function exists only to be able to pass the service + # instance to the method that does the real CallMethod work. + def _WrapCallMethod(srvc, method_descriptor, + rpc_controller, request, callback): + return self._CallMethod(srvc, method_descriptor, + rpc_controller, request, callback) + self.cls = cls + cls.CallMethod = _WrapCallMethod + cls.GetDescriptor = staticmethod(lambda: self.descriptor) + cls.GetDescriptor.__doc__ = "Returns the service descriptor." + cls.GetRequestClass = self._GetRequestClass + cls.GetResponseClass = self._GetResponseClass + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) + + def _CallMethod(self, srvc, method_descriptor, + rpc_controller, request, callback): + """Calls the method described by a given method descriptor. + + Args: + srvc: Instance of the service for which this method is called. + method_descriptor: Descriptor that represent the method to call. + rpc_controller: RPC controller to use for this method's execution. + request: Request protocol message. + callback: A callback to invoke after the method has completed. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'CallMethod() given method descriptor for wrong service type.') + method = getattr(srvc, method_descriptor.name) + return method(rpc_controller, request, callback) + + def _GetRequestClass(self, method_descriptor): + """Returns the class of the request protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + request protocol message class. + + Returns: + A class that represents the input protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetRequestClass() given method descriptor for wrong service type.') + return method_descriptor.input_type._concrete_class + + def _GetResponseClass(self, method_descriptor): + """Returns the class of the response protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + response protocol message class. + + Returns: + A class that represents the output protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetResponseClass() given method descriptor for wrong service type.') + return method_descriptor.output_type._concrete_class + + def _GenerateNonImplementedMethod(self, method): + """Generates and returns a method that can be set for a service methods. + + Args: + method: Descriptor of the service method for which a method is to be + generated. + + Returns: + A method that can be added to the service class. + """ + return lambda inst, rpc_controller, request, callback: ( + self._NonImplementedMethod(method.name, rpc_controller, callback)) + + def _NonImplementedMethod(self, method_name, rpc_controller, callback): + """The body of all methods in the generated service class. + + Args: + method_name: Name of the method being executed. + rpc_controller: RPC controller used to execute this method. + callback: A callback which will be invoked when the method finishes. + """ + rpc_controller.SetFailed('Method %s not implemented.' % method_name) + callback(None) + + +class _ServiceStubBuilder(object): + + """Constructs a protocol service stub class using a service descriptor. + + Given a service descriptor, this class constructs a suitable stub class. + A stub is just a type-safe wrapper around an RpcChannel which emulates a + local implementation of the service. + + One service stub builder instance constructs exactly one class. It means all + instances of that class share the same service stub builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service stub class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + stub class. + """ + self.descriptor = service_descriptor + + def BuildServiceStub(self, cls): + """Constructs the stub class. + + Args: + cls: The class that will be constructed. + """ + + def _ServiceStubInit(stub, rpc_channel): + stub.rpc_channel = rpc_channel + self.cls = cls + cls.__init__ = _ServiceStubInit + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateStubMethod(method)) + + def _GenerateStubMethod(self, method): + return (lambda inst, rpc_controller, request, callback=None: + self._StubMethod(inst, method, rpc_controller, request, callback)) + + def _StubMethod(self, stub, method_descriptor, + rpc_controller, request, callback): + """The body of all service methods in the generated stub class. + + Args: + stub: Stub instance. + method_descriptor: Descriptor of the invoked method. + rpc_controller: Rpc controller to execute the method. + request: Request protocol message. + callback: A callback to execute when the method finishes. + Returns: + Response message (in case of blocking call). + """ + return stub.rpc_channel.CallMethod( + method_descriptor, rpc_controller, request, + method_descriptor.output_type._concrete_class, callback) diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py new file mode 100755 index 0000000..1cddce6 --- /dev/null +++ b/python/google/protobuf/text_format.py @@ -0,0 +1,650 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in text format.""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import cStringIO +import re + +from collections import deque +from google.protobuf.internal import type_checkers +from google.protobuf import descriptor + +__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge' ] + + +class ParseError(Exception): + """Thrown in case of ASCII parsing error.""" + + +def MessageToString(message): + out = cStringIO.StringIO() + PrintMessage(message, out) + result = out.getvalue() + out.close() + return result + + +def PrintMessage(message, out, indent = 0): + for field, value in message.ListFields(): + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + PrintField(field, element, out, indent) + else: + PrintField(field, value, out, indent) + + +def PrintField(field, value, out, indent = 0): + """Print a single field name/value pair. For repeated fields, the value + should be a single element.""" + + out.write(' ' * indent); + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + PrintFieldValue(field, value, out, indent) + out.write('\n') + + +def PrintFieldValue(field, value, out, indent = 0): + """Print a single field value (not including name). For repeated fields, + the value should be a single element.""" + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + out.write(' {\n') + PrintMessage(value, out, indent + 2) + out.write(' ' * indent + '}') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + out.write(field.enum_type.values_by_number[value].name) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + out.write(_CEscape(value)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write("true") + else: + out.write("false") + else: + out.write(str(value)) + + +def Merge(text, message): + """Merges an ASCII representation of a protocol message into a message. + + Args: + text: Message ASCII representation. + message: A protocol buffer message to merge into. + + Raises: + ParseError: On ASCII parsing problems. + """ + tokenizer = _Tokenizer(text) + while not tokenizer.AtEnd(): + _MergeField(tokenizer, message) + + +def _MergeField(tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of ASCII parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + field = message.Extensions._FindExtensionByName(name) + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) + elif message_descriptor != field.containing_type: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" does not extend message type "%s".' % ( + name, message_descriptor.full_name)) + tokenizer.Consume(']') + else: + name = tokenizer.ConsumeIdentifier() + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' % ( + message_descriptor.full_name, name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + + if tokenizer.TryConsume('<'): + end_token = '>' + else: + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) + _MergeField(tokenizer, sub_message) + else: + _MergeScalarField(tokenizer, message, field) + + +def _MergeScalarField(tokenizer, message, field): + """Merges a single protocol message scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of ASCII parsing problems. + RuntimeError: On runtime errors. + """ + tokenizer.Consume(':') + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = tokenizer.ConsumeInt32() + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = tokenizer.ConsumeInt64() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = tokenizer.ConsumeUint32() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = tokenizer.ConsumeUint64() + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + # Enum can be specified by a number (the enum value), or by + # a string literal (the enum name). + enum_descriptor = field.enum_type + if tokenizer.LookingAtInteger(): + number = tokenizer.ConsumeInt32() + enum_value = enum_descriptor.values_by_number.get(number, None) + if enum_value is None: + raise tokenizer.ParseErrorPreviousToken( + 'Enum type "%s" has no value with number %d.' % ( + enum_descriptor.full_name, number)) + else: + identifier = tokenizer.ConsumeIdentifier() + enum_value = enum_descriptor.values_by_name.get(identifier, None) + if enum_value is None: + raise tokenizer.ParseErrorPreviousToken( + 'Enum type "%s" has no value named %s.' % ( + enum_descriptor.full_name, identifier)) + value = enum_value.number + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + message.Extensions[field] = value + else: + setattr(message, field.name, value) + + +class _Tokenizer(object): + """Protocol buffer ASCII representation tokenizer. + + This class handles the lower level string parsing by splitting it into + meaningful tokens. + + It was directly ported from the Java protocol buffer API. + """ + + _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE) + _TOKEN = re.compile( + '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier + '[0-9+-][0-9a-zA-Z_.+-]*|' # a number + '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string + '\'([^\"\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string + _IDENTIFIER = re.compile('\w+') + _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(), + type_checkers.Int32ValueChecker(), + type_checkers.Uint64ValueChecker(), + type_checkers.Int64ValueChecker()] + _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE) + _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE) + + def __init__(self, text_message): + self._text_message = text_message + + self._position = 0 + self._line = -1 + self._column = 0 + self._token_start = None + self.token = '' + self._lines = deque(text_message.split('\n')) + self._current_line = '' + self._previous_line = 0 + self._previous_column = 0 + self._SkipWhitespace() + self.NextToken() + + def AtEnd(self): + """Checks the end of the text was reached. + + Returns: + True iff the end was reached. + """ + return not self._lines and not self._current_line + + def _PopLine(self): + while not self._current_line: + if not self._lines: + self._current_line = '' + return + self._line += 1 + self._column = 0 + self._current_line = self._lines.popleft() + + def _SkipWhitespace(self): + while True: + self._PopLine() + match = re.match(self._WHITESPACE, self._current_line) + if not match: + break + length = len(match.group(0)) + self._current_line = self._current_line[length:] + self._column += length + + def TryConsume(self, token): + """Tries to consume a given piece of text. + + Args: + token: Text to consume. + + Returns: + True iff the text was consumed. + """ + if self.token == token: + self.NextToken() + return True + return False + + def Consume(self, token): + """Consumes a piece of text. + + Args: + token: Text to consume. + + Raises: + ParseError: If the text couldn't be consumed. + """ + if not self.TryConsume(token): + raise self._ParseError('Expected "%s".' % token) + + def LookingAtInteger(self): + """Checks if the current token is an integer. + + Returns: + True iff the current token is an integer. + """ + if not self.token: + return False + c = self.token[0] + return (c >= '0' and c <= '9') or c == '-' or c == '+' + + def ConsumeIdentifier(self): + """Consumes protocol message field identifier. + + Returns: + Identifier string. + + Raises: + ParseError: If an identifier couldn't be consumed. + """ + result = self.token + if not re.match(self._IDENTIFIER, result): + raise self._ParseError('Expected identifier.') + self.NextToken() + return result + + def ConsumeInt32(self): + """Consumes a signed 32bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=True, is_long=False) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeUint32(self): + """Consumes an unsigned 32bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 32bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=False, is_long=False) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeInt64(self): + """Consumes a signed 64bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 64bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=True, is_long=True) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeUint64(self): + """Consumes an unsigned 64bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 64bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=False, is_long=True) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeFloat(self): + """Consumes an floating point number. + + Returns: + The number parsed. + + Raises: + ParseError: If a floating point number couldn't be consumed. + """ + text = self.token + if re.match(self._FLOAT_INFINITY, text): + self.NextToken() + if text.startswith('-'): + return float('-inf') + return float('inf') + + if re.match(self._FLOAT_NAN, text): + self.NextToken() + return float('nan') + + try: + result = float(text) + except ValueError, e: + raise self._FloatParseError(e) + self.NextToken() + return result + + def ConsumeBool(self): + """Consumes a boolean value. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. + """ + if self.token == 'true': + self.NextToken() + return True + elif self.token == 'false': + self.NextToken() + return False + else: + raise self._ParseError('Expected "true" or "false".') + + def ConsumeString(self): + """Consumes a string value. + + Returns: + The string parsed. + + Raises: + ParseError: If a string value couldn't be consumed. + """ + return unicode(self.ConsumeByteString(), 'utf-8') + + def ConsumeByteString(self): + """Consumes a byte array value. + + Returns: + The array parsed (as a string). + + Raises: + ParseError: If a byte array value couldn't be consumed. + """ + text = self.token + if len(text) < 1 or text[0] not in ('\'', '"'): + raise self._ParseError('Exptected string.') + + if len(text) < 2 or text[-1] != text[0]: + raise self._ParseError('String missing ending quote.') + + try: + result = _CUnescape(text[1:-1]) + except ValueError, e: + raise self._ParseError(str(e)) + self.NextToken() + return result + + def _ParseInteger(self, text, is_signed=False, is_long=False): + """Parses an integer. + + Args: + text: The text to parse. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer value. + + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + pos = 0 + if text.startswith('-'): + pos += 1 + + base = 10 + if text.startswith('0x', pos) or text.startswith('0X', pos): + base = 16 + elif text.startswith('0', pos): + base = 8 + + # Do the actual parsing. Exception handling is propagated to caller. + result = int(text, base) + + # Check if the integer is sane. Exceptions handled by callers. + checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] + checker.CheckValue(result) + return result + + def ParseErrorPreviousToken(self, message): + """Creates and *returns* a ParseError for the previously read token. + + Args: + message: A message to set for the exception. + + Returns: + A ParseError instance. + """ + return ParseError('%d:%d : %s' % ( + self._previous_line + 1, self._previous_column + 1, message)) + + def _ParseError(self, message): + """Creates and *returns* a ParseError for the current token.""" + return ParseError('%d:%d : %s' % ( + self._line + 1, self._column + 1, message)) + + def _IntegerParseError(self, e): + return self._ParseError('Couldn\'t parse integer: ' + str(e)) + + def _FloatParseError(self, e): + return self._ParseError('Couldn\'t parse number: ' + str(e)) + + def NextToken(self): + """Reads the next meaningful token.""" + self._previous_line = self._line + self._previous_column = self._column + if self.AtEnd(): + self.token = '' + return + self._column += len(self.token) + + # Make sure there is data to work on. + self._PopLine() + + match = re.match(self._TOKEN, self._current_line) + if match: + token = match.group(0) + self._current_line = self._current_line[len(token):] + self.token = token + else: + self.token = self._current_line[0] + self._current_line = self._current_line[1:] + self._SkipWhitespace() + + +# text.encode('string_escape') does not seem to satisfy our needs as it +# encodes unprintable characters using two-digit hex escapes whereas our +# C++ unescaping function allows hex escapes to be any length. So, +# "\0011".encode('string_escape') ends up being "\\x011", which will be +# decoded in C++ as a single-character string with char code 0x11. +def _CEscape(text): + def escape(c): + o = ord(c) + if o == 10: return r"\n" # optional escape + if o == 13: return r"\r" # optional escape + if o == 9: return r"\t" # optional escape + if o == 39: return r"\'" # optional escape + + if o == 34: return r'\"' # necessary escape + if o == 92: return r"\\" # necessary escape + + if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes + return c + return "".join([escape(c) for c in text]) + + +_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])') + + +def _CUnescape(text): + def ReplaceHex(m): + return chr(int(m.group(0)[2:], 16)) + # This is required because the 'string_escape' encoding doesn't + # allow single-digit hex escapes (like '\xf'). + result = _CUNESCAPE_HEX.sub(ReplaceHex, text) + return result.decode('string_escape') diff --git a/python/mox.py b/python/mox.py new file mode 100755 index 0000000..ce80ba5 --- /dev/null +++ b/python/mox.py @@ -0,0 +1,1401 @@ +#!/usr/bin/python2.4 +# +# Copyright 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is used for testing. The original is at: +# http://code.google.com/p/pymox/ + +"""Mox, an object-mocking framework for Python. + +Mox works in the record-replay-verify paradigm. When you first create +a mock object, it is in record mode. You then programmatically set +the expected behavior of the mock object (what methods are to be +called on it, with what parameters, what they should return, and in +what order). + +Once you have set up the expected mock behavior, you put it in replay +mode. Now the mock responds to method calls just as you told it to. +If an unexpected method (or an expected method with unexpected +parameters) is called, then an exception will be raised. + +Once you are done interacting with the mock, you need to verify that +all the expected interactions occured. (Maybe your code exited +prematurely without calling some cleanup method!) The verify phase +ensures that every expected method was called; otherwise, an exception +will be raised. + +Suggested usage / workflow: + + # Create Mox factory + my_mox = Mox() + + # Create a mock data access object + mock_dao = my_mox.CreateMock(DAOClass) + + # Set up expected behavior + mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) + mock_dao.DeletePerson(person) + + # Put mocks in replay mode + my_mox.ReplayAll() + + # Inject mock object and run test + controller.SetDao(mock_dao) + controller.DeletePersonById('1') + + # Verify all methods were called as expected + my_mox.VerifyAll() +""" + +from collections import deque +import re +import types +import unittest + +import stubout + +class Error(AssertionError): + """Base exception for this module.""" + + pass + + +class ExpectedMethodCallsError(Error): + """Raised when Verify() is called before all expected methods have been called + """ + + def __init__(self, expected_methods): + """Init exception. + + Args: + # expected_methods: A sequence of MockMethod objects that should have been + # called. + expected_methods: [MockMethod] + + Raises: + ValueError: if expected_methods contains no methods. + """ + + if not expected_methods: + raise ValueError("There must be at least one expected method") + Error.__init__(self) + self._expected_methods = expected_methods + + def __str__(self): + calls = "\n".join(["%3d. %s" % (i, m) + for i, m in enumerate(self._expected_methods)]) + return "Verify: Expected methods never called:\n%s" % (calls,) + + +class UnexpectedMethodCallError(Error): + """Raised when an unexpected method is called. + + This can occur if a method is called with incorrect parameters, or out of the + specified order. + """ + + def __init__(self, unexpected_method, expected): + """Init exception. + + Args: + # unexpected_method: MockMethod that was called but was not at the head of + # the expected_method queue. + # expected: MockMethod or UnorderedGroup the method should have + # been in. + unexpected_method: MockMethod + expected: MockMethod or UnorderedGroup + """ + + Error.__init__(self) + self._unexpected_method = unexpected_method + self._expected = expected + + def __str__(self): + return "Unexpected method call: %s. Expecting: %s" % \ + (self._unexpected_method, self._expected) + + +class UnknownMethodCallError(Error): + """Raised if an unknown method is requested of the mock object.""" + + def __init__(self, unknown_method_name): + """Init exception. + + Args: + # unknown_method_name: Method call that is not part of the mocked class's + # public interface. + unknown_method_name: str + """ + + Error.__init__(self) + self._unknown_method_name = unknown_method_name + + def __str__(self): + return "Method called is not a member of the object: %s" % \ + self._unknown_method_name + + +class Mox(object): + """Mox: a factory for creating mock objects.""" + + # A list of types that should be stubbed out with MockObjects (as + # opposed to MockAnythings). + _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, + types.ObjectType, types.TypeType] + + def __init__(self): + """Initialize a new Mox.""" + + self._mock_objects = [] + self.stubs = stubout.StubOutForTesting() + + def CreateMock(self, class_to_mock): + """Create a new mock object. + + Args: + # class_to_mock: the class to be mocked + class_to_mock: class + + Returns: + MockObject that can be used as the class_to_mock would be. + """ + + new_mock = MockObject(class_to_mock) + self._mock_objects.append(new_mock) + return new_mock + + def CreateMockAnything(self): + """Create a mock that will accept any method calls. + + This does not enforce an interface. + """ + + new_mock = MockAnything() + self._mock_objects.append(new_mock) + return new_mock + + def ReplayAll(self): + """Set all mock objects to replay mode.""" + + for mock_obj in self._mock_objects: + mock_obj._Replay() + + + def VerifyAll(self): + """Call verify on all mock objects created.""" + + for mock_obj in self._mock_objects: + mock_obj._Verify() + + def ResetAll(self): + """Call reset on all mock objects. This does not unset stubs.""" + + for mock_obj in self._mock_objects: + mock_obj._Reset() + + def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): + """Replace a method, attribute, etc. with a Mock. + + This will replace a class or module with a MockObject, and everything else + (method, function, etc) with a MockAnything. This can be overridden to + always use a MockAnything by setting use_mock_anything to True. + + Args: + obj: A Python object (class, module, instance, callable). + attr_name: str. The name of the attribute to replace with a mock. + use_mock_anything: bool. True if a MockAnything should be used regardless + of the type of attribute. + """ + + attr_to_replace = getattr(obj, attr_name) + if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: + stub = self.CreateMock(attr_to_replace) + else: + stub = self.CreateMockAnything() + + self.stubs.Set(obj, attr_name, stub) + + def UnsetStubs(self): + """Restore stubs to their original state.""" + + self.stubs.UnsetAll() + +def Replay(*args): + """Put mocks into Replay mode. + + Args: + # args is any number of mocks to put into replay mode. + """ + + for mock in args: + mock._Replay() + + +def Verify(*args): + """Verify mocks. + + Args: + # args is any number of mocks to be verified. + """ + + for mock in args: + mock._Verify() + + +def Reset(*args): + """Reset mocks. + + Args: + # args is any number of mocks to be reset. + """ + + for mock in args: + mock._Reset() + + +class MockAnything: + """A mock that can be used to mock anything. + + This is helpful for mocking classes that do not provide a public interface. + """ + + def __init__(self): + """ """ + self._Reset() + + def __getattr__(self, method_name): + """Intercept method calls on this object. + + A new MockMethod is returned that is aware of the MockAnything's + state (record or replay). The call will be recorded or replayed + by the MockMethod's __call__. + + Args: + # method name: the name of the method being called. + method_name: str + + Returns: + A new MockMethod aware of MockAnything's state (record or replay). + """ + + return self._CreateMockMethod(method_name) + + def _CreateMockMethod(self, method_name): + """Create a new mock method call and return it. + + Args: + # method name: the name of the method being called. + method_name: str + + Returns: + A new MockMethod aware of MockAnything's state (record or replay). + """ + + return MockMethod(method_name, self._expected_calls_queue, + self._replay_mode) + + def __nonzero__(self): + """Return 1 for nonzero so the mock can be used as a conditional.""" + + return 1 + + def __eq__(self, rhs): + """Provide custom logic to compare objects.""" + + return (isinstance(rhs, MockAnything) and + self._replay_mode == rhs._replay_mode and + self._expected_calls_queue == rhs._expected_calls_queue) + + def __ne__(self, rhs): + """Provide custom logic to compare objects.""" + + return not self == rhs + + def _Replay(self): + """Start replaying expected method calls.""" + + self._replay_mode = True + + def _Verify(self): + """Verify that all of the expected calls have been made. + + Raises: + ExpectedMethodCallsError: if there are still more method calls in the + expected queue. + """ + + # If the list of expected calls is not empty, raise an exception + if self._expected_calls_queue: + # The last MultipleTimesGroup is not popped from the queue. + if (len(self._expected_calls_queue) == 1 and + isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and + self._expected_calls_queue[0].IsSatisfied()): + pass + else: + raise ExpectedMethodCallsError(self._expected_calls_queue) + + def _Reset(self): + """Reset the state of this mock to record mode with an empty queue.""" + + # Maintain a list of method calls we are expecting + self._expected_calls_queue = deque() + + # Make sure we are in setup mode, not replay mode + self._replay_mode = False + + +class MockObject(MockAnything, object): + """A mock object that simulates the public/protected interface of a class.""" + + def __init__(self, class_to_mock): + """Initialize a mock object. + + This determines the methods and properties of the class and stores them. + + Args: + # class_to_mock: class to be mocked + class_to_mock: class + """ + + # This is used to hack around the mixin/inheritance of MockAnything, which + # is not a proper object (it can be anything. :-) + MockAnything.__dict__['__init__'](self) + + # Get a list of all the public and special methods we should mock. + self._known_methods = set() + self._known_vars = set() + self._class_to_mock = class_to_mock + for method in dir(class_to_mock): + if callable(getattr(class_to_mock, method)): + self._known_methods.add(method) + else: + self._known_vars.add(method) + + def __getattr__(self, name): + """Intercept attribute request on this object. + + If the attribute is a public class variable, it will be returned and not + recorded as a call. + + If the attribute is not a variable, it is handled like a method + call. The method name is checked against the set of mockable + methods, and a new MockMethod is returned that is aware of the + MockObject's state (record or replay). The call will be recorded + or replayed by the MockMethod's __call__. + + Args: + # name: the name of the attribute being requested. + name: str + + Returns: + Either a class variable or a new MockMethod that is aware of the state + of the mock (record or replay). + + Raises: + UnknownMethodCallError if the MockObject does not mock the requested + method. + """ + + if name in self._known_vars: + return getattr(self._class_to_mock, name) + + if name in self._known_methods: + return self._CreateMockMethod(name) + + raise UnknownMethodCallError(name) + + def __eq__(self, rhs): + """Provide custom logic to compare objects.""" + + return (isinstance(rhs, MockObject) and + self._class_to_mock == rhs._class_to_mock and + self._replay_mode == rhs._replay_mode and + self._expected_calls_queue == rhs._expected_calls_queue) + + def __setitem__(self, key, value): + """Provide custom logic for mocking classes that support item assignment. + + Args: + key: Key to set the value for. + value: Value to set. + + Returns: + Expected return value in replay mode. A MockMethod object for the + __setitem__ method that has already been called if not in replay mode. + + Raises: + TypeError if the underlying class does not support item assignment. + UnexpectedMethodCallError if the object does not expect the call to + __setitem__. + + """ + setitem = self._class_to_mock.__dict__.get('__setitem__', None) + + # Verify the class supports item assignment. + if setitem is None: + raise TypeError('object does not support item assignment') + + # If we are in replay mode then simply call the mock __setitem__ method. + if self._replay_mode: + return MockMethod('__setitem__', self._expected_calls_queue, + self._replay_mode)(key, value) + + + # Otherwise, create a mock method __setitem__. + return self._CreateMockMethod('__setitem__')(key, value) + + def __getitem__(self, key): + """Provide custom logic for mocking classes that are subscriptable. + + Args: + key: Key to return the value for. + + Returns: + Expected return value in replay mode. A MockMethod object for the + __getitem__ method that has already been called if not in replay mode. + + Raises: + TypeError if the underlying class is not subscriptable. + UnexpectedMethodCallError if the object does not expect the call to + __setitem__. + + """ + getitem = self._class_to_mock.__dict__.get('__getitem__', None) + + # Verify the class supports item assignment. + if getitem is None: + raise TypeError('unsubscriptable object') + + # If we are in replay mode then simply call the mock __getitem__ method. + if self._replay_mode: + return MockMethod('__getitem__', self._expected_calls_queue, + self._replay_mode)(key) + + + # Otherwise, create a mock method __getitem__. + return self._CreateMockMethod('__getitem__')(key) + + def __call__(self, *params, **named_params): + """Provide custom logic for mocking classes that are callable.""" + + # Verify the class we are mocking is callable + callable = self._class_to_mock.__dict__.get('__call__', None) + if callable is None: + raise TypeError('Not callable') + + # Because the call is happening directly on this object instead of a method, + # the call on the mock method is made right here + mock_method = self._CreateMockMethod('__call__') + return mock_method(*params, **named_params) + + @property + def __class__(self): + """Return the class that is being mocked.""" + + return self._class_to_mock + + +class MockMethod(object): + """Callable mock method. + + A MockMethod should act exactly like the method it mocks, accepting parameters + and returning a value, or throwing an exception (as specified). When this + method is called, it can optionally verify whether the called method (name and + signature) matches the expected method. + """ + + def __init__(self, method_name, call_queue, replay_mode): + """Construct a new mock method. + + Args: + # method_name: the name of the method + # call_queue: deque of calls, verify this call against the head, or add + # this call to the queue. + # replay_mode: False if we are recording, True if we are verifying calls + # against the call queue. + method_name: str + call_queue: list or deque + replay_mode: bool + """ + + self._name = method_name + self._call_queue = call_queue + if not isinstance(call_queue, deque): + self._call_queue = deque(self._call_queue) + self._replay_mode = replay_mode + + self._params = None + self._named_params = None + self._return_value = None + self._exception = None + self._side_effects = None + + def __call__(self, *params, **named_params): + """Log parameters and return the specified return value. + + If the Mock(Anything/Object) associated with this call is in record mode, + this MockMethod will be pushed onto the expected call queue. If the mock + is in replay mode, this will pop a MockMethod off the top of the queue and + verify this call is equal to the expected call. + + Raises: + UnexpectedMethodCall if this call is supposed to match an expected method + call and it does not. + """ + + self._params = params + self._named_params = named_params + + if not self._replay_mode: + self._call_queue.append(self) + return self + + expected_method = self._VerifyMethodCall() + + if expected_method._side_effects: + expected_method._side_effects(*params, **named_params) + + if expected_method._exception: + raise expected_method._exception + + return expected_method._return_value + + def __getattr__(self, name): + """Raise an AttributeError with a helpful message.""" + + raise AttributeError('MockMethod has no attribute "%s". ' + 'Did you remember to put your mocks in replay mode?' % name) + + def _PopNextMethod(self): + """Pop the next method from our call queue.""" + try: + return self._call_queue.popleft() + except IndexError: + raise UnexpectedMethodCallError(self, None) + + def _VerifyMethodCall(self): + """Verify the called method is expected. + + This can be an ordered method, or part of an unordered set. + + Returns: + The expected mock method. + + Raises: + UnexpectedMethodCall if the method called was not expected. + """ + + expected = self._PopNextMethod() + + # Loop here, because we might have a MethodGroup followed by another + # group. + while isinstance(expected, MethodGroup): + expected, method = expected.MethodCalled(self) + if method is not None: + return method + + # This is a mock method, so just check equality. + if expected != self: + raise UnexpectedMethodCallError(self, expected) + + return expected + + def __str__(self): + params = ', '.join( + [repr(p) for p in self._params or []] + + ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) + desc = "%s(%s) -> %r" % (self._name, params, self._return_value) + return desc + + def __eq__(self, rhs): + """Test whether this MockMethod is equivalent to another MockMethod. + + Args: + # rhs: the right hand side of the test + rhs: MockMethod + """ + + return (isinstance(rhs, MockMethod) and + self._name == rhs._name and + self._params == rhs._params and + self._named_params == rhs._named_params) + + def __ne__(self, rhs): + """Test whether this MockMethod is not equivalent to another MockMethod. + + Args: + # rhs: the right hand side of the test + rhs: MockMethod + """ + + return not self == rhs + + def GetPossibleGroup(self): + """Returns a possible group from the end of the call queue or None if no + other methods are on the stack. + """ + + # Remove this method from the tail of the queue so we can add it to a group. + this_method = self._call_queue.pop() + assert this_method == self + + # Determine if the tail of the queue is a group, or just a regular ordered + # mock method. + group = None + try: + group = self._call_queue[-1] + except IndexError: + pass + + return group + + def _CheckAndCreateNewGroup(self, group_name, group_class): + """Checks if the last method (a possible group) is an instance of our + group_class. Adds the current method to this group or creates a new one. + + Args: + + group_name: the name of the group. + group_class: the class used to create instance of this new group + """ + group = self.GetPossibleGroup() + + # If this is a group, and it is the correct group, add the method. + if isinstance(group, group_class) and group.group_name() == group_name: + group.AddMethod(self) + return self + + # Create a new group and add the method. + new_group = group_class(group_name) + new_group.AddMethod(self) + self._call_queue.append(new_group) + return self + + def InAnyOrder(self, group_name="default"): + """Move this method into a group of unordered calls. + + A group of unordered calls must be defined together, and must be executed + in full before the next expected method can be called. There can be + multiple groups that are expected serially, if they are given + different group names. The same group name can be reused if there is a + standard method call, or a group with a different name, spliced between + usages. + + Args: + group_name: the name of the unordered group. + + Returns: + self + """ + return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) + + def MultipleTimes(self, group_name="default"): + """Move this method into group of calls which may be called multiple times. + + A group of repeating calls must be defined together, and must be executed in + full before the next expected mehtod can be called. + + Args: + group_name: the name of the unordered group. + + Returns: + self + """ + return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) + + def AndReturn(self, return_value): + """Set the value to return when this method is called. + + Args: + # return_value can be anything. + """ + + self._return_value = return_value + return return_value + + def AndRaise(self, exception): + """Set the exception to raise when this method is called. + + Args: + # exception: the exception to raise when this method is called. + exception: Exception + """ + + self._exception = exception + + def WithSideEffects(self, side_effects): + """Set the side effects that are simulated when this method is called. + + Args: + side_effects: A callable which modifies the parameters or other relevant + state which a given test case depends on. + + Returns: + Self for chaining with AndReturn and AndRaise. + """ + self._side_effects = side_effects + return self + +class Comparator: + """Base class for all Mox comparators. + + A Comparator can be used as a parameter to a mocked method when the exact + value is not known. For example, the code you are testing might build up a + long SQL string that is passed to your mock DAO. You're only interested that + the IN clause contains the proper primary keys, so you can set your mock + up as follows: + + mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) + + Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. + + A Comparator may replace one or more parameters, for example: + # return at most 10 rows + mock_dao.RunQuery(StrContains('SELECT'), 10) + + or + + # Return some non-deterministic number of rows + mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) + """ + + def equals(self, rhs): + """Special equals method that all comparators must implement. + + Args: + rhs: any python object + """ + + raise NotImplementedError, 'method must be implemented by a subclass.' + + def __eq__(self, rhs): + return self.equals(rhs) + + def __ne__(self, rhs): + return not self.equals(rhs) + + +class IsA(Comparator): + """This class wraps a basic Python type or class. It is used to verify + that a parameter is of the given type or class. + + Example: + mock_dao.Connect(IsA(DbConnectInfo)) + """ + + def __init__(self, class_name): + """Initialize IsA + + Args: + class_name: basic python type or a class + """ + + self._class_name = class_name + + def equals(self, rhs): + """Check to see if the RHS is an instance of class_name. + + Args: + # rhs: the right hand side of the test + rhs: object + + Returns: + bool + """ + + try: + return isinstance(rhs, self._class_name) + except TypeError: + # Check raw types if there was a type error. This is helpful for + # things like cStringIO.StringIO. + return type(rhs) == type(self._class_name) + + def __repr__(self): + return str(self._class_name) + +class IsAlmost(Comparator): + """Comparison class used to check whether a parameter is nearly equal + to a given value. Generally useful for floating point numbers. + + Example mock_dao.SetTimeout((IsAlmost(3.9))) + """ + + def __init__(self, float_value, places=7): + """Initialize IsAlmost. + + Args: + float_value: The value for making the comparison. + places: The number of decimal places to round to. + """ + + self._float_value = float_value + self._places = places + + def equals(self, rhs): + """Check to see if RHS is almost equal to float_value + + Args: + rhs: the value to compare to float_value + + Returns: + bool + """ + + try: + return round(rhs-self._float_value, self._places) == 0 + except TypeError: + # This is probably because either float_value or rhs is not a number. + return False + + def __repr__(self): + return str(self._float_value) + +class StrContains(Comparator): + """Comparison class used to check whether a substring exists in a + string parameter. This can be useful in mocking a database with SQL + passed in as a string parameter, for example. + + Example: + mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) + """ + + def __init__(self, search_string): + """Initialize. + + Args: + # search_string: the string you are searching for + search_string: str + """ + + self._search_string = search_string + + def equals(self, rhs): + """Check to see if the search_string is contained in the rhs string. + + Args: + # rhs: the right hand side of the test + rhs: object + + Returns: + bool + """ + + try: + return rhs.find(self._search_string) > -1 + except Exception: + return False + + def __repr__(self): + return '<str containing \'%s\'>' % self._search_string + + +class Regex(Comparator): + """Checks if a string matches a regular expression. + + This uses a given regular expression to determine equality. + """ + + def __init__(self, pattern, flags=0): + """Initialize. + + Args: + # pattern is the regular expression to search for + pattern: str + # flags passed to re.compile function as the second argument + flags: int + """ + + self.regex = re.compile(pattern, flags=flags) + + def equals(self, rhs): + """Check to see if rhs matches regular expression pattern. + + Returns: + bool + """ + + return self.regex.search(rhs) is not None + + def __repr__(self): + s = '<regular expression \'%s\'' % self.regex.pattern + if self.regex.flags: + s += ', flags=%d' % self.regex.flags + s += '>' + return s + + +class In(Comparator): + """Checks whether an item (or key) is in a list (or dict) parameter. + + Example: + mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) + """ + + def __init__(self, key): + """Initialize. + + Args: + # key is any thing that could be in a list or a key in a dict + """ + + self._key = key + + def equals(self, rhs): + """Check to see whether key is in rhs. + + Args: + rhs: dict + + Returns: + bool + """ + + return self._key in rhs + + def __repr__(self): + return '<sequence or map containing \'%s\'>' % self._key + + +class ContainsKeyValue(Comparator): + """Checks whether a key/value pair is in a dict parameter. + + Example: + mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) + """ + + def __init__(self, key, value): + """Initialize. + + Args: + # key: a key in a dict + # value: the corresponding value + """ + + self._key = key + self._value = value + + def equals(self, rhs): + """Check whether the given key/value pair is in the rhs dict. + + Returns: + bool + """ + + try: + return rhs[self._key] == self._value + except Exception: + return False + + def __repr__(self): + return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) + + +class SameElementsAs(Comparator): + """Checks whether iterables contain the same elements (ignoring order). + + Example: + mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) + """ + + def __init__(self, expected_seq): + """Initialize. + + Args: + expected_seq: a sequence + """ + + self._expected_seq = expected_seq + + def equals(self, actual_seq): + """Check to see whether actual_seq has same elements as expected_seq. + + Args: + actual_seq: sequence + + Returns: + bool + """ + + try: + expected = dict([(element, None) for element in self._expected_seq]) + actual = dict([(element, None) for element in actual_seq]) + except TypeError: + # Fall back to slower list-compare if any of the objects are unhashable. + expected = list(self._expected_seq) + actual = list(actual_seq) + expected.sort() + actual.sort() + return expected == actual + + def __repr__(self): + return '<sequence with same elements as \'%s\'>' % self._expected_seq + + +class And(Comparator): + """Evaluates one or more Comparators on RHS and returns an AND of the results. + """ + + def __init__(self, *args): + """Initialize. + + Args: + *args: One or more Comparator + """ + + self._comparators = args + + def equals(self, rhs): + """Checks whether all Comparators are equal to rhs. + + Args: + # rhs: can be anything + + Returns: + bool + """ + + for comparator in self._comparators: + if not comparator.equals(rhs): + return False + + return True + + def __repr__(self): + return '<AND %s>' % str(self._comparators) + + +class Or(Comparator): + """Evaluates one or more Comparators on RHS and returns an OR of the results. + """ + + def __init__(self, *args): + """Initialize. + + Args: + *args: One or more Mox comparators + """ + + self._comparators = args + + def equals(self, rhs): + """Checks whether any Comparator is equal to rhs. + + Args: + # rhs: can be anything + + Returns: + bool + """ + + for comparator in self._comparators: + if comparator.equals(rhs): + return True + + return False + + def __repr__(self): + return '<OR %s>' % str(self._comparators) + + +class Func(Comparator): + """Call a function that should verify the parameter passed in is correct. + + You may need the ability to perform more advanced operations on the parameter + in order to validate it. You can use this to have a callable validate any + parameter. The callable should return either True or False. + + + Example: + + def myParamValidator(param): + # Advanced logic here + return True + + mock_dao.DoSomething(Func(myParamValidator), true) + """ + + def __init__(self, func): + """Initialize. + + Args: + func: callable that takes one parameter and returns a bool + """ + + self._func = func + + def equals(self, rhs): + """Test whether rhs passes the function test. + + rhs is passed into func. + + Args: + rhs: any python object + + Returns: + the result of func(rhs) + """ + + return self._func(rhs) + + def __repr__(self): + return str(self._func) + + +class IgnoreArg(Comparator): + """Ignore an argument. + + This can be used when we don't care about an argument of a method call. + + Example: + # Check if CastMagic is called with 3 as first arg and 'disappear' as third. + mymock.CastMagic(3, IgnoreArg(), 'disappear') + """ + + def equals(self, unused_rhs): + """Ignores arguments and returns True. + + Args: + unused_rhs: any python object + + Returns: + always returns True + """ + + return True + + def __repr__(self): + return '<IgnoreArg>' + + +class MethodGroup(object): + """Base class containing common behaviour for MethodGroups.""" + + def __init__(self, group_name): + self._group_name = group_name + + def group_name(self): + return self._group_name + + def __str__(self): + return '<%s "%s">' % (self.__class__.__name__, self._group_name) + + def AddMethod(self, mock_method): + raise NotImplementedError + + def MethodCalled(self, mock_method): + raise NotImplementedError + + def IsSatisfied(self): + raise NotImplementedError + +class UnorderedGroup(MethodGroup): + """UnorderedGroup holds a set of method calls that may occur in any order. + + This construct is helpful for non-deterministic events, such as iterating + over the keys of a dict. + """ + + def __init__(self, group_name): + super(UnorderedGroup, self).__init__(group_name) + self._methods = [] + + def AddMethod(self, mock_method): + """Add a method to this group. + + Args: + mock_method: A mock method to be added to this group. + """ + + self._methods.append(mock_method) + + def MethodCalled(self, mock_method): + """Remove a method call from the group. + + If the method is not in the set, an UnexpectedMethodCallError will be + raised. + + Args: + mock_method: a mock method that should be equal to a method in the group. + + Returns: + The mock method from the group + + Raises: + UnexpectedMethodCallError if the mock_method was not in the group. + """ + + # Check to see if this method exists, and if so, remove it from the set + # and return it. + for method in self._methods: + if method == mock_method: + # Remove the called mock_method instead of the method in the group. + # The called method will match any comparators when equality is checked + # during removal. The method in the group could pass a comparator to + # another comparator during the equality check. + self._methods.remove(mock_method) + + # If this group is not empty, put it back at the head of the queue. + if not self.IsSatisfied(): + mock_method._call_queue.appendleft(self) + + return self, method + + raise UnexpectedMethodCallError(mock_method, self) + + def IsSatisfied(self): + """Return True if there are not any methods in this group.""" + + return len(self._methods) == 0 + + +class MultipleTimesGroup(MethodGroup): + """MultipleTimesGroup holds methods that may be called any number of times. + + Note: Each method must be called at least once. + + This is helpful, if you don't know or care how many times a method is called. + """ + + def __init__(self, group_name): + super(MultipleTimesGroup, self).__init__(group_name) + self._methods = set() + self._methods_called = set() + + def AddMethod(self, mock_method): + """Add a method to this group. + + Args: + mock_method: A mock method to be added to this group. + """ + + self._methods.add(mock_method) + + def MethodCalled(self, mock_method): + """Remove a method call from the group. + + If the method is not in the set, an UnexpectedMethodCallError will be + raised. + + Args: + mock_method: a mock method that should be equal to a method in the group. + + Returns: + The mock method from the group + + Raises: + UnexpectedMethodCallError if the mock_method was not in the group. + """ + + # Check to see if this method exists, and if so add it to the set of + # called methods. + + for method in self._methods: + if method == mock_method: + self._methods_called.add(mock_method) + # Always put this group back on top of the queue, because we don't know + # when we are done. + mock_method._call_queue.appendleft(self) + return self, method + + if self.IsSatisfied(): + next_method = mock_method._PopNextMethod(); + return next_method, None + else: + raise UnexpectedMethodCallError(mock_method, self) + + def IsSatisfied(self): + """Return True if all methods in this group are called at least once.""" + # NOTE(psycho): We can't use the simple set difference here because we want + # to match different parameters which are considered the same e.g. IsA(str) + # and some string. This solution is O(n^2) but n should be small. + tmp = self._methods.copy() + for called in self._methods_called: + for expected in tmp: + if called == expected: + tmp.remove(expected) + if not tmp: + return True + break + return False + + +class MoxMetaTestBase(type): + """Metaclass to add mox cleanup and verification to every test. + + As the mox unit testing class is being constructed (MoxTestBase or a + subclass), this metaclass will modify all test functions to call the + CleanUpMox method of the test class after they finish. This means that + unstubbing and verifying will happen for every test with no additional code, + and any failures will result in test failures as opposed to errors. + """ + + def __init__(cls, name, bases, d): + type.__init__(cls, name, bases, d) + + # also get all the attributes from the base classes to account + # for a case when test class is not the immediate child of MoxTestBase + for base in bases: + for attr_name in dir(base): + d[attr_name] = getattr(base, attr_name) + + for func_name, func in d.items(): + if func_name.startswith('test') and callable(func): + setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) + + @staticmethod + def CleanUpTest(cls, func): + """Adds Mox cleanup code to any MoxTestBase method. + + Always unsets stubs after a test. Will verify all mocks for tests that + otherwise pass. + + Args: + cls: MoxTestBase or subclass; the class whose test method we are altering. + func: method; the method of the MoxTestBase test class we wish to alter. + + Returns: + The modified method. + """ + def new_method(self, *args, **kwargs): + mox_obj = getattr(self, 'mox', None) + cleanup_mox = False + if mox_obj and isinstance(mox_obj, Mox): + cleanup_mox = True + try: + func(self, *args, **kwargs) + finally: + if cleanup_mox: + mox_obj.UnsetStubs() + if cleanup_mox: + mox_obj.VerifyAll() + new_method.__name__ = func.__name__ + new_method.__doc__ = func.__doc__ + new_method.__module__ = func.__module__ + return new_method + + +class MoxTestBase(unittest.TestCase): + """Convenience test class to make stubbing easier. + + Sets up a "mox" attribute which is an instance of Mox - any mox tests will + want this. Also automatically unsets any stubs and verifies that all mock + methods have been called at the end of each test, eliminating boilerplate + code. + """ + + __metaclass__ = MoxMetaTestBase + + def setUp(self): + self.mox = Mox() diff --git a/python/setup.py b/python/setup.py new file mode 100755 index 0000000..d2997da --- /dev/null +++ b/python/setup.py @@ -0,0 +1,136 @@ +#! /usr/bin/python +# +# See README for usage instructions. + +# We must use setuptools, not distutils, because we need to use the +# namespace_packages option for the "google" package. +from ez_setup import use_setuptools +use_setuptools() + +from setuptools import setup +from distutils.spawn import find_executable +import sys +import os +import subprocess + +maintainer_email = "protobuf@googlegroups.com" + +# Find the Protocol Compiler. +if os.path.exists("../src/protoc"): + protoc = "../src/protoc" +elif os.path.exists("../src/protoc.exe"): + protoc = "../src/protoc.exe" +else: + protoc = find_executable("protoc") + +def generate_proto(source): + """Invokes the Protocol Compiler to generate a _pb2.py from the given + .proto file. Does nothing if the output already exists and is newer than + the input.""" + + output = source.replace(".proto", "_pb2.py").replace("../src/", "") + + if not os.path.exists(source): + print "Can't find required file: " + source + sys.exit(-1) + + if (not os.path.exists(output) or + (os.path.exists(source) and + os.path.getmtime(source) > os.path.getmtime(output))): + print "Generating %s..." % output + + if protoc == None: + sys.stderr.write( + "protoc is not installed nor found in ../src. Please compile it " + "or install the binary package.\n") + sys.exit(-1) + + protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ] + if subprocess.call(protoc_command) != 0: + sys.exit(-1) + +def MakeTestSuite(): + # This is apparently needed on some systems to make sure that the tests + # work even if a previous version is already installed. + if 'google' in sys.modules: + del sys.modules['google'] + + generate_proto("../src/google/protobuf/unittest.proto") + generate_proto("../src/google/protobuf/unittest_import.proto") + generate_proto("../src/google/protobuf/unittest_mset.proto") + generate_proto("google/protobuf/internal/more_extensions.proto") + generate_proto("google/protobuf/internal/more_messages.proto") + + import unittest + import google.protobuf.internal.generator_test as generator_test + import google.protobuf.internal.decoder_test as decoder_test + import google.protobuf.internal.descriptor_test as descriptor_test + import google.protobuf.internal.encoder_test as encoder_test + import google.protobuf.internal.input_stream_test as input_stream_test + import google.protobuf.internal.output_stream_test as output_stream_test + import google.protobuf.internal.reflection_test as reflection_test + import google.protobuf.internal.service_reflection_test \ + as service_reflection_test + import google.protobuf.internal.text_format_test as text_format_test + import google.protobuf.internal.wire_format_test as wire_format_test + + loader = unittest.defaultTestLoader + suite = unittest.TestSuite() + for test in [ generator_test, + decoder_test, + descriptor_test, + encoder_test, + input_stream_test, + output_stream_test, + reflection_test, + service_reflection_test, + text_format_test, + wire_format_test ]: + suite.addTest(loader.loadTestsFromModule(test)) + + return suite + +if __name__ == '__main__': + # TODO(kenton): Integrate this into setuptools somehow? + if len(sys.argv) >= 2 and sys.argv[1] == "clean": + # Delete generated _pb2.py files and .pyc files in the code tree. + for (dirpath, dirnames, filenames) in os.walk("."): + for filename in filenames: + filepath = os.path.join(dirpath, filename) + if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"): + os.remove(filepath) + else: + # Generate necessary .proto file if it doesn't exist. + # TODO(kenton): Maybe we should hook this into a distutils command? + generate_proto("../src/google/protobuf/descriptor.proto") + + setup(name = 'protobuf', + version = '2.2.0', + packages = [ 'google' ], + namespace_packages = [ 'google' ], + test_suite = 'setup.MakeTestSuite', + # Must list modules explicitly so that we don't install tests. + py_modules = [ + 'google.protobuf.internal.containers', + 'google.protobuf.internal.decoder', + 'google.protobuf.internal.encoder', + 'google.protobuf.internal.input_stream', + 'google.protobuf.internal.message_listener', + 'google.protobuf.internal.output_stream', + 'google.protobuf.internal.type_checkers', + 'google.protobuf.internal.wire_format', + 'google.protobuf.descriptor', + 'google.protobuf.descriptor_pb2', + 'google.protobuf.message', + 'google.protobuf.reflection', + 'google.protobuf.service', + 'google.protobuf.service_reflection', + 'google.protobuf.text_format' ], + url = 'http://code.google.com/p/protobuf/', + maintainer = maintainer_email, + maintainer_email = 'protobuf@googlegroups.com', + license = 'New BSD License', + description = 'Protocol Buffers', + long_description = + "Protocol Buffers are Google's data interchange format.", + ) diff --git a/python/stubout.py b/python/stubout.py new file mode 100755 index 0000000..aee4f2d --- /dev/null +++ b/python/stubout.py @@ -0,0 +1,140 @@ +#!/usr/bin/python2.4 +# +# Copyright 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is used for testing. The original is at: +# http://code.google.com/p/pymox/ + +class StubOutForTesting: + """Sample Usage: + You want os.path.exists() to always return true during testing. + + stubs = StubOutForTesting() + stubs.Set(os.path, 'exists', lambda x: 1) + ... + stubs.UnsetAll() + + The above changes os.path.exists into a lambda that returns 1. Once + the ... part of the code finishes, the UnsetAll() looks up the old value + of os.path.exists and restores it. + + """ + def __init__(self): + self.cache = [] + self.stubs = [] + + def __del__(self): + self.SmartUnsetAll() + self.UnsetAll() + + def SmartSet(self, obj, attr_name, new_attr): + """Replace obj.attr_name with new_attr. This method is smart and works + at the module, class, and instance level while preserving proper + inheritance. It will not stub out C types however unless that has been + explicitly allowed by the type. + + This method supports the case where attr_name is a staticmethod or a + classmethod of obj. + + Notes: + - If obj is an instance, then it is its class that will actually be + stubbed. Note that the method Set() does not do that: if obj is + an instance, it (and not its class) will be stubbed. + - The stubbing is using the builtin getattr and setattr. So, the __get__ + and __set__ will be called when stubbing (TODO: A better idea would + probably be to manipulate obj.__dict__ instead of getattr() and + setattr()). + + Raises AttributeError if the attribute cannot be found. + """ + if (inspect.ismodule(obj) or + (not inspect.isclass(obj) and obj.__dict__.has_key(attr_name))): + orig_obj = obj + orig_attr = getattr(obj, attr_name) + + else: + if not inspect.isclass(obj): + mro = list(inspect.getmro(obj.__class__)) + else: + mro = list(inspect.getmro(obj)) + + mro.reverse() + + orig_attr = None + + for cls in mro: + try: + orig_obj = cls + orig_attr = getattr(obj, attr_name) + except AttributeError: + continue + + if orig_attr is None: + raise AttributeError("Attribute not found.") + + # Calling getattr() on a staticmethod transforms it to a 'normal' function. + # We need to ensure that we put it back as a staticmethod. + old_attribute = obj.__dict__.get(attr_name) + if old_attribute is not None and isinstance(old_attribute, staticmethod): + orig_attr = staticmethod(orig_attr) + + self.stubs.append((orig_obj, attr_name, orig_attr)) + setattr(orig_obj, attr_name, new_attr) + + def SmartUnsetAll(self): + """Reverses all the SmartSet() calls, restoring things to their original + definition. Its okay to call SmartUnsetAll() repeatedly, as later calls + have no effect if no SmartSet() calls have been made. + + """ + self.stubs.reverse() + + for args in self.stubs: + setattr(*args) + + self.stubs = [] + + def Set(self, parent, child_name, new_child): + """Replace child_name's old definition with new_child, in the context + of the given parent. The parent could be a module when the child is a + function at module scope. Or the parent could be a class when a class' + method is being replaced. The named child is set to new_child, while + the prior definition is saved away for later, when UnsetAll() is called. + + This method supports the case where child_name is a staticmethod or a + classmethod of parent. + """ + old_child = getattr(parent, child_name) + + old_attribute = parent.__dict__.get(child_name) + if old_attribute is not None and isinstance(old_attribute, staticmethod): + old_child = staticmethod(old_child) + + self.cache.append((parent, old_child, child_name)) + setattr(parent, child_name, new_child) + + def UnsetAll(self): + """Reverses all the Set() calls, restoring things to their original + definition. Its okay to call UnsetAll() repeatedly, as later calls have + no effect if no Set() calls have been made. + + """ + # Undo calls to Set() in reverse order, in case Set() was called on the + # same arguments repeatedly (want the original call to be last one undone) + self.cache.reverse() + + for (parent, old_child, child_name) in self.cache: + setattr(parent, child_name, old_child) + self.cache = [] |