path: root/sync/tools/testserver
author: rsimha@chromium.org <rsimha@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> 2013-01-20 01:10:24 +0000
committer: rsimha@chromium.org <rsimha@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> 2013-01-20 01:10:24 +0000
commit: e4c029f76eb948af468a4d11ec0d3272671ddb58 (patch)
tree: f80be661f716ca63d85e2f94d562b89c242af599 /sync/tools/testserver
parent: 00353d188a87c6a2f953e22a73d7c18fa2c37b2a (diff)
[sync] Divorce python sync test server chromiumsync.py from testserver.py
Various chrome test suites use the infrastructure in net::LocalTestServer
and net/tools/testserver.py to create local test server instances against
which to run automated tests. Sync tests use reference implementations of
sync and xmpp servers, which build on the testserver infrastructure in net/.

In the past, the sync testserver was small enough that it made sense for it
to be a part of the testserver in net/. This, however, resulted in an
unwanted dependency from net/ onto sync/, due to the sync proto modules
needed to run a python sync server. Now that the sync testserver has grown
considerably in scope, it is time to separate it out from net/ while reusing
the base testserver code, and to eliminate the dependency from net/ onto
sync/. This work also provides us with the opportunity to remove a whole
bunch of dead pyauto sync test code in chrome/test/functional.

This patch does the following:
- Moves the native class LocalSyncTestServer from net/test/ to sync/test/.
- Moves chromiumsync{_test}.py and xmppserver{_test}.py from
  net/tools/testserver/ to sync/tools/testserver/.
- Removes all sync server specific code from net/.
- Adds a new sync_testserver.py runner script for the python sync
  testserver.
- Moves some base classes from testserver.py to testserver_base.py so they
  can be reused by sync_testserver.py.
- Audits all the python imports in testserver.py, testserver_base.py and
  sync_testserver.py to make sure there are no unnecessary / missing
  imports.
- Adds a new run_sync_testserver runner executable to launch a sync
  testserver.
- Removes a couple of static methods from LocalTestServer that were being
  used by run_testserver, and refactors run_sync_testserver to use their
  non-static versions.
- Adds the ability to run both chromiumsync_test.py and xmppserver_test.py
  from run_sync_testserver.
- Fixes chromiumsync.py to undo / rectify some older changes that broke
  tests in chromiumsync_test.py.
- Adds a new test target called test_support_sync_testserver to sync.gyp.
- Removes the hacky dependency on sync_proto from net.gyp:net_test_support.
- Updates various gyp files across chrome to use the new sync testserver
  target.
- Audits the dependencies of net_test_support, run_testserver, and the
  newly added targets.
- Fixes the android chrome testserver spawner script to account for the
  above changes.
- Removes all mentions of TYPE_SYNC from the pyauto TestServer shim.
- Deletes all (deprecated) pyauto sync tests. (They had all become broken
  over time, gotten disabled, and were all redundant due to their
  equivalent sync integration tests.)
- Removes all sync related pyauto hooks from TestingAutomationProvider,
  since they are no longer going to be used.
- Takes care of a TODO in safe_browser_testserver.py to remove an
  unnecessary code block.

Note: A majority of the bugs listed below are for individual pyauto sync
tests. Deleting the sync pyauto test script fixes all these bugs in one
fell swoop.

TBR=mattm@chromium.org
BUG=117559, 119403, 159731, 15016, 80329, 49378, 87642, 86949, 88679, 104227, 88593, 124913
TEST=run_testserver, run_sync_testserver, sync_integration_tests,
sync_performance_tests. All chrome tests that use a testserver should
continue to work.

Review URL: https://codereview.chromium.org/11971025

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@177864 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'sync/tools/testserver')
 -rw-r--r--  sync/tools/testserver/DEPS                   |    3
 -rw-r--r--  sync/tools/testserver/OWNERS                 |    3
 -rw-r--r--  sync/tools/testserver/chromiumsync.py        | 1370
 -rwxr-xr-x  sync/tools/testserver/chromiumsync_test.py   |  655
 -rw-r--r--  sync/tools/testserver/run_sync_testserver.cc |  121
 -rwxr-xr-x  sync/tools/testserver/sync_testserver.py     |  447
 -rw-r--r--  sync/tools/testserver/xmppserver.py          |  594
 -rwxr-xr-x  sync/tools/testserver/xmppserver_test.py     |  421
8 files changed, 3614 insertions, 0 deletions
diff --git a/sync/tools/testserver/DEPS b/sync/tools/testserver/DEPS
new file mode 100644
index 0000000..f9b201f
--- /dev/null
+++ b/sync/tools/testserver/DEPS
@@ -0,0 +1,3 @@
+include_rules = [
+ "+sync/test",
+]
diff --git a/sync/tools/testserver/OWNERS b/sync/tools/testserver/OWNERS
new file mode 100644
index 0000000..e628479
--- /dev/null
+++ b/sync/tools/testserver/OWNERS
@@ -0,0 +1,3 @@
+akalin@chromium.org
+nick@chromium.org
+rsimha@chromium.org
diff --git a/sync/tools/testserver/chromiumsync.py b/sync/tools/testserver/chromiumsync.py
new file mode 100644
index 0000000..eb631c2
--- /dev/null
+++ b/sync/tools/testserver/chromiumsync.py
@@ -0,0 +1,1370 @@
+# Copyright 2013 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""An implementation of the server side of the Chromium sync protocol.
+
+The details of the protocol are described mostly by comments in the protocol
+buffer definition at chrome/browser/sync/protocol/sync.proto.
+"""
+
+import cgi
+import copy
+import operator
+import pickle
+import random
+import string
+import sys
+import threading
+import time
+import urlparse
+
+import app_notification_specifics_pb2
+import app_setting_specifics_pb2
+import app_specifics_pb2
+import autofill_specifics_pb2
+import bookmark_specifics_pb2
+import get_updates_caller_info_pb2
+import extension_setting_specifics_pb2
+import extension_specifics_pb2
+import history_delete_directive_specifics_pb2
+import nigori_specifics_pb2
+import password_specifics_pb2
+import preference_specifics_pb2
+import search_engine_specifics_pb2
+import session_specifics_pb2
+import sync_pb2
+import sync_enums_pb2
+import synced_notification_specifics_pb2
+import theme_specifics_pb2
+import typed_url_specifics_pb2
+
+# An enumeration of the various kinds of data that can be synced.
+# Over the wire, this enumeration is not used: a sync object's type is
+# inferred by which EntitySpecifics field it has. But in the context
+# of a program, it is useful to have an enumeration.
+ALL_TYPES = (
+ TOP_LEVEL, # The type of the 'Google Chrome' folder.
+ APPS,
+ APP_NOTIFICATION,
+ APP_SETTINGS,
+ AUTOFILL,
+ AUTOFILL_PROFILE,
+ BOOKMARK,
+ DEVICE_INFO,
+ EXPERIMENTS,
+ EXTENSIONS,
+ HISTORY_DELETE_DIRECTIVE,
+ NIGORI,
+ PASSWORD,
+ PREFERENCE,
+ SEARCH_ENGINE,
+ SESSION,
+ SYNCED_NOTIFICATION,
+ THEME,
+ TYPED_URL,
+ EXTENSION_SETTINGS) = range(20)
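+
+# For example, a SyncEntity whose specifics has only the 'bookmark' field
+# populated is treated as a BOOKMARK; see GetEntryTypesFromSpecifics below,
+# which performs this inference.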
+
+# An enumeration of the frequencies at which the server should send errors
+# to the client. The desired frequency is specified by the URL that triggers
+# the error.
+# Note: This enum should be kept in the same order as the enum in sync_test.h.
+SYNC_ERROR_FREQUENCY = (
+ ERROR_FREQUENCY_NONE,
+ ERROR_FREQUENCY_ALWAYS,
+ ERROR_FREQUENCY_TWO_THIRDS) = range(3)
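+
+# For example, a trigger URL whose query string carries 'frequency=2'
+# selects ERROR_FREQUENCY_TWO_THIRDS; see HandleSetInducedError below.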
+
+# Well-known server tag of the top level 'Google Chrome' folder.
+TOP_LEVEL_FOLDER_TAG = 'google_chrome'
+
+# Given a sync type from ALL_TYPES, find the FieldDescriptor corresponding
+# to that datatype. Note that TOP_LEVEL has no such token.
+SYNC_TYPE_FIELDS = sync_pb2.EntitySpecifics.DESCRIPTOR.fields_by_name
+SYNC_TYPE_TO_DESCRIPTOR = {
+ APP_NOTIFICATION: SYNC_TYPE_FIELDS['app_notification'],
+ APP_SETTINGS: SYNC_TYPE_FIELDS['app_setting'],
+ APPS: SYNC_TYPE_FIELDS['app'],
+ AUTOFILL: SYNC_TYPE_FIELDS['autofill'],
+ AUTOFILL_PROFILE: SYNC_TYPE_FIELDS['autofill_profile'],
+ BOOKMARK: SYNC_TYPE_FIELDS['bookmark'],
+ DEVICE_INFO: SYNC_TYPE_FIELDS['device_info'],
+ EXPERIMENTS: SYNC_TYPE_FIELDS['experiments'],
+ EXTENSION_SETTINGS: SYNC_TYPE_FIELDS['extension_setting'],
+ EXTENSIONS: SYNC_TYPE_FIELDS['extension'],
+ HISTORY_DELETE_DIRECTIVE: SYNC_TYPE_FIELDS['history_delete_directive'],
+ NIGORI: SYNC_TYPE_FIELDS['nigori'],
+ PASSWORD: SYNC_TYPE_FIELDS['password'],
+ PREFERENCE: SYNC_TYPE_FIELDS['preference'],
+ SEARCH_ENGINE: SYNC_TYPE_FIELDS['search_engine'],
+ SESSION: SYNC_TYPE_FIELDS['session'],
+ SYNCED_NOTIFICATION: SYNC_TYPE_FIELDS['synced_notification'],
+ THEME: SYNC_TYPE_FIELDS['theme'],
+ TYPED_URL: SYNC_TYPE_FIELDS['typed_url'],
+ }
+
+# The parent ID used to indicate a top-level node.
+ROOT_ID = '0'
+
+# Unix time epoch in struct_time format. The tuple corresponds to UTC
+# Thursday, Jan 1 1970, 00:00:00, non-dst.
+UNIX_TIME_EPOCH = (1970, 1, 1, 0, 0, 0, 3, 1, 0)
+
+# The number of characters in the server-generated encryption key.
+KEYSTORE_KEY_LENGTH = 16
+
+# The hashed client tag for the keystore encryption experiment node.
+KEYSTORE_ENCRYPTION_EXPERIMENT_TAG = "pis8ZRzh98/MKLtVEio2mr42LQA="
+
+class Error(Exception):
+ """Error class for this module."""
+
+
+class ProtobufDataTypeFieldNotUnique(Error):
+ """An entry should not have more than one data type present."""
+
+
+class DataTypeIdNotRecognized(Error):
+ """The requested data type is not recognized."""
+
+
+class MigrationDoneError(Error):
+ """A server-side migration occurred; clients must re-sync some datatypes.
+
+ Attributes:
+ datatypes: a list of the datatypes (python enum) needing migration.
+ """
+
+ def __init__(self, datatypes):
+ self.datatypes = datatypes
+
+
+class StoreBirthdayError(Error):
+ """The client sent a birthday that doesn't correspond to this server."""
+
+
+class TransientError(Error):
+ """The client would be sent a transient error."""
+
+
+class SyncInducedError(Error):
+ """The client would be sent an error."""
+
+
+class InducedErrorFrequencyNotDefined(Error):
+ """The error frequency defined is not handled."""
+
+
+def GetEntryType(entry):
+ """Extract the sync type from a SyncEntry.
+
+ Args:
+ entry: A SyncEntity protobuf object whose type to determine.
+ Returns:
+ An enum value from ALL_TYPES if the entry's type can be determined, or None
+ if the type cannot be determined.
+ Raises:
+ ProtobufDataTypeFieldNotUnique: More than one type was indicated by
+ the entry.
+ """
+ if entry.server_defined_unique_tag == TOP_LEVEL_FOLDER_TAG:
+ return TOP_LEVEL
+ entry_types = GetEntryTypesFromSpecifics(entry.specifics)
+ if not entry_types:
+ return None
+
+ # If there is more than one, either there's a bug, or else the caller
+ # should use GetEntryTypes.
+ if len(entry_types) > 1:
+ raise ProtobufDataTypeFieldNotUnique
+ return entry_types[0]
+
+
+def GetEntryTypesFromSpecifics(specifics):
+ """Determine the sync types indicated by an EntitySpecifics's field(s).
+
+ If the specifics have more than one recognized data type field (as commonly
+ happens with the requested_types field of GetUpdatesMessage), all types
+ will be returned. Callers must handle the possibility of the returned
+ value having more than one item.
+
+ Args:
+ specifics: A EntitySpecifics protobuf message whose extensions to
+ enumerate.
+ Returns:
+ A list of the sync types (values from ALL_TYPES) associated with each
+ recognized extension of the specifics message.
+ """
+ return [data_type for data_type, field_descriptor
+ in SYNC_TYPE_TO_DESCRIPTOR.iteritems()
+ if specifics.HasField(field_descriptor.name)]
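+
+# For example, the requested_types field of a GetUpdatesMessage asking for
+# bookmarks and preferences has both the 'bookmark' and 'preference' fields
+# set, so this function returns [BOOKMARK, PREFERENCE] (in no guaranteed
+# order, since it iterates over a dict).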
+
+
+def SyncTypeToProtocolDataTypeId(data_type):
+ """Convert from a sync type (python enum) to the protocol's data type id."""
+ return SYNC_TYPE_TO_DESCRIPTOR[data_type].number
+
+
+def ProtocolDataTypeIdToSyncType(protocol_data_type_id):
+ """Convert from the protocol's data type id to a sync type (python enum)."""
+ for data_type, field_descriptor in SYNC_TYPE_TO_DESCRIPTOR.iteritems():
+ if field_descriptor.number == protocol_data_type_id:
+ return data_type
+ raise DataTypeIdNotRecognized
+
+
+def DataTypeStringToSyncTypeLoose(data_type_string):
+ """Converts a human-readable string to a sync type (python enum).
+
+ Capitalization and pluralization don't matter; this function is appropriate
+ for values that might have been typed by a human being; e.g., command-line
+ flags or query parameters.
+ """
+ if data_type_string.isdigit():
+ return ProtocolDataTypeIdToSyncType(int(data_type_string))
+ name = data_type_string.lower().rstrip('s')
+ for data_type, field_descriptor in SYNC_TYPE_TO_DESCRIPTOR.iteritems():
+ if field_descriptor.name.lower().rstrip('s') == name:
+ return data_type
+ raise DataTypeIdNotRecognized
+
+
+def MakeNewKeystoreKey():
+ """Returns a new random keystore key."""
+ return ''.join(random.choice(string.ascii_uppercase + string.digits)
+ for x in xrange(KEYSTORE_KEY_LENGTH))
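+
+# A generated key is a 16-character string of uppercase letters and digits,
+# e.g. something like 'R2D2C3PO8GOLD1AU'.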
+
+
+def SyncTypeToString(data_type):
+ """Formats a sync type enum (from ALL_TYPES) to a human-readable string."""
+ return SYNC_TYPE_TO_DESCRIPTOR[data_type].name
+
+
+def CallerInfoToString(caller_info_source):
+ """Formats a GetUpdatesSource enum value to a readable string."""
+ return get_updates_caller_info_pb2.GetUpdatesCallerInfo \
+ .DESCRIPTOR.enum_types_by_name['GetUpdatesSource'] \
+ .values_by_number[caller_info_source].name
+
+
+def ShortDatatypeListSummary(data_types):
+ """Formats compactly a list of sync types (python enums) for human eyes.
+
+ This function is intended for use by logging. If the list of datatypes
+ contains almost all of the values, the return value will be expressed
+ in terms of the datatypes that aren't set.
+ """
+ included = set(data_types) - set([TOP_LEVEL])
+ if not included:
+ return 'nothing'
+ excluded = set(ALL_TYPES) - included - set([TOP_LEVEL])
+ if not excluded:
+ return 'everything'
+ simple_text = '+'.join(sorted([SyncTypeToString(x) for x in included]))
+ all_but_text = 'all except %s' % (
+ '+'.join(sorted([SyncTypeToString(x) for x in excluded])))
+ if len(included) < len(excluded) or len(simple_text) <= len(all_but_text):
+ return simple_text
+ else:
+ return all_but_text
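+
+# For example, [BOOKMARK, THEME] is summarized as 'bookmark+theme', while a
+# list containing every type except PASSWORD becomes 'all except password'.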
+
+
+def GetDefaultEntitySpecifics(data_type):
+ """Get an EntitySpecifics having a sync type's default field value."""
+ specifics = sync_pb2.EntitySpecifics()
+ if data_type in SYNC_TYPE_TO_DESCRIPTOR:
+ descriptor = SYNC_TYPE_TO_DESCRIPTOR[data_type]
+ getattr(specifics, descriptor.name).SetInParent()
+ return specifics
+
+
+class PermanentItem(object):
+ """A specification of one server-created permanent item.
+
+ Attributes:
+ tag: A known-to-the-client value that uniquely identifies a server-created
+ permanent item.
+ name: The human-readable display name for this item.
+ parent_tag: The tag of the permanent item's parent. If ROOT_ID, indicates
+ a top-level item. Otherwise, this must be the tag value of some other
+ server-created permanent item.
+ sync_type: A value from ALL_TYPES, giving the datatype of this permanent
+ item. This controls which types of client GetUpdates requests will
+ cause the permanent item to be created and returned.
+ create_by_default: Whether the permanent item is created at startup or not.
+ This value is set to True in the default case. Non-default permanent items
+ are those that are created only when a client explicitly tells the server
+ to do so.
+ """
+
+ def __init__(self, tag, name, parent_tag, sync_type, create_by_default=True):
+ self.tag = tag
+ self.name = name
+ self.parent_tag = parent_tag
+ self.sync_type = sync_type
+ self.create_by_default = create_by_default
+
+
+class MigrationHistory(object):
+ """A record of the migration events associated with an account.
+
+ Each migration event invalidates one or more datatypes on all clients
+ that had synced the datatype before the event. Such clients will continue
+ to receive MigrationDone errors until they throw away their progress and
+ re-sync that datatype from the beginning.
+ """
+ def __init__(self):
+ self._migrations = {}
+ for datatype in ALL_TYPES:
+ self._migrations[datatype] = [1]
+ self._next_migration_version = 2
+
+ def GetLatestVersion(self, datatype):
+ return self._migrations[datatype][-1]
+
+ def CheckAllCurrent(self, versions_map):
+ """Raises an error if any the provided versions are out of date.
+
+ This function intentionally returns migrations in the order that they were
+ triggered. Doing it this way allows the client to queue up two migrations
+ in a row, so the second one is received while responding to the first.
+
+ Arguments:
+ versions_map: a map whose keys are datatypes and whose values are versions.
+
+ Raises:
+ MigrationDoneError: if a mismatch is found.
+ """
+ problems = {}
+ for datatype, client_migration in versions_map.iteritems():
+ for server_migration in self._migrations[datatype]:
+ if client_migration < server_migration:
+ problems.setdefault(server_migration, []).append(datatype)
+ if problems:
+ raise MigrationDoneError(problems[min(problems.keys())])
+
+ def Bump(self, datatypes):
+ """Add a record of a migration, to cause errors on future requests."""
+ for datatype in datatypes:
+ self._migrations[datatype].append(self._next_migration_version)
+ self._next_migration_version += 1
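+
+ # For example, starting from a fresh history (every datatype at version
+ # 1), Bump([BOOKMARK]) records version 2 for BOOKMARK; a client still
+ # holding a version-1 progress token for BOOKMARK will then hit
+ # MigrationDoneError in CheckAllCurrent.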
+
+
+class UpdateSieve(object):
+ """A filter to remove items the client has already seen."""
+ def __init__(self, request, migration_history=None):
+ self._original_request = request
+ self._state = {}
+ self._migration_history = migration_history or MigrationHistory()
+ self._migration_versions_to_check = {}
+ if request.from_progress_marker:
+ for marker in request.from_progress_marker:
+ data_type = ProtocolDataTypeIdToSyncType(marker.data_type_id)
+ if marker.HasField('timestamp_token_for_migration'):
+ timestamp = marker.timestamp_token_for_migration
+ if timestamp:
+ self._migration_versions_to_check[data_type] = 1
+ elif marker.token:
+ (timestamp, version) = pickle.loads(marker.token)
+ self._migration_versions_to_check[data_type] = version
+ elif marker.HasField('token'):
+ timestamp = 0
+ else:
+ raise ValueError('No timestamp information in progress marker.')
+ data_type = ProtocolDataTypeIdToSyncType(marker.data_type_id)
+ self._state[data_type] = timestamp
+ elif request.HasField('from_timestamp'):
+ for data_type in GetEntryTypesFromSpecifics(request.requested_types):
+ self._state[data_type] = request.from_timestamp
+ self._migration_versions_to_check[data_type] = 1
+ if self._state:
+ self._state[TOP_LEVEL] = min(self._state.itervalues())
+
+ def SummarizeRequest(self):
+ timestamps = {}
+ for data_type, timestamp in self._state.iteritems():
+ if data_type == TOP_LEVEL:
+ continue
+ timestamps.setdefault(timestamp, []).append(data_type)
+ return ', '.join('<%s>@%d' % (ShortDatatypeListSummary(types), stamp)
+ for stamp, types in sorted(timestamps.iteritems()))
+
+ def CheckMigrationState(self):
+ self._migration_history.CheckAllCurrent(self._migration_versions_to_check)
+
+ def ClientWantsItem(self, item):
+ """Return true if the client hasn't already seen an item."""
+ return self._state.get(GetEntryType(item), sys.maxint) < item.version
+
+ def HasAnyTimestamp(self):
+ """Return true if at least one datatype was requested."""
+ return bool(self._state)
+
+ def GetMinTimestamp(self):
+ """Return true the smallest timestamp requested across all datatypes."""
+ return min(self._state.itervalues())
+
+ def GetFirstTimeTypes(self):
+ """Return a list of datatypes requesting updates from timestamp zero."""
+ return [datatype for datatype, timestamp in self._state.iteritems()
+ if timestamp == 0]
+
+ def SaveProgress(self, new_timestamp, get_updates_response):
+ """Write the new_timestamp or new_progress_marker fields to a response."""
+ if self._original_request.from_progress_marker:
+ for data_type, old_timestamp in self._state.iteritems():
+ if data_type == TOP_LEVEL:
+ continue
+ new_marker = sync_pb2.DataTypeProgressMarker()
+ new_marker.data_type_id = SyncTypeToProtocolDataTypeId(data_type)
+ final_stamp = max(old_timestamp, new_timestamp)
+ final_migration = self._migration_history.GetLatestVersion(data_type)
+ new_marker.token = pickle.dumps((final_stamp, final_migration))
+ if new_marker not in self._original_request.from_progress_marker:
+ get_updates_response.new_progress_marker.add().MergeFrom(new_marker)
+ elif self._original_request.HasField('from_timestamp'):
+ if self._original_request.from_timestamp < new_timestamp:
+ get_updates_response.new_timestamp = new_timestamp
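+
+ # Note on the token format: the progress marker token is opaque to the
+ # client; here it is simply a pickled (timestamp, migration_version)
+ # tuple. For example, pickle.dumps((15, 1)) stands for "all changes
+ # through version 15, with datatype migration version 1".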
+
+
+class SyncDataModel(object):
+ """Models the account state of one sync user."""
+ _BATCH_SIZE = 100
+
+ # Specify all the permanent items that a model might need.
+ _PERMANENT_ITEM_SPECS = [
+ PermanentItem('google_chrome_apps', name='Apps',
+ parent_tag=ROOT_ID, sync_type=APPS),
+ PermanentItem('google_chrome_app_notifications', name='App Notifications',
+ parent_tag=ROOT_ID, sync_type=APP_NOTIFICATION),
+ PermanentItem('google_chrome_app_settings',
+ name='App Settings',
+ parent_tag=ROOT_ID, sync_type=APP_SETTINGS),
+ PermanentItem('google_chrome_bookmarks', name='Bookmarks',
+ parent_tag=ROOT_ID, sync_type=BOOKMARK),
+ PermanentItem('bookmark_bar', name='Bookmark Bar',
+ parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK),
+ PermanentItem('other_bookmarks', name='Other Bookmarks',
+ parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK),
+ PermanentItem('synced_bookmarks', name='Synced Bookmarks',
+ parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK,
+ create_by_default=False), # Must be True in the iOS tree.
+ PermanentItem('google_chrome_autofill', name='Autofill',
+ parent_tag=ROOT_ID, sync_type=AUTOFILL),
+ PermanentItem('google_chrome_autofill_profiles', name='Autofill Profiles',
+ parent_tag=ROOT_ID, sync_type=AUTOFILL_PROFILE),
+ PermanentItem('google_chrome_device_info', name='Device Info',
+ parent_tag=ROOT_ID, sync_type=DEVICE_INFO),
+ PermanentItem('google_chrome_experiments', name='Experiments',
+ parent_tag=ROOT_ID, sync_type=EXPERIMENTS),
+ PermanentItem('google_chrome_extension_settings',
+ name='Extension Settings',
+ parent_tag=ROOT_ID, sync_type=EXTENSION_SETTINGS),
+ PermanentItem('google_chrome_extensions', name='Extensions',
+ parent_tag=ROOT_ID, sync_type=EXTENSIONS),
+ PermanentItem('google_chrome_history_delete_directives',
+ name='History Delete Directives',
+ parent_tag=ROOT_ID,
+ sync_type=HISTORY_DELETE_DIRECTIVE),
+ PermanentItem('google_chrome_nigori', name='Nigori',
+ parent_tag=ROOT_ID, sync_type=NIGORI),
+ PermanentItem('google_chrome_passwords', name='Passwords',
+ parent_tag=ROOT_ID, sync_type=PASSWORD),
+ PermanentItem('google_chrome_preferences', name='Preferences',
+ parent_tag=ROOT_ID, sync_type=PREFERENCE),
+ PermanentItem('google_chrome_synced_notifications',
+ name='Synced Notifications',
+ parent_tag=ROOT_ID, sync_type=SYNCED_NOTIFICATION),
+ PermanentItem('google_chrome_search_engines', name='Search Engines',
+ parent_tag=ROOT_ID, sync_type=SEARCH_ENGINE),
+ PermanentItem('google_chrome_sessions', name='Sessions',
+ parent_tag=ROOT_ID, sync_type=SESSION),
+ PermanentItem('google_chrome_themes', name='Themes',
+ parent_tag=ROOT_ID, sync_type=THEME),
+ PermanentItem('google_chrome_typed_urls', name='Typed URLs',
+ parent_tag=ROOT_ID, sync_type=TYPED_URL),
+ ]
+
+ def __init__(self):
+ # Monotonically increasing version number. The next object change will
+ # take on this value + 1.
+ self._version = 0
+
+ # The definitive copy of this client's items: a map from ID string to a
+ # SyncEntity protocol buffer.
+ self._entries = {}
+
+ self.ResetStoreBirthday()
+
+ self.migration_history = MigrationHistory()
+
+ self.induced_error = sync_pb2.ClientToServerResponse.Error()
+ self.induced_error_frequency = 0
+ self.sync_count_before_errors = 0
+
+ self._keys = [MakeNewKeystoreKey()]
+
+ def _SaveEntry(self, entry):
+ """Insert or update an entry in the change log, and give it a new version.
+
+ The ID fields of this entry are assumed to be valid server IDs. This
+ entry will be updated with a new version number and sync_timestamp.
+
+ Args:
+ entry: The entry to be added or updated.
+ """
+ self._version += 1
+ # Maintain a global (rather than per-item) sequence number and use it
+ # both as the per-entry version as well as the update-progress timestamp.
+ # This simulates the behavior of the original server implementation.
+ entry.version = self._version
+ entry.sync_timestamp = self._version
+
+ # Preserve the originator info, which the client is not required to send
+ # when updating.
+ base_entry = self._entries.get(entry.id_string)
+ if base_entry:
+ entry.originator_cache_guid = base_entry.originator_cache_guid
+ entry.originator_client_item_id = base_entry.originator_client_item_id
+
+ self._entries[entry.id_string] = copy.deepcopy(entry)
+
+ def _ServerTagToId(self, tag):
+ """Determine the server ID from a server-unique tag.
+
+ The resulting value is guaranteed not to collide with the other ID
+ generation methods.
+
+ Args:
+ datatype: The sync type (python enum) of the identified object.
+ tag: The unique, known-to-the-client tag of a server-generated item.
+ Returns:
+ The string value of the computed server ID.
+ """
+ if not tag or tag == ROOT_ID:
+ return tag
+ spec = [x for x in self._PERMANENT_ITEM_SPECS if x.tag == tag][0]
+ return self._MakeCurrentId(spec.sync_type, '<server tag>%s' % tag)
+
+ def _ClientTagToId(self, datatype, tag):
+ """Determine the server ID from a client-unique tag.
+
+ The resulting value is guaranteed not to collide with the other ID
+ generation methods.
+
+ Args:
+ datatype: The sync type (python enum) of the identified object.
+ tag: The unique, opaque-to-the-server tag of a client-tagged item.
+ Returns:
+ The string value of the computed server ID.
+ """
+ return self._MakeCurrentId(datatype, '<client tag>%s' % tag)
+
+ def _ClientIdToId(self, datatype, client_guid, client_item_id):
+ """Compute a unique server ID from a client-local ID tag.
+
+ The resulting value is guaranteed not to collide with the other ID
+ generation methods.
+
+ Args:
+ datatype: The sync type (python enum) of the identified object.
+ client_guid: A globally unique ID that identifies the client which
+ created this item.
+ client_item_id: An ID that uniquely identifies this item on the client
+ which created it.
+ Returns:
+ The string value of the computed server ID.
+ """
+ # Using the client ID info is not required here (we could instead generate
+ # a random ID), but it's useful for debugging.
+ return self._MakeCurrentId(datatype,
+ '<server ID originally>%s/%s' % (client_guid, client_item_id))
+
+ def _MakeCurrentId(self, datatype, inner_id):
+ return '%d^%d^%s' % (datatype,
+ self.migration_history.GetLatestVersion(datatype),
+ inner_id)
+
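+ # For example, with no migrations recorded yet, _ServerTagToId maps the
+ # tag 'google_chrome_apps' (sync type APPS == 1, migration version 1) to
+ # the server ID '1^1^<server tag>google_chrome_apps'.
+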
+ def _ExtractIdInfo(self, id_string):
+ if not id_string or id_string == ROOT_ID:
+ return None
+ datatype_string, separator, remainder = id_string.partition('^')
+ migration_version_string, separator, inner_id = remainder.partition('^')
+ return (int(datatype_string), int(migration_version_string), inner_id)
+
+ def _WritePosition(self, entry, parent_id):
+ """Ensure the entry has an absolute, numeric position and parent_id.
+
+ Historically, clients would specify positions using the predecessor-based
+ references in the insert_after_item_id field; starting July 2011, this
+ was changed and Chrome now sends up the absolute position. The server
+ must store a position_in_parent value and must not maintain
+ insert_after_item_id.
+
+ Args:
+ entry: The entry for which to write a position. Its ID fields are
+ assumed to be server IDs. This entry will have its parent_id_string
+ and position_in_parent fields updated; its insert_after_item_id field
+ will be cleared.
+ parent_id: The ID of the entry intended as the new parent.
+ """
+
+ entry.parent_id_string = parent_id
+ if not entry.HasField('position_in_parent'):
+ entry.position_in_parent = 1337 # A debuggable, distinctive default.
+ entry.ClearField('insert_after_item_id')
+
+ def _ItemExists(self, id_string):
+ """Determine whether an item exists in the changelog."""
+ return id_string in self._entries
+
+ def _CreatePermanentItem(self, spec):
+ """Create one permanent item from its spec, if it doesn't exist.
+
+ The resulting item is added to the changelog.
+
+ Args:
+ spec: A PermanentItem object holding the properties of the item to create.
+ """
+ id_string = self._ServerTagToId(spec.tag)
+ if self._ItemExists(id_string):
+ return
+ print 'Creating permanent item: %s' % spec.name
+ entry = sync_pb2.SyncEntity()
+ entry.id_string = id_string
+ entry.non_unique_name = spec.name
+ entry.name = spec.name
+ entry.server_defined_unique_tag = spec.tag
+ entry.folder = True
+ entry.deleted = False
+ entry.specifics.CopyFrom(GetDefaultEntitySpecifics(spec.sync_type))
+ self._WritePosition(entry, self._ServerTagToId(spec.parent_tag))
+ self._SaveEntry(entry)
+
+ def _CreateDefaultPermanentItems(self, requested_types):
+ """Ensure creation of all default permanent items for a given set of types.
+
+ Args:
+ requested_types: A list of sync data types from ALL_TYPES.
+ All default permanent items of only these types will be created.
+ """
+ for spec in self._PERMANENT_ITEM_SPECS:
+ if spec.sync_type in requested_types and spec.create_by_default:
+ self._CreatePermanentItem(spec)
+
+ def ResetStoreBirthday(self):
+ """Resets the store birthday to a random value."""
+ # TODO(nick): uuid.uuid1() is better, but python 2.5 only.
+ self.store_birthday = '%0.30f' % random.random()
+
+ def StoreBirthday(self):
+ """Gets the store birthday."""
+ return self.store_birthday
+
+ def GetChanges(self, sieve):
+ """Get entries which have changed, oldest first.
+
+ The returned entries are limited to being _BATCH_SIZE many. The entries
+ are returned in strict version order.
+
+ Args:
+ sieve: An update sieve to use to filter out updates the client
+ has already seen.
+ Returns:
+ A tuple of (version, entries, changes_remaining). Version is a new
+ timestamp value, which should be used as the starting point for the
+ next query. Entries is the batch of entries meeting the current
+ timestamp query. Changes_remaining indicates the number of changes
+ left on the server after this batch.
+ """
+ if not sieve.HasAnyTimestamp():
+ return (0, [], 0)
+ min_timestamp = sieve.GetMinTimestamp()
+ self._CreateDefaultPermanentItems(sieve.GetFirstTimeTypes())
+ change_log = sorted(self._entries.values(),
+ key=operator.attrgetter('version'))
+ new_changes = [x for x in change_log if x.version > min_timestamp]
+ # Pick batch_size new changes, and then filter them. This matches
+ # the RPC behavior of the production sync server.
+ batch = new_changes[:self._BATCH_SIZE]
+ if not batch:
+ # Client is up to date.
+ return (min_timestamp, [], 0)
+
+ # Restrict batch to requested types. Tombstones are untyped
+ # and will always get included.
+ filtered = [copy.deepcopy(item) for item in batch
+ if item.deleted or sieve.ClientWantsItem(item)]
+
+ # The new client timestamp is the timestamp of the last item in the
+ # batch, even if that item was filtered out.
+ return (batch[-1].version, filtered, len(new_changes) - len(batch))
+
+ def GetKeystoreKeys(self):
+ """Returns the encryption keys for this account."""
+ print "Returning encryption keys: %s" % self._keys
+ return self._keys
+
+ def _CopyOverImmutableFields(self, entry):
+ """Preserve immutable fields by copying pre-commit state.
+
+ Args:
+ entry: A sync entity from the client.
+ """
+ if entry.id_string in self._entries:
+ if self._entries[entry.id_string].HasField(
+ 'server_defined_unique_tag'):
+ entry.server_defined_unique_tag = (
+ self._entries[entry.id_string].server_defined_unique_tag)
+
+ def _CheckVersionForCommit(self, entry):
+ """Perform an optimistic concurrency check on the version number.
+
+ Clients are only allowed to commit if they report having seen the most
+ recent version of an object.
+
+ Args:
+ entry: A sync entity from the client. It is assumed that ID fields
+ have been converted to server IDs.
+ Returns:
+ A boolean value indicating whether the client's version matches the
+ newest server version for the given entry.
+ """
+ if entry.id_string in self._entries:
+ # Allow edits/deletes if the version matches, and any undeletion.
+ return (self._entries[entry.id_string].version == entry.version or
+ self._entries[entry.id_string].deleted)
+ else:
+ # Allow unknown ID only if the client thinks it's new too.
+ return entry.version == 0
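+
+ # For example, if the server's copy of an item is at version 5, a commit
+ # carrying version 4 fails this check and HandleCommit replies CONFLICT.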
+
+ def _CheckParentIdForCommit(self, entry):
+ """Check that the parent ID referenced in a SyncEntity actually exists.
+
+ Args:
+ entry: A sync entity from the client. It is assumed that ID fields
+ have been converted to server IDs.
+ Returns:
+ A boolean value indicating whether the entity's parent ID is an object
+ that actually exists (and is not deleted) in the current account state.
+ """
+ if entry.parent_id_string == ROOT_ID:
+ # This is generally allowed.
+ return True
+ if entry.parent_id_string not in self._entries:
+ print 'Warning: Client sent unknown ID. Should never happen.'
+ return False
+ if entry.parent_id_string == entry.id_string:
+ print 'Warning: Client sent circular reference. Should never happen.'
+ return False
+ if self._entries[entry.parent_id_string].deleted:
+ # This can happen in a race condition between two clients.
+ return False
+ if not self._entries[entry.parent_id_string].folder:
+ print 'Warning: Client sent non-folder parent. Should never happen.'
+ return False
+ return True
+
+ def _RewriteIdsAsServerIds(self, entry, cache_guid, commit_session):
+ """Convert ID fields in a client sync entry to server IDs.
+
+ A commit batch sent by a client may contain new items for which the
+ server has not generated IDs yet. And within a commit batch, later
+ items are allowed to refer to earlier items. This method will
+ generate server IDs for new items, as well as rewrite references
+ to items whose server IDs were generated earlier in the batch.
+
+ Args:
+ entry: The client sync entry to modify.
+ cache_guid: The globally unique ID of the client that sent this
+ commit request.
+ commit_session: A dictionary mapping the original IDs to the new server
+ IDs, for any items committed earlier in the batch.
+ """
+ if entry.version == 0:
+ data_type = GetEntryType(entry)
+ if entry.HasField('client_defined_unique_tag'):
+ # When present, this should determine the item's ID.
+ new_id = self._ClientTagToId(data_type, entry.client_defined_unique_tag)
+ else:
+ new_id = self._ClientIdToId(data_type, cache_guid, entry.id_string)
+ entry.originator_cache_guid = cache_guid
+ entry.originator_client_item_id = entry.id_string
+ commit_session[entry.id_string] = new_id # Remember the remapping.
+ entry.id_string = new_id
+ if entry.parent_id_string in commit_session:
+ entry.parent_id_string = commit_session[entry.parent_id_string]
+ if entry.insert_after_item_id in commit_session:
+ entry.insert_after_item_id = commit_session[entry.insert_after_item_id]
+
+ def ValidateCommitEntries(self, entries):
+ """Raise an exception if a commit batch contains any global errors.
+
+ Arguments:
+ entries: an iterable containing commit-form SyncEntity protocol buffers.
+
+ Raises:
+ MigrationDoneError: if any of the entries reference a recently-migrated
+ datatype.
+ """
+ server_ids_in_commit = set()
+ local_ids_in_commit = set()
+ for entry in entries:
+ if entry.version:
+ server_ids_in_commit.add(entry.id_string)
+ else:
+ local_ids_in_commit.add(entry.id_string)
+ if entry.HasField('parent_id_string'):
+ if entry.parent_id_string not in local_ids_in_commit:
+ server_ids_in_commit.add(entry.parent_id_string)
+
+ versions_present = {}
+ for server_id in server_ids_in_commit:
+ parsed = self._ExtractIdInfo(server_id)
+ if parsed:
+ datatype, version, _ = parsed
+ versions_present.setdefault(datatype, []).append(version)
+
+ self.migration_history.CheckAllCurrent(
+ dict((k, min(v)) for k, v in versions_present.iteritems()))
+
+ def CommitEntry(self, entry, cache_guid, commit_session):
+ """Attempt to commit one entry to the user's account.
+
+ Args:
+ entry: A SyncEntity protobuf representing desired object changes.
+ cache_guid: A string value uniquely identifying the client; this
+ is used for ID generation and will determine the originator_cache_guid
+ if the entry is new.
+ commit_session: A dictionary mapping client IDs to server IDs for any
+ objects committed earlier this session. If the entry gets a new ID
+ during commit, the change will be recorded here.
+ Returns:
+ A SyncEntity reflecting the post-commit value of the entry, or None
+ if the entry was not committed due to an error.
+ """
+ entry = copy.deepcopy(entry)
+
+ # Generate server IDs for this entry, and write generated server IDs
+ # from earlier entries into the message's fields, as appropriate. The
+ # ID generation state is stored in 'commit_session'.
+ self._RewriteIdsAsServerIds(entry, cache_guid, commit_session)
+
+ # Perform the optimistic concurrency check on the entry's version number.
+ # Clients are not allowed to commit unless they indicate that they've seen
+ # the most recent version of an object.
+ if not self._CheckVersionForCommit(entry):
+ return None
+
+ # Check the validity of the parent ID; it must exist at this point.
+ # TODO(nick): Implement cycle detection and resolution.
+ if not self._CheckParentIdForCommit(entry):
+ return None
+
+ self._CopyOverImmutableFields(entry)
+
+ # At this point, the commit is definitely going to happen.
+
+ # Deletion works by storing a limited record for an entry, called a
+ # tombstone. A sync server must track deleted IDs forever, since it does
+ # not keep track of client knowledge (there's no deletion ACK event).
+ if entry.deleted:
+ def MakeTombstone(id_string):
+ """Make a tombstone entry that will replace the entry being deleted.
+
+ Args:
+ id_string: The ID string of the SyncEntity to be deleted.
+ Returns:
+ A new SyncEntity reflecting the fact that the entry is deleted.
+ """
+ # Only the ID, version and deletion state are preserved on a tombstone.
+ # TODO(nick): Does the production server not preserve the type? Not
+ # doing so means that tombstones cannot be filtered based on
+ # requested_types at GetUpdates time.
+ tombstone = sync_pb2.SyncEntity()
+ tombstone.id_string = id_string
+ tombstone.deleted = True
+ tombstone.name = ''
+ return tombstone
+
+ def IsChild(child_id):
+ """Check if a SyncEntity is a child of entry, or any of its children.
+
+ Args:
+ child_id: The ID string of the SyncEntity that is a possible child of entry.
+ Returns:
+ True if it is a child; false otherwise.
+ """
+ if child_id not in self._entries:
+ return False
+ if self._entries[child_id].parent_id_string == entry.id_string:
+ return True
+ return IsChild(self._entries[child_id].parent_id_string)
+
+ # Identify any children entry might have.
+ child_ids = [child.id_string for child in self._entries.itervalues()
+ if IsChild(child.id_string)]
+
+ # Mark all children that were identified as deleted.
+ for child_id in child_ids:
+ self._SaveEntry(MakeTombstone(child_id))
+
+ # Delete entry itself.
+ entry = MakeTombstone(entry.id_string)
+ else:
+ # Comments in sync.proto detail how the representation of positional
+ # ordering works: either the 'insert_after_item_id' field or the
+ # 'position_in_parent' field may determine the sibling order during
+ # Commit operations. The 'position_in_parent' field provides an absolute
+ # ordering in GetUpdates contexts. Here we assume the client will
+ # always send a valid position_in_parent (this is the newer style), and
+ # we ignore insert_after_item_id (an older style).
+ self._WritePosition(entry, entry.parent_id_string)
+
+ # Preserve the originator info, which the client is not required to send
+ # when updating.
+ base_entry = self._entries.get(entry.id_string)
+ if base_entry and not entry.HasField('originator_cache_guid'):
+ entry.originator_cache_guid = base_entry.originator_cache_guid
+ entry.originator_client_item_id = base_entry.originator_client_item_id
+
+ # Store the current time since the Unix epoch in milliseconds.
+ entry.mtime = (int((time.mktime(time.gmtime()) -
+ time.mktime(UNIX_TIME_EPOCH))*1000))
+
+ # Commit the change. This also updates the version number.
+ self._SaveEntry(entry)
+ return entry
+
+ def _RewriteVersionInId(self, id_string):
+ """Rewrites an ID so that its migration version becomes current."""
+ parsed_id = self._ExtractIdInfo(id_string)
+ if not parsed_id:
+ return id_string
+ datatype, old_migration_version, inner_id = parsed_id
+ return self._MakeCurrentId(datatype, inner_id)
+
+ def TriggerMigration(self, datatypes):
+ """Cause a migration to occur for a set of datatypes on this account.
+
+ Clients will see the MIGRATION_DONE error for these datatypes until they
+ resync them.
+ """
+ versions_to_remap = self.migration_history.Bump(datatypes)
+ all_entries = self._entries.values()
+ self._entries.clear()
+ for entry in all_entries:
+ new_id = self._RewriteVersionInId(entry.id_string)
+ entry.id_string = new_id
+ if entry.HasField('parent_id_string'):
+ entry.parent_id_string = self._RewriteVersionInId(
+ entry.parent_id_string)
+ self._entries[entry.id_string] = entry
+
+ def TriggerSyncTabFavicons(self):
+ """Set the 'sync_tab_favicons' field to this account's nigori node.
+
+ If the field is not currently set, will write a new nigori node entry
+ with the field set. Else does nothing.
+ """
+
+ nigori_tag = "google_chrome_nigori"
+ nigori_original = self._entries.get(self._ServerTagToId(nigori_tag))
+ if (nigori_original.specifics.nigori.sync_tab_favicons):
+ return
+ nigori_new = copy.deepcopy(nigori_original)
+ nigori_new.specifics.nigori.sync_tab_favicons = True
+ self._SaveEntry(nigori_new)
+
+ def TriggerCreateSyncedBookmarks(self):
+ """Create the Synced Bookmarks folder under the Bookmarks permanent item.
+
+ Clients will then receive the Synced Bookmarks folder on future
+ GetUpdates, and new bookmarks can be added within the Synced Bookmarks
+ folder.
+ """
+
+ synced_bookmarks_spec, = [spec for spec in self._PERMANENT_ITEM_SPECS
+ if spec.name == "Synced Bookmarks"]
+ self._CreatePermanentItem(synced_bookmarks_spec)
+
+ def TriggerEnableKeystoreEncryption(self):
+ """Create the keystore_encryption experiment entity and enable it.
+
+ A new entity within the EXPERIMENTS datatype is created with the unique
+ client tag "keystore_encryption" if it doesn't already exist. The
+ keystore_encryption message is then filled with |enabled| set to true.
+ """
+
+ experiment_id = self._ServerTagToId("google_chrome_experiments")
+ keystore_encryption_id = self._ClientTagToId(
+ EXPERIMENTS,
+ KEYSTORE_ENCRYPTION_EXPERIMENT_TAG)
+ keystore_entry = self._entries.get(keystore_encryption_id)
+ if keystore_entry is None:
+ keystore_entry = sync_pb2.SyncEntity()
+ keystore_entry.id_string = keystore_encryption_id
+ keystore_entry.name = "Keystore Encryption"
+ keystore_entry.client_defined_unique_tag = (
+ KEYSTORE_ENCRYPTION_EXPERIMENT_TAG)
+ keystore_entry.folder = False
+ keystore_entry.deleted = False
+ keystore_entry.specifics.CopyFrom(GetDefaultEntitySpecifics(EXPERIMENTS))
+ self._WritePosition(keystore_entry, experiment_id)
+
+ keystore_entry.specifics.experiments.keystore_encryption.enabled = True
+
+ self._SaveEntry(keystore_entry)
+
+ def TriggerRotateKeystoreKeys(self):
+ """Rotate the current set of keystore encryption keys.
+
+ |self._keys| will have a new random encryption key appended to it. We touch
+ the nigori node so that each client will receive the new encryption keys
+ only once.
+ """
+
+ # Add a new encryption key.
+ self._keys += [MakeNewKeystoreKey(), ]
+
+ # Increment the nigori node's timestamp, so clients will get the new keys
+ # on their next GetUpdates (any time the nigori node is sent back, we also
+ # send back the keystore keys).
+ nigori_tag = "google_chrome_nigori"
+ self._SaveEntry(self._entries.get(self._ServerTagToId(nigori_tag)))
+
+ def SetInducedError(self, error, error_frequency,
+ sync_count_before_errors):
+ self.induced_error = error
+ self.induced_error_frequency = error_frequency
+ self.sync_count_before_errors = sync_count_before_errors
+
+ def GetInducedError(self):
+ return self.induced_error
+
+
+class TestServer(object):
+ """An object to handle requests for one (and only one) Chrome Sync account.
+
+ TestServer consumes the sync command messages that are the outermost
+ layers of the protocol, performs the corresponding actions on its
+ SyncDataModel, and constructs an appropriate response message.
+ """
+
+ def __init__(self):
+ # The implementation supports exactly one account; its state is here.
+ self.account = SyncDataModel()
+ self.account_lock = threading.Lock()
+ # Clients that have talked to us: a map from the full client ID
+ # to its nickname.
+ self.clients = {}
+ self.client_name_generator = ('+' * times + chr(c)
+ for times in xrange(0, sys.maxint) for c in xrange(ord('A'), ord('Z')))
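+ # Nicknames are handed out in alphabetical order ('A', 'B', ...), gaining
+ # an extra '+' prefix on each wraparound of the letter range, so log
+ # output stays short even with many clients.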
+ self.transient_error = False
+ self.sync_count = 0
+
+ def GetShortClientName(self, query):
+ parsed = cgi.parse_qs(query[query.find('?')+1:])
+ client_id = parsed.get('client_id')
+ if not client_id:
+ return '?'
+ client_id = client_id[0]
+ if client_id not in self.clients:
+ self.clients[client_id] = self.client_name_generator.next()
+ return self.clients[client_id]
+
+ def CheckStoreBirthday(self, request):
+ """Raises StoreBirthdayError if the request's birthday is a mismatch."""
+ if not request.HasField('store_birthday'):
+ return
+ if self.account.StoreBirthday() != request.store_birthday:
+ raise StoreBirthdayError
+
+ def CheckTransientError(self):
+ """Raises TransientError if transient_error variable is set."""
+ if self.transient_error:
+ raise TransientError
+
+ def CheckSendError(self):
+ """Raises SyncInducedError if needed."""
+ if (self.account.induced_error.error_type !=
+ sync_enums_pb2.SyncEnums.UNKNOWN):
+ # Always means return the given error for all requests.
+ if self.account.induced_error_frequency == ERROR_FREQUENCY_ALWAYS:
+ raise SyncInducedError
+ # This means the FIRST 2 requests of every 3 requests
+ # return an error. Don't switch the order of failures. There are
+ # test cases that rely on the first 2 being the failure rather than
+ # the last 2.
+ elif (self.account.induced_error_frequency ==
+ ERROR_FREQUENCY_TWO_THIRDS):
+ if (((self.sync_count -
+ self.account.sync_count_before_errors) % 3) != 0):
+ raise SyncInducedError
+ else:
+ raise InducedErrorFrequencyNotDefined
+
+ def HandleMigrate(self, path):
+ query = urlparse.urlparse(path)[4]
+ code = 200
+ self.account_lock.acquire()
+ try:
+ datatypes = [DataTypeStringToSyncTypeLoose(x)
+ for x in urlparse.parse_qs(query).get('type',[])]
+ if datatypes:
+ self.account.TriggerMigration(datatypes)
+ response = 'Migrated datatypes %s' % (
+ ' and '.join(SyncTypeToString(x).upper() for x in datatypes))
+ else:
+ response = 'Please specify one or more <i>type=name</i> parameters'
+ code = 400
+ except DataTypeIdNotRecognized, error:
+ response = 'Could not interpret datatype name'
+ code = 400
+ finally:
+ self.account_lock.release()
+ return (code, '<html><title>Migration: %d</title><H1>%d %s</H1></html>' %
+ (code, code, response))
+
+ def HandleSetInducedError(self, path):
+ query = urlparse.urlparse(path)[4]
+ self.account_lock.acquire()
+ code = 200
+ response = 'Success'
+ error = sync_pb2.ClientToServerResponse.Error()
+ try:
+ error_type = urlparse.parse_qs(query)['error']
+ action = urlparse.parse_qs(query)['action']
+ error.error_type = int(error_type[0])
+ error.action = int(action[0])
+ try:
+ error.url = (urlparse.parse_qs(query)['url'])[0]
+ except KeyError:
+ error.url = ''
+ try:
+ error.error_description = (
+ (urlparse.parse_qs(query)['error_description'])[0])
+ except KeyError:
+ error.error_description = ''
+ try:
+ error_frequency = int((urlparse.parse_qs(query)['frequency'])[0])
+ except KeyError:
+ error_frequency = ERROR_FREQUENCY_ALWAYS
+ self.account.SetInducedError(error, error_frequency, self.sync_count)
+ response = ('Error = %d, action = %d, url = %s, description = %s' %
+ (error.error_type, error.action,
+ error.url,
+ error.error_description))
+ except (KeyError, ValueError):
+ response = 'Could not parse url'
+ code = 400
+ finally:
+ self.account_lock.release()
+ return (code, '<html><title>SetError: %d</title><H1>%d %s</H1></html>' %
+ (code, code, response))
+
+ def HandleCreateBirthdayError(self):
+ self.account.ResetStoreBirthday()
+ return (
+ 200,
+ '<html><title>Birthday error</title><H1>Birthday error</H1></html>')
+
+ def HandleSetTransientError(self):
+ self.transient_error = True
+ return (
+ 200,
+ '<html><title>Transient error</title><H1>Transient error</H1></html>')
+
+ def HandleSetSyncTabFavicons(self):
+ """Set 'sync_tab_favicons' field of the nigori node for this account."""
+ self.account.TriggerSyncTabFavicons()
+ return (
+ 200,
+ '<html><title>Tab Favicons</title><H1>Tab Favicons</H1></html>')
+
+ def HandleCreateSyncedBookmarks(self):
+ """Create the Synced Bookmarks folder under Bookmarks."""
+ self.account.TriggerCreateSyncedBookmarks()
+ return (
+ 200,
+ '<html><title>Synced Bookmarks</title><H1>Synced Bookmarks</H1></html>')
+
+ def HandleEnableKeystoreEncryption(self):
+ """Enables the keystore encryption experiment."""
+ self.account.TriggerEnableKeystoreEncryption()
+ return (
+ 200,
+ '<html><title>Enable Keystore Encryption</title>'
+ '<H1>Enable Keystore Encryption</H1></html>')
+
+ def HandleRotateKeystoreKeys(self):
+ """Rotate the keystore encryption keys."""
+ self.account.TriggerRotateKeystoreKeys()
+ return (
+ 200,
+ '<html><title>Rotate Keystore Keys</title>'
+ '<H1>Rotate Keystore Keys</H1></html>')
+
+ def HandleCommand(self, query, raw_request):
+ """Decode and handle a sync command from a raw input of bytes.
+
+ This is the main entry point for this class. It is safe to call this
+ method from multiple threads.
+
+ Args:
+ raw_request: An iterable byte sequence to be interpreted as a sync
+ protocol command.
+ Returns:
+ A tuple (response_code, raw_response); the first value is an HTTP
+ result code, while the second value is a string of bytes which is the
+ serialized reply to the command.
+ """
+ self.account_lock.acquire()
+ self.sync_count += 1
+ def print_context(direction):
+ print '[Client %s %s %s.py]' % (self.GetShortClientName(query), direction,
+ __name__),
+
+ try:
+ request = sync_pb2.ClientToServerMessage()
+ request.MergeFromString(raw_request)
+ contents = request.message_contents
+
+ response = sync_pb2.ClientToServerResponse()
+ response.error_code = sync_enums_pb2.SyncEnums.SUCCESS
+ self.CheckStoreBirthday(request)
+ response.store_birthday = self.account.store_birthday
+ self.CheckTransientError()
+ self.CheckSendError()
+
+ print_context('->')
+
+ if contents == sync_pb2.ClientToServerMessage.AUTHENTICATE:
+ print 'Authenticate'
+ # We accept any authentication token, and support only one account.
+ # TODO(nick): Mock out the GAIA authentication as well; hook up here.
+ response.authenticate.user.email = 'syncjuser@chromium'
+ response.authenticate.user.display_name = 'Sync J User'
+ elif contents == sync_pb2.ClientToServerMessage.COMMIT:
+ print 'Commit %d item(s)' % len(request.commit.entries)
+ self.HandleCommit(request.commit, response.commit)
+ elif contents == sync_pb2.ClientToServerMessage.GET_UPDATES:
+ print 'GetUpdates',
+ self.HandleGetUpdates(request.get_updates, response.get_updates)
+ print_context('<-')
+ print '%d update(s)' % len(response.get_updates.entries)
+ else:
+ print 'Unrecognizable sync request!'
+ return (400, None) # Bad request.
+ return (200, response.SerializeToString())
+ except MigrationDoneError, error:
+ print_context('<-')
+ print 'MIGRATION_DONE: <%s>' % (ShortDatatypeListSummary(error.datatypes))
+ response = sync_pb2.ClientToServerResponse()
+ response.store_birthday = self.account.store_birthday
+ response.error_code = sync_enums_pb2.SyncEnums.MIGRATION_DONE
+ response.migrated_data_type_id[:] = [
+ SyncTypeToProtocolDataTypeId(x) for x in error.datatypes]
+ return (200, response.SerializeToString())
+ except StoreBirthdayError, error:
+ print_context('<-')
+ print 'NOT_MY_BIRTHDAY'
+ response = sync_pb2.ClientToServerResponse()
+ response.store_birthday = self.account.store_birthday
+ response.error_code = sync_enums_pb2.SyncEnums.NOT_MY_BIRTHDAY
+ return (200, response.SerializeToString())
+ except TransientError, error:
+ ### This is deprecated; it will be removed once the test cases that
+ ### rely on it are removed.
+ print_context('<-')
+ print 'TRANSIENT_ERROR'
+ response.store_birthday = self.account.store_birthday
+ response.error_code = sync_enums_pb2.SyncEnums.TRANSIENT_ERROR
+ return (200, response.SerializeToString())
+ except SyncInducedError, error:
+ print_context('<-')
+ print 'INDUCED_ERROR'
+ response.store_birthday = self.account.store_birthday
+ error = self.account.GetInducedError()
+ response.error.error_type = error.error_type
+ response.error.url = error.url
+ response.error.error_description = error.error_description
+ response.error.action = error.action
+ return (200, response.SerializeToString())
+ finally:
+ self.account_lock.release()
+
+ def HandleCommit(self, commit_message, commit_response):
+ """Respond to a Commit request by updating the user's account state.
+
+ Commit attempts stop after the first error, returning a CONFLICT result
+ for any unattempted entries.
+
+ Args:
+ commit_message: A sync_pb.CommitMessage protobuf holding the content
+ of the client's request.
+ commit_response: A sync_pb.CommitResponse protobuf into which a reply
+ to the client request will be written.
+ """
+ commit_response.SetInParent()
+ batch_failure = False
+ session = {} # Tracks ID renaming during the commit operation.
+ guid = commit_message.cache_guid
+
+ self.account.ValidateCommitEntries(commit_message.entries)
+
+ for entry in commit_message.entries:
+ server_entry = None
+ if not batch_failure:
+ # Try to commit the change to the account.
+ server_entry = self.account.CommitEntry(entry, guid, session)
+
+ # An entryresponse is returned in both success and failure cases.
+ reply = commit_response.entryresponse.add()
+ if not server_entry:
+ reply.response_type = sync_pb2.CommitResponse.CONFLICT
+ reply.error_message = 'Conflict.'
+ batch_failure = True # One failure halts the batch.
+ else:
+ reply.response_type = sync_pb2.CommitResponse.SUCCESS
+ # These are the properties that the server is allowed to override
+ # during commit; the client wants to know their values at the end
+ # of the operation.
+ reply.id_string = server_entry.id_string
+ if not server_entry.deleted:
+ # Note: the production server doesn't actually send the
+ # parent_id_string on commit responses, so we don't either.
+ reply.position_in_parent = server_entry.position_in_parent
+ reply.version = server_entry.version
+ reply.name = server_entry.name
+ reply.non_unique_name = server_entry.non_unique_name
+ else:
+ reply.version = entry.version + 1
+
+ def HandleGetUpdates(self, update_request, update_response):
+ """Respond to a GetUpdates request by querying the user's account.
+
+ Args:
+ update_request: A sync_pb.GetUpdatesMessage protobuf holding the content
+ of the client's request.
+ update_response: A sync_pb.GetUpdatesResponse protobuf into which a reply
+ to the client request will be written.
+ """
+ update_response.SetInParent()
+ update_sieve = UpdateSieve(update_request, self.account.migration_history)
+
+ print CallerInfoToString(update_request.caller_info.source),
+ print update_sieve.SummarizeRequest()
+
+ update_sieve.CheckMigrationState()
+
+ new_timestamp, entries, remaining = self.account.GetChanges(update_sieve)
+
+ update_response.changes_remaining = remaining
+ sending_nigori_node = False
+ for entry in entries:
+ if entry.name == 'Nigori':
+ sending_nigori_node = True
+ reply = update_response.entries.add()
+ reply.CopyFrom(entry)
+ update_sieve.SaveProgress(new_timestamp, update_response)
+
+ if update_request.need_encryption_key or sending_nigori_node:
+ update_response.encryption_keys.extend(self.account.GetKeystoreKeys())
diff --git a/sync/tools/testserver/chromiumsync_test.py b/sync/tools/testserver/chromiumsync_test.py
new file mode 100755
index 0000000..e56c04b
--- /dev/null
+++ b/sync/tools/testserver/chromiumsync_test.py
@@ -0,0 +1,655 @@
+#!/usr/bin/env python
+# Copyright 2013 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Tests exercising chromiumsync and SyncDataModel."""
+
+import pickle
+import unittest
+
+import autofill_specifics_pb2
+import bookmark_specifics_pb2
+import chromiumsync
+import sync_pb2
+import theme_specifics_pb2
+
+class SyncDataModelTest(unittest.TestCase):
+ def setUp(self):
+ self.model = chromiumsync.SyncDataModel()
+ # The Synced Bookmarks folder is not created by default
+ self._expect_synced_bookmarks_folder = False
+
+ def AddToModel(self, proto):
+ self.model._entries[proto.id_string] = proto
+
+ def GetChangesFromTimestamp(self, requested_types, timestamp):
+ message = sync_pb2.GetUpdatesMessage()
+ message.from_timestamp = timestamp
+ for data_type in requested_types:
+ getattr(message.requested_types,
+ chromiumsync.SYNC_TYPE_TO_DESCRIPTOR[
+ data_type].name).SetInParent()
+ return self.model.GetChanges(
+ chromiumsync.UpdateSieve(message, self.model.migration_history))
+
+ def FindMarkerByNumber(self, markers, datatype):
+ """Search a list of progress markers and find the one for a datatype."""
+ for marker in markers:
+ if marker.data_type_id == datatype.number:
+ return marker
+ self.fail('Required marker not found: %s' % datatype.name)
+
+ def testPermanentItemSpecs(self):
+ specs = chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS
+
+ declared_specs = set(['0'])
+ for spec in specs:
+ self.assertTrue(spec.parent_tag in declared_specs, 'parent tags must '
+ 'be declared before use')
+ declared_specs.add(spec.tag)
+
+ unique_datatypes = set([x.sync_type for x in specs])
+ self.assertEqual(unique_datatypes,
+ set(chromiumsync.ALL_TYPES[1:]),
+ 'Every sync datatype should have a permanent folder '
+ 'associated with it')
+
+ def testSaveEntry(self):
+ proto = sync_pb2.SyncEntity()
+ proto.id_string = 'abcd'
+ proto.version = 0
+ self.assertFalse(self.model._ItemExists(proto.id_string))
+ self.model._SaveEntry(proto)
+ self.assertEqual(1, proto.version)
+ self.assertTrue(self.model._ItemExists(proto.id_string))
+ self.model._SaveEntry(proto)
+ self.assertEqual(2, proto.version)
+ proto.version = 0
+ self.assertTrue(self.model._ItemExists(proto.id_string))
+ self.assertEqual(2, self.model._entries[proto.id_string].version)
+
+ def testCreatePermanentItems(self):
+ self.model._CreateDefaultPermanentItems(chromiumsync.ALL_TYPES)
+ self.assertEqual(len(chromiumsync.ALL_TYPES) + 1,
+ len(self.model._entries))
+
+ def ExpectedPermanentItemCount(self, sync_type):
+ if sync_type == chromiumsync.BOOKMARK:
+ if self._expect_synced_bookmarks_folder:
+ return 4
+ else:
+ return 3
+ else:
+ return 1
+
+ def testGetChangesFromTimestampZeroForEachType(self):
+ all_types = chromiumsync.ALL_TYPES[1:]
+ for sync_type in all_types:
+ self.model = chromiumsync.SyncDataModel()
+ request_types = [sync_type]
+
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp(request_types, 0))
+
+ expected_count = self.ExpectedPermanentItemCount(sync_type)
+ self.assertEqual(expected_count, version)
+ self.assertEqual(expected_count, len(changes))
+ for change in changes:
+ self.assertTrue(change.HasField('server_defined_unique_tag'))
+ self.assertEqual(change.version, change.sync_timestamp)
+ self.assertTrue(change.version <= version)
+
+ # Test idempotence: another GetUpdates from ts=0 shouldn't recreate.
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp(request_types, 0))
+ self.assertEqual(expected_count, version)
+ self.assertEqual(expected_count, len(changes))
+ self.assertEqual(0, remaining)
+
+ # Doing a wider GetUpdates from timestamp zero shouldn't recreate either.
+ new_version, changes, remaining = (
+ self.GetChangesFromTimestamp(all_types, 0))
+ if self._expect_synced_bookmarks_folder:
+ self.assertEqual(len(chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS),
+ new_version)
+ else:
+ self.assertEqual(
+          len(chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS) - 1,
+ new_version)
+ self.assertEqual(new_version, len(changes))
+ self.assertEqual(0, remaining)
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp(request_types, 0))
+ self.assertEqual(new_version, version)
+ self.assertEqual(expected_count, len(changes))
+ self.assertEqual(0, remaining)
+
+ def testBatchSize(self):
+ for sync_type in chromiumsync.ALL_TYPES[1:]:
+ specifics = chromiumsync.GetDefaultEntitySpecifics(sync_type)
+ self.model = chromiumsync.SyncDataModel()
+ request_types = [sync_type]
+
+ for i in range(self.model._BATCH_SIZE*3):
+ entry = sync_pb2.SyncEntity()
+ entry.id_string = 'batch test %d' % i
+ entry.specifics.CopyFrom(specifics)
+ self.model._SaveEntry(entry)
+ last_bit = self.ExpectedPermanentItemCount(sync_type)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, 0))
+ self.assertEqual(self.model._BATCH_SIZE, version)
+ self.assertEqual(self.model._BATCH_SIZE*2 + last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(self.model._BATCH_SIZE*2, version)
+ self.assertEqual(self.model._BATCH_SIZE + last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(self.model._BATCH_SIZE*3, version)
+ self.assertEqual(last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(self.model._BATCH_SIZE*3 + last_bit, version)
+ self.assertEqual(0, changes_remaining)
+
+ # Now delete a third of the items.
+ for i in xrange(self.model._BATCH_SIZE*3 - 1, 0, -3):
+ entry = sync_pb2.SyncEntity()
+ entry.id_string = 'batch test %d' % i
+ entry.deleted = True
+ self.model._SaveEntry(entry)
+
+ # The batch counts shouldn't change.
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, 0))
+ self.assertEqual(self.model._BATCH_SIZE, len(changes))
+ self.assertEqual(self.model._BATCH_SIZE*2 + last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(self.model._BATCH_SIZE, len(changes))
+ self.assertEqual(self.model._BATCH_SIZE + last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(self.model._BATCH_SIZE, len(changes))
+ self.assertEqual(last_bit, changes_remaining)
+ version, changes, changes_remaining = (
+ self.GetChangesFromTimestamp(request_types, version))
+ self.assertEqual(last_bit, len(changes))
+ self.assertEqual(self.model._BATCH_SIZE*4 + last_bit, version)
+ self.assertEqual(0, changes_remaining)
+
+ def testCommitEachDataType(self):
+ for sync_type in chromiumsync.ALL_TYPES[1:]:
+ specifics = chromiumsync.GetDefaultEntitySpecifics(sync_type)
+ self.model = chromiumsync.SyncDataModel()
+ my_cache_guid = '112358132134'
+ parent = 'foobar'
+ commit_session = {}
+
+ # Start with a GetUpdates from timestamp 0, to populate permanent items.
+ original_version, original_changes, changes_remaining = (
+ self.GetChangesFromTimestamp([sync_type], 0))
+
+ def DoCommit(original=None, id_string='', name=None, parent=None,
+ position=0):
+ proto = sync_pb2.SyncEntity()
+ if original is not None:
+ proto.version = original.version
+ proto.id_string = original.id_string
+ proto.parent_id_string = original.parent_id_string
+ proto.name = original.name
+ else:
+ proto.id_string = id_string
+ proto.version = 0
+ proto.specifics.CopyFrom(specifics)
+ if name is not None:
+ proto.name = name
+ if parent:
+ proto.parent_id_string = parent.id_string
+ proto.insert_after_item_id = 'please discard'
+ proto.position_in_parent = position
+ proto.folder = True
+ proto.deleted = False
+ result = self.model.CommitEntry(proto, my_cache_guid, commit_session)
+ self.assertTrue(result)
+ return (proto, result)
+
+ # Commit a new item.
+ proto1, result1 = DoCommit(name='namae', id_string='Foo',
+ parent=original_changes[-1], position=100)
+ # Commit an item whose parent is another item (referenced via the
+ # pre-commit ID).
+ proto2, result2 = DoCommit(name='Secondo', id_string='Bar',
+ parent=proto1, position=-100)
+ # Commit a sibling of the second item.
+ proto3, result3 = DoCommit(name='Third!', id_string='Baz',
+ parent=proto1, position=-50)
+
+ self.assertEqual(3, len(commit_session))
+ for p, r in [(proto1, result1), (proto2, result2), (proto3, result3)]:
+ self.assertNotEqual(r.id_string, p.id_string)
+ self.assertEqual(r.originator_client_item_id, p.id_string)
+ self.assertEqual(r.originator_cache_guid, my_cache_guid)
+ self.assertTrue(r is not self.model._entries[r.id_string],
+ "Commit result didn't make a defensive copy.")
+ self.assertTrue(p is not self.model._entries[r.id_string],
+ "Commit result didn't make a defensive copy.")
+ self.assertEqual(commit_session.get(p.id_string), r.id_string)
+ self.assertTrue(r.version > original_version)
+ self.assertEqual(result1.parent_id_string, proto1.parent_id_string)
+ self.assertEqual(result2.parent_id_string, result1.id_string)
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp([sync_type], original_version))
+ self.assertEqual(3, len(changes))
+ self.assertEqual(0, remaining)
+ self.assertEqual(original_version + 3, version)
+ self.assertEqual([result1, result2, result3], changes)
+ for c in changes:
+ self.assertTrue(c is not self.model._entries[c.id_string],
+ "GetChanges didn't make a defensive copy.")
+ self.assertTrue(result2.position_in_parent < result3.position_in_parent)
+ self.assertEqual(-100, result2.position_in_parent)
+
+      # Now update the items so that the second item is the parent of the
+      # first, with the first sandwiched between two new items (4 and 5).
+ # Do this in a new commit session, meaning we'll reference items from
+ # the first batch by their post-commit, server IDs.
+ commit_session = {}
+ old_cache_guid = my_cache_guid
+ my_cache_guid = 'A different GUID'
+ proto2b, result2b = DoCommit(original=result2,
+ parent=original_changes[-1])
+ proto4, result4 = DoCommit(id_string='ID4', name='Four',
+ parent=result2, position=-200)
+ proto1b, result1b = DoCommit(original=result1,
+ parent=result2, position=-150)
+ proto5, result5 = DoCommit(id_string='ID5', name='Five', parent=result2,
+ position=150)
+
+ self.assertEqual(2, len(commit_session), 'Only new items in second '
+ 'batch should be in the session')
+ for p, r, original in [(proto2b, result2b, proto2),
+ (proto4, result4, proto4),
+ (proto1b, result1b, proto1),
+ (proto5, result5, proto5)]:
+ self.assertEqual(r.originator_client_item_id, original.id_string)
+ if original is not p:
+ self.assertEqual(r.id_string, p.id_string,
+ 'Ids should be stable after first commit')
+ self.assertEqual(r.originator_cache_guid, old_cache_guid)
+ else:
+ self.assertNotEqual(r.id_string, p.id_string)
+ self.assertEqual(r.originator_cache_guid, my_cache_guid)
+ self.assertEqual(commit_session.get(p.id_string), r.id_string)
+ self.assertTrue(r is not self.model._entries[r.id_string],
+ "Commit result didn't make a defensive copy.")
+ self.assertTrue(p is not self.model._entries[r.id_string],
+ "Commit didn't make a defensive copy.")
+ self.assertTrue(r.version > p.version)
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp([sync_type], original_version))
+ self.assertEqual(5, len(changes))
+ self.assertEqual(0, remaining)
+ self.assertEqual(original_version + 7, version)
+ self.assertEqual([result3, result2b, result4, result1b, result5], changes)
+ for c in changes:
+ self.assertTrue(c is not self.model._entries[c.id_string],
+ "GetChanges didn't make a defensive copy.")
+ self.assertTrue(result4.parent_id_string ==
+ result1b.parent_id_string ==
+ result5.parent_id_string ==
+ result2b.id_string)
+ self.assertTrue(result4.position_in_parent <
+ result1b.position_in_parent <
+ result5.position_in_parent)
+
+ def testUpdateSieve(self):
+ # from_timestamp, legacy mode
+ autofill = chromiumsync.SYNC_TYPE_FIELDS['autofill']
+ theme = chromiumsync.SYNC_TYPE_FIELDS['theme']
+ msg = sync_pb2.GetUpdatesMessage()
+ msg.from_timestamp = 15412
+ msg.requested_types.autofill.SetInParent()
+ msg.requested_types.theme.SetInParent()
+
+ sieve = chromiumsync.UpdateSieve(msg)
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 15412,
+ chromiumsync.AUTOFILL: 15412,
+ chromiumsync.THEME: 15412})
+
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(15412, response)
+ self.assertEqual(0, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(15413, response)
+ self.assertEqual(0, len(response.new_progress_marker))
+ self.assertTrue(response.HasField('new_timestamp'))
+ self.assertEqual(15413, response.new_timestamp)
+
+ # Existing tokens
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = pickle.dumps((15412, 1))
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 1))
+ sieve = chromiumsync.UpdateSieve(msg)
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 15412,
+ chromiumsync.AUTOFILL: 15412,
+ chromiumsync.THEME: 15413})
+
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(15413, response)
+ self.assertEqual(1, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = response.new_progress_marker[0]
+ self.assertEqual(marker.data_type_id, autofill.number)
+ self.assertEqual(pickle.loads(marker.token), (15413, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ # Empty tokens indicating from timestamp = 0
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = pickle.dumps((412, 1))
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = ''
+ sieve = chromiumsync.UpdateSieve(msg)
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 0,
+ chromiumsync.AUTOFILL: 412,
+ chromiumsync.THEME: 0})
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(1, response)
+ self.assertEqual(1, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = response.new_progress_marker[0]
+ self.assertEqual(marker.data_type_id, theme.number)
+ self.assertEqual(pickle.loads(marker.token), (1, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(412, response)
+ self.assertEqual(1, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = response.new_progress_marker[0]
+ self.assertEqual(marker.data_type_id, theme.number)
+ self.assertEqual(pickle.loads(marker.token), (412, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(413, response)
+ self.assertEqual(2, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, theme)
+ self.assertEqual(pickle.loads(marker.token), (413, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, autofill)
+ self.assertEqual(pickle.loads(marker.token), (413, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ # Migration token timestamps (client gives timestamp, server returns token)
+ # These are for migrating from the old 'timestamp' protocol to the
+ # progressmarker protocol, and have nothing to do with the MIGRATION_DONE
+ # error code.
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.timestamp_token_for_migration = 15213
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.timestamp_token_for_migration = 15211
+ sieve = chromiumsync.UpdateSieve(msg)
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 15211,
+ chromiumsync.AUTOFILL: 15213,
+ chromiumsync.THEME: 15211})
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(16000, response) # There were updates
+ self.assertEqual(2, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, theme)
+ self.assertEqual(pickle.loads(marker.token), (16000, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, autofill)
+ self.assertEqual(pickle.loads(marker.token), (16000, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.timestamp_token_for_migration = 3000
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.timestamp_token_for_migration = 3000
+ sieve = chromiumsync.UpdateSieve(msg)
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 3000,
+ chromiumsync.AUTOFILL: 3000,
+ chromiumsync.THEME: 3000})
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(3000, response) # Already up to date
+ self.assertEqual(2, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, theme)
+ self.assertEqual(pickle.loads(marker.token), (3000, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, autofill)
+ self.assertEqual(pickle.loads(marker.token), (3000, 1))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+
+ def testCheckRaiseTransientError(self):
+ testserver = chromiumsync.TestServer()
+    http_code, raw_response = testserver.HandleSetTransientError()
+ self.assertEqual(http_code, 200)
+ try:
+ testserver.CheckTransientError()
+ self.fail('Should have raised transient error exception')
+ except chromiumsync.TransientError:
+ self.assertTrue(testserver.transient_error)
+
+ def testUpdateSieveStoreMigration(self):
+ autofill = chromiumsync.SYNC_TYPE_FIELDS['autofill']
+ theme = chromiumsync.SYNC_TYPE_FIELDS['theme']
+ migrator = chromiumsync.MigrationHistory()
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = pickle.dumps((15412, 1))
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 1))
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+
+ migrator.Bump([chromiumsync.BOOKMARK, chromiumsync.PASSWORD]) # v=2
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+ self.assertEqual(sieve._state,
+ {chromiumsync.TOP_LEVEL: 15412,
+ chromiumsync.AUTOFILL: 15412,
+ chromiumsync.THEME: 15413})
+
+ migrator.Bump([chromiumsync.AUTOFILL, chromiumsync.PASSWORD]) # v=3
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ try:
+ sieve.CheckMigrationState()
+ self.fail('Should have raised.')
+ except chromiumsync.MigrationDoneError, error:
+ # We want this to happen.
+ self.assertEqual([chromiumsync.AUTOFILL], error.datatypes)
+
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = ''
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 1))
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(15412, response) # There were updates
+ self.assertEqual(1, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, autofill)
+ self.assertEqual(pickle.loads(marker.token), (15412, 3))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = pickle.dumps((15412, 3))
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 1))
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+
+ migrator.Bump([chromiumsync.THEME, chromiumsync.AUTOFILL]) # v=4
+ migrator.Bump([chromiumsync.AUTOFILL]) # v=5
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ try:
+ sieve.CheckMigrationState()
+ self.fail("Should have raised.")
+ except chromiumsync.MigrationDoneError, error:
+ # We want this to happen.
+ self.assertEqual(set([chromiumsync.THEME, chromiumsync.AUTOFILL]),
+ set(error.datatypes))
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = ''
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 1))
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ try:
+ sieve.CheckMigrationState()
+ self.fail("Should have raised.")
+ except chromiumsync.MigrationDoneError, error:
+ # We want this to happen.
+ self.assertEqual([chromiumsync.THEME], error.datatypes)
+
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = ''
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = ''
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+ response = sync_pb2.GetUpdatesResponse()
+ sieve.SaveProgress(15412, response) # There were updates
+ self.assertEqual(2, len(response.new_progress_marker))
+ self.assertFalse(response.HasField('new_timestamp'))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, autofill)
+ self.assertEqual(pickle.loads(marker.token), (15412, 5))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ marker = self.FindMarkerByNumber(response.new_progress_marker, theme)
+ self.assertEqual(pickle.loads(marker.token), (15412, 4))
+ self.assertFalse(marker.HasField('timestamp_token_for_migration'))
+ msg = sync_pb2.GetUpdatesMessage()
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = autofill.number
+ marker.token = pickle.dumps((15412, 5))
+ marker = msg.from_progress_marker.add()
+ marker.data_type_id = theme.number
+ marker.token = pickle.dumps((15413, 4))
+ sieve = chromiumsync.UpdateSieve(msg, migrator)
+ sieve.CheckMigrationState()
+
+  def testCreateSyncedBookmarks(self):
+ version1, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], 0))
+ id_string = self.model._MakeCurrentId(chromiumsync.BOOKMARK,
+ '<server tag>synced_bookmarks')
+ self.assertFalse(self.model._ItemExists(id_string))
+ self._expect_synced_bookmarks_folder = True
+ self.model.TriggerCreateSyncedBookmarks()
+ self.assertTrue(self.model._ItemExists(id_string))
+
+ # Check that the version changed when the folder was created and the only
+ # change was the folder creation.
+ version2, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], version1))
+ self.assertEqual(len(changes), 1)
+ self.assertEqual(changes[0].id_string, id_string)
+ self.assertNotEqual(version1, version2)
+ self.assertEqual(
+ self.ExpectedPermanentItemCount(chromiumsync.BOOKMARK),
+ version2)
+
+ # Ensure getting from timestamp 0 includes the folder.
+ version, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], 0))
+ self.assertEqual(
+ self.ExpectedPermanentItemCount(chromiumsync.BOOKMARK),
+ len(changes))
+ self.assertEqual(version2, version)
+
+ def testGetKey(self):
+ [key1] = self.model.GetKeystoreKeys()
+ [key2] = self.model.GetKeystoreKeys()
+ self.assertTrue(len(key1))
+ self.assertEqual(key1, key2)
+
+ # Trigger the rotation. A subsequent GetUpdates should return the nigori
+ # node (whose timestamp was bumped by the rotation).
+ version1, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.NIGORI], 0))
+ self.model.TriggerRotateKeystoreKeys()
+ version2, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.NIGORI], version1))
+ self.assertNotEqual(version1, version2)
+    self.assertEqual(len(changes), 1)
+    self.assertEqual(changes[0].name, "Nigori")
+
+ # The current keys should contain the old keys, with the new key appended.
+ [key1, key3] = self.model.GetKeystoreKeys()
+    self.assertEqual(key1, key2)
+ self.assertNotEqual(key1, key3)
+ self.assertTrue(len(key3) > 0)
+
+ def testTriggerEnableKeystoreEncryption(self):
+ version1, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.EXPERIMENTS], 0))
+ keystore_encryption_id_string = (
+ self.model._ClientTagToId(
+ chromiumsync.EXPERIMENTS,
+ chromiumsync.KEYSTORE_ENCRYPTION_EXPERIMENT_TAG))
+
+ self.assertFalse(self.model._ItemExists(keystore_encryption_id_string))
+ self.model.TriggerEnableKeystoreEncryption()
+ self.assertTrue(self.model._ItemExists(keystore_encryption_id_string))
+
+ # The creation of the experiment should be downloaded on the next
+ # GetUpdates.
+ version2, changes, remaining = (
+ self.GetChangesFromTimestamp([chromiumsync.EXPERIMENTS], version1))
+ self.assertEqual(len(changes), 1)
+ self.assertEqual(changes[0].id_string, keystore_encryption_id_string)
+ self.assertNotEqual(version1, version2)
+
+ # Verify the experiment was created properly and is enabled.
+ self.assertEqual(chromiumsync.KEYSTORE_ENCRYPTION_EXPERIMENT_TAG,
+ changes[0].client_defined_unique_tag)
+ self.assertTrue(changes[0].HasField("specifics"))
+ self.assertTrue(changes[0].specifics.HasField("experiments"))
+ self.assertTrue(
+ changes[0].specifics.experiments.HasField("keystore_encryption"))
+ self.assertTrue(
+ changes[0].specifics.experiments.keystore_encryption.enabled)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sync/tools/testserver/run_sync_testserver.cc b/sync/tools/testserver/run_sync_testserver.cc
new file mode 100644
index 0000000..74f186d
--- /dev/null
+++ b/sync/tools/testserver/run_sync_testserver.cc
@@ -0,0 +1,121 @@
+// Copyright 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <stdio.h>
+
+#include "base/at_exit.h"
+#include "base/command_line.h"
+#include "base/file_path.h"
+#include "base/logging.h"
+#include "base/message_loop.h"
+#include "base/process_util.h"
+#include "base/string_number_conversions.h"
+#include "base/test/test_timeouts.h"
+#include "net/test/python_utils.h"
+#include "sync/test/local_sync_test_server.h"
+
+static void PrintUsage() {
+  printf("run_sync_testserver [--port=<port>] [--xmpp-port=<xmpp_port>]\n"
+         "run_sync_testserver --sync-test\n"
+         "run_sync_testserver --xmpp-test\n");
+}
+
+// Launches the chromiumsync_test.py or xmppserver_test.py scripts, which test
+// the sync HTTP and XMPP server functionality, respectively.
+static bool RunSyncTest(const FilePath::StringType& sync_test_script_name) {
+ scoped_ptr<syncer::LocalSyncTestServer> test_server(
+ new syncer::LocalSyncTestServer());
+ if (!test_server->SetPythonPath()) {
+ LOG(ERROR) << "Error trying to set python path. Exiting.";
+ return false;
+ }
+
+ FilePath sync_test_script_path;
+ if (!test_server->GetTestScriptPath(sync_test_script_name,
+ &sync_test_script_path)) {
+ LOG(ERROR) << "Error trying to get path for test script "
+ << sync_test_script_name;
+ return false;
+ }
+
+ CommandLine python_command(CommandLine::NO_PROGRAM);
+ if (!GetPythonCommand(&python_command)) {
+ LOG(ERROR) << "Could not get python runtime command.";
+ return false;
+ }
+
+ python_command.AppendArgPath(sync_test_script_path);
+ if (!base::LaunchProcess(python_command, base::LaunchOptions(), NULL)) {
+ LOG(ERROR) << "Failed to launch test script " << sync_test_script_name;
+ return false;
+ }
+ return true;
+}
+
+// Gets a port value from the switch with name |switch_name| and writes it to
+// |port|. If the switch is absent, |port| is set to 0. Returns false only if
+// the switch is present but its value is not a valid integer.
+static bool GetPortFromSwitch(const std::string& switch_name, uint16* port) {
+ DCHECK(port != NULL) << "|port| is NULL";
+ CommandLine* command_line = CommandLine::ForCurrentProcess();
+ int port_int = 0;
+ if (command_line->HasSwitch(switch_name)) {
+ std::string port_str = command_line->GetSwitchValueASCII(switch_name);
+ if (!base::StringToInt(port_str, &port_int)) {
+ return false;
+ }
+ }
+ *port = static_cast<uint16>(port_int);
+ return true;
+}
+
+int main(int argc, const char* argv[]) {
+ base::AtExitManager at_exit_manager;
+ MessageLoopForIO message_loop;
+
+ // Process command line
+ CommandLine::Init(argc, argv);
+ CommandLine* command_line = CommandLine::ForCurrentProcess();
+
+ if (!logging::InitLogging(
+ FILE_PATH_LITERAL("sync_testserver.log"),
+ logging::LOG_TO_BOTH_FILE_AND_SYSTEM_DEBUG_LOG,
+ logging::LOCK_LOG_FILE,
+ logging::APPEND_TO_OLD_LOG_FILE,
+ logging::DISABLE_DCHECK_FOR_NON_OFFICIAL_RELEASE_BUILDS)) {
+ printf("Error: could not initialize logging. Exiting.\n");
+ return -1;
+ }
+
+ TestTimeouts::Initialize();
+
+ if (command_line->HasSwitch("help")) {
+ PrintUsage();
+ return 0;
+ }
+
+ if (command_line->HasSwitch("sync-test")) {
+ return RunSyncTest(FILE_PATH_LITERAL("chromiumsync_test.py")) ? 0 : -1;
+ }
+
+ if (command_line->HasSwitch("xmpp-test")) {
+ return RunSyncTest(FILE_PATH_LITERAL("xmppserver_test.py")) ? 0 : -1;
+ }
+
+ uint16 port = 0;
+ GetPortFromSwitch("port", &port);
+
+ uint16 xmpp_port = 0;
+ GetPortFromSwitch("xmpp-port", &xmpp_port);
+
+ scoped_ptr<syncer::LocalSyncTestServer> test_server(
+ new syncer::LocalSyncTestServer(port, xmpp_port));
+ if (!test_server->Start()) {
+ printf("Error: failed to start python sync test server. Exiting.\n");
+ return -1;
+ }
+
+ printf("Python sync test server running at %s (type ctrl+c to exit)\n",
+ test_server->host_port_pair().ToString().c_str());
+
+ message_loop.Run();
+ return 0;
+}
diff --git a/sync/tools/testserver/sync_testserver.py b/sync/tools/testserver/sync_testserver.py
new file mode 100755
index 0000000..6925776
--- /dev/null
+++ b/sync/tools/testserver/sync_testserver.py
@@ -0,0 +1,447 @@
+#!/usr/bin/env python
+# Copyright 2013 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""This is a python sync server used for testing Chrome Sync.
+
+By default, it listens on an ephemeral port for sync commands and on another
+ephemeral port for XMPP traffic, and sends both port numbers back to the
+originating process over a pipe. The originating process can specify an
+explicit port and xmpp_port if necessary.
+"""
+
+import asyncore
+import BaseHTTPServer
+import errno
+import os
+import select
+import socket
+import sys
+import urlparse
+
+import chromiumsync
+import echo_message
+import testserver_base
+import xmppserver
+
+
+class SyncHTTPServer(testserver_base.ClientRestrictingServerMixIn,
+ testserver_base.BrokenPipeHandlerMixIn,
+ testserver_base.StoppableHTTPServer):
+ """An HTTP server that handles sync commands."""
+
+ def __init__(self, server_address, xmpp_port, request_handler_class):
+ testserver_base.StoppableHTTPServer.__init__(self,
+ server_address,
+ request_handler_class)
+ self._sync_handler = chromiumsync.TestServer()
+ self._xmpp_socket_map = {}
+ self._xmpp_server = xmppserver.XmppServer(
+ self._xmpp_socket_map, ('localhost', xmpp_port))
+ self.xmpp_port = self._xmpp_server.getsockname()[1]
+ self.authenticated = True
+
+ def GetXmppServer(self):
+ return self._xmpp_server
+
+ def HandleCommand(self, query, raw_request):
+ return self._sync_handler.HandleCommand(query, raw_request)
+
+ def HandleRequestNoBlock(self):
+ """Handles a single request.
+
+ Copied from SocketServer._handle_request_noblock().
+ """
+
+ try:
+ request, client_address = self.get_request()
+ except socket.error:
+ return
+ if self.verify_request(request, client_address):
+ try:
+ self.process_request(request, client_address)
+ except Exception:
+ self.handle_error(request, client_address)
+ self.close_request(request)
+
+ def SetAuthenticated(self, auth_valid):
+ self.authenticated = auth_valid
+
+ def GetAuthenticated(self):
+ return self.authenticated
+
+ def serve_forever(self):
+ """This is a merge of asyncore.loop() and SocketServer.serve_forever().
+ """
+
+ def HandleXmppSocket(fd, socket_map, handler):
+ """Runs the handler for the xmpp connection for fd.
+
+ Adapted from asyncore.read() et al.
+ """
+
+ xmpp_connection = socket_map.get(fd)
+ # This could happen if a previous handler call caused fd to get
+ # removed from socket_map.
+ if xmpp_connection is None:
+ return
+ try:
+ handler(xmpp_connection)
+ except (asyncore.ExitNow, KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ xmpp_connection.handle_error()
+
+ while True:
+      read_fds = [self.fileno()]
+ write_fds = []
+ exceptional_fds = []
+
+ for fd, xmpp_connection in self._xmpp_socket_map.items():
+ is_r = xmpp_connection.readable()
+ is_w = xmpp_connection.writable()
+ if is_r:
+ read_fds.append(fd)
+ if is_w:
+ write_fds.append(fd)
+ if is_r or is_w:
+ exceptional_fds.append(fd)
+
+ try:
+ read_fds, write_fds, exceptional_fds = (
+ select.select(read_fds, write_fds, exceptional_fds))
+ except select.error, err:
+ if err.args[0] != errno.EINTR:
+ raise
+ else:
+ continue
+
+ for fd in read_fds:
+ if fd == self.fileno():
+ self.HandleRequestNoBlock()
+ continue
+ HandleXmppSocket(fd, self._xmpp_socket_map,
+ asyncore.dispatcher.handle_read_event)
+
+ for fd in write_fds:
+ HandleXmppSocket(fd, self._xmpp_socket_map,
+ asyncore.dispatcher.handle_write_event)
+
+ for fd in exceptional_fds:
+ HandleXmppSocket(fd, self._xmpp_socket_map,
+ asyncore.dispatcher.handle_expt_event)
+
+
+class SyncPageHandler(testserver_base.BasePageHandler):
+ """Handler for the main HTTP sync server."""
+
+ def __init__(self, request, client_address, sync_http_server):
+ get_handlers = [self.ChromiumSyncTimeHandler,
+ self.ChromiumSyncMigrationOpHandler,
+ self.ChromiumSyncCredHandler,
+ self.ChromiumSyncXmppCredHandler,
+ self.ChromiumSyncDisableNotificationsOpHandler,
+ self.ChromiumSyncEnableNotificationsOpHandler,
+ self.ChromiumSyncSendNotificationOpHandler,
+ self.ChromiumSyncBirthdayErrorOpHandler,
+ self.ChromiumSyncTransientErrorOpHandler,
+ self.ChromiumSyncErrorOpHandler,
+ self.ChromiumSyncSyncTabFaviconsOpHandler,
+ self.ChromiumSyncCreateSyncedBookmarksOpHandler,
+ self.ChromiumSyncEnableKeystoreEncryptionOpHandler,
+ self.ChromiumSyncRotateKeystoreKeysOpHandler]
+
+ post_handlers = [self.ChromiumSyncCommandHandler,
+ self.ChromiumSyncTimeHandler]
+ testserver_base.BasePageHandler.__init__(self, request, client_address,
+ sync_http_server, [], get_handlers,
+ [], post_handlers, [])
+
+
+ def ChromiumSyncTimeHandler(self):
+ """Handle Chromium sync .../time requests.
+
+ The syncer sometimes checks server reachability by examining /time.
+ """
+
+ test_name = "/chromiumsync/time"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+
+ # Chrome hates it if we send a response before reading the request.
+ if self.headers.getheader('content-length'):
+ length = int(self.headers.getheader('content-length'))
+ _raw_request = self.rfile.read(length)
+
+ self.send_response(200)
+ self.send_header('Content-Type', 'text/plain')
+ self.end_headers()
+ self.wfile.write('0123456789')
+ return True
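+
+  # A hedged sketch of exercising this handler from a client (the port is
+  # illustrative; the path comes from test_name above):
+  #
+  #   import urllib2
+  #   print urllib2.urlopen('http://127.0.0.1:8080/chromiumsync/time').read()
+  #   # -> '0123456789'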
+
+ def ChromiumSyncCommandHandler(self):
+ """Handle a chromiumsync command arriving via http.
+
+ This covers all sync protocol commands: authentication, getupdates, and
+ commit.
+ """
+
+ test_name = "/chromiumsync/command"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+
+ length = int(self.headers.getheader('content-length'))
+ raw_request = self.rfile.read(length)
+ http_response = 200
+    raw_reply = ''  # Overwritten below unless authentication fails.
+ if not self.server.GetAuthenticated():
+ http_response = 401
+ challenge = 'GoogleLogin realm="http://%s", service="chromiumsync"' % (
+ self.server.server_address[0])
+ else:
+ http_response, raw_reply = self.server.HandleCommand(
+ self.path, raw_request)
+
+ ### Now send the response to the client. ###
+ self.send_response(http_response)
+ if http_response == 401:
+      self.send_header('WWW-Authenticate', challenge)
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncMigrationOpHandler(self):
+ test_name = "/chromiumsync/migrate"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+
+ http_response, raw_reply = self.server._sync_handler.HandleMigrate(
+ self.path)
+ self.send_response(http_response)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncCredHandler(self):
+ test_name = "/chromiumsync/cred"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ try:
+ query = urlparse.urlparse(self.path)[4]
+ cred_valid = urlparse.parse_qs(query)['valid']
+ if cred_valid[0] == 'True':
+ self.server.SetAuthenticated(True)
+ else:
+ self.server.SetAuthenticated(False)
+ except Exception:
+ self.server.SetAuthenticated(False)
+
+ http_response = 200
+ raw_reply = 'Authenticated: %s ' % self.server.GetAuthenticated()
+ self.send_response(http_response)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncXmppCredHandler(self):
+ test_name = "/chromiumsync/xmppcred"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ xmpp_server = self.server.GetXmppServer()
+ try:
+ query = urlparse.urlparse(self.path)[4]
+ cred_valid = urlparse.parse_qs(query)['valid']
+ if cred_valid[0] == 'True':
+ xmpp_server.SetAuthenticated(True)
+ else:
+ xmpp_server.SetAuthenticated(False)
+    except Exception:
+ xmpp_server.SetAuthenticated(False)
+
+ http_response = 200
+ raw_reply = 'XMPP Authenticated: %s ' % xmpp_server.GetAuthenticated()
+ self.send_response(http_response)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncDisableNotificationsOpHandler(self):
+ test_name = "/chromiumsync/disablenotifications"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ self.server.GetXmppServer().DisableNotifications()
+ result = 200
+ raw_reply = ('<html><title>Notifications disabled</title>'
+ '<H1>Notifications disabled</H1></html>')
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncEnableNotificationsOpHandler(self):
+ test_name = "/chromiumsync/enablenotifications"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ self.server.GetXmppServer().EnableNotifications()
+ result = 200
+ raw_reply = ('<html><title>Notifications enabled</title>'
+ '<H1>Notifications enabled</H1></html>')
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncSendNotificationOpHandler(self):
+ test_name = "/chromiumsync/sendnotification"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ query = urlparse.urlparse(self.path)[4]
+ query_params = urlparse.parse_qs(query)
+ channel = ''
+ data = ''
+ if 'channel' in query_params:
+ channel = query_params['channel'][0]
+ if 'data' in query_params:
+ data = query_params['data'][0]
+ self.server.GetXmppServer().SendNotification(channel, data)
+ result = 200
+ raw_reply = ('<html><title>Notification sent</title>'
+ '<H1>Notification sent with channel "%s" '
+ 'and data "%s"</H1></html>'
+ % (channel, data))
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncBirthdayErrorOpHandler(self):
+ test_name = "/chromiumsync/birthdayerror"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = self.server._sync_handler.HandleCreateBirthdayError()
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncTransientErrorOpHandler(self):
+ test_name = "/chromiumsync/transienterror"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = self.server._sync_handler.HandleSetTransientError()
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncErrorOpHandler(self):
+ test_name = "/chromiumsync/error"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = self.server._sync_handler.HandleSetInducedError(
+ self.path)
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncSyncTabFaviconsOpHandler(self):
+ test_name = "/chromiumsync/synctabfavicons"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = self.server._sync_handler.HandleSetSyncTabFavicons()
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncCreateSyncedBookmarksOpHandler(self):
+ test_name = "/chromiumsync/createsyncedbookmarks"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = self.server._sync_handler.HandleCreateSyncedBookmarks()
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncEnableKeystoreEncryptionOpHandler(self):
+ test_name = "/chromiumsync/enablekeystoreencryption"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = (
+ self.server._sync_handler.HandleEnableKeystoreEncryption())
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+ def ChromiumSyncRotateKeystoreKeysOpHandler(self):
+ test_name = "/chromiumsync/rotatekeystorekeys"
+ if not self._ShouldHandleRequest(test_name):
+ return False
+ result, raw_reply = (
+ self.server._sync_handler.HandleRotateKeystoreKeys())
+ self.send_response(result)
+ self.send_header('Content-Type', 'text/html')
+ self.send_header('Content-Length', len(raw_reply))
+ self.end_headers()
+ self.wfile.write(raw_reply)
+ return True
+
+
+class SyncServerRunner(testserver_base.TestServerRunner):
+ """TestServerRunner for the net test servers."""
+
+ def __init__(self):
+ super(SyncServerRunner, self).__init__()
+
+ def create_server(self, server_data):
+ port = self.options.port
+ host = self.options.host
+ xmpp_port = self.options.xmpp_port
+ server = SyncHTTPServer((host, port), xmpp_port, SyncPageHandler)
+ print 'Sync HTTP server started on port %d...' % server.server_port
+ print 'Sync XMPP server started on port %d...' % server.xmpp_port
+ server_data['port'] = server.server_port
+ server_data['xmpp_port'] = server.xmpp_port
+ return server
+
+ def run_server(self):
+ testserver_base.TestServerRunner.run_server(self)
+
+ def add_options(self):
+ testserver_base.TestServerRunner.add_options(self)
+ self.option_parser.add_option('--xmpp-port', default='0', type='int',
+ help='Port used by the XMPP server. If '
+ 'unspecified, the XMPP server will listen on '
+ 'an ephemeral port.')
+ # Override the default logfile name used in testserver.py.
+ self.option_parser.set_defaults(log_file='sync_testserver.log')
+
+if __name__ == '__main__':
+ sys.exit(SyncServerRunner().main())
diff --git a/sync/tools/testserver/xmppserver.py b/sync/tools/testserver/xmppserver.py
new file mode 100644
index 0000000..f9599c0
--- /dev/null
+++ b/sync/tools/testserver/xmppserver.py
@@ -0,0 +1,594 @@
+# Copyright 2013 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""A bare-bones and non-compliant XMPP server.
+
+Just enough of the protocol is implemented to get it to work with
+Chrome's sync notification system.
+"""
+
+import asynchat
+import asyncore
+import base64
+import re
+import socket
+from xml.dom import minidom
+
+# pychecker complains about the use of fileno(), which is implemented
+# by asyncore by forwarding to an internal object via __getattr__.
+__pychecker__ = 'no-classattr'
+
+
+class Error(Exception):
+ """Error class for this module."""
+ pass
+
+
+class UnexpectedXml(Error):
+ """Raised when an unexpected XML element has been encountered."""
+
+ def __init__(self, xml_element):
+ xml_text = xml_element.toxml()
+ Error.__init__(self, 'Unexpected XML element', xml_text)
+
+
+def ParseXml(xml_string):
+ """Parses the given string as XML and returns a minidom element
+ object.
+ """
+ dom = minidom.parseString(xml_string)
+
+ # minidom handles xmlns specially, but there's a bug where it sets
+ # the attribute value to None, which causes toxml() or toprettyxml()
+ # to break.
+ def FixMinidomXmlnsBug(xml_element):
+ if xml_element.getAttribute('xmlns') is None:
+ xml_element.setAttribute('xmlns', '')
+
+ def ApplyToAllDescendantElements(xml_element, fn):
+ fn(xml_element)
+ for node in xml_element.childNodes:
+ if node.nodeType == node.ELEMENT_NODE:
+ ApplyToAllDescendantElements(node, fn)
+
+ root = dom.documentElement
+ ApplyToAllDescendantElements(root, FixMinidomXmlnsBug)
+ return root
+
+
+def CloneXml(xml):
+ """Returns a deep copy of the given XML element.
+
+ Args:
+ xml: The XML element, which should be something returned from
+ ParseXml() (i.e., a root element).
+ """
+ return xml.ownerDocument.cloneNode(True).documentElement
+
+
+class StanzaParser(object):
+ """A hacky incremental XML parser.
+
+ StanzaParser consumes data incrementally via FeedString() and feeds
+ its delegate complete parsed stanzas (i.e., XML documents) via
+ FeedStanza(). Any stanzas passed to FeedStanza() are unlinked after
+ the callback is done.
+
+ Use like so:
+
+ class MyClass(object):
+ ...
+ def __init__(self, ...):
+ ...
+ self._parser = StanzaParser(self)
+ ...
+
+ def SomeFunction(self, ...):
+ ...
+ self._parser.FeedString(some_data)
+ ...
+
+ def FeedStanza(self, stanza):
+ ...
+ print stanza.toprettyxml()
+ ...
+ """
+
+ # NOTE(akalin): The following regexps are naive, but necessary since
+ # none of the existing Python 2.4/2.5 XML libraries support
+ # incremental parsing. This works well enough for our purposes.
+ #
+ # The regexps below assume that any present XML element starts at
+ # the beginning of the string, but there may be trailing whitespace.
+
+ # Matches an opening stream tag (e.g., '<stream:stream foo="bar">')
+ # (assumes that the stream XML namespace is defined in the tag).
+ _stream_re = re.compile(r'^(<stream:stream [^>]*>)\s*')
+
+ # Matches an empty element tag (e.g., '<foo bar="baz"/>').
+ _empty_element_re = re.compile(r'^(<[^>]*/>)\s*')
+
+ # Matches a non-empty element (e.g., '<foo bar="baz">quux</foo>').
+ # Does *not* handle nested elements.
+ _non_empty_element_re = re.compile(r'^(<([^ >]*)[^>]*>.*?</\2>)\s*')
+
+ # The closing tag for a stream tag. We have to insert this
+ # ourselves since all XML stanzas are children of the stream tag,
+ # which is never closed until the connection is closed.
+ _stream_suffix = '</stream:stream>'
+
+ def __init__(self, delegate):
+ self._buffer = ''
+ self._delegate = delegate
+
+ def FeedString(self, data):
+ """Consumes the given string data, possibly feeding one or more
+ stanzas to the delegate.
+ """
+ self._buffer += data
+ while (self._ProcessBuffer(self._stream_re, self._stream_suffix) or
+ self._ProcessBuffer(self._empty_element_re) or
+ self._ProcessBuffer(self._non_empty_element_re)):
+ pass
+
+ def _ProcessBuffer(self, regexp, xml_suffix=''):
+ """If the buffer matches the given regexp, removes the match from
+ the buffer, appends the given suffix, parses it, and feeds it to
+ the delegate.
+
+ Returns:
+ Whether or not the buffer matched the given regexp.
+ """
+ results = regexp.match(self._buffer)
+ if not results:
+ return False
+ xml_text = self._buffer[:results.end()] + xml_suffix
+ self._buffer = self._buffer[results.end():]
+ stanza = ParseXml(xml_text)
+ self._delegate.FeedStanza(stanza)
+ # Needed because stanza may have cycles.
+ stanza.unlink()
+ return True
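+
+  # A hedged sketch of the incremental behavior described above (the
+  # PrintingDelegate name is hypothetical):
+  #
+  #   class PrintingDelegate(object):
+  #     def FeedStanza(self, stanza):
+  #       print stanza.tagName
+  #
+  #   parser = StanzaParser(PrintingDelegate())
+  #   parser.FeedString('<message><push xmlns="goo')  # No full stanza yet.
+  #   parser.FeedString('gle:push"/></message>')      # Prints 'message'.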
+
+
+class Jid(object):
+ """Simple struct for an XMPP jid (essentially an e-mail address with
+ an optional resource string).
+ """
+
+ def __init__(self, username, domain, resource=''):
+ self.username = username
+ self.domain = domain
+ self.resource = resource
+
+ def __str__(self):
+ jid_str = "%s@%s" % (self.username, self.domain)
+ if self.resource:
+ jid_str += '/' + self.resource
+ return jid_str
+
+ def GetBareJid(self):
+ return Jid(self.username, self.domain)
+
+
+class IdGenerator(object):
+ """Simple class to generate unique IDs for XMPP messages."""
+
+ def __init__(self, prefix):
+ self._prefix = prefix
+ self._id = 0
+
+ def GetNextId(self):
+ next_id = "%s.%s" % (self._prefix, self._id)
+ self._id += 1
+ return next_id
+
+
+class HandshakeTask(object):
+ """Class to handle the initial handshake with a connected XMPP
+ client.
+ """
+
+ # The handshake states in order.
+ (_INITIAL_STREAM_NEEDED,
+ _AUTH_NEEDED,
+ _AUTH_STREAM_NEEDED,
+ _BIND_NEEDED,
+ _SESSION_NEEDED,
+ _FINISHED) = range(6)
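+
+  # A rough sketch of the exchange this state machine expects, derived from
+  # FeedStanza() below:
+  #
+  #   client: <stream:stream to="...">          (_INITIAL_STREAM_NEEDED)
+  #   server: _STREAM_DATA, then _AUTH_STANZA
+  #   client: <auth>base64 credentials</auth>   (_AUTH_NEEDED)
+  #   server: _AUTH_SUCCESS_STANZA (or _AUTH_FAILURE_STANZA and finish)
+  #   client: a second <stream:stream>          (_AUTH_STREAM_NEEDED)
+  #   server: _STREAM_DATA, then _BIND_STANZA
+  #   client: <iq type="set"> bind request      (_BIND_NEEDED)
+  #   server: _BIND_RESULT_STANZA with the full JID
+  #   client: <iq type="set"> session request   (_SESSION_NEEDED)
+  #   server: _IQ_RESPONSE_STANZA, then _FINISHED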
+
+ # Used when in the _INITIAL_STREAM_NEEDED and _AUTH_STREAM_NEEDED
+ # states. Not an XML object as it's only the opening tag.
+ #
+ # The from and id attributes are filled in later.
+ _STREAM_DATA = (
+ '<stream:stream from="%s" id="%s" '
+ 'version="1.0" xmlns:stream="http://etherx.jabber.org/streams" '
+ 'xmlns="jabber:client">')
+
+ # Used when in the _INITIAL_STREAM_NEEDED state.
+ _AUTH_STANZA = ParseXml(
+ '<stream:features xmlns:stream="http://etherx.jabber.org/streams">'
+ ' <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">'
+ ' <mechanism>PLAIN</mechanism>'
+ ' <mechanism>X-GOOGLE-TOKEN</mechanism>'
+ ' </mechanisms>'
+ '</stream:features>')
+
+ # Used when in the _AUTH_NEEDED state.
+ _AUTH_SUCCESS_STANZA = ParseXml(
+ '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
+
+ # Used when in the _AUTH_NEEDED state.
+ _AUTH_FAILURE_STANZA = ParseXml(
+ '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>')
+
+ # Used when in the _AUTH_STREAM_NEEDED state.
+ _BIND_STANZA = ParseXml(
+ '<stream:features xmlns:stream="http://etherx.jabber.org/streams">'
+ ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"/>'
+ ' <session xmlns="urn:ietf:params:xml:ns:xmpp-session"/>'
+ '</stream:features>')
+
+ # Used when in the _BIND_NEEDED state.
+ #
+ # The id and jid attributes are filled in later.
+ _BIND_RESULT_STANZA = ParseXml(
+ '<iq id="" type="result">'
+ ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind">'
+ ' <jid/>'
+ ' </bind>'
+ '</iq>')
+
+ # Used when in the _SESSION_NEEDED state.
+ #
+ # The id attribute is filled in later.
+ _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>')
+
+ def __init__(self, connection, resource_prefix, authenticated):
+ self._connection = connection
+ self._id_generator = IdGenerator(resource_prefix)
+ self._username = ''
+ self._domain = ''
+ self._jid = None
+ self._authenticated = authenticated
+ self._resource_prefix = resource_prefix
+ self._state = self._INITIAL_STREAM_NEEDED
+
+ def FeedStanza(self, stanza):
+ """Inspects the given stanza and changes the handshake state if needed.
+
+ Called when a stanza is received from the client. Inspects the
+ stanza to make sure it has the expected attributes given the
+ current state, advances the state if needed, and sends a reply to
+ the client if needed.
+ """
+ def ExpectStanza(stanza, name):
+ if stanza.tagName != name:
+ raise UnexpectedXml(stanza)
+
+ def ExpectIq(stanza, type, name):
+ ExpectStanza(stanza, 'iq')
+ if (stanza.getAttribute('type') != type or
+ stanza.firstChild.tagName != name):
+ raise UnexpectedXml(stanza)
+
+ def GetStanzaId(stanza):
+ return stanza.getAttribute('id')
+
+ def HandleStream(stanza):
+ ExpectStanza(stanza, 'stream:stream')
+ domain = stanza.getAttribute('to')
+ if domain:
+ self._domain = domain
+ SendStreamData()
+
+ def SendStreamData():
+ next_id = self._id_generator.GetNextId()
+ stream_data = self._STREAM_DATA % (self._domain, next_id)
+ self._connection.SendData(stream_data)
+
+ def GetUserDomain(stanza):
+ encoded_username_password = stanza.firstChild.data
+ username_password = base64.b64decode(encoded_username_password)
+ (_, username_domain, _) = username_password.split('\0')
+ # The domain may be omitted.
+ #
+ # If we were using python 2.5, we'd be able to do:
+ #
+ # username, _, domain = username_domain.partition('@')
+ # if not domain:
+ # domain = self._domain
+ at_pos = username_domain.find('@')
+ if at_pos != -1:
+ username = username_domain[:at_pos]
+ domain = username_domain[at_pos+1:]
+ else:
+ username = username_domain
+ domain = self._domain
+ return (username, domain)
+
+ def Finish():
+ self._state = self._FINISHED
+ self._connection.HandshakeDone(self._jid)
+
+ if self._state == self._INITIAL_STREAM_NEEDED:
+ HandleStream(stanza)
+ self._connection.SendStanza(self._AUTH_STANZA, False)
+ self._state = self._AUTH_NEEDED
+
+ elif self._state == self._AUTH_NEEDED:
+ ExpectStanza(stanza, 'auth')
+ (self._username, self._domain) = GetUserDomain(stanza)
+ if self._authenticated:
+ self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False)
+ self._state = self._AUTH_STREAM_NEEDED
+ else:
+ self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False)
+ Finish()
+
+ elif self._state == self._AUTH_STREAM_NEEDED:
+ HandleStream(stanza)
+ self._connection.SendStanza(self._BIND_STANZA, False)
+ self._state = self._BIND_NEEDED
+
+ elif self._state == self._BIND_NEEDED:
+ ExpectIq(stanza, 'set', 'bind')
+ stanza_id = GetStanzaId(stanza)
+ resource_element = stanza.getElementsByTagName('resource')[0]
+ resource = resource_element.firstChild.data
+ full_resource = '%s.%s' % (self._resource_prefix, resource)
+ response = CloneXml(self._BIND_RESULT_STANZA)
+ response.setAttribute('id', stanza_id)
+ self._jid = Jid(self._username, self._domain, full_resource)
+ jid_text = response.parentNode.createTextNode(str(self._jid))
+ response.getElementsByTagName('jid')[0].appendChild(jid_text)
+ self._connection.SendStanza(response)
+ self._state = self._SESSION_NEEDED
+
+ elif self._state == self._SESSION_NEEDED:
+ ExpectIq(stanza, 'set', 'session')
+ stanza_id = GetStanzaId(stanza)
+ xml = CloneXml(self._IQ_RESPONSE_STANZA)
+ xml.setAttribute('id', stanza_id)
+ self._connection.SendStanza(xml)
+ Finish()
+
+
+def AddrString(addr):
+ return '%s:%d' % addr
+
+
+class XmppConnection(asynchat.async_chat):
+ """A single XMPP client connection.
+
+ This class handles the connection to a single XMPP client (via a
+ socket). It does the XMPP handshake and also implements the (old)
+ Google notification protocol.
+ """
+
+ # Used for acknowledgements to the client.
+ #
+ # The from and id attributes are filled in later.
+ _IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>')
+
+ def __init__(self, sock, socket_map, delegate, addr, authenticated):
+ """Starts up the xmpp connection.
+
+ Args:
+ sock: The socket to the client.
+ socket_map: A map from sockets to their owning objects.
+ delegate: The delegate, which is notified when the XMPP
+ handshake is successful, when the connection is closed, and
+ when a notification has to be broadcast.
+      addr: The host/port of the client.
+      authenticated: Whether the client should be treated as
+        authenticated during the XMPP handshake.
+    """
+ # We do this because in versions of python < 2.6,
+ # async_chat.__init__ doesn't take a map argument nor pass it to
+ # dispatcher.__init__. We rely on the fact that
+ # async_chat.__init__ calls dispatcher.__init__ as the last thing
+ # it does, and that calling dispatcher.__init__ with socket=None
+ # and map=None is essentially a no-op.
+ asynchat.async_chat.__init__(self)
+ asyncore.dispatcher.__init__(self, sock, socket_map)
+
+ self.set_terminator(None)
+
+ self._delegate = delegate
+ self._parser = StanzaParser(self)
+ self._jid = None
+
+ self._addr = addr
+ addr_str = AddrString(self._addr)
+ self._handshake_task = HandshakeTask(self, addr_str, authenticated)
+ print 'Starting connection to %s' % self
+
+ def __str__(self):
+ if self._jid:
+ return str(self._jid)
+ else:
+ return AddrString(self._addr)
+
+ # async_chat implementation.
+
+ def collect_incoming_data(self, data):
+ self._parser.FeedString(data)
+
+ # This is only here to make pychecker happy.
+ def found_terminator(self):
+ asynchat.async_chat.found_terminator(self)
+
+ def close(self):
+    print 'Closing connection to %s' % self
+ self._delegate.OnXmppConnectionClosed(self)
+ asynchat.async_chat.close(self)
+
+ # Called by self._parser.FeedString().
+ def FeedStanza(self, stanza):
+ if self._handshake_task:
+ self._handshake_task.FeedStanza(stanza)
+ elif stanza.tagName == 'iq' and stanza.getAttribute('type') == 'result':
+ # Ignore all client acks.
+ pass
+ elif (stanza.firstChild and
+ stanza.firstChild.namespaceURI == 'google:push'):
+ self._HandlePushCommand(stanza)
+ else:
+ raise UnexpectedXml(stanza)
+
+ # Called by self._handshake_task.
+ def HandshakeDone(self, jid):
+ if jid:
+ self._jid = jid
+ self._handshake_task = None
+ self._delegate.OnXmppHandshakeDone(self)
+      print 'Handshake done for %s' % self
+    else:
+      print 'Handshake failed for %s' % self
+ self.close()
+
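+  # _HandlePushCommand accepts stanzas of these two forms (the same forms
+  # exercised by the unit tests in xmppserver_test.py):
+  #
+  #   <iq><subscribe xmlns="google:push"/></iq>        (subscription request)
+  #   <message><push xmlns="google:push"/></message>   (notification)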
+ def _HandlePushCommand(self, stanza):
+ if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe':
+ # Subscription request.
+ self._SendIqResponseStanza(stanza)
+ elif stanza.tagName == 'message' and stanza.firstChild.tagName == 'push':
+ # Send notification request.
+ self._delegate.ForwardNotification(self, stanza)
+ else:
+      raise UnexpectedXml(stanza)
+
+ def _SendIqResponseStanza(self, iq):
+ stanza = CloneXml(self._IQ_RESPONSE_STANZA)
+ stanza.setAttribute('from', str(self._jid.GetBareJid()))
+ stanza.setAttribute('id', iq.getAttribute('id'))
+ self.SendStanza(stanza)
+
+ def SendStanza(self, stanza, unlink=True):
+ """Sends a stanza to the client.
+
+ Args:
+ stanza: The stanza to send.
+ unlink: Whether to unlink stanza after sending it. (Pass in
+ False if stanza is a constant.)
+ """
+ self.SendData(stanza.toxml())
+ if unlink:
+ stanza.unlink()
+
+ def SendData(self, data):
+    """Sends raw data to the client."""
+ # We explicitly encode to ascii as that is what the client expects
+ # (some minidom library functions return unicode strings).
+ self.push(data.encode('ascii'))
+
+ def ForwardNotification(self, notification_stanza):
+ """Forwards a notification to the client."""
+ notification_stanza.setAttribute('from', str(self._jid.GetBareJid()))
+ notification_stanza.setAttribute('to', str(self._jid))
+ self.SendStanza(notification_stanza, False)
+
+
+class XmppServer(asyncore.dispatcher):
+ """The main XMPP server class.
+
+ The XMPP server starts accepting connections on the given address
+ and spawns off XmppConnection objects for each one.
+
+ Use like so:
+
+ socket_map = {}
+ xmpp_server = xmppserver.XmppServer(socket_map, ('127.0.0.1', 5222))
+ asyncore.loop(30.0, False, socket_map)
+ """
+
+ # Used when sending a notification.
+ _NOTIFICATION_STANZA = ParseXml(
+ '<message>'
+ ' <push xmlns="google:push">'
+ ' <data/>'
+ ' </push>'
+ '</message>')
+
+ def __init__(self, socket_map, addr):
+ asyncore.dispatcher.__init__(self, None, socket_map)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.set_reuse_addr()
+ self.bind(addr)
+ self.listen(5)
+ self._socket_map = socket_map
+ self._connections = set()
+ self._handshake_done_connections = set()
+ self._notifications_enabled = True
+ self._authenticated = True
+
+ def handle_accept(self):
+ (sock, addr) = self.accept()
+ xmpp_connection = XmppConnection(
+ sock, self._socket_map, self, addr, self._authenticated)
+ self._connections.add(xmpp_connection)
+ # Return the new XmppConnection for testing.
+ return xmpp_connection
+
+ def close(self):
+ # A copy is necessary since calling close on each connection
+ # removes it from self._connections.
+ for connection in self._connections.copy():
+ connection.close()
+ asyncore.dispatcher.close(self)
+
+ def EnableNotifications(self):
+ self._notifications_enabled = True
+
+ def DisableNotifications(self):
+ self._notifications_enabled = False
+
+ def MakeNotification(self, channel, data):
+    """Makes a notification from the given channel and data.
+
+    Args:
+      channel: The channel on which to send the notification.
+      data: The notification payload, which is base64-encoded before
+        being embedded in the stanza.
+    """
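+    # The resulting stanza has this shape (data base64-encoded):
+    #
+    #   <message>
+    #     <push channel="..." xmlns="google:push">
+    #       <data>...</data>
+    #     </push>
+    #   </message>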
+ notification_stanza = CloneXml(self._NOTIFICATION_STANZA)
+ push_element = notification_stanza.getElementsByTagName('push')[0]
+ push_element.setAttribute('channel', channel)
+ data_element = push_element.getElementsByTagName('data')[0]
+ encoded_data = base64.b64encode(data)
+ data_text = notification_stanza.parentNode.createTextNode(encoded_data)
+ data_element.appendChild(data_text)
+ return notification_stanza
+
+ def SendNotification(self, channel, data):
+ """Sends a notification to all connections.
+
+ Args:
+ channel: The channel on which to send the notification.
+ data: The notification payload.
+ """
+ notification_stanza = self.MakeNotification(channel, data)
+ self.ForwardNotification(None, notification_stanza)
+ notification_stanza.unlink()
+
+ def SetAuthenticated(self, auth_valid):
+ self._authenticated = auth_valid
+
+ def GetAuthenticated(self):
+ return self._authenticated
+
+ # XmppConnection delegate methods.
+ def OnXmppHandshakeDone(self, xmpp_connection):
+ self._handshake_done_connections.add(xmpp_connection)
+
+ def OnXmppConnectionClosed(self, xmpp_connection):
+ self._connections.discard(xmpp_connection)
+ self._handshake_done_connections.discard(xmpp_connection)
+
+ def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
+ if self._notifications_enabled:
+ for connection in self._handshake_done_connections:
+ print 'Sending notification to %s' % connection
+ connection.ForwardNotification(notification_stanza)
+ else:
+ print 'Notifications disabled; dropping notification'
diff --git a/sync/tools/testserver/xmppserver_test.py b/sync/tools/testserver/xmppserver_test.py
new file mode 100755
index 0000000..1a539d1
--- /dev/null
+++ b/sync/tools/testserver/xmppserver_test.py
@@ -0,0 +1,421 @@
+#!/usr/bin/env python
+# Copyright 2013 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Tests exercising the various classes in xmppserver.py."""
+
+import unittest
+
+import base64
+import xmppserver
+
+class XmlUtilsTest(unittest.TestCase):
+
+ def testParseXml(self):
+ xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
+ xml = xmppserver.ParseXml(xml_text)
+ self.assertEqual(xml.toxml(), xml_text)
+
+ def testCloneXml(self):
+ xml = xmppserver.ParseXml('<foo/>')
+ xml_clone = xmppserver.CloneXml(xml)
+ xml_clone.setAttribute('bar', 'baz')
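+    # minidom nodes don't define __eq__, so equality is identity; the
+    # clone must therefore compare unequal to the original.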
+ self.assertEqual(xml, xml)
+ self.assertEqual(xml_clone, xml_clone)
+ self.assertNotEqual(xml, xml_clone)
+
+ def testCloneXmlUnlink(self):
+ xml_text = '<foo/>'
+ xml = xmppserver.ParseXml(xml_text)
+ xml_clone = xmppserver.CloneXml(xml)
+ xml.unlink()
+ self.assertEqual(xml.parentNode, None)
+ self.assertNotEqual(xml_clone.parentNode, None)
+ self.assertEqual(xml_clone.toxml(), xml_text)
+
+class StanzaParserTest(unittest.TestCase):
+
+ def setUp(self):
+ self.stanzas = []
+
+ def FeedStanza(self, stanza):
+ # We can't append stanza directly because it is unlinked after
+ # this callback.
+ self.stanzas.append(stanza.toxml())
+
+ def testBasic(self):
+ parser = xmppserver.StanzaParser(self)
+ parser.FeedString('<foo')
+ self.assertEqual(len(self.stanzas), 0)
+ parser.FeedString('/><bar></bar>')
+ self.assertEqual(self.stanzas[0], '<foo/>')
+ self.assertEqual(self.stanzas[1], '<bar/>')
+
+ def testStream(self):
+ parser = xmppserver.StanzaParser(self)
+ parser.FeedString('<stream')
+ self.assertEqual(len(self.stanzas), 0)
+ parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
+ self.assertEqual(self.stanzas[0],
+ '<stream:stream foo="bar" xmlns:stream="baz"/>')
+
+ def testNested(self):
+ parser = xmppserver.StanzaParser(self)
+ parser.FeedString('<foo')
+ self.assertEqual(len(self.stanzas), 0)
+ parser.FeedString(' bar="baz"')
+ parser.FeedString('><baz/><blah>meh</blah></foo>')
+ self.assertEqual(self.stanzas[0],
+ '<foo bar="baz"><baz/><blah>meh</blah></foo>')
+
+
+class JidTest(unittest.TestCase):
+
+ def testBasic(self):
+ jid = xmppserver.Jid('foo', 'bar.com')
+ self.assertEqual(str(jid), 'foo@bar.com')
+
+ def testResource(self):
+ jid = xmppserver.Jid('foo', 'bar.com', 'resource')
+ self.assertEqual(str(jid), 'foo@bar.com/resource')
+
+ def testGetBareJid(self):
+ jid = xmppserver.Jid('foo', 'bar.com', 'resource')
+ self.assertEqual(str(jid.GetBareJid()), 'foo@bar.com')
+
+
+class IdGeneratorTest(unittest.TestCase):
+
+ def testBasic(self):
+ id_generator = xmppserver.IdGenerator('foo')
+ for i in xrange(0, 100):
+ self.assertEqual('foo.%d' % i, id_generator.GetNextId())
+
+
+class HandshakeTaskTest(unittest.TestCase):
+
+ def setUp(self):
+ self.Reset()
+
+ def Reset(self):
+ self.data_received = 0
+ self.handshake_done = False
+ self.jid = None
+
+ def SendData(self, _):
+ self.data_received += 1
+
+ def SendStanza(self, _, unused=True):
+ self.data_received += 1
+
+ def HandshakeDone(self, jid):
+ self.handshake_done = True
+ self.jid = jid
+
+ def DoHandshake(self, resource_prefix, resource, username,
+ initial_stream_domain, auth_domain, auth_stream_domain):
+ self.Reset()
+ handshake_task = (
+ xmppserver.HandshakeTask(self, resource_prefix, True))
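+    # Each FeedStanza() below triggers a fixed number of server sends:
+    # stream header + auth request (2), auth result (1), stream header +
+    # bind stanza (2), bind result (1), and session result (1);
+    # data_received tracks the running total.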
+ stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
+ stream_xml.setAttribute('to', initial_stream_domain)
+ self.assertEqual(self.data_received, 0)
+ handshake_task.FeedStanza(stream_xml)
+ self.assertEqual(self.data_received, 2)
+
+ if auth_domain:
+ username_domain = '%s@%s' % (username, auth_domain)
+ else:
+ username_domain = username
+ auth_string = base64.b64encode('\0%s\0bar' % username_domain)
+    auth_xml = xmppserver.ParseXml('<auth>%s</auth>' % auth_string)
+ handshake_task.FeedStanza(auth_xml)
+ self.assertEqual(self.data_received, 3)
+
+ stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
+ stream_xml.setAttribute('to', auth_stream_domain)
+ handshake_task.FeedStanza(stream_xml)
+ self.assertEqual(self.data_received, 5)
+
+ bind_xml = xmppserver.ParseXml(
+ '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
+ handshake_task.FeedStanza(bind_xml)
+ self.assertEqual(self.data_received, 6)
+
+ self.assertFalse(self.handshake_done)
+
+ session_xml = xmppserver.ParseXml(
+ '<iq type="set"><session></session></iq>')
+ handshake_task.FeedStanza(session_xml)
+ self.assertEqual(self.data_received, 7)
+
+ self.assertTrue(self.handshake_done)
+
+ self.assertEqual(self.jid.username, username)
+ self.assertEqual(self.jid.domain,
+ auth_stream_domain or auth_domain or
+ initial_stream_domain)
+ self.assertEqual(self.jid.resource,
+ '%s.%s' % (resource_prefix, resource))
+
+ handshake_task.FeedStanza('<ignored/>')
+ self.assertEqual(self.data_received, 7)
+
+ def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
+ initial_stream_domain):
+ self.Reset()
+ handshake_task = (
+ xmppserver.HandshakeTask(self, resource_prefix, False))
+ stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
+ stream_xml.setAttribute('to', initial_stream_domain)
+ self.assertEqual(self.data_received, 0)
+ handshake_task.FeedStanza(stream_xml)
+ self.assertEqual(self.data_received, 2)
+
+ self.assertFalse(self.handshake_done)
+
+ auth_string = base64.b64encode('\0%s\0bar' % username)
+    auth_xml = xmppserver.ParseXml('<auth>%s</auth>' % auth_string)
+ handshake_task.FeedStanza(auth_xml)
+ self.assertEqual(self.data_received, 3)
+
+ self.assertTrue(self.handshake_done)
+
+ self.assertEqual(self.jid, None)
+
+ handshake_task.FeedStanza('<ignored/>')
+ self.assertEqual(self.data_received, 3)
+
+ def testBasic(self):
+ self.DoHandshake('resource_prefix', 'resource',
+ 'foo', 'bar.com', 'baz.com', 'quux.com')
+
+ def testDomainBehavior(self):
+ self.DoHandshake('resource_prefix', 'resource',
+ 'foo', 'bar.com', 'baz.com', 'quux.com')
+ self.DoHandshake('resource_prefix', 'resource',
+ 'foo', 'bar.com', 'baz.com', '')
+ self.DoHandshake('resource_prefix', 'resource',
+ 'foo', 'bar.com', '', '')
+ self.DoHandshake('resource_prefix', 'resource',
+ 'foo', '', '', '')
+
+ def testBasicUnauthenticated(self):
+ self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
+ 'foo', 'bar.com')
+
+
+class FakeSocket(object):
+  """A fake socket object used for testing."""
+
+ def __init__(self):
+ self._sent_data = []
+
+ def GetSentData(self):
+ return self._sent_data
+
+ # socket-like methods.
+ def fileno(self):
+ return 0
+
+  def setblocking(self, flag):
+ pass
+
+ def getpeername(self):
+ return ('', 0)
+
+  def send(self, data):
+    self._sent_data.append(data)
+    # Report the whole chunk as sent so asynchat can drain its buffer.
+    return len(data)
+
+ def close(self):
+ pass
+
+
+class XmppConnectionTest(unittest.TestCase):
+
+ def setUp(self):
+ self.connections = set()
+ self.fake_socket = FakeSocket()
+
+ # XmppConnection delegate methods.
+ def OnXmppHandshakeDone(self, xmpp_connection):
+ self.connections.add(xmpp_connection)
+
+ def OnXmppConnectionClosed(self, xmpp_connection):
+ self.connections.discard(xmpp_connection)
+
+ def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
+ for connection in self.connections:
+ connection.ForwardNotification(notification_stanza)
+
+ def testBasic(self):
+ socket_map = {}
+ xmpp_connection = xmppserver.XmppConnection(
+ self.fake_socket, socket_map, self, ('', 0), True)
+ self.assertEqual(len(socket_map), 1)
+ self.assertEqual(len(self.connections), 0)
+ xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
+ self.assertEqual(len(socket_map), 1)
+ self.assertEqual(len(self.connections), 1)
+
+ sent_data = self.fake_socket.GetSentData()
+
+ # Test subscription request.
+ self.assertEqual(len(sent_data), 0)
+ xmpp_connection.collect_incoming_data(
+ '<iq><subscribe xmlns="google:push"></subscribe></iq>')
+ self.assertEqual(len(sent_data), 1)
+
+ # Test acks.
+ xmpp_connection.collect_incoming_data('<iq type="result"/>')
+ self.assertEqual(len(sent_data), 1)
+
+ # Test notification.
+ xmpp_connection.collect_incoming_data(
+ '<message><push xmlns="google:push"/></message>')
+ self.assertEqual(len(sent_data), 2)
+
+ # Test unexpected stanza.
+ def SendUnexpectedStanza():
+ xmpp_connection.collect_incoming_data('<foo/>')
+ self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
+
+ # Test unexpected notifier command.
+ def SendUnexpectedNotifierCommand():
+ xmpp_connection.collect_incoming_data(
+ '<iq><foo xmlns="google:notifier"/></iq>')
+ self.assertRaises(xmppserver.UnexpectedXml,
+ SendUnexpectedNotifierCommand)
+
+ # Test close.
+ xmpp_connection.close()
+ self.assertEqual(len(socket_map), 0)
+ self.assertEqual(len(self.connections), 0)
+
+ def testBasicUnauthenticated(self):
+ socket_map = {}
+ xmpp_connection = xmppserver.XmppConnection(
+ self.fake_socket, socket_map, self, ('', 0), False)
+ self.assertEqual(len(socket_map), 1)
+ self.assertEqual(len(self.connections), 0)
+ xmpp_connection.HandshakeDone(None)
+ self.assertEqual(len(socket_map), 0)
+ self.assertEqual(len(self.connections), 0)
+
+ # Test unexpected stanza.
+ def SendUnexpectedStanza():
+ xmpp_connection.collect_incoming_data('<foo/>')
+ self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)
+
+ # Test redundant close.
+ xmpp_connection.close()
+ self.assertEqual(len(socket_map), 0)
+ self.assertEqual(len(self.connections), 0)
+
+
+class FakeXmppServer(xmppserver.XmppServer):
+  """A fake XMPP server object used for testing."""
+
+ def __init__(self):
+ self._socket_map = {}
+ self._fake_sockets = set()
+ self._next_jid_suffix = 1
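+    # Note that XmppServer.__init__ still binds a real listening socket
+    # on an ephemeral port; only accept() is faked below.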
+ xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0))
+
+ def GetSocketMap(self):
+ return self._socket_map
+
+ def GetFakeSockets(self):
+ return self._fake_sockets
+
+ def AddHandshakeCompletedConnection(self):
+    """Creates a new XMPP connection and completes its handshake."""
+ xmpp_connection = self.handle_accept()
+ jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com')
+ self._next_jid_suffix += 1
+ xmpp_connection.HandshakeDone(jid)
+
+ # XmppServer overrides.
+ def accept(self):
+ fake_socket = FakeSocket()
+ self._fake_sockets.add(fake_socket)
+ return (fake_socket, ('', 0))
+
+ def close(self):
+ self._fake_sockets.clear()
+ xmppserver.XmppServer.close(self)
+
+
+class XmppServerTest(unittest.TestCase):
+
+ def setUp(self):
+ self.xmpp_server = FakeXmppServer()
+
+ def AssertSentDataLength(self, expected_length):
+ for fake_socket in self.xmpp_server.GetFakeSockets():
+ self.assertEqual(len(fake_socket.GetSentData()), expected_length)
+
+ def testBasic(self):
+ socket_map = self.xmpp_server.GetSocketMap()
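+    # The server's own listening socket occupies the first slot.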
+ self.assertEqual(len(socket_map), 1)
+ self.xmpp_server.AddHandshakeCompletedConnection()
+ self.assertEqual(len(socket_map), 2)
+ self.xmpp_server.close()
+ self.assertEqual(len(socket_map), 0)
+
+ def testMakeNotification(self):
+ notification = self.xmpp_server.MakeNotification('channel', 'data')
+ expected_xml = (
+ '<message>'
+ ' <push channel="channel" xmlns="google:push">'
+ ' <data>%s</data>'
+ ' </push>'
+ '</message>' % base64.b64encode('data'))
+ self.assertEqual(notification.toxml(), expected_xml)
+
+ def testSendNotification(self):
+ # Add a few connections.
+ for _ in xrange(0, 7):
+ self.xmpp_server.AddHandshakeCompletedConnection()
+
+ self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7)
+
+ self.AssertSentDataLength(0)
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(1)
+
+ def testEnableDisableNotifications(self):
+ # Add a few connections.
+ for _ in xrange(0, 5):
+ self.xmpp_server.AddHandshakeCompletedConnection()
+
+ self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5)
+
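+    # Notifications are enabled by default, so the first send is delivered.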
+ self.AssertSentDataLength(0)
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(1)
+
+ self.xmpp_server.EnableNotifications()
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(2)
+
+ self.xmpp_server.DisableNotifications()
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(2)
+
+ self.xmpp_server.DisableNotifications()
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(2)
+
+ self.xmpp_server.EnableNotifications()
+ self.xmpp_server.SendNotification('channel', 'data')
+ self.AssertSentDataLength(3)
+
+
+if __name__ == '__main__':
+ unittest.main()