summaryrefslogtreecommitdiffstats
path: root/python/google/protobuf/internal/reflection_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py870
1 files changed, 784 insertions, 86 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 2c9fa30..d59815d 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -3,7 +3,7 @@
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
+# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -37,12 +37,12 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)'
+import copy
+import gc
import operator
import struct
-import unittest
-# TODO(robinson): When we split this test in two, only some of these imports
-# will be necessary in each test.
+from google.apputils import basetest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
@@ -50,6 +50,8 @@ from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
+from google.protobuf import text_format
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
@@ -102,12 +104,12 @@ class _MiniDecoder(object):
return self._pos == len(self._bytes)
-class ReflectionTest(unittest.TestCase):
+class ReflectionTest(basetest.TestCase):
- def assertIs(self, values, others):
+ def assertListsEqual(self, values, others):
self.assertEqual(len(values), len(others))
for i in range(len(values)):
- self.assertTrue(values[i] is others[i])
+ self.assertEqual(values[i], others[i])
def testScalarConstructor(self):
# Constructor with only scalar types should succeed.
@@ -200,6 +202,41 @@ class ReflectionTest(unittest.TestCase):
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
+ def testConstructorTypeError(self):
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
+ self.assertRaises(
+ TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
+
+ def testConstructorInvalidatesCachedByteSize(self):
+ message = unittest_pb2.TestAllTypes(optional_int32 = 12)
+ self.assertEquals(2, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(
+ optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
+ self.assertEquals(3, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
+ self.assertEquals(3, message.ByteSize())
+
+ message = unittest_pb2.TestAllTypes(
+ repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
+ self.assertEquals(3, message.ByteSize())
+
def testSimpleHasBits(self):
# Test a scalar.
proto = unittest_pb2.TestAllTypes()
@@ -284,12 +321,6 @@ class ReflectionTest(unittest.TestCase):
# ...and ensure that the scalar field has returned to its default.
self.assertEqual(0, getattr(composite_field, scalar_field_name))
- # Finally, ensure that modifications to the old composite field object
- # don't have any effect on the parent.
- #
- # (NOTE that when we clear the composite field in the parent, we actually
- # don't recursively clear down the tree. Instead, we just disconnect the
- # cleared composite from the tree.)
self.assertTrue(old_composite_field is not composite_field)
setattr(old_composite_field, scalar_field_name, new_val)
self.assertTrue(not composite_field.HasField(scalar_field_name))
@@ -319,6 +350,64 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(not proto.HasField('optional_nested_message'))
self.assertEqual(0, proto.optional_nested_message.bb)
+ def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ proto.ClearField('optional_nested_message')
+ del proto
+ del nested
+ # Force a garbage collect so that the underlying CMessages are freed along
+ # with the Messages they point to. This is to make sure we're not deleting
+ # default message instances.
+ gc.collect()
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+
+ def testDisconnectingNestedMessageAfterSettingField(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ nested.bb = 5
+ self.assertTrue(proto.HasField('optional_nested_message'))
+ proto.ClearField('optional_nested_message') # Should disconnect from parent
+ self.assertEqual(5, nested.bb)
+ self.assertEqual(0, proto.optional_nested_message.bb)
+ self.assertTrue(nested is not proto.optional_nested_message)
+ nested.bb = 23
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+
+ def testDisconnectingNestedMessageBeforeGettingField(self):
+ proto = unittest_pb2.TestAllTypes()
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ proto.ClearField('optional_nested_message')
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+
+ def testDisconnectingNestedMessageAfterMerge(self):
+ # This test exercises the code path that does not use ReleaseMessage().
+ # The underlying fear is that if we use ReleaseMessage() incorrectly,
+ # we will have memory leaks. It's hard to check that that doesn't happen,
+ # but at least we can exercise that code path to make sure it works.
+ proto1 = unittest_pb2.TestAllTypes()
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.optional_nested_message.bb = 5
+ proto1.MergeFrom(proto2)
+ self.assertTrue(proto1.HasField('optional_nested_message'))
+ proto1.ClearField('optional_nested_message')
+ self.assertTrue(not proto1.HasField('optional_nested_message'))
+
+ def testDisconnectingLazyNestedMessage(self):
+ # This test exercises releasing a nested message that is lazy. This test
+ # only exercises real code in the C++ implementation as Python does not
+ # support lazy parsing, but the current C++ implementation results in
+ # memory corruption and a crash.
+ if api_implementation.Type() != 'python':
+ return
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_lazy_message.bb = 5
+ proto.ClearField('optional_lazy_message')
+ del proto
+ gc.collect()
+
def testHasBitsWhenModifyingRepeatedFields(self):
# Test nesting when we add an element to a repeated field in a submessage.
proto = unittest_pb2.TestNestedMessageHasBits()
@@ -446,7 +535,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(0.0, proto.optional_double)
self.assertEqual(False, proto.optional_bool)
self.assertEqual('', proto.optional_string)
- self.assertEqual('', proto.optional_bytes)
+ self.assertEqual(b'', proto.optional_bytes)
self.assertEqual(41, proto.default_int32)
self.assertEqual(42, proto.default_int64)
@@ -462,7 +551,7 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(52e3, proto.default_double)
self.assertEqual(True, proto.default_bool)
self.assertEqual('hello', proto.default_string)
- self.assertEqual('world', proto.default_bytes)
+ self.assertEqual(b'world', proto.default_bytes)
self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_BAR,
@@ -479,6 +568,17 @@ class ReflectionTest(unittest.TestCase):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
+ def testClearRemovesChildren(self):
+ # Make sure there aren't any implementation bugs that are only partially
+ # clearing the message (which can happen in the more complex C++
+ # implementation which has parallel message lists).
+ proto = unittest_pb2.TestRequiredForeign()
+ for i in range(10):
+ proto.repeated_message.add()
+ proto2 = unittest_pb2.TestRequiredForeign()
+ proto.CopyFrom(proto2)
+ self.assertRaises(IndexError, lambda: proto.repeated_message[5])
+
def testDisallowedAssignments(self):
# It's illegal to assign values directly to repeated fields
# or to nonrepeated composite fields. Ensure that this fails.
@@ -500,7 +600,6 @@ class ReflectionTest(unittest.TestCase):
# proto.nonexistent_field = 23 should fail as well.
self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
- # TODO(robinson): Add type-safety check for enums.
def testSingleScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
@@ -508,11 +607,37 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
+ def testIntegerTypes(self):
+ def TestGetAndDeserialize(field_name, value, expected_type):
+ proto = unittest_pb2.TestAllTypes()
+ setattr(proto, field_name, value)
+ self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
+ proto2 = unittest_pb2.TestAllTypes()
+ proto2.ParseFromString(proto.SerializeToString())
+ self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
+
+ TestGetAndDeserialize('optional_int32', 1, int)
+ TestGetAndDeserialize('optional_int32', 1 << 30, int)
+ TestGetAndDeserialize('optional_uint32', 1 << 30, int)
+ if struct.calcsize('L') == 4:
+ # Python only has signed ints, so 32-bit python can't fit an uint32
+ # in an int.
+ TestGetAndDeserialize('optional_uint32', 1 << 31, long)
+ else:
+ # 64-bit python can fit uint32 inside an int
+ TestGetAndDeserialize('optional_uint32', 1 << 31, int)
+ TestGetAndDeserialize('optional_int64', 1 << 30, long)
+ TestGetAndDeserialize('optional_int64', 1 << 60, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 30, long)
+ TestGetAndDeserialize('optional_uint64', 1 << 60, long)
+
def testSingleScalarBoundsChecking(self):
def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
pb = unittest_pb2.TestAllTypes()
setattr(pb, field_name, expected_min)
+ self.assertEqual(expected_min, getattr(pb, field_name))
setattr(pb, field_name, expected_max)
+ self.assertEqual(expected_max, getattr(pb, field_name))
self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
@@ -520,7 +645,10 @@ class ReflectionTest(unittest.TestCase):
TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
- TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
+
+ pb = unittest_pb2.TestAllTypes()
+ pb.optional_nested_enum = 1
+ self.assertEqual(1, pb.optional_nested_enum)
def testRepeatedScalarTypeSafety(self):
proto = unittest_pb2.TestAllTypes()
@@ -534,11 +662,19 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
+ # Repeated enums tests.
+ #proto.repeated_nested_enum.append(0)
+
def testSingleScalarGettersAndSetters(self):
proto = unittest_pb2.TestAllTypes()
self.assertEqual(0, proto.optional_int32)
proto.optional_int32 = 1
self.assertEqual(1, proto.optional_int32)
+
+ proto.optional_uint64 = 0xffffffffffff
+ self.assertEqual(0xffffffffffff, proto.optional_uint64)
+ proto.optional_uint64 = 0xffffffffffffffff
+ self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
# TODO(robinson): Test all other scalar field types.
def testSingleScalarClearField(self):
@@ -561,6 +697,77 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
+ def testEnum_Name(self):
+ self.assertEqual('FOREIGN_FOO',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
+ self.assertEqual('FOREIGN_BAR',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
+ self.assertEqual('FOREIGN_BAZ',
+ unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
+ self.assertRaises(ValueError,
+ unittest_pb2.ForeignEnum.Name, 11312)
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual('FOO',
+ proto.NestedEnum.Name(proto.FOO))
+ self.assertEqual('FOO',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
+ self.assertEqual('BAR',
+ proto.NestedEnum.Name(proto.BAR))
+ self.assertEqual('BAR',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
+ self.assertEqual('BAZ',
+ proto.NestedEnum.Name(proto.BAZ))
+ self.assertEqual('BAZ',
+ unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
+ self.assertRaises(ValueError,
+ proto.NestedEnum.Name, 11312)
+ self.assertRaises(ValueError,
+ unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
+
+ def testEnum_Value(self):
+ self.assertEqual(unittest_pb2.FOREIGN_FOO,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
+ self.assertEqual(unittest_pb2.FOREIGN_BAR,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
+ self.assertEqual(unittest_pb2.FOREIGN_BAZ,
+ unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
+ self.assertRaises(ValueError,
+ unittest_pb2.ForeignEnum.Value, 'FO')
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(proto.FOO,
+ proto.NestedEnum.Value('FOO'))
+ self.assertEqual(proto.FOO,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
+ self.assertEqual(proto.BAR,
+ proto.NestedEnum.Value('BAR'))
+ self.assertEqual(proto.BAR,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
+ self.assertEqual(proto.BAZ,
+ proto.NestedEnum.Value('BAZ'))
+ self.assertEqual(proto.BAZ,
+ unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
+ self.assertRaises(ValueError,
+ proto.NestedEnum.Value, 'Foo')
+ self.assertRaises(ValueError,
+ unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
+
+ def testEnum_KeysAndValues(self):
+ self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
+ unittest_pb2.ForeignEnum.keys())
+ self.assertEqual([4, 5, 6],
+ unittest_pb2.ForeignEnum.values())
+ self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
+ ('FOREIGN_BAZ', 6)],
+ unittest_pb2.ForeignEnum.items())
+
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
+ self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
+ self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
+ proto.NestedEnum.items())
+
def testRepeatedScalars(self):
proto = unittest_pb2.TestAllTypes()
@@ -619,11 +826,38 @@ class ReflectionTest(unittest.TestCase):
del proto.repeated_int32[2:]
self.assertEqual([5, 35], proto.repeated_int32)
+ # Test extending.
+ proto.repeated_int32.extend([3, 13])
+ self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
+
# Test clearing.
proto.ClearField('repeated_int32')
self.assertTrue(not proto.repeated_int32)
self.assertEqual(0, len(proto.repeated_int32))
+ proto.repeated_int32.append(1)
+ self.assertEqual(1, proto.repeated_int32[-1])
+ # Test assignment to a negative index.
+ proto.repeated_int32[-1] = 2
+ self.assertEqual(2, proto.repeated_int32[-1])
+
+ # Test deletion at negative indices.
+ proto.repeated_int32[:] = [0, 1, 2, 3]
+ del proto.repeated_int32[-1]
+ self.assertEqual([0, 1, 2], proto.repeated_int32)
+
+ del proto.repeated_int32[-2]
+ self.assertEqual([0, 2], proto.repeated_int32)
+
+ self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
+ self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
+
+ del proto.repeated_int32[-2:-1]
+ self.assertEqual([2], proto.repeated_int32)
+
+ del proto.repeated_int32[100:10000]
+ self.assertEqual([2], proto.repeated_int32)
+
def testRepeatedScalarsRemove(self):
proto = unittest_pb2.TestAllTypes()
@@ -661,7 +895,7 @@ class ReflectionTest(unittest.TestCase):
m1 = proto.repeated_nested_message.add()
self.assertTrue(proto.repeated_nested_message)
self.assertEqual(2, len(proto.repeated_nested_message))
- self.assertIs([m0, m1], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1], proto.repeated_nested_message)
self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
# Test out-of-bounds indices.
@@ -680,32 +914,86 @@ class ReflectionTest(unittest.TestCase):
m2 = proto.repeated_nested_message.add()
m3 = proto.repeated_nested_message.add()
m4 = proto.repeated_nested_message.add()
- self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
- self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
+ self.assertListsEqual(
+ [m1, m2, m3], proto.repeated_nested_message[1:4])
+ self.assertListsEqual(
+ [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
+ self.assertListsEqual(
+ [m0, m1], proto.repeated_nested_message[:2])
+ self.assertListsEqual(
+ [m2, m3, m4], proto.repeated_nested_message[2:])
+ self.assertEqual(
+ m0, proto.repeated_nested_message[0])
+ self.assertListsEqual(
+ [m0], proto.repeated_nested_message[:1])
# Test that we can use the field as an iterator.
result = []
for i in proto.repeated_nested_message:
result.append(i)
- self.assertIs([m0, m1, m2, m3, m4], result)
+ self.assertListsEqual([m0, m1, m2, m3, m4], result)
# Test single deletion.
del proto.repeated_nested_message[2]
- self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
# Test slice deletion.
del proto.repeated_nested_message[2:]
- self.assertIs([m0, m1], proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1], proto.repeated_nested_message)
+
+ # Test extending.
+ n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
+ n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
+ proto.repeated_nested_message.extend([n1,n2])
+ self.assertEqual(4, len(proto.repeated_nested_message))
+ self.assertEqual(n1, proto.repeated_nested_message[2])
+ self.assertEqual(n2, proto.repeated_nested_message[3])
# Test clearing.
proto.ClearField('repeated_nested_message')
self.assertTrue(not proto.repeated_nested_message)
self.assertEqual(0, len(proto.repeated_nested_message))
+ # Test constructing an element while adding it.
+ proto.repeated_nested_message.add(bb=23)
+ self.assertEqual(1, len(proto.repeated_nested_message))
+ self.assertEqual(23, proto.repeated_nested_message[0].bb)
+
+ def testRepeatedCompositeRemove(self):
+ proto = unittest_pb2.TestAllTypes()
+
+ self.assertEqual(0, len(proto.repeated_nested_message))
+ m0 = proto.repeated_nested_message.add()
+ # Need to set some differentiating variable so m0 != m1 != m2:
+ m0.bb = len(proto.repeated_nested_message)
+ m1 = proto.repeated_nested_message.add()
+ m1.bb = len(proto.repeated_nested_message)
+ self.assertTrue(m0 != m1)
+ m2 = proto.repeated_nested_message.add()
+ m2.bb = len(proto.repeated_nested_message)
+ self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
+
+ self.assertEqual(3, len(proto.repeated_nested_message))
+ proto.repeated_nested_message.remove(m0)
+ self.assertEqual(2, len(proto.repeated_nested_message))
+ self.assertEqual(m1, proto.repeated_nested_message[0])
+ self.assertEqual(m2, proto.repeated_nested_message[1])
+
+ # Removing m0 again or removing None should raise error
+ self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
+ self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
+ self.assertEqual(2, len(proto.repeated_nested_message))
+
+ proto.repeated_nested_message.remove(m2)
+ self.assertEqual(1, len(proto.repeated_nested_message))
+ self.assertEqual(m1, proto.repeated_nested_message[0])
+
def testHandWrittenReflection(self):
- # TODO(robinson): We probably need a better way to specify
- # protocol types by hand. But then again, this isn't something
- # we expect many people to do. Hmm.
+ # Hand written extensions are only supported by the pure-Python
+ # implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
FieldDescriptor = descriptor.FieldDescriptor
foo_field_descriptor = FieldDescriptor(
name='foo_field', full_name='MyProto.foo_field',
@@ -730,6 +1018,68 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
+ def testDescriptorProtoSupport(self):
+ # Hand written descriptors/reflection are only supported by the pure-Python
+ # implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
+ def AddDescriptorField(proto, field_name, field_type):
+ AddDescriptorField.field_index += 1
+ new_field = proto.field.add()
+ new_field.name = field_name
+ new_field.type = field_type
+ new_field.number = AddDescriptorField.field_index
+ new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
+
+ AddDescriptorField.field_index = 0
+
+ desc_proto = descriptor_pb2.DescriptorProto()
+ desc_proto.name = 'Car'
+ fdp = descriptor_pb2.FieldDescriptorProto
+ AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
+ AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
+ AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
+ AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
+ # Add a repeated field
+ AddDescriptorField.field_index += 1
+ new_field = desc_proto.field.add()
+ new_field.name = 'owners'
+ new_field.type = fdp.TYPE_STRING
+ new_field.number = AddDescriptorField.field_index
+ new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
+
+ desc = descriptor.MakeDescriptor(desc_proto)
+ self.assertTrue(desc.fields_by_name.has_key('name'))
+ self.assertTrue(desc.fields_by_name.has_key('year'))
+ self.assertTrue(desc.fields_by_name.has_key('automatic'))
+ self.assertTrue(desc.fields_by_name.has_key('price'))
+ self.assertTrue(desc.fields_by_name.has_key('owners'))
+
+ class CarMessage(message.Message):
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ DESCRIPTOR = desc
+
+ prius = CarMessage()
+ prius.name = 'prius'
+ prius.year = 2010
+ prius.automatic = True
+ prius.price = 25134.75
+ prius.owners.extend(['bob', 'susan'])
+
+ serialized_prius = prius.SerializeToString()
+ new_prius = reflection.ParseMessage(desc, serialized_prius)
+ self.assertTrue(new_prius is not prius)
+ self.assertEqual(prius, new_prius)
+
+ # these are unnecessary assuming message equality works as advertised but
+ # explicitly check to be safe since we're mucking about in metaclass foo
+ self.assertEqual(prius.name, new_prius.name)
+ self.assertEqual(prius.year, new_prius.year)
+ self.assertEqual(prius.automatic, new_prius.automatic)
+ self.assertEqual(prius.price, new_prius.price)
+ self.assertEqual(prius.owners, new_prius.owners)
+
def testTopLevelExtensionsForOptionalScalar(self):
extendee_proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.optional_int32_extension
@@ -819,6 +1169,14 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(required is not extendee_proto.Extensions[extension])
self.assertTrue(not extendee_proto.HasExtension(extension))
+ def testRegisteredExtensions(self):
+ self.assertTrue('protobuf_unittest.optional_int32_extension' in
+ unittest_pb2.TestAllExtensions._extensions_by_name)
+ self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
+ # Make sure extensions haven't been registered into types that shouldn't
+ # have any.
+ self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
+
# If message A directly contains message B, and
# a.HasField('b') is currently False, then mutating any
# extension in B should change a.HasField('b') to True
@@ -868,7 +1226,7 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(not toplevel.HasField('submessage'))
foreign = toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension].add()
- self.assertTrue(foreign is toplevel.submessage.Extensions[
+ self.assertEqual(foreign, toplevel.submessage.Extensions[
more_extensions_pb2.repeated_message_extension][0])
self.assertTrue(toplevel.HasField('submessage'))
@@ -971,6 +1329,12 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(123, proto2.repeated_nested_message[1].bb)
self.assertEqual(321, proto2.repeated_nested_message[2].bb)
+ proto3 = unittest_pb2.TestAllTypes()
+ proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
+ self.assertEqual(999, proto3.repeated_nested_message[0].bb)
+ self.assertEqual(123, proto3.repeated_nested_message[1].bb)
+ self.assertEqual(321, proto3.repeated_nested_message[2].bb)
+
def testMergeFromAllFields(self):
# With all fields set.
proto1 = unittest_pb2.TestAllTypes()
@@ -1035,6 +1399,19 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(222, ext2[1].bb)
self.assertEqual(333, ext2[2].bb)
+ def testMergeFromBug(self):
+ message1 = unittest_pb2.TestAllTypes()
+ message2 = unittest_pb2.TestAllTypes()
+
+ # Cause optional_nested_message to be instantiated within message1, even
+ # though it is not considered to be "present".
+ message1.optional_nested_message
+ self.assertFalse(message1.HasField('optional_nested_message'))
+
+ # Merge into message2. This should not instantiate the field is message2.
+ message2.MergeFrom(message1)
+ self.assertFalse(message2.HasField('optional_nested_message'))
+
def testCopyFromSingularField(self):
# Test copy with just a singular field.
proto1 = unittest_pb2.TestAllTypes()
@@ -1087,9 +1464,36 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(2, proto1.optional_int32)
self.assertEqual('important-text', proto1.optional_string)
+ def testCopyFromBadType(self):
+ # The python implementation doesn't raise an exception in this
+ # case. In theory it should.
+ if api_implementation.Type() == 'python':
+ return
+ proto1 = unittest_pb2.TestAllTypes()
+ proto2 = unittest_pb2.TestAllExtensions()
+ self.assertRaises(TypeError, proto1.CopyFrom, proto2)
+
+ def testDeepCopy(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ proto1.optional_int32 = 1
+ proto2 = copy.deepcopy(proto1)
+ self.assertEqual(1, proto2.optional_int32)
+
+ proto1.repeated_int32.append(2)
+ proto1.repeated_int32.append(3)
+ container = copy.deepcopy(proto1.repeated_int32)
+ self.assertEqual([2, 3], container)
+
+ # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
+
def testClear(self):
proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto)
+ # C++ implementation does not support lazy fields right now so leave it
+ # out for now.
+ if api_implementation.Type() == 'python':
+ test_util.SetAllFields(proto)
+ else:
+ test_util.SetAllNonLazyFields(proto)
# Clear the message.
proto.Clear()
self.assertEquals(proto.ByteSize(), 0)
@@ -1105,6 +1509,45 @@ class ReflectionTest(unittest.TestCase):
empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto)
+ def testDisconnectingBeforeClear(self):
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ proto.Clear()
+ self.assertTrue(nested is not proto.optional_nested_message)
+ nested.bb = 23
+ self.assertTrue(not proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+
+ proto = unittest_pb2.TestAllTypes()
+ nested = proto.optional_nested_message
+ nested.bb = 5
+ foreign = proto.optional_foreign_message
+ foreign.c = 6
+
+ proto.Clear()
+ self.assertTrue(nested is not proto.optional_nested_message)
+ self.assertTrue(foreign is not proto.optional_foreign_message)
+ self.assertEqual(5, nested.bb)
+ self.assertEqual(6, foreign.c)
+ nested.bb = 15
+ foreign.c = 16
+ self.assertFalse(proto.HasField('optional_nested_message'))
+ self.assertEqual(0, proto.optional_nested_message.bb)
+ self.assertFalse(proto.HasField('optional_foreign_message'))
+ self.assertEqual(0, proto.optional_foreign_message.c)
+
+ def testOneOf(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.oneof_uint32 = 10
+ proto.oneof_nested_message.bb = 11
+ self.assertEqual(11, proto.oneof_nested_message.bb)
+ self.assertFalse(proto.HasField('oneof_uint32'))
+ nested = proto.oneof_nested_message
+ proto.oneof_string = 'abc'
+ self.assertEqual('abc', proto.oneof_string)
+ self.assertEqual(11, nested.bb)
+ self.assertFalse(proto.HasField('oneof_nested_message'))
+
def assertInitialized(self, proto):
self.assertTrue(proto.IsInitialized())
# Neither method should raise an exception.
@@ -1175,6 +1618,40 @@ class ReflectionTest(unittest.TestCase):
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
+ @basetest.unittest.skipIf(
+ api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
+ 'Errors are only available from the most recent C++ implementation.')
+ def testFileDescriptorErrors(self):
+ file_name = 'test_file_descriptor_errors.proto'
+ package_name = 'test_file_descriptor_errors.proto'
+ file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
+ file_descriptor_proto.name = file_name
+ file_descriptor_proto.package = package_name
+ m1 = file_descriptor_proto.message_type.add()
+ m1.name = 'msg1'
+ # Compiles the proto into the C++ descriptor pool
+ descriptor.FileDescriptor(
+ file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ # Add a FileDescriptorProto that has duplicate symbols
+ another_file_name = 'another_test_file_descriptor_errors.proto'
+ file_descriptor_proto.name = another_file_name
+ m2 = file_descriptor_proto.message_type.add()
+ m2.name = 'msg2'
+ with self.assertRaises(TypeError) as cm:
+ descriptor.FileDescriptor(
+ another_file_name,
+ package_name,
+ serialized_pb=file_descriptor_proto.SerializeToString())
+ self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
+ getattr(cm.expected, '__name__', cm.expected))
+ self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
+ # Error message will say something about this definition being a
+ # duplicate, though we don't check the message exactly to avoid a
+ # dependency on the C++ logging code.
+ self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
+
def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes()
@@ -1192,16 +1669,15 @@ class ReflectionTest(unittest.TestCase):
proto.optional_string = str('Testing')
self.assertEqual(proto.optional_string, unicode('Testing'))
- # Values of type 'str' are also accepted as long as they can be encoded in
- # UTF-8.
- self.assertEqual(type(proto.optional_string), str)
-
# Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
self.assertRaises(ValueError,
- setattr, proto, 'optional_string', str('a\x80a'))
- # Assign a 'str' object which contains a UTF-8 encoded string.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', 'Тест')
+ setattr, proto, 'optional_string', b'a\x80a')
+ if str is bytes: # PY2
+ # Assign a 'str' object which contains a UTF-8 encoded string.
+ self.assertRaises(ValueError,
+ setattr, proto, 'optional_string', 'Тест')
+ else:
+ proto.optional_string = 'Тест'
# No exception thrown.
proto.optional_string = 'abc'
@@ -1224,7 +1700,8 @@ class ReflectionTest(unittest.TestCase):
self.assertEqual(proto.ByteSize(), len(serialized))
raw = unittest_mset_pb2.RawMessageSet()
- raw.MergeFromString(serialized)
+ bytes_read = raw.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_read)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
@@ -1232,18 +1709,37 @@ class ReflectionTest(unittest.TestCase):
# Check that the type_id is the same as the tag ID in the .proto file.
self.assertEqual(raw.item[0].type_id, 1547769)
- # Check the actually bytes on the wire.
+ # Check the actual bytes on the wire.
self.assertTrue(
raw.item[0].message.endswith(test_utf8_bytes))
- message2.MergeFromString(raw.item[0].message)
+ bytes_read = message2.MergeFromString(raw.item[0].message)
+ self.assertEqual(len(raw.item[0].message), bytes_read)
self.assertEqual(type(message2.str), unicode)
self.assertEqual(message2.str, test_utf8)
- # How about if the bytes on the wire aren't a valid UTF-8 encoded string.
- bytes = raw.item[0].message.replace(
- test_utf8_bytes, len(test_utf8_bytes) * '\xff')
- self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
+ # The pure Python API throws an exception on MergeFromString(),
+ # if any of the string fields of the message can't be UTF-8 decoded.
+ # The C++ implementation of the API has no way to check that on
+ # MergeFromString and thus has no way to throw the exception.
+ #
+ # The pure Python API always returns objects of type 'unicode' (UTF-8
+ # encoded), or 'bytes' (in 7 bit ASCII).
+ badbytes = raw.item[0].message.replace(
+ test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
+
+ unicode_decode_failed = False
+ try:
+ message2.MergeFromString(badbytes)
+ except UnicodeDecodeError:
+ unicode_decode_failed = True
+ string_field = message2.str
+ self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
+
+ def testBytesInTextFormat(self):
+ proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
+ self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
+ unicode(proto))
def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
@@ -1257,16 +1753,19 @@ class ReflectionTest(unittest.TestCase):
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.MergeFromString('')
+ bytes_read = proto.optional_nested_message.MergeFromString(b'')
+ self.assertEqual(0, bytes_read)
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.ParseFromString('')
+ proto.optional_nested_message.ParseFromString(b'')
self.assertTrue(proto.HasField('optional_nested_message'))
serialized = proto.SerializeToString()
proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertTrue(proto2.HasField('optional_nested_message'))
def testSetInParent(self):
@@ -1280,12 +1779,15 @@ class ReflectionTest(unittest.TestCase):
# into separate TestCase classes.
-class TestAllTypesEqualityTest(unittest.TestCase):
+class TestAllTypesEqualityTest(basetest.TestCase):
def setUp(self):
self.first_proto = unittest_pb2.TestAllTypes()
self.second_proto = unittest_pb2.TestAllTypes()
+ def testNotHashable(self):
+ self.assertRaises(TypeError, hash, self.first_proto)
+
def testSelfEquality(self):
self.assertEqual(self.first_proto, self.first_proto)
@@ -1293,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase):
self.assertEqual(self.first_proto, self.second_proto)
-class FullProtosEqualityTest(unittest.TestCase):
+class FullProtosEqualityTest(basetest.TestCase):
"""Equality tests using completely-full protos as a starting point."""
@@ -1303,6 +1805,9 @@ class FullProtosEqualityTest(unittest.TestCase):
test_util.SetAllFields(self.first_proto)
test_util.SetAllFields(self.second_proto)
+ def testNotHashable(self):
+ self.assertRaises(TypeError, hash, self.first_proto)
+
def testNoneNotEqual(self):
self.assertNotEqual(self.first_proto, None)
self.assertNotEqual(None, self.second_proto)
@@ -1371,15 +1876,12 @@ class FullProtosEqualityTest(unittest.TestCase):
self.first_proto.ClearField('optional_nested_message')
self.second_proto.optional_nested_message.ClearField('bb')
self.assertNotEqual(self.first_proto, self.second_proto)
- # TODO(robinson): Replace next two lines with method
- # to set the "has" bit without changing the value,
- # if/when such a method exists.
self.first_proto.optional_nested_message.bb = 0
self.first_proto.optional_nested_message.ClearField('bb')
self.assertEqual(self.first_proto, self.second_proto)
-class ExtensionEqualityTest(unittest.TestCase):
+class ExtensionEqualityTest(basetest.TestCase):
def testExtensionEquality(self):
first_proto = unittest_pb2.TestAllExtensions()
@@ -1412,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class MutualRecursionEqualityTest(unittest.TestCase):
+class MutualRecursionEqualityTest(basetest.TestCase):
def testEqualityWithMutualRecursion(self):
first_proto = unittest_pb2.TestMutualRecursionA()
@@ -1424,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
-class ByteSizeTest(unittest.TestCase):
+class ByteSizeTest(basetest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
@@ -1438,6 +1940,14 @@ class ByteSizeTest(unittest.TestCase):
def testEmptyMessage(self):
self.assertEqual(0, self.proto.ByteSize())
+ def testSizedOnKwargs(self):
+ # Use a separate message to ensure testing right after creation.
+ proto = unittest_pb2.TestAllTypes()
+ self.assertEqual(0, proto.ByteSize())
+ proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
+ # One byte for the tag, one to encode varint 1.
+ self.assertEqual(2, proto_kwargs.ByteSize())
+
def testVarints(self):
def Test(i, expected_varint_size):
self.proto.Clear()
@@ -1629,10 +2139,13 @@ class ByteSizeTest(unittest.TestCase):
self.assertEqual(3, self.proto.ByteSize())
self.proto.ClearField('optional_foreign_message')
self.assertEqual(0, self.proto.ByteSize())
- child = self.proto.optional_foreign_message
- self.proto.ClearField('optional_foreign_message')
- child.c = 128
- self.assertEqual(0, self.proto.ByteSize())
+
+ if api_implementation.Type() == 'python':
+ # This is only possible in pure-Python implementation of the API.
+ child = self.proto.optional_foreign_message
+ self.proto.ClearField('optional_foreign_message')
+ child.c = 128
+ self.assertEqual(0, self.proto.ByteSize())
# Test within extension.
extension = more_extensions_pb2.optional_message_extension
@@ -1698,7 +2211,6 @@ class ByteSizeTest(unittest.TestCase):
self.assertEqual(19, self.packed_extended_proto.ByteSize())
-# TODO(robinson): We need cross-language serialization consistency tests.
# Issues to be sure to cover include:
# * Handling of unrecognized tags ("uninterpreted_bytes").
# * Handling of MessageSets.
@@ -1710,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase):
# * Handling of empty submessages (with and without "has"
# bits set).
-class SerializationTest(unittest.TestCase):
+class SerializationTest(basetest.TestCase):
def testSerializeEmtpyMessage(self):
first_proto = unittest_pb2.TestAllTypes()
second_proto = unittest_pb2.TestAllTypes()
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllFields(self):
@@ -1726,7 +2240,9 @@ class SerializationTest(unittest.TestCase):
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeAllExtensions(self):
@@ -1734,7 +2250,19 @@ class SerializationTest(unittest.TestCase):
second_proto = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(first_proto)
serialized = first_proto.SerializeToString()
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeWithOptionalGroup(self):
+ first_proto = unittest_pb2.TestAllTypes()
+ second_proto = unittest_pb2.TestAllTypes()
+ first_proto.optionalgroup.a = 242
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual(first_proto, second_proto)
def testSerializeNegativeValues(self):
@@ -1753,6 +2281,10 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(first_proto, second_proto)
def testParseTruncated(self):
+ # This test is only applicable for the Python implementation of the API.
+ if api_implementation.Type() != 'python':
+ return
+
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
@@ -1822,7 +2354,9 @@ class SerializationTest(unittest.TestCase):
second_proto.optional_int32 = 100
second_proto.optional_nested_message.bb = 999
- second_proto.MergeFromString(serialized)
+ bytes_parsed = second_proto.MergeFromString(serialized)
+ self.assertEqual(len(serialized), bytes_parsed)
+
# Ensure that we append to repeated fields.
self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
# Ensure that we overwrite nonrepeatd scalars.
@@ -1847,20 +2381,28 @@ class SerializationTest(unittest.TestCase):
raw = unittest_mset_pb2.RawMessageSet()
self.assertEqual(False,
raw.DESCRIPTOR.GetOptions().message_set_wire_format)
- raw.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ raw.MergeFromString(serialized))
self.assertEqual(2, len(raw.item))
message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.MergeFromString(raw.item[0].message)
+ self.assertEqual(
+ len(raw.item[0].message),
+ message1.MergeFromString(raw.item[0].message))
self.assertEqual(123, message1.i)
message2 = unittest_mset_pb2.TestMessageSetExtension2()
- message2.MergeFromString(raw.item[1].message)
+ self.assertEqual(
+ len(raw.item[1].message),
+ message2.MergeFromString(raw.item[1].message))
self.assertEqual('foo', message2.str)
# Deserialize using the MessageSet wire format.
proto2 = unittest_mset_pb2.TestMessageSet()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(123, proto2.Extensions[extension1].i)
self.assertEqual('foo', proto2.Extensions[extension2].str)
@@ -1900,7 +2442,9 @@ class SerializationTest(unittest.TestCase):
# Parse message using the message set wire format.
proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto.MergeFromString(serialized))
# Check that the message parsed well.
extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
@@ -1918,7 +2462,9 @@ class SerializationTest(unittest.TestCase):
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
# Now test with a int64 field set.
proto = unittest_pb2.TestAllTypes()
@@ -1928,13 +2474,15 @@ class SerializationTest(unittest.TestCase):
# unknown.
proto2 = unittest_pb2.TestEmptyMessage()
# Parsing this message should succeed.
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
def _CheckRaises(self, exc_class, callable_obj, exception):
"""This method checks if the excpetion type and message are as expected."""
try:
callable_obj()
- except exc_class, ex:
+ except exc_class as ex:
# Check if the exception message is the right one.
self.assertEqual(exception, str(ex))
return
@@ -1946,15 +2494,22 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: a,b,c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: '
+ 'a,b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
+ proto2 = unittest_pb2.TestRequired()
+ self.assertFalse(proto2.HasField('a'))
+ # proto2 ParseFromString does not check that required fields are set.
+ proto2.ParseFromString(partial)
+ self.assertFalse(proto2.HasField('a'))
+
proto.a = 1
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: b,c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: b,c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1962,7 +2517,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: c')
+ 'Message protobuf_unittest.TestRequired is missing required fields: c')
# Shouldn't raise exceptions.
partial = proto.SerializePartialToString()
@@ -1972,11 +2527,15 @@ class SerializationTest(unittest.TestCase):
partial = proto.SerializePartialToString()
proto2 = unittest_pb2.TestRequired()
- proto2.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ proto2.MergeFromString(serialized))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
- proto2.ParseFromString(partial)
+ self.assertEqual(
+ len(partial),
+ proto2.MergeFromString(partial))
self.assertEqual(1, proto2.a)
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
@@ -1991,7 +2550,8 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: '
+ 'Message protobuf_unittest.TestRequiredForeign '
+ 'is missing required fields: '
'optional_message.b,optional_message.c')
proto.optional_message.b = 2
@@ -2003,7 +2563,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
- 'Message is missing required fields: '
+ 'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
'repeated_message[0].b,repeated_message[0].c,'
'repeated_message[1].a,repeated_message[1].c')
@@ -2043,7 +2603,9 @@ class SerializationTest(unittest.TestCase):
second_proto.packed_double.extend([1.0, 2.0])
second_proto.packed_sint32.append(4)
- second_proto.MergeFromString(serialized)
+ self.assertEqual(
+ len(serialized),
+ second_proto.MergeFromString(serialized))
self.assertEqual([3, 1, 2], second_proto.packed_int32)
self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
self.assertEqual([4], second_proto.packed_sint32)
@@ -2076,7 +2638,10 @@ class SerializationTest(unittest.TestCase):
unpacked = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(unpacked)
packed = unittest_pb2.TestPackedTypes()
- packed.MergeFromString(unpacked.SerializeToString())
+ serialized = unpacked.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ packed.MergeFromString(serialized))
expected = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(expected)
self.assertEqual(expected, packed)
@@ -2085,7 +2650,10 @@ class SerializationTest(unittest.TestCase):
packed = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(packed)
unpacked = unittest_pb2.TestUnpackedTypes()
- unpacked.MergeFromString(packed.SerializeToString())
+ serialized = packed.SerializeToString()
+ self.assertEqual(
+ len(serialized),
+ unpacked.MergeFromString(serialized))
expected = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(expected)
self.assertEqual(expected, unpacked)
@@ -2137,7 +2705,7 @@ class SerializationTest(unittest.TestCase):
optional_int32=1,
optional_string='foo',
optional_bool=True,
- optional_bytes='bar',
+ optional_bytes=b'bar',
optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
@@ -2155,7 +2723,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1, proto.optional_int32)
self.assertEqual('foo', proto.optional_string)
self.assertEqual(True, proto.optional_bool)
- self.assertEqual('bar', proto.optional_bytes)
+ self.assertEqual(b'bar', proto.optional_bytes)
self.assertEqual(1, proto.optional_nested_message.bb)
self.assertEqual(1, proto.optional_foreign_message.c)
self.assertEqual(unittest_pb2.TestAllTypes.FOO,
@@ -2205,7 +2773,7 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(3, proto.repeated_int32[2])
-class OptionsTest(unittest.TestCase):
+class OptionsTest(basetest.TestCase):
def testMessageOptions(self):
proto = unittest_mset_pb2.TestMessageSet()
@@ -2232,5 +2800,135 @@ class OptionsTest(unittest.TestCase):
+class ClassAPITest(basetest.TestCase):
+
+ def testMakeClassWithNestedDescriptor(self):
+ leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
+ containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
+ containing_type=None, fields=[],
+ nested_types=[leaf_desc], enum_types=[],
+ extensions=[])
+ sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
+ '', containing_type=None, fields=[],
+ nested_types=[], enum_types=[],
+ extensions=[])
+ parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
+ containing_type=None, fields=[],
+ nested_types=[child_desc, sibling_desc],
+ enum_types=[], extensions=[])
+ message_class = reflection.MakeClass(parent_desc)
+ self.assertIn('child', message_class.__dict__)
+ self.assertIn('sibling', message_class.__dict__)
+ self.assertIn('leaf', message_class.child.__dict__)
+
+ def _GetSerializedFileDescriptor(self, name):
+ """Get a serialized representation of a test FileDescriptorProto.
+
+ Args:
+ name: All calls to this must use a unique message name, to avoid
+ collisions in the cpp descriptor pool.
+ Returns:
+ A string containing the serialized form of a test FileDescriptorProto.
+ """
+ file_descriptor_str = (
+ 'message_type {'
+ ' name: "' + name + '"'
+ ' field {'
+ ' name: "flat"'
+ ' number: 1'
+ ' label: LABEL_REPEATED'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' field {'
+ ' name: "bar"'
+ ' number: 2'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Bar"'
+ ' }'
+ ' nested_type {'
+ ' name: "Bar"'
+ ' field {'
+ ' name: "baz"'
+ ' number: 3'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_MESSAGE'
+ ' type_name: "Baz"'
+ ' }'
+ ' nested_type {'
+ ' name: "Baz"'
+ ' enum_type {'
+ ' name: "deep_enum"'
+ ' value {'
+ ' name: "VALUE_A"'
+ ' number: 0'
+ ' }'
+ ' }'
+ ' field {'
+ ' name: "deep"'
+ ' number: 4'
+ ' label: LABEL_OPTIONAL'
+ ' type: TYPE_UINT32'
+ ' }'
+ ' }'
+ ' }'
+ '}')
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ text_format.Merge(file_descriptor_str, file_descriptor)
+ return file_descriptor.SerializeToString()
+
+ def testParsingFlatClassWithExplicitClassDeclaration(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+
+ class MessageClass(message.Message):
+ __metaclass__ = reflection.GeneratedProtocolMessageType
+ DESCRIPTOR = msg_descriptor
+ msg = MessageClass()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingFlatClass(self):
+ """Test that the generated class can parse a flat message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'flat: 0 '
+ 'flat: 1 '
+ 'flat: 2 ')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.flat, [0, 1, 2])
+
+ def testParsingNestedClass(self):
+ """Test that the generated class can parse a nested message."""
+ file_descriptor = descriptor_pb2.FileDescriptorProto()
+ file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
+ msg_descriptor = descriptor.MakeDescriptor(
+ file_descriptor.message_type[0])
+ msg_class = reflection.MakeClass(msg_descriptor)
+ msg = msg_class()
+ msg_str = (
+ 'bar {'
+ ' baz {'
+ ' deep: 4'
+ ' }'
+ '}')
+ text_format.Merge(msg_str, msg)
+ self.assertEqual(msg.bar.baz.deep, 4)
+
if __name__ == '__main__':
- unittest.main()
+ basetest.main()