summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorqsr <qsr@chromium.org>2014-09-26 09:28:25 -0700
committerCommit bot <commit-bot@chromium.org>2014-09-26 16:28:49 +0000
commite86ad221e441c938f50e482f8fd6ad7e59bde83c (patch)
tree1264767bbc6f463154c2d062da036fb1c4bea3ee
parentf7d782116910dc7d962c245e0ccac2a5feefa6cc (diff)
downloadchromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.zip
chromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.tar.gz
chromium_src-e86ad221e441c938f50e482f8fd6ad7e59bde83c.tar.bz2
mojo: Add router for python bindings.
This is a reland of https://codereview.chromium.org/607513003 with the fix for the failing test. BUG=417707 TBR=sdefresne@chromium.org Review URL: https://codereview.chromium.org/612443002 Cr-Commit-Position: refs/heads/master@{#296958}
-rw-r--r--mojo/public/python/mojo/bindings/messaging.py220
-rw-r--r--mojo/python/tests/messaging_unittest.py150
2 files changed, 358 insertions, 12 deletions
diff --git a/mojo/public/python/mojo/bindings/messaging.py b/mojo/public/python/mojo/bindings/messaging.py
index a6eb575..956f5b3 100644
--- a/mojo/public/python/mojo/bindings/messaging.py
+++ b/mojo/public/python/mojo/bindings/messaging.py
@@ -5,18 +5,140 @@
"""Utility classes to handle sending and receiving messages."""
+import struct
import weakref
# pylint: disable=F0401
+import mojo.bindings.serialization as serialization
import mojo.system as system
+# The flag values for a message header.
+NO_FLAG = 0
+MESSAGE_EXPECTS_RESPONSE_FLAG = 1 << 0
+MESSAGE_IS_RESPONSE_FLAG = 1 << 1
+
+
+class MessageHeader(object):
+ """The header of a mojo message."""
+
+ _SIMPLE_MESSAGE_NUM_FIELDS = 2
+ _SIMPLE_MESSAGE_STRUCT = struct.Struct("=IIII")
+
+ _REQUEST_ID_STRUCT = struct.Struct("=Q")
+ _REQUEST_ID_OFFSET = _SIMPLE_MESSAGE_STRUCT.size
+
+ _MESSAGE_WITH_REQUEST_ID_NUM_FIELDS = 3
+ _MESSAGE_WITH_REQUEST_ID_SIZE = (
+ _SIMPLE_MESSAGE_STRUCT.size + _REQUEST_ID_STRUCT.size)
+
+ def __init__(self, message_type, flags, request_id=0, data=None):
+ self._message_type = message_type
+ self._flags = flags
+ self._request_id = request_id
+ self._data = data
+
+ @classmethod
+ def Deserialize(cls, data):
+ buf = buffer(data)
+ if len(data) < cls._SIMPLE_MESSAGE_STRUCT.size:
+ raise serialization.DeserializationException('Header is too short.')
+ (size, version, message_type, flags) = (
+ cls._SIMPLE_MESSAGE_STRUCT.unpack_from(buf))
+ if (version < cls._SIMPLE_MESSAGE_NUM_FIELDS):
+ raise serialization.DeserializationException('Incorrect version.')
+ request_id = 0
+ if _HasRequestId(flags):
+ if version < cls._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS:
+ raise serialization.DeserializationException('Incorrect version.')
+ if (size < cls._MESSAGE_WITH_REQUEST_ID_SIZE or
+ len(data) < cls._MESSAGE_WITH_REQUEST_ID_SIZE):
+ raise serialization.DeserializationException('Header is too short.')
+ (request_id, ) = cls._REQUEST_ID_STRUCT.unpack_from(
+ buf, cls._REQUEST_ID_OFFSET)
+ return MessageHeader(message_type, flags, request_id, data)
+
+ @property
+ def message_type(self):
+ return self._message_type
+
+ # pylint: disable=E0202
+ @property
+ def request_id(self):
+ assert self.has_request_id
+ return self._request_id
+
+ # pylint: disable=E0202
+ @request_id.setter
+ def request_id(self, request_id):
+ assert self.has_request_id
+ self._request_id = request_id
+ self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
+ request_id)
+
+ @property
+ def has_request_id(self):
+ return _HasRequestId(self._flags)
+
+ @property
+ def expects_response(self):
+ return self._HasFlag(MESSAGE_EXPECTS_RESPONSE_FLAG)
+
+ @property
+ def is_response(self):
+ return self._HasFlag(MESSAGE_IS_RESPONSE_FLAG)
+
+ @property
+ def size(self):
+ if self.has_request_id:
+ return self._MESSAGE_WITH_REQUEST_ID_SIZE
+ return self._SIMPLE_MESSAGE_STRUCT.size
+
+ def Serialize(self):
+ if not self._data:
+ self._data = bytearray(self.size)
+ version = self._SIMPLE_MESSAGE_NUM_FIELDS
+ size = self._SIMPLE_MESSAGE_STRUCT.size
+ if self.has_request_id:
+ version = self._MESSAGE_WITH_REQUEST_ID_NUM_FIELDS
+ size = self._MESSAGE_WITH_REQUEST_ID_SIZE
+ self._SIMPLE_MESSAGE_STRUCT.pack_into(self._data, 0, size, version,
+ self._message_type, self._flags)
+ if self.has_request_id:
+ self._REQUEST_ID_STRUCT.pack_into(self._data, self._REQUEST_ID_OFFSET,
+ self._request_id)
+ return self._data
+
+ def _HasFlag(self, flag):
+ return self._flags & flag != 0
+
+
class Message(object):
"""A message for a message pipe. This contains data and handles."""
def __init__(self, data=None, handles=None):
self.data = data
self.handles = handles
+ self._header = None
+ self._payload = None
+
+ @property
+ def header(self):
+ if self._header is None:
+ self._header = MessageHeader.Deserialize(self.data)
+ return self._header
+
+ @property
+ def payload(self):
+ if self._payload is None:
+ self._payload = Message(self.data[self.header.size:], self.handles)
+ return self._payload
+
+ def SetRequestId(self, request_id):
+ header = self.header
+ header.request_id = request_id
+ (data, _) = header.Serialize()
+ self.data[:header.Size] = data[:header.Size]
class MessageReceiver(object):
@@ -111,6 +233,12 @@ class Connector(MessageReceiver):
result = self._handle.WriteMessage(message.data, message.handles)
return result == system.RESULT_OK
+ def Close(self):
+ if self._cancellable:
+ self._cancellable()
+ self._cancellable = None
+ self._handle.Close()
+
def _OnAsyncWaiterResult(self, result):
self._cancellable = None
if result == system.RESULT_OK:
@@ -141,6 +269,96 @@ class Connector(MessageReceiver):
self._OnError(result)
+class Router(MessageReceiverWithResponder):
+ """
+ A Router will handle mojo message and forward those to a Connector. It deals
+ with parsing of headers and adding of request ids in order to be able to match
+ a response to a request.
+ """
+
+ def __init__(self, handle):
+ MessageReceiverWithResponder.__init__(self)
+ self._incoming_message_receiver = None
+ self._next_request_id = 1
+ self._responders = {}
+ self._connector = Connector(handle)
+ self._connector.SetIncomingMessageReceiver(
+ ForwardingMessageReceiver(self._HandleIncomingMessage))
+
+ def Start(self):
+ self._connector.Start()
+
+ def SetIncomingMessageReceiver(self, message_receiver):
+ """
+ Set the MessageReceiver that will receive message from the owned message
+ pipe.
+ """
+ self._incoming_message_receiver = message_receiver
+
+ def SetErrorHandler(self, error_handler):
+ """
+ Set the ConnectionErrorHandler that will be notified of errors on the owned
+ message pipe.
+ """
+ self._connector.SetErrorHandler(error_handler)
+
+ def Accept(self, message):
+ # A message without responder is directly forwarded to the connector.
+ return self._connector.Accept(message)
+
+ def AcceptWithResponder(self, message, responder):
+ # The message must have a header.
+ header = message.header
+ assert header.expects_response
+ request_id = self.NextRequestId()
+ header.request_id = request_id
+ if not self._connector.Accept(message):
+ return False
+ self._responders[request_id] = responder
+ return True
+
+ def Close(self):
+ self._connector.Close()
+
+ def _HandleIncomingMessage(self, message):
+ header = message.header
+ if header.expects_response:
+ if self._incoming_message_receiver:
+ return self._incoming_message_receiver.AcceptWithResponder(
+ message, self)
+ # If we receive a request expecting a response when the client is not
+ # listening, then we have no choice but to tear down the pipe.
+ self.Close()
+ return False
+ if header.is_response:
+ request_id = header.request_id
+ responder = self._responders.pop(request_id, None)
+ if responder is None:
+ return False
+ return responder.Accept(message)
+ if self._incoming_message_receiver:
+ return self._incoming_message_receiver.Accept(message)
+ # Ok to drop the message
+ return False
+
+ def NextRequestId(self):
+ request_id = self._next_request_id
+ while request_id == 0 or request_id in self._responders:
+ request_id = (request_id + 1) % (1 << 64)
+ self._next_request_id = (request_id + 1) % (1 << 64)
+ return request_id
+
+class ForwardingMessageReceiver(MessageReceiver):
+ """A MessageReceiver that forward calls to |Accept| to a callable."""
+
+ def __init__(self, callback):
+ MessageReceiver.__init__(self)
+ self._callback = callback
+
+ def Accept(self, message):
+ return self._callback(message)
+
+
def _WeakCallback(callback):
func = callback.im_func
self = callback.im_self
@@ -165,3 +383,5 @@ def _ReadAndDispatchMessage(handle, message_receiver):
message_receiver.Accept(Message(data[0], data[1]))
return result
+def _HasRequestId(flags):
+ return flags & (MESSAGE_EXPECTS_RESPONSE_FLAG|MESSAGE_IS_RESPONSE_FLAG) != 0
diff --git a/mojo/python/tests/messaging_unittest.py b/mojo/python/tests/messaging_unittest.py
index c67048b..2d08941 100644
--- a/mojo/python/tests/messaging_unittest.py
+++ b/mojo/python/tests/messaging_unittest.py
@@ -10,16 +10,6 @@ from mojo.bindings import messaging
from mojo import system
-class _ForwardingMessageReceiver(messaging.MessageReceiver):
-
- def __init__(self, callback):
- self._callback = callback
-
- def Accept(self, message):
- self._callback(message)
- return True
-
-
class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler):
def __init__(self, callback):
@@ -29,7 +19,7 @@ class _ForwardingConnectionErrorHandler(messaging.ConnectionErrorHandler):
self._callback(result)
-class MessagingTest(unittest.TestCase):
+class ConnectorTest(unittest.TestCase):
def setUp(self):
mojo.embedder.Init()
@@ -38,12 +28,13 @@ class MessagingTest(unittest.TestCase):
self.received_errors = []
def _OnMessage(message):
self.received_messages.append(message)
+ return True
def _OnError(result):
self.received_errors.append(result)
handles = system.MessagePipe()
self.connector = messaging.Connector(handles.handle1)
self.connector.SetIncomingMessageReceiver(
- _ForwardingMessageReceiver(_OnMessage))
+ messaging.ForwardingMessageReceiver(_OnMessage))
self.connector.SetErrorHandler(
_ForwardingConnectionErrorHandler(_OnError))
self.connector.Start()
@@ -79,3 +70,138 @@ class MessagingTest(unittest.TestCase):
self.connector = None
(result, _, _) = self.handle.ReadMessage()
self.assertEquals(result, system.RESULT_FAILED_PRECONDITION)
+
+
+class HeaderTest(unittest.TestCase):
+
+ def testSimpleMessageHeader(self):
+ header = messaging.MessageHeader(0xdeadbeaf, messaging.NO_FLAG)
+ self.assertEqual(header.message_type, 0xdeadbeaf)
+ self.assertFalse(header.has_request_id)
+ self.assertFalse(header.expects_response)
+ self.assertFalse(header.is_response)
+ data = header.Serialize()
+ other_header = messaging.MessageHeader.Deserialize(data)
+ self.assertEqual(other_header.message_type, 0xdeadbeaf)
+ self.assertFalse(other_header.has_request_id)
+ self.assertFalse(other_header.expects_response)
+ self.assertFalse(other_header.is_response)
+
+ def testMessageHeaderWithRequestID(self):
+ # Request message.
+ header = messaging.MessageHeader(0xdeadbeaf,
+ messaging.MESSAGE_EXPECTS_RESPONSE_FLAG)
+
+ self.assertEqual(header.message_type, 0xdeadbeaf)
+ self.assertTrue(header.has_request_id)
+ self.assertTrue(header.expects_response)
+ self.assertFalse(header.is_response)
+ self.assertEqual(header.request_id, 0)
+
+ data = header.Serialize()
+ other_header = messaging.MessageHeader.Deserialize(data)
+
+ self.assertEqual(other_header.message_type, 0xdeadbeaf)
+ self.assertTrue(other_header.has_request_id)
+ self.assertTrue(other_header.expects_response)
+ self.assertFalse(other_header.is_response)
+ self.assertEqual(other_header.request_id, 0)
+
+ header.request_id = 0xdeadbeafdeadbeaf
+ data = header.Serialize()
+ other_header = messaging.MessageHeader.Deserialize(data)
+
+ self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf)
+
+ # Response message.
+ header = messaging.MessageHeader(0xdeadbeaf,
+ messaging.MESSAGE_IS_RESPONSE_FLAG,
+ 0xdeadbeafdeadbeaf)
+
+ self.assertEqual(header.message_type, 0xdeadbeaf)
+ self.assertTrue(header.has_request_id)
+ self.assertFalse(header.expects_response)
+ self.assertTrue(header.is_response)
+ self.assertEqual(header.request_id, 0xdeadbeafdeadbeaf)
+
+ data = header.Serialize()
+ other_header = messaging.MessageHeader.Deserialize(data)
+
+ self.assertEqual(other_header.message_type, 0xdeadbeaf)
+ self.assertTrue(other_header.has_request_id)
+ self.assertFalse(other_header.expects_response)
+ self.assertTrue(other_header.is_response)
+ self.assertEqual(other_header.request_id, 0xdeadbeafdeadbeaf)
+
+
+class RouterTest(unittest.TestCase):
+
+ def setUp(self):
+ mojo.embedder.Init()
+ self.loop = system.RunLoop()
+ self.received_messages = []
+ self.received_errors = []
+ def _OnMessage(message):
+ self.received_messages.append(message)
+ return True
+ def _OnError(result):
+ self.received_errors.append(result)
+ handles = system.MessagePipe()
+ self.router = messaging.Router(handles.handle1)
+ self.router.SetIncomingMessageReceiver(
+ messaging.ForwardingMessageReceiver(_OnMessage))
+ self.router.SetErrorHandler(
+ _ForwardingConnectionErrorHandler(_OnError))
+ self.router.Start()
+ self.handle = handles.handle0
+
+ def tearDown(self):
+ self.router = None
+ self.handle = None
+ self.loop = None
+
+ def testSimpleMessage(self):
+ header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize()
+ message = messaging.Message(header_data)
+ self.router.Accept(message)
+ self.loop.RunUntilIdle()
+ self.assertFalse(self.received_errors)
+ self.assertFalse(self.received_messages)
+ (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data)))
+ self.assertEquals(system.RESULT_OK, res)
+ self.assertEquals(data[0], header_data)
+
+ def testSimpleReception(self):
+ header_data = messaging.MessageHeader(0, messaging.NO_FLAG).Serialize()
+ self.handle.WriteMessage(header_data)
+ self.loop.RunUntilIdle()
+ self.assertFalse(self.received_errors)
+ self.assertEquals(len(self.received_messages), 1)
+ self.assertEquals(self.received_messages[0].data, header_data)
+
+ def testRequestResponse(self):
+ header_data = messaging.MessageHeader(
+ 0, messaging.MESSAGE_EXPECTS_RESPONSE_FLAG).Serialize()
+ message = messaging.Message(header_data)
+ back_messages = []
+ def OnBackMessage(message):
+ back_messages.append(message)
+ self.router.AcceptWithResponder(message,
+ messaging.ForwardingMessageReceiver(
+ OnBackMessage))
+ self.loop.RunUntilIdle()
+ self.assertFalse(self.received_errors)
+ self.assertFalse(self.received_messages)
+ (res, data, _) = self.handle.ReadMessage(bytearray(len(header_data)))
+ self.assertEquals(system.RESULT_OK, res)
+ message_header = messaging.MessageHeader.Deserialize(data[0])
+ self.assertNotEquals(message_header.request_id, 0)
+ response_header_data = messaging.MessageHeader(
+ 0,
+ messaging.MESSAGE_IS_RESPONSE_FLAG,
+ message_header.request_id).Serialize()
+ self.handle.WriteMessage(response_header_data)
+ self.loop.RunUntilIdle()
+ self.assertFalse(self.received_errors)
+ self.assertEquals(len(back_messages), 1)
+ self.assertEquals(back_messages[0].data, response_header_data)