diff options
Diffstat (limited to 'python/google/protobuf/internal/reflection_test.py')
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 870 |
1 files changed, 784 insertions, 86 deletions
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 2c9fa30..d59815d 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -3,7 +3,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -37,12 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy +import gc import operator import struct -import unittest -# TODO(robinson): When we split this test in two, only some of these imports -# will be necessary in each test. +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -50,6 +50,8 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 from google.protobuf.internal import wire_format @@ -102,12 +104,12 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): - def assertIs(self, values, others): + def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) for i in range(len(values)): - self.assertTrue(values[i] is others[i]) + self.assertEqual(values[i], others[i]) def testScalarConstructor(self): # Constructor with only scalar types should succeed. @@ -200,6 +202,41 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + def testConstructorTypeError(self): + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_int32="foo") + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234]) + + def testConstructorInvalidatesCachedByteSize(self): + message = unittest_pb2.TestAllTypes(optional_int32 = 12) + self.assertEquals(2, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) + self.assertEquals(3, message.ByteSize()) + def testSimpleHasBits(self): # Test a scalar. proto = unittest_pb2.TestAllTypes() @@ -284,12 +321,6 @@ class ReflectionTest(unittest.TestCase): # ...and ensure that the scalar field has returned to its default. self.assertEqual(0, getattr(composite_field, scalar_field_name)) - # Finally, ensure that modifications to the old composite field object - # don't have any effect on the parent. - # - # (NOTE that when we clear the composite field in the parent, we actually - # don't recursively clear down the tree. Instead, we just disconnect the - # cleared composite from the tree.) self.assertTrue(old_composite_field is not composite_field) setattr(old_composite_field, scalar_field_name, new_val) self.assertTrue(not composite_field.HasField(scalar_field_name)) @@ -319,6 +350,64 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) + def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') + del proto + del nested + # Force a garbage collect so that the underlying CMessages are freed along + # with the Messages they point to. This is to make sure we're not deleting + # default message instances. + gc.collect() + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + + def testDisconnectingNestedMessageAfterSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + self.assertTrue(proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertEqual(5, nested.bb) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testDisconnectingNestedMessageBeforeGettingField(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + + def testDisconnectingNestedMessageAfterMerge(self): + # This test exercises the code path that does not use ReleaseMessage(). + # The underlying fear is that if we use ReleaseMessage() incorrectly, + # we will have memory leaks. It's hard to check that that doesn't happen, + # but at least we can exercise that code path to make sure it works. + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_nested_message.bb = 5 + proto1.MergeFrom(proto2) + self.assertTrue(proto1.HasField('optional_nested_message')) + proto1.ClearField('optional_nested_message') + self.assertTrue(not proto1.HasField('optional_nested_message')) + + def testDisconnectingLazyNestedMessage(self): + # This test exercises releasing a nested message that is lazy. This test + # only exercises real code in the C++ implementation as Python does not + # support lazy parsing, but the current C++ implementation results in + # memory corruption and a crash. + if api_implementation.Type() != 'python': + return + proto = unittest_pb2.TestAllTypes() + proto.optional_lazy_message.bb = 5 + proto.ClearField('optional_lazy_message') + del proto + gc.collect() + def testHasBitsWhenModifyingRepeatedFields(self): # Test nesting when we add an element to a repeated field in a submessage. proto = unittest_pb2.TestNestedMessageHasBits() @@ -446,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -462,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -479,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -500,7 +600,6 @@ class ReflectionTest(unittest.TestCase): # proto.nonexistent_field = 23 should fail as well. self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) - # TODO(robinson): Add type-safety check for enums. def testSingleScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) @@ -508,11 +607,37 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() setattr(pb, field_name, expected_min) + self.assertEqual(expected_min, getattr(pb, field_name)) setattr(pb, field_name, expected_max) + self.assertEqual(expected_max, getattr(pb, field_name)) self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) @@ -520,7 +645,10 @@ class ReflectionTest(unittest.TestCase): TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) - TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + pb = unittest_pb2.TestAllTypes() + pb.optional_nested_enum = 1 + self.assertEqual(1, pb.optional_nested_enum) def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() @@ -534,11 +662,19 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + # Repeated enums tests. + #proto.repeated_nested_enum.append(0) + def testSingleScalarGettersAndSetters(self): proto = unittest_pb2.TestAllTypes() self.assertEqual(0, proto.optional_int32) proto.optional_int32 = 1 self.assertEqual(1, proto.optional_int32) + + proto.optional_uint64 = 0xffffffffffff + self.assertEqual(0xffffffffffff, proto.optional_uint64) + proto.optional_uint64 = 0xffffffffffffffff + self.assertEqual(0xffffffffffffffff, proto.optional_uint64) # TODO(robinson): Test all other scalar field types. def testSingleScalarClearField(self): @@ -561,6 +697,77 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testEnum_Name(self): + self.assertEqual('FOREIGN_FOO', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) + self.assertEqual('FOREIGN_BAR', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) + self.assertEqual('FOREIGN_BAZ', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Name, 11312) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual('FOO', + proto.NestedEnum.Name(proto.FOO)) + self.assertEqual('FOO', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) + self.assertEqual('BAR', + proto.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAR', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAZ', + proto.NestedEnum.Name(proto.BAZ)) + self.assertEqual('BAZ', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) + self.assertRaises(ValueError, + proto.NestedEnum.Name, 11312) + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) + + def testEnum_Value(self): + self.assertEqual(unittest_pb2.FOREIGN_FOO, + unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Value, 'FO') + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(proto.FOO, + proto.NestedEnum.Value('FOO')) + self.assertEqual(proto.FOO, + unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) + self.assertEqual(proto.BAR, + proto.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAR, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAZ, + proto.NestedEnum.Value('BAZ')) + self.assertEqual(proto.BAZ, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) + self.assertRaises(ValueError, + proto.NestedEnum.Value, 'Foo') + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') + + def testEnum_KeysAndValues(self): + self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], + unittest_pb2.ForeignEnum.keys()) + self.assertEqual([4, 5, 6], + unittest_pb2.ForeignEnum.values()) + self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), + ('FOREIGN_BAZ', 6)], + unittest_pb2.ForeignEnum.items()) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], + proto.NestedEnum.items()) + def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -619,11 +826,38 @@ class ReflectionTest(unittest.TestCase): del proto.repeated_int32[2:] self.assertEqual([5, 35], proto.repeated_int32) + # Test extending. + proto.repeated_int32.extend([3, 13]) + self.assertEqual([5, 35, 3, 13], proto.repeated_int32) + # Test clearing. proto.ClearField('repeated_int32') self.assertTrue(not proto.repeated_int32) self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(1) + self.assertEqual(1, proto.repeated_int32[-1]) + # Test assignment to a negative index. + proto.repeated_int32[-1] = 2 + self.assertEqual(2, proto.repeated_int32[-1]) + + # Test deletion at negative indices. + proto.repeated_int32[:] = [0, 1, 2, 3] + del proto.repeated_int32[-1] + self.assertEqual([0, 1, 2], proto.repeated_int32) + + del proto.repeated_int32[-2] + self.assertEqual([0, 2], proto.repeated_int32) + + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) + + del proto.repeated_int32[-2:-1] + self.assertEqual([2], proto.repeated_int32) + + del proto.repeated_int32[100:10000] + self.assertEqual([2], proto.repeated_int32) + def testRepeatedScalarsRemove(self): proto = unittest_pb2.TestAllTypes() @@ -661,7 +895,7 @@ class ReflectionTest(unittest.TestCase): m1 = proto.repeated_nested_message.add() self.assertTrue(proto.repeated_nested_message) self.assertEqual(2, len(proto.repeated_nested_message)) - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) # Test out-of-bounds indices. @@ -680,32 +914,86 @@ class ReflectionTest(unittest.TestCase): m2 = proto.repeated_nested_message.add() m3 = proto.repeated_nested_message.add() m4 = proto.repeated_nested_message.add() - self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) - self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m1, m2, m3], proto.repeated_nested_message[1:4]) + self.assertListsEqual( + [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m0, m1], proto.repeated_nested_message[:2]) + self.assertListsEqual( + [m2, m3, m4], proto.repeated_nested_message[2:]) + self.assertEqual( + m0, proto.repeated_nested_message[0]) + self.assertListsEqual( + [m0], proto.repeated_nested_message[:1]) # Test that we can use the field as an iterator. result = [] for i in proto.repeated_nested_message: result.append(i) - self.assertIs([m0, m1, m2, m3, m4], result) + self.assertListsEqual([m0, m1, m2, m3, m4], result) # Test single deletion. del proto.repeated_nested_message[2] - self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) # Test slice deletion. del proto.repeated_nested_message[2:] - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) + + # Test extending. + n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) + n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) + proto.repeated_nested_message.extend([n1,n2]) + self.assertEqual(4, len(proto.repeated_nested_message)) + self.assertEqual(n1, proto.repeated_nested_message[2]) + self.assertEqual(n2, proto.repeated_nested_message[3]) # Test clearing. proto.ClearField('repeated_nested_message') self.assertTrue(not proto.repeated_nested_message) self.assertEqual(0, len(proto.repeated_nested_message)) + # Test constructing an element while adding it. + proto.repeated_nested_message.add(bb=23) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(23, proto.repeated_nested_message[0].bb) + + def testRepeatedCompositeRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + # Need to set some differentiating variable so m0 != m1 != m2: + m0.bb = len(proto.repeated_nested_message) + m1 = proto.repeated_nested_message.add() + m1.bb = len(proto.repeated_nested_message) + self.assertTrue(m0 != m1) + m2 = proto.repeated_nested_message.add() + m2.bb = len(proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) + + self.assertEqual(3, len(proto.repeated_nested_message)) + proto.repeated_nested_message.remove(m0) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + self.assertEqual(m2, proto.repeated_nested_message[1]) + + # Removing m0 again or removing None should raise error + self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) + self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) + self.assertEqual(2, len(proto.repeated_nested_message)) + + proto.repeated_nested_message.remove(m2) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + def testHandWrittenReflection(self): - # TODO(robinson): We probably need a better way to specify - # protocol types by hand. But then again, this isn't something - # we expect many people to do. Hmm. + # Hand written extensions are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + FieldDescriptor = descriptor.FieldDescriptor foo_field_descriptor = FieldDescriptor( name='foo_field', full_name='MyProto.foo_field', @@ -730,6 +1018,68 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + def testDescriptorProtoSupport(self): + # Hand written descriptors/reflection are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + + def AddDescriptorField(proto, field_name, field_type): + AddDescriptorField.field_index += 1 + new_field = proto.field.add() + new_field.name = field_name + new_field.type = field_type + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + AddDescriptorField.field_index = 0 + + desc_proto = descriptor_pb2.DescriptorProto() + desc_proto.name = 'Car' + fdp = descriptor_pb2.FieldDescriptorProto + AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) + AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) + AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) + AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) + # Add a repeated field + AddDescriptorField.field_index += 1 + new_field = desc_proto.field.add() + new_field.name = 'owners' + new_field.type = fdp.TYPE_STRING + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + + desc = descriptor.MakeDescriptor(desc_proto) + self.assertTrue(desc.fields_by_name.has_key('name')) + self.assertTrue(desc.fields_by_name.has_key('year')) + self.assertTrue(desc.fields_by_name.has_key('automatic')) + self.assertTrue(desc.fields_by_name.has_key('price')) + self.assertTrue(desc.fields_by_name.has_key('owners')) + + class CarMessage(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = desc + + prius = CarMessage() + prius.name = 'prius' + prius.year = 2010 + prius.automatic = True + prius.price = 25134.75 + prius.owners.extend(['bob', 'susan']) + + serialized_prius = prius.SerializeToString() + new_prius = reflection.ParseMessage(desc, serialized_prius) + self.assertTrue(new_prius is not prius) + self.assertEqual(prius, new_prius) + + # these are unnecessary assuming message equality works as advertised but + # explicitly check to be safe since we're mucking about in metaclass foo + self.assertEqual(prius.name, new_prius.name) + self.assertEqual(prius.year, new_prius.year) + self.assertEqual(prius.automatic, new_prius.automatic) + self.assertEqual(prius.price, new_prius.price) + self.assertEqual(prius.owners, new_prius.owners) + def testTopLevelExtensionsForOptionalScalar(self): extendee_proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.optional_int32_extension @@ -819,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -868,7 +1226,7 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not toplevel.HasField('submessage')) foreign = toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension].add() - self.assertTrue(foreign is toplevel.submessage.Extensions[ + self.assertEqual(foreign, toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension][0]) self.assertTrue(toplevel.HasField('submessage')) @@ -971,6 +1329,12 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(123, proto2.repeated_nested_message[1].bb) self.assertEqual(321, proto2.repeated_nested_message[2].bb) + proto3 = unittest_pb2.TestAllTypes() + proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) + self.assertEqual(999, proto3.repeated_nested_message[0].bb) + self.assertEqual(123, proto3.repeated_nested_message[1].bb) + self.assertEqual(321, proto3.repeated_nested_message[2].bb) + def testMergeFromAllFields(self): # With all fields set. proto1 = unittest_pb2.TestAllTypes() @@ -1035,6 +1399,19 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(222, ext2[1].bb) self.assertEqual(333, ext2[2].bb) + def testMergeFromBug(self): + message1 = unittest_pb2.TestAllTypes() + message2 = unittest_pb2.TestAllTypes() + + # Cause optional_nested_message to be instantiated within message1, even + # though it is not considered to be "present". + message1.optional_nested_message + self.assertFalse(message1.HasField('optional_nested_message')) + + # Merge into message2. This should not instantiate the field is message2. + message2.MergeFrom(message1) + self.assertFalse(message2.HasField('optional_nested_message')) + def testCopyFromSingularField(self): # Test copy with just a singular field. proto1 = unittest_pb2.TestAllTypes() @@ -1087,9 +1464,36 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(2, proto1.optional_int32) self.assertEqual('important-text', proto1.optional_string) + def testCopyFromBadType(self): + # The python implementation doesn't raise an exception in this + # case. In theory it should. + if api_implementation.Type() == 'python': + return + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllExtensions() + self.assertRaises(TypeError, proto1.CopyFrom, proto2) + + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() - test_util.SetAllFields(proto) + # C++ implementation does not support lazy fields right now so leave it + # out for now. + if api_implementation.Type() == 'python': + test_util.SetAllFields(proto) + else: + test_util.SetAllNonLazyFields(proto) # Clear the message. proto.Clear() self.assertEquals(proto.ByteSize(), 0) @@ -1105,6 +1509,45 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def testDisconnectingBeforeClear(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + foreign = proto.optional_foreign_message + foreign.c = 6 + + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + self.assertTrue(foreign is not proto.optional_foreign_message) + self.assertEqual(5, nested.bb) + self.assertEqual(6, foreign.c) + nested.bb = 15 + foreign.c = 16 + self.assertFalse(proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertFalse(proto.HasField('optional_foreign_message')) + self.assertEqual(0, proto.optional_foreign_message.c) + + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1175,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1192,16 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - # Values of type 'str' are also accepted as long as they can be encoded in - # UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1224,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1232,18 +1709,37 @@ class ReflectionTest(unittest.TestCase): # Check that the type_id is the same as the tag ID in the .proto file. self.assertEqual(raw.item[0].type_id, 1547769) - # Check the actually bytes on the wire. + # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) - # How about if the bytes on the wire aren't a valid UTF-8 encoded string. - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') - self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + # The pure Python API throws an exception on MergeFromString(), + # if any of the string fields of the message can't be UTF-8 decoded. + # The C++ implementation of the API has no way to check that on + # MergeFromString and thus has no way to throw the exception. + # + # The pure Python API always returns objects of type 'unicode' (UTF-8 + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') + + unicode_decode_failed = False + try: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: + unicode_decode_failed = True + string_field = message2.str + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1257,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1280,12 +1779,15 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() self.second_proto = unittest_pb2.TestAllTypes() + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testSelfEquality(self): self.assertEqual(self.first_proto, self.first_proto) @@ -1293,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1303,6 +1805,9 @@ class FullProtosEqualityTest(unittest.TestCase): test_util.SetAllFields(self.first_proto) test_util.SetAllFields(self.second_proto) + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testNoneNotEqual(self): self.assertNotEqual(self.first_proto, None) self.assertNotEqual(None, self.second_proto) @@ -1371,15 +1876,12 @@ class FullProtosEqualityTest(unittest.TestCase): self.first_proto.ClearField('optional_nested_message') self.second_proto.optional_nested_message.ClearField('bb') self.assertNotEqual(self.first_proto, self.second_proto) - # TODO(robinson): Replace next two lines with method - # to set the "has" bit without changing the value, - # if/when such a method exists. self.first_proto.optional_nested_message.bb = 0 self.first_proto.optional_nested_message.ClearField('bb') self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1412,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1424,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -1438,6 +1940,14 @@ class ByteSizeTest(unittest.TestCase): def testEmptyMessage(self): self.assertEqual(0, self.proto.ByteSize()) + def testSizedOnKwargs(self): + # Use a separate message to ensure testing right after creation. + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.ByteSize()) + proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1) + # One byte for the tag, one to encode varint 1. + self.assertEqual(2, proto_kwargs.ByteSize()) + def testVarints(self): def Test(i, expected_varint_size): self.proto.Clear() @@ -1629,10 +2139,13 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(3, self.proto.ByteSize()) self.proto.ClearField('optional_foreign_message') self.assertEqual(0, self.proto.ByteSize()) - child = self.proto.optional_foreign_message - self.proto.ClearField('optional_foreign_message') - child.c = 128 - self.assertEqual(0, self.proto.ByteSize()) + + if api_implementation.Type() == 'python': + # This is only possible in pure-Python implementation of the API. + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) # Test within extension. extension = more_extensions_pb2.optional_message_extension @@ -1698,7 +2211,6 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(19, self.packed_extended_proto.ByteSize()) -# TODO(robinson): We need cross-language serialization consistency tests. # Issues to be sure to cover include: # * Handling of unrecognized tags ("uninterpreted_bytes"). # * Handling of MessageSets. @@ -1710,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -1726,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -1734,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -1753,6 +2281,10 @@ class SerializationTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) def testParseTruncated(self): + # This test is only applicable for the Python implementation of the API. + if api_implementation.Type() != 'python': + return + first_proto = unittest_pb2.TestAllTypes() test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() @@ -1822,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -1847,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -1900,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -1918,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -1928,13 +2474,15 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" try: callable_obj() - except exc_class, ex: + except exc_class as ex: # Check if the exception message is the right one. self.assertEqual(exception, str(ex)) return @@ -1946,15 +2494,22 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: a,b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: ' + 'a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() + proto2 = unittest_pb2.TestRequired() + self.assertFalse(proto2.HasField('a')) + # proto2 ParseFromString does not check that required fields are set. + proto2.ParseFromString(partial) + self.assertFalse(proto2.HasField('a')) + proto.a = 1 self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1962,7 +2517,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: c') + 'Message protobuf_unittest.TestRequired is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1972,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -1991,7 +2550,8 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign ' + 'is missing required fields: ' 'optional_message.b,optional_message.c') proto.optional_message.b = 2 @@ -2003,7 +2563,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 'repeated_message[0].b,repeated_message[0].c,' 'repeated_message[1].a,repeated_message[1].c') @@ -2043,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2076,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2085,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2137,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2155,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2205,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2232,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() |