diff options
author | qsr <qsr@chromium.org> | 2014-09-26 09:28:25 -0700 |
---|---|---|
committer | Commit bot <commit-bot@chromium.org> | 2014-09-26 16:28:49 +0000 |
commit | e86ad221e441c938f50e482f8fd6ad7e59bde83c (patch) | |
tree | 1264767bbc6f463154c2d062da036fb1c4bea3ee | |
parent | f7d782116910dc7d962c245e0ccac2a5feefa6cc (diff) | |
download | chromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.zip chromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.tar.gz chromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.tar.bz2 |
mojo: Add router for python bindings.
This is a reland of https://codereview.chromium.org/607513003 with the
fix for the failing test.
BUG=417707
TBR=sdefresne@chromium.org
Review URL: https://codereview.chromium.org/612443002
Cr-Commit-Position: refs/heads/master@{#296958}
-rw-r--r-- | mojo/public/python/mojo/bindings/messaging.py | 220 | ||||
-rw-r--r-- | mojo/python/tests/messaging_unittest.py | 150 |
2 files changed, 358 insertions, 12 deletions
diff --git a/mojo/public/python/mojo/bindings/messaging.py b/mojo/public/python/mojo/bindings/messaging.py index a6eb575..956f5b3 100644 --- a/mojo/public/python/mojo/bindings/messaging.py +++ b/mojo/public/python/mojo/bindings/messaging.py @@ -5,18 +5,140 @@ """Utility classes to handle sending and receiving messages.""" +import struct import weakref # pylint: disable=F0401 +import mojo.bindings.serialization as serialization import mojo.system as system +# The flag values for a message header. +NO_FLAG = 0 +MESSAGE_EXPECTS_RESPONSE_FLAG = 1 << 0 +MESSAGE_IS_RESPONSE_FLAG = 1 << 1 + + +class MessageHeader(object): + """The header of a mojo message.""" + + _SIMPLE_MESSAGE_NUM_FIELDS = 2 + _SIMPLE_MESSAGE_STRUCT = struct.Struct("=IIII") + + _REQUEST_ID_STRUCT = struct.Struct("=Q") + _REQUEST_ID_OFFSET = _SIMPLE_MESSAGE_STRUCT.size + + _MESSAGE_WITH_REQUEST_ID_NUM_FIELDS = 3 + _MESSAGE_WITH_REQUEST_ID_SIZE = ( + _SIMPLE_MESSAGE_STRUCT.size + _REQUEST_ID_STRUCT.size) + + def __init__(self, message_type, flags, request_id=0, data=None): + self._message_type = message_type + self._flags = flags + self._request_id = request_id + self._data = data + + @classmethod + def Deserialize(cls, data): + buf = buffer(data) + if len(data) < cls._SIMPLE_MESSAGE_STRUCT.size: + raise serialization.DeserializationException('Header is too short.') + (size, version, message_type, flags) = ( + cls._SIMPLE_MESSAGE_STRUCT.unpack_from(buf)) + if (version < cls._SIMPLE_MESSAGE_NUM_FIELDS): + raise serialization.DeserializationException('Incorrect version.') + request_id = 0 + if _HasRequestId(flags): + if version < cls._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS: + raise serialization.DeserializationException('Incorrect version.') + if (size < cls._MESSAGE_WITH_REQUEST_ID_SIZE or + len(data) < cls._MESSAGE_WITH_REQUEST_ID_SIZE): + raise serialization.DeserializationException('Header is too short.') + (request_id, ) = cls._REQUEST_ID_STRUCT.unpack_from( + buf, cls._REQUEST_ID_OFFSET) + return MessageHeader(message_type, flags, request_id, data) + + @property + def message_type(self): + return self._message_type + + # pylint: disable=E0202 + @property + def request_id(self): + assert self.has_request_id + return self._request_id + + # pylint: disable=E0202 + @request_id.setter + def request_id(self, request_id): + assert self.has_request_id + self._request_id = request_id + self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET, + request_id) + + @property + def has_request_id(self): + return _HasRequestId(self._flags) + + @property + def expects_response(self): + return self._HasFlag(MESSAGE_EXPECTS_RESPONSE_FLAG) + + @property + def is_response(self): + return self._HasFlag(MESSAGE_IS_RESPONSE_FLAG) + + @property + def size(self): + if self.has_request_id: + return self._MESSAGE_WITH_REQUEST_ID_SIZE + return self._SIMPLE_MESSAGE_STRUCT.size + + def Serialize(self): + if not self._data: + self._data = bytearray(self.size) + version = self._SIMPLE_MESSAGE_NUM_FIELDS + size = self._SIMPLE_MESSAGE_STRUCT.size + if self.has_request_id: + version = self._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS + size = self._MESSAGE_WITH_REQUEST_ID_SIZE + self._SIMPLE_MESSAGE_STRUCT.pack_into(self._data, 0, size, version, + self._message_type, self._flags) + if self.has_request_id: + self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET, + self._request_id) + return self._data + + def _HasFlag(self, flag): + return self._flags & flag != 0 + + class Message(object): """A message for a message pipe. This contains data and handles.""" def __init__(self, data=None, handles=None): self.data = data self.handles = handles + self._header = None + self._payload = None + + @property + def header(self): + if self._header is None: + self._header = MessageHeader.Deserialize(self.data) + return self._header + + @property + def payload(self): + if self._payload is None: + self._payload = Message(self.data[self.header.size:], self.handles) + return self._payload + + def SetRequestId(self, request_id): + header = self.header + header.request_id = request_id + (data, _) = header.Serialize() + self.data[:header.Size] = data[:header.Size] class MessageReceiver(object): @@ -111,6 +233,12 @@ class Connector(MessageReceiver): result = self._handle.WriteMessage(message.data, message.handles) return result == system.RESULT_OK + def Close(self): + if self._cancellable: + self._cancellable() + self._cancellable = None + self._handle.Close() + def _OnAsyncWaiterResult(self, result): self._cancellable = None if result == system.RESULT_OK: @@ -141,6 +269,96 @@ class Connector(MessageReceiver): self._OnError(result) +class Router(MessageReceiverWithResponder): + """ + A Router will handle mojo message and forward those to a Connector. It deals + with parsing of headers and adding of request ids in order to be able to match + a response to a request. + """ + + def __init__(self, handle): + MessageReceiverWithResponder.__init__(self) + self._incoming_message_receiver = None + self._next_request_id = 1 + self._responders = {} + self._connector = Connector(handle) + self._connector.SetIncomingMessageReceiver( + ForwardingMessageReceiver(self._HandleIncomingMessage)) + + def Start(self): + self._connector.Start() + + def SetIncomingMessageReceiver(self, message_receiver): + """ + Set the MessageReceiver that will receive message from the owned message + pipe. + """ + self._incoming_message_receiver = message_receiver + + def SetErrorHandler(self, error_handler): + """ + Set the ConnectionErrorHandler that will be notified of errors on the owned + message pipe. + """ + self._connector.SetErrorHandler(error_handler) + + def Accept(self, message): + # A message without responder is directly forwarded to the connector. + return self._connector.Accept(message) + + def AcceptWithResponder(self, message, responder): + # The message must have a header. + header = message.header + assert header.expects_response + request_id = self.NextRequestId() + header.request_id = request_id + if not self._connector.Accept(message): + return False + self._responders[request_id] = responder + return True + + def Close(self): + self._connector.Close() + + def _HandleIncomingMessage(self, message): + header = message.header + if header.expects_response: + if self._incoming_message_receiver: + return self._incoming_message_receiver.AcceptWithResponder( + message, self) + # If we receive a request expecting a response when the client is not + # listening, then we have no choice but to tear down the pipe. + self.Close() + return False + if header.is_response: + request_id = header.request_id + responder = self._responders.pop(request_id, None) + if responder is None: + return False + return responder.Accept(message) + if self._incoming_message_receiver: + return self._incoming_message_receiver.Accept(message) + # Ok to drop the message + return False + + def NextRequestId(self): + request_id = self._next_request_id + while request_id == 0 or request_id in self._responders: + request_id = (request_id + 1) % (1 << 64) + self._next_request_id = (request_id + 1) % (1 << 64) + return request_id + +class ForwardingMessageReceiver(MessageReceiver): + """A MessageReceiver that forward calls to |Accept| to a callable.""" + + def __init__(self, callback): + MessageReceiver.__init__(self) + self._callback = callback + + def Accept(self, message): + return self._callback(message) + + def _WeakCallback(callback): func = callback.im_func self = callback.im_self @@ -165,3 +383,5 @@ def _ReadAndDispatchMessage(handle, message_receiver): message_receiver.Accept(Message(data[0], data[1])) return result +def _HasRequestId(flags): + return flags & (MESSAGE_EXPECTS_RESPONSE_FLAG|MESSAGE_IS_RESPONSE_FLAG) != 0 diff --git a/mojo/python/tests/messaging_unittest.py b/mojo/python/tests/messaging_unittest.py index c67048b..2d08941 100644 --- a/mojo/python/tests/messaging_unittest.py +++ b/mojo/python/tests/messaging_unittest.py @@ -10,16 +10,6 @@ from mojo.bindings import messaging from mojo import system -class _ForwardingMessageReceiver(messaging.MessageReceiver): - - def __init__(self, callback): - self._callback = callback - - def Accept(self, message): - self._callback(message) - return True - - class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler): def __init__(self, callback): @@ -29,7 +19,7 @@ class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler): self._callback(result) -class MessagingTest(unittest.TestCase): +class ConnectorTest(unittest.TestCase): def setUp(self): mojo.embedder.Init() @@ -38,12 +28,13 @@ class MessagingTest(unittest.TestCase): self.received_errors = [] def _OnMessage(message): self.received_messages.append(message) + return True def _OnError(result): self.received_errors.append(result) handles = system.MessagePipe() self.connector = messaging.Connector(handles.handle1) self.connector.SetIncomingMessageReceiver( - _ForwardingMessageReceiver(_OnMessage)) + messaging.ForwardingMessageReceiver(_OnMessage)) self.connector.SetErrorHandler( _ForwardingConnectionErrorHandler(_OnError)) self.connector.Start() @@ -79,3 +70,138 @@ class MessagingTest(unittest.TestCase): self.connector = None (result, _, _) = self.handle.ReadMessage() self.assertEquals(result, system.RESULT_FAILED_PRECONDITION) + + +class HeaderTest(unittest.TestCase): + + def testSimpleMessageHeader(self): + header = messaging.MessageHeader(0xdeadbeaf, messaging.NO_FLAG) + self.assertEqual(header.message_type, 0xdeadbeaf) + self.assertFalse(header.has_request_id) + self.assertFalse(header.expects_response) + self.assertFalse(header.is_response) + data = header.Serialize() + other_header = messaging.MessageHeader.Deserialize(data) + self.assertEqual(other_header.message_type, 0xdeadbeaf) + self.assertFalse(other_header.has_request_id) + self.assertFalse(other_header.expects_response) + self.assertFalse(other_header.is_response) + + def testMessageHeaderWithRequestID(self): + # Request message. + header = messaging.MessageHeader(0xdeadbeaf, + messaging.MESSAGE_EXPECTS_RESPONSE_FLAG) + + self.assertEqual(header.message_type, 0xdeadbeaf) + self.assertTrue(header.has_request_id) + self.assertTrue(header.expects_response) + self.assertFalse(header.is_response) + self.assertEqual(header.request_id, 0) + + data = header.Serialize() + other_header = messaging.MessageHeader.Deserialize(data) + + self.assertEqual(other_header.message_type, 0xdeadbeaf) + self.assertTrue(other_header.has_request_id) + self.assertTrue(other_header.expects_response) + self.assertFalse(other_header.is_response) + self.assertEqual(other_header.request_id, 0) + + header.request_id = 0xdeadbeafdeadbeaf + data = header.Serialize() + other_header = messaging.MessageHeader.Deserialize(data) + + self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf) + + # Response message. + header = messaging.MessageHeader(0xdeadbeaf, + messaging.MESSAGE_IS_RESPONSE_FLAG, + 0xdeadbeafdeadbeaf) + + self.assertEqual(header.message_type, 0xdeadbeaf) + self.assertTrue(header.has_request_id) + self.assertFalse(header.expects_response) + self.assertTrue(header.is_response) + self.assertEqual(header.request_id, 0xdeadbeafdeadbeaf) + + data = header.Serialize() + other_header = messaging.MessageHeader.Deserialize(data) + + self.assertEqual(other_header.message_type, 0xdeadbeaf) + self.assertTrue(other_header.has_request_id) + self.assertFalse(other_header.expects_response) + self.assertTrue(other_header.is_response) + self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf) + + +class RouterTest(unittest.TestCase): + + def setUp(self): + mojo.embedder.Init() + self.loop = system.RunLoop() + self.received_messages = [] + self.received_errors = [] + def _OnMessage(message): + self.received_messages.append(message) + return True + def _OnError(result): + self.received_errors.append(result) + handles = system.MessagePipe() + self.router = messaging.Router(handles.handle1) + self.router.SetIncomingMessageReceiver( + messaging.ForwardingMessageReceiver(_OnMessage)) + self.router.SetErrorHandler( + _ForwardingConnectionErrorHandler(_OnError)) + self.router.Start() + self.handle = handles.handle0 + + def tearDown(self): + self.router = None + self.handle = None + self.loop = None + + def testSimpleMessage(self): + header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize() + message = messaging.Message(header_data) + self.router.Accept(message) + self.loop.RunUntilIdle() + self.assertFalse(self.received_errors) + self.assertFalse(self.received_messages) + (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data))) + self.assertEquals(system.RESULT_OK, res) + self.assertEquals(data[0], header_data) + + def testSimpleReception(self): + header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize() + self.handle.WriteMessage(header_data) + self.loop.RunUntilIdle() + self.assertFalse(self.received_errors) + self.assertEquals(len(self.received_messages), 1) + self.assertEquals(self.received_messages[0].data, header_data) + + def testRequestResponse(self): + header_data = messaging.MessageHeader( + 0, messaging.MESSAGE_EXPECTS_RESPONSE_FLAG).Serialize() + message = messaging.Message(header_data) + back_messages = [] + def OnBackMessage(message): + back_messages.append(message) + self.router.AcceptWithResponder(message, + messaging.ForwardingMessageReceiver( + OnBackMessage)) + self.loop.RunUntilIdle() + self.assertFalse(self.received_errors) + self.assertFalse(self.received_messages) + (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data))) + self.assertEquals(system.RESULT_OK, res) + message_header = messaging.MessageHeader.Deserialize(data[0]) + self.assertNotEquals(message_header.request_id, 0) + response_header_data = messaging.MessageHeader( + 0, + messaging.MESSAGE_IS_RESPONSE_FLAG, + message_header.request_id).Serialize() + self.handle.WriteMessage(response_header_data) + self.loop.RunUntilIdle() + self.assertFalse(self.received_errors) + self.assertEquals(len(back_messages), 1) + self.assertEquals(back_messages[0].data, response_header_data) |