diff options
author | rsimha@chromium.org <rsimha@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-01-20 01:10:24 +0000 |
---|---|---|
committer | rsimha@chromium.org <rsimha@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2013-01-20 01:10:24 +0000 |
commit | e4c029f76eb948af468a4d11ec0d3272671ddb58 (patch) | |
tree | f80be661f716ca63d85e2f94d562b89c242af599 /sync/tools/testserver | |
parent | 00353d188a87c6a2f953e22a73d7c18fa2c37b2a (diff) | |
download | chromium_src-e4c029f76eb948af468a4d11ec0d3272671ddb58.zip chromium_src-e4c029f76eb948af468a4d11ec0d3272671ddb58.tar.gz chromium_src-e4c029f76eb948af468a4d11ec0d3272671ddb58.tar.bz2 |
[sync] Divorce python sync test server chromiumsync.py from testserver.py
Various chrome test suites use the infrastructure in net::LocalTestServer and net/tools/testserver.py to create local test server instances against which to run automated tests. Sync tests use reference implementations of sync and xmpp servers, which build on the testserver infrastructure in net/. In the past, the sync testserver was small enough that it made sense for it to be a part of the testserver in net/. This, however, resulted in an unwanted dependency from net/ onto sync/, due to the sync proto modules needed to run a python sync server. Now that the sync testserver has grown considerably in scope, it is time to separate it out from net/ while reusing base testserver code, and eliminate the dependency from net/ onto sync/. This work also provides us with the opportunity to remove a whole bunch of dead pyauto sync test code in chrome/test/functional.
This patch does the following:
- Moves the native class LocalSyncTestServer from net/test/ to sync/test/.
- Moves chromiumsync{_test}.py and xmppserver{_test}.py from net/tools/testserver/ to sync/tools/testserver/.
- Removes all sync server specific code from net/.
- Adds a new sync_testserver.py runner script for the python sync test.
- Moves some base classes from testserver.py to testserver_base.py so they can be reused by sync_testserver.py.
- Audits all the python imports in testserver.py, testserver_base.py and sync_testserver.py to make sure there are no unnecessary / missing imports.
- Adds a new run_sync_testserver runner executable to launch a sync testserver.
- Removes a couple of static methods from LocalTestServer, that were being used by run_testserver, and refactors run_sync_testserver to use their non-static versions.
- Adds the ability to run both chromiumsync_test.py and xmppserver_test.py from run_sync_testserver.
- Fixes chromiumsync.py to undo / rectify some older changes that broke tests in chromiumsync_test.py.
- Adds a new test target called test_support_sync_testserver to sync.gyp.
- Removes the hacky dependency on sync_proto from net.gyp:net_test_support.
- Updates various gyp files across chrome to use the new sync testserver target.
- Audits dependencies of net_test_support, run_testserver, and the newly added targets.
- Fixes the android chrome testserver spawner script to account for the above changes.
- Removes all mentions of TYPE_SYNC from the pyauto TestServer shim.
- Deletes all (deprecated) pyauto sync tests. (They had all become broken over time, gotten disabled, and were all redundant due to their equivalent sync integration tests.)
- Removes all sync related pyauto hooks from TestingAutomationProvider, since they are no longer going to be used.
- Takes care of a TODO in safe_browser_testserver.py to remove an unnecessary code block.
Note: A majority of the bugs listed below are for individual pyauto sync tests. Deleting the sync pyauto test script fixes all these bugs in one fell swoop.
TBR=mattm@chromium.org
BUG=117559, 119403, 159731, 15016, 80329, 49378, 87642, 86949, 88679, 104227, 88593, 124913
TEST=run_testserver, run_sync_testserver, sync_integration_tests, sync_performance_tests. All chrome tests that use a testserver should continue to work.
Review URL: https://codereview.chromium.org/11971025
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@177864 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'sync/tools/testserver')
-rw-r--r-- | sync/tools/testserver/DEPS | 3 | ||||
-rw-r--r-- | sync/tools/testserver/OWNERS | 3 | ||||
-rw-r--r-- | sync/tools/testserver/chromiumsync.py | 1370 | ||||
-rwxr-xr-x | sync/tools/testserver/chromiumsync_test.py | 655 | ||||
-rw-r--r-- | sync/tools/testserver/run_sync_testserver.cc | 121 | ||||
-rwxr-xr-x | sync/tools/testserver/sync_testserver.py | 447 | ||||
-rw-r--r-- | sync/tools/testserver/xmppserver.py | 594 | ||||
-rwxr-xr-x | sync/tools/testserver/xmppserver_test.py | 421 |
8 files changed, 3614 insertions, 0 deletions
diff --git a/sync/tools/testserver/DEPS b/sync/tools/testserver/DEPS new file mode 100644 index 0000000..f9b201f --- /dev/null +++ b/sync/tools/testserver/DEPS @@ -0,0 +1,3 @@ +include_rules = [ + "+sync/test", +] diff --git a/sync/tools/testserver/OWNERS b/sync/tools/testserver/OWNERS new file mode 100644 index 0000000..e628479 --- /dev/null +++ b/sync/tools/testserver/OWNERS @@ -0,0 +1,3 @@ +akalin@chromium.org +nick@chromium.org +rsimha@chromium.org diff --git a/sync/tools/testserver/chromiumsync.py b/sync/tools/testserver/chromiumsync.py new file mode 100644 index 0000000..eb631c2 --- /dev/null +++ b/sync/tools/testserver/chromiumsync.py @@ -0,0 +1,1370 @@ +# Copyright 2013 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""An implementation of the server side of the Chromium sync protocol. + +The details of the protocol are described mostly by comments in the protocol +buffer definition at chrome/browser/sync/protocol/sync.proto. +""" + +import cgi +import copy +import operator +import pickle +import random +import string +import sys +import threading +import time +import urlparse + +import app_notification_specifics_pb2 +import app_setting_specifics_pb2 +import app_specifics_pb2 +import autofill_specifics_pb2 +import bookmark_specifics_pb2 +import get_updates_caller_info_pb2 +import extension_setting_specifics_pb2 +import extension_specifics_pb2 +import history_delete_directive_specifics_pb2 +import nigori_specifics_pb2 +import password_specifics_pb2 +import preference_specifics_pb2 +import search_engine_specifics_pb2 +import session_specifics_pb2 +import sync_pb2 +import sync_enums_pb2 +import synced_notification_specifics_pb2 +import theme_specifics_pb2 +import typed_url_specifics_pb2 + +# An enumeration of the various kinds of data that can be synced. +# Over the wire, this enumeration is not used: a sync object's type is +# inferred by which EntitySpecifics field it has. But in the context +# of a program, it is useful to have an enumeration. +ALL_TYPES = ( + TOP_LEVEL, # The type of the 'Google Chrome' folder. + APPS, + APP_NOTIFICATION, + APP_SETTINGS, + AUTOFILL, + AUTOFILL_PROFILE, + BOOKMARK, + DEVICE_INFO, + EXPERIMENTS, + EXTENSIONS, + HISTORY_DELETE_DIRECTIVE, + NIGORI, + PASSWORD, + PREFERENCE, + SEARCH_ENGINE, + SESSION, + SYNCED_NOTIFICATION, + THEME, + TYPED_URL, + EXTENSION_SETTINGS) = range(20) + +# An eumeration on the frequency at which the server should send errors +# to the client. This would be specified by the url that triggers the error. +# Note: This enum should be kept in the same order as the enum in sync_test.h. +SYNC_ERROR_FREQUENCY = ( + ERROR_FREQUENCY_NONE, + ERROR_FREQUENCY_ALWAYS, + ERROR_FREQUENCY_TWO_THIRDS) = range(3) + +# Well-known server tag of the top level 'Google Chrome' folder. +TOP_LEVEL_FOLDER_TAG = 'google_chrome' + +# Given a sync type from ALL_TYPES, find the FieldDescriptor corresponding +# to that datatype. Note that TOP_LEVEL has no such token. +SYNC_TYPE_FIELDS = sync_pb2.EntitySpecifics.DESCRIPTOR.fields_by_name +SYNC_TYPE_TO_DESCRIPTOR = { + APP_NOTIFICATION: SYNC_TYPE_FIELDS['app_notification'], + APP_SETTINGS: SYNC_TYPE_FIELDS['app_setting'], + APPS: SYNC_TYPE_FIELDS['app'], + AUTOFILL: SYNC_TYPE_FIELDS['autofill'], + AUTOFILL_PROFILE: SYNC_TYPE_FIELDS['autofill_profile'], + BOOKMARK: SYNC_TYPE_FIELDS['bookmark'], + DEVICE_INFO: SYNC_TYPE_FIELDS['device_info'], + EXPERIMENTS: SYNC_TYPE_FIELDS['experiments'], + EXTENSION_SETTINGS: SYNC_TYPE_FIELDS['extension_setting'], + EXTENSIONS: SYNC_TYPE_FIELDS['extension'], + HISTORY_DELETE_DIRECTIVE: SYNC_TYPE_FIELDS['history_delete_directive'], + NIGORI: SYNC_TYPE_FIELDS['nigori'], + PASSWORD: SYNC_TYPE_FIELDS['password'], + PREFERENCE: SYNC_TYPE_FIELDS['preference'], + SEARCH_ENGINE: SYNC_TYPE_FIELDS['search_engine'], + SESSION: SYNC_TYPE_FIELDS['session'], + SYNCED_NOTIFICATION: SYNC_TYPE_FIELDS["synced_notification"], + THEME: SYNC_TYPE_FIELDS['theme'], + TYPED_URL: SYNC_TYPE_FIELDS['typed_url'], + } + +# The parent ID used to indicate a top-level node. +ROOT_ID = '0' + +# Unix time epoch in struct_time format. The tuple corresponds to UTC Wednesday +# Jan 1 1970, 00:00:00, non-dst. +UNIX_TIME_EPOCH = (1970, 1, 1, 0, 0, 0, 3, 1, 0) + +# The number of characters in the server-generated encryption key. +KEYSTORE_KEY_LENGTH = 16 + +# The hashed client tag for the keystore encryption experiment node. +KEYSTORE_ENCRYPTION_EXPERIMENT_TAG = "pis8ZRzh98/MKLtVEio2mr42LQA=" + +class Error(Exception): + """Error class for this module.""" + + +class ProtobufDataTypeFieldNotUnique(Error): + """An entry should not have more than one data type present.""" + + +class DataTypeIdNotRecognized(Error): + """The requested data type is not recognized.""" + + +class MigrationDoneError(Error): + """A server-side migration occurred; clients must re-sync some datatypes. + + Attributes: + datatypes: a list of the datatypes (python enum) needing migration. + """ + + def __init__(self, datatypes): + self.datatypes = datatypes + + +class StoreBirthdayError(Error): + """The client sent a birthday that doesn't correspond to this server.""" + + +class TransientError(Error): + """The client would be sent a transient error.""" + + +class SyncInducedError(Error): + """The client would be sent an error.""" + + +class InducedErrorFrequencyNotDefined(Error): + """The error frequency defined is not handled.""" + + +def GetEntryType(entry): + """Extract the sync type from a SyncEntry. + + Args: + entry: A SyncEntity protobuf object whose type to determine. + Returns: + An enum value from ALL_TYPES if the entry's type can be determined, or None + if the type cannot be determined. + Raises: + ProtobufDataTypeFieldNotUnique: More than one type was indicated by + the entry. + """ + if entry.server_defined_unique_tag == TOP_LEVEL_FOLDER_TAG: + return TOP_LEVEL + entry_types = GetEntryTypesFromSpecifics(entry.specifics) + if not entry_types: + return None + + # If there is more than one, either there's a bug, or else the caller + # should use GetEntryTypes. + if len(entry_types) > 1: + raise ProtobufDataTypeFieldNotUnique + return entry_types[0] + + +def GetEntryTypesFromSpecifics(specifics): + """Determine the sync types indicated by an EntitySpecifics's field(s). + + If the specifics have more than one recognized data type field (as commonly + happens with the requested_types field of GetUpdatesMessage), all types + will be returned. Callers must handle the possibility of the returned + value having more than one item. + + Args: + specifics: A EntitySpecifics protobuf message whose extensions to + enumerate. + Returns: + A list of the sync types (values from ALL_TYPES) associated with each + recognized extension of the specifics message. + """ + return [data_type for data_type, field_descriptor + in SYNC_TYPE_TO_DESCRIPTOR.iteritems() + if specifics.HasField(field_descriptor.name)] + + +def SyncTypeToProtocolDataTypeId(data_type): + """Convert from a sync type (python enum) to the protocol's data type id.""" + return SYNC_TYPE_TO_DESCRIPTOR[data_type].number + + +def ProtocolDataTypeIdToSyncType(protocol_data_type_id): + """Convert from the protocol's data type id to a sync type (python enum).""" + for data_type, field_descriptor in SYNC_TYPE_TO_DESCRIPTOR.iteritems(): + if field_descriptor.number == protocol_data_type_id: + return data_type + raise DataTypeIdNotRecognized + + +def DataTypeStringToSyncTypeLoose(data_type_string): + """Converts a human-readable string to a sync type (python enum). + + Capitalization and pluralization don't matter; this function is appropriate + for values that might have been typed by a human being; e.g., command-line + flags or query parameters. + """ + if data_type_string.isdigit(): + return ProtocolDataTypeIdToSyncType(int(data_type_string)) + name = data_type_string.lower().rstrip('s') + for data_type, field_descriptor in SYNC_TYPE_TO_DESCRIPTOR.iteritems(): + if field_descriptor.name.lower().rstrip('s') == name: + return data_type + raise DataTypeIdNotRecognized + + +def MakeNewKeystoreKey(): + """Returns a new random keystore key.""" + return ''.join(random.choice(string.ascii_uppercase + string.digits) + for x in xrange(KEYSTORE_KEY_LENGTH)) + + +def SyncTypeToString(data_type): + """Formats a sync type enum (from ALL_TYPES) to a human-readable string.""" + return SYNC_TYPE_TO_DESCRIPTOR[data_type].name + + +def CallerInfoToString(caller_info_source): + """Formats a GetUpdatesSource enum value to a readable string.""" + return get_updates_caller_info_pb2.GetUpdatesCallerInfo \ + .DESCRIPTOR.enum_types_by_name['GetUpdatesSource'] \ + .values_by_number[caller_info_source].name + + +def ShortDatatypeListSummary(data_types): + """Formats compactly a list of sync types (python enums) for human eyes. + + This function is intended for use by logging. If the list of datatypes + contains almost all of the values, the return value will be expressed + in terms of the datatypes that aren't set. + """ + included = set(data_types) - set([TOP_LEVEL]) + if not included: + return 'nothing' + excluded = set(ALL_TYPES) - included - set([TOP_LEVEL]) + if not excluded: + return 'everything' + simple_text = '+'.join(sorted([SyncTypeToString(x) for x in included])) + all_but_text = 'all except %s' % ( + '+'.join(sorted([SyncTypeToString(x) for x in excluded]))) + if len(included) < len(excluded) or len(simple_text) <= len(all_but_text): + return simple_text + else: + return all_but_text + + +def GetDefaultEntitySpecifics(data_type): + """Get an EntitySpecifics having a sync type's default field value.""" + specifics = sync_pb2.EntitySpecifics() + if data_type in SYNC_TYPE_TO_DESCRIPTOR: + descriptor = SYNC_TYPE_TO_DESCRIPTOR[data_type] + getattr(specifics, descriptor.name).SetInParent() + return specifics + + +class PermanentItem(object): + """A specification of one server-created permanent item. + + Attributes: + tag: A known-to-the-client value that uniquely identifies a server-created + permanent item. + name: The human-readable display name for this item. + parent_tag: The tag of the permanent item's parent. If ROOT_ID, indicates + a top-level item. Otherwise, this must be the tag value of some other + server-created permanent item. + sync_type: A value from ALL_TYPES, giving the datatype of this permanent + item. This controls which types of client GetUpdates requests will + cause the permanent item to be created and returned. + create_by_default: Whether the permanent item is created at startup or not. + This value is set to True in the default case. Non-default permanent items + are those that are created only when a client explicitly tells the server + to do so. + """ + + def __init__(self, tag, name, parent_tag, sync_type, create_by_default=True): + self.tag = tag + self.name = name + self.parent_tag = parent_tag + self.sync_type = sync_type + self.create_by_default = create_by_default + + +class MigrationHistory(object): + """A record of the migration events associated with an account. + + Each migration event invalidates one or more datatypes on all clients + that had synced the datatype before the event. Such clients will continue + to receive MigrationDone errors until they throw away their progress and + re-sync that datatype from the beginning. + """ + def __init__(self): + self._migrations = {} + for datatype in ALL_TYPES: + self._migrations[datatype] = [1] + self._next_migration_version = 2 + + def GetLatestVersion(self, datatype): + return self._migrations[datatype][-1] + + def CheckAllCurrent(self, versions_map): + """Raises an error if any the provided versions are out of date. + + This function intentionally returns migrations in the order that they were + triggered. Doing it this way allows the client to queue up two migrations + in a row, so the second one is received while responding to the first. + + Arguments: + version_map: a map whose keys are datatypes and whose values are versions. + + Raises: + MigrationDoneError: if a mismatch is found. + """ + problems = {} + for datatype, client_migration in versions_map.iteritems(): + for server_migration in self._migrations[datatype]: + if client_migration < server_migration: + problems.setdefault(server_migration, []).append(datatype) + if problems: + raise MigrationDoneError(problems[min(problems.keys())]) + + def Bump(self, datatypes): + """Add a record of a migration, to cause errors on future requests.""" + for idx, datatype in enumerate(datatypes): + self._migrations[datatype].append(self._next_migration_version) + self._next_migration_version += 1 + + +class UpdateSieve(object): + """A filter to remove items the client has already seen.""" + def __init__(self, request, migration_history=None): + self._original_request = request + self._state = {} + self._migration_history = migration_history or MigrationHistory() + self._migration_versions_to_check = {} + if request.from_progress_marker: + for marker in request.from_progress_marker: + data_type = ProtocolDataTypeIdToSyncType(marker.data_type_id) + if marker.HasField('timestamp_token_for_migration'): + timestamp = marker.timestamp_token_for_migration + if timestamp: + self._migration_versions_to_check[data_type] = 1 + elif marker.token: + (timestamp, version) = pickle.loads(marker.token) + self._migration_versions_to_check[data_type] = version + elif marker.HasField('token'): + timestamp = 0 + else: + raise ValueError('No timestamp information in progress marker.') + data_type = ProtocolDataTypeIdToSyncType(marker.data_type_id) + self._state[data_type] = timestamp + elif request.HasField('from_timestamp'): + for data_type in GetEntryTypesFromSpecifics(request.requested_types): + self._state[data_type] = request.from_timestamp + self._migration_versions_to_check[data_type] = 1 + if self._state: + self._state[TOP_LEVEL] = min(self._state.itervalues()) + + def SummarizeRequest(self): + timestamps = {} + for data_type, timestamp in self._state.iteritems(): + if data_type == TOP_LEVEL: + continue + timestamps.setdefault(timestamp, []).append(data_type) + return ', '.join('<%s>@%d' % (ShortDatatypeListSummary(types), stamp) + for stamp, types in sorted(timestamps.iteritems())) + + def CheckMigrationState(self): + self._migration_history.CheckAllCurrent(self._migration_versions_to_check) + + def ClientWantsItem(self, item): + """Return true if the client hasn't already seen an item.""" + return self._state.get(GetEntryType(item), sys.maxint) < item.version + + def HasAnyTimestamp(self): + """Return true if at least one datatype was requested.""" + return bool(self._state) + + def GetMinTimestamp(self): + """Return true the smallest timestamp requested across all datatypes.""" + return min(self._state.itervalues()) + + def GetFirstTimeTypes(self): + """Return a list of datatypes requesting updates from timestamp zero.""" + return [datatype for datatype, timestamp in self._state.iteritems() + if timestamp == 0] + + def SaveProgress(self, new_timestamp, get_updates_response): + """Write the new_timestamp or new_progress_marker fields to a response.""" + if self._original_request.from_progress_marker: + for data_type, old_timestamp in self._state.iteritems(): + if data_type == TOP_LEVEL: + continue + new_marker = sync_pb2.DataTypeProgressMarker() + new_marker.data_type_id = SyncTypeToProtocolDataTypeId(data_type) + final_stamp = max(old_timestamp, new_timestamp) + final_migration = self._migration_history.GetLatestVersion(data_type) + new_marker.token = pickle.dumps((final_stamp, final_migration)) + if new_marker not in self._original_request.from_progress_marker: + get_updates_response.new_progress_marker.add().MergeFrom(new_marker) + elif self._original_request.HasField('from_timestamp'): + if self._original_request.from_timestamp < new_timestamp: + get_updates_response.new_timestamp = new_timestamp + + +class SyncDataModel(object): + """Models the account state of one sync user.""" + _BATCH_SIZE = 100 + + # Specify all the permanent items that a model might need. + _PERMANENT_ITEM_SPECS = [ + PermanentItem('google_chrome_apps', name='Apps', + parent_tag=ROOT_ID, sync_type=APPS), + PermanentItem('google_chrome_app_notifications', name='App Notifications', + parent_tag=ROOT_ID, sync_type=APP_NOTIFICATION), + PermanentItem('google_chrome_app_settings', + name='App Settings', + parent_tag=ROOT_ID, sync_type=APP_SETTINGS), + PermanentItem('google_chrome_bookmarks', name='Bookmarks', + parent_tag=ROOT_ID, sync_type=BOOKMARK), + PermanentItem('bookmark_bar', name='Bookmark Bar', + parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK), + PermanentItem('other_bookmarks', name='Other Bookmarks', + parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK), + PermanentItem('synced_bookmarks', name='Synced Bookmarks', + parent_tag='google_chrome_bookmarks', sync_type=BOOKMARK, + create_by_default=False), # Must be True in the iOS tree. + PermanentItem('google_chrome_autofill', name='Autofill', + parent_tag=ROOT_ID, sync_type=AUTOFILL), + PermanentItem('google_chrome_autofill_profiles', name='Autofill Profiles', + parent_tag=ROOT_ID, sync_type=AUTOFILL_PROFILE), + PermanentItem('google_chrome_device_info', name='Device Info', + parent_tag=ROOT_ID, sync_type=DEVICE_INFO), + PermanentItem('google_chrome_experiments', name='Experiments', + parent_tag=ROOT_ID, sync_type=EXPERIMENTS), + PermanentItem('google_chrome_extension_settings', + name='Extension Settings', + parent_tag=ROOT_ID, sync_type=EXTENSION_SETTINGS), + PermanentItem('google_chrome_extensions', name='Extensions', + parent_tag=ROOT_ID, sync_type=EXTENSIONS), + PermanentItem('google_chrome_history_delete_directives', + name='History Delete Directives', + parent_tag=ROOT_ID, + sync_type=HISTORY_DELETE_DIRECTIVE), + PermanentItem('google_chrome_nigori', name='Nigori', + parent_tag=ROOT_ID, sync_type=NIGORI), + PermanentItem('google_chrome_passwords', name='Passwords', + parent_tag=ROOT_ID, sync_type=PASSWORD), + PermanentItem('google_chrome_preferences', name='Preferences', + parent_tag=ROOT_ID, sync_type=PREFERENCE), + PermanentItem('google_chrome_synced_notifications', + name='Synced Notifications', + parent_tag=ROOT_ID, sync_type=SYNCED_NOTIFICATION), + PermanentItem('google_chrome_search_engines', name='Search Engines', + parent_tag=ROOT_ID, sync_type=SEARCH_ENGINE), + PermanentItem('google_chrome_sessions', name='Sessions', + parent_tag=ROOT_ID, sync_type=SESSION), + PermanentItem('google_chrome_themes', name='Themes', + parent_tag=ROOT_ID, sync_type=THEME), + PermanentItem('google_chrome_typed_urls', name='Typed URLs', + parent_tag=ROOT_ID, sync_type=TYPED_URL), + ] + + def __init__(self): + # Monotonically increasing version number. The next object change will + # take on this value + 1. + self._version = 0 + + # The definitive copy of this client's items: a map from ID string to a + # SyncEntity protocol buffer. + self._entries = {} + + self.ResetStoreBirthday() + + self.migration_history = MigrationHistory() + + self.induced_error = sync_pb2.ClientToServerResponse.Error() + self.induced_error_frequency = 0 + self.sync_count_before_errors = 0 + + self._keys = [MakeNewKeystoreKey()] + + def _SaveEntry(self, entry): + """Insert or update an entry in the change log, and give it a new version. + + The ID fields of this entry are assumed to be valid server IDs. This + entry will be updated with a new version number and sync_timestamp. + + Args: + entry: The entry to be added or updated. + """ + self._version += 1 + # Maintain a global (rather than per-item) sequence number and use it + # both as the per-entry version as well as the update-progress timestamp. + # This simulates the behavior of the original server implementation. + entry.version = self._version + entry.sync_timestamp = self._version + + # Preserve the originator info, which the client is not required to send + # when updating. + base_entry = self._entries.get(entry.id_string) + if base_entry: + entry.originator_cache_guid = base_entry.originator_cache_guid + entry.originator_client_item_id = base_entry.originator_client_item_id + + self._entries[entry.id_string] = copy.deepcopy(entry) + + def _ServerTagToId(self, tag): + """Determine the server ID from a server-unique tag. + + The resulting value is guaranteed not to collide with the other ID + generation methods. + + Args: + datatype: The sync type (python enum) of the identified object. + tag: The unique, known-to-the-client tag of a server-generated item. + Returns: + The string value of the computed server ID. + """ + if not tag or tag == ROOT_ID: + return tag + spec = [x for x in self._PERMANENT_ITEM_SPECS if x.tag == tag][0] + return self._MakeCurrentId(spec.sync_type, '<server tag>%s' % tag) + + def _ClientTagToId(self, datatype, tag): + """Determine the server ID from a client-unique tag. + + The resulting value is guaranteed not to collide with the other ID + generation methods. + + Args: + datatype: The sync type (python enum) of the identified object. + tag: The unique, opaque-to-the-server tag of a client-tagged item. + Returns: + The string value of the computed server ID. + """ + return self._MakeCurrentId(datatype, '<client tag>%s' % tag) + + def _ClientIdToId(self, datatype, client_guid, client_item_id): + """Compute a unique server ID from a client-local ID tag. + + The resulting value is guaranteed not to collide with the other ID + generation methods. + + Args: + datatype: The sync type (python enum) of the identified object. + client_guid: A globally unique ID that identifies the client which + created this item. + client_item_id: An ID that uniquely identifies this item on the client + which created it. + Returns: + The string value of the computed server ID. + """ + # Using the client ID info is not required here (we could instead generate + # a random ID), but it's useful for debugging. + return self._MakeCurrentId(datatype, + '<server ID originally>%s/%s' % (client_guid, client_item_id)) + + def _MakeCurrentId(self, datatype, inner_id): + return '%d^%d^%s' % (datatype, + self.migration_history.GetLatestVersion(datatype), + inner_id) + + def _ExtractIdInfo(self, id_string): + if not id_string or id_string == ROOT_ID: + return None + datatype_string, separator, remainder = id_string.partition('^') + migration_version_string, separator, inner_id = remainder.partition('^') + return (int(datatype_string), int(migration_version_string), inner_id) + + def _WritePosition(self, entry, parent_id): + """Ensure the entry has an absolute, numeric position and parent_id. + + Historically, clients would specify positions using the predecessor-based + references in the insert_after_item_id field; starting July 2011, this + was changed and Chrome now sends up the absolute position. The server + must store a position_in_parent value and must not maintain + insert_after_item_id. + + Args: + entry: The entry for which to write a position. Its ID field are + assumed to be server IDs. This entry will have its parent_id_string + and position_in_parent fields updated; its insert_after_item_id field + will be cleared. + parent_id: The ID of the entry intended as the new parent. + """ + + entry.parent_id_string = parent_id + if not entry.HasField('position_in_parent'): + entry.position_in_parent = 1337 # A debuggable, distinctive default. + entry.ClearField('insert_after_item_id') + + def _ItemExists(self, id_string): + """Determine whether an item exists in the changelog.""" + return id_string in self._entries + + def _CreatePermanentItem(self, spec): + """Create one permanent item from its spec, if it doesn't exist. + + The resulting item is added to the changelog. + + Args: + spec: A PermanentItem object holding the properties of the item to create. + """ + id_string = self._ServerTagToId(spec.tag) + if self._ItemExists(id_string): + return + print 'Creating permanent item: %s' % spec.name + entry = sync_pb2.SyncEntity() + entry.id_string = id_string + entry.non_unique_name = spec.name + entry.name = spec.name + entry.server_defined_unique_tag = spec.tag + entry.folder = True + entry.deleted = False + entry.specifics.CopyFrom(GetDefaultEntitySpecifics(spec.sync_type)) + self._WritePosition(entry, self._ServerTagToId(spec.parent_tag)) + self._SaveEntry(entry) + + def _CreateDefaultPermanentItems(self, requested_types): + """Ensure creation of all default permanent items for a given set of types. + + Args: + requested_types: A list of sync data types from ALL_TYPES. + All default permanent items of only these types will be created. + """ + for spec in self._PERMANENT_ITEM_SPECS: + if spec.sync_type in requested_types and spec.create_by_default: + self._CreatePermanentItem(spec) + + def ResetStoreBirthday(self): + """Resets the store birthday to a random value.""" + # TODO(nick): uuid.uuid1() is better, but python 2.5 only. + self.store_birthday = '%0.30f' % random.random() + + def StoreBirthday(self): + """Gets the store birthday.""" + return self.store_birthday + + def GetChanges(self, sieve): + """Get entries which have changed, oldest first. + + The returned entries are limited to being _BATCH_SIZE many. The entries + are returned in strict version order. + + Args: + sieve: An update sieve to use to filter out updates the client + has already seen. + Returns: + A tuple of (version, entries, changes_remaining). Version is a new + timestamp value, which should be used as the starting point for the + next query. Entries is the batch of entries meeting the current + timestamp query. Changes_remaining indicates the number of changes + left on the server after this batch. + """ + if not sieve.HasAnyTimestamp(): + return (0, [], 0) + min_timestamp = sieve.GetMinTimestamp() + self._CreateDefaultPermanentItems(sieve.GetFirstTimeTypes()) + change_log = sorted(self._entries.values(), + key=operator.attrgetter('version')) + new_changes = [x for x in change_log if x.version > min_timestamp] + # Pick batch_size new changes, and then filter them. This matches + # the RPC behavior of the production sync server. + batch = new_changes[:self._BATCH_SIZE] + if not batch: + # Client is up to date. + return (min_timestamp, [], 0) + + # Restrict batch to requested types. Tombstones are untyped + # and will always get included. + filtered = [copy.deepcopy(item) for item in batch + if item.deleted or sieve.ClientWantsItem(item)] + + # The new client timestamp is the timestamp of the last item in the + # batch, even if that item was filtered out. + return (batch[-1].version, filtered, len(new_changes) - len(batch)) + + def GetKeystoreKeys(self): + """Returns the encryption keys for this account.""" + print "Returning encryption keys: %s" % self._keys + return self._keys + + def _CopyOverImmutableFields(self, entry): + """Preserve immutable fields by copying pre-commit state. + + Args: + entry: A sync entity from the client. + """ + if entry.id_string in self._entries: + if self._entries[entry.id_string].HasField( + 'server_defined_unique_tag'): + entry.server_defined_unique_tag = ( + self._entries[entry.id_string].server_defined_unique_tag) + + def _CheckVersionForCommit(self, entry): + """Perform an optimistic concurrency check on the version number. + + Clients are only allowed to commit if they report having seen the most + recent version of an object. + + Args: + entry: A sync entity from the client. It is assumed that ID fields + have been converted to server IDs. + Returns: + A boolean value indicating whether the client's version matches the + newest server version for the given entry. + """ + if entry.id_string in self._entries: + # Allow edits/deletes if the version matches, and any undeletion. + return (self._entries[entry.id_string].version == entry.version or + self._entries[entry.id_string].deleted) + else: + # Allow unknown ID only if the client thinks it's new too. + return entry.version == 0 + + def _CheckParentIdForCommit(self, entry): + """Check that the parent ID referenced in a SyncEntity actually exists. + + Args: + entry: A sync entity from the client. It is assumed that ID fields + have been converted to server IDs. + Returns: + A boolean value indicating whether the entity's parent ID is an object + that actually exists (and is not deleted) in the current account state. + """ + if entry.parent_id_string == ROOT_ID: + # This is generally allowed. + return True + if entry.parent_id_string not in self._entries: + print 'Warning: Client sent unknown ID. Should never happen.' + return False + if entry.parent_id_string == entry.id_string: + print 'Warning: Client sent circular reference. Should never happen.' + return False + if self._entries[entry.parent_id_string].deleted: + # This can happen in a race condition between two clients. + return False + if not self._entries[entry.parent_id_string].folder: + print 'Warning: Client sent non-folder parent. Should never happen.' + return False + return True + + def _RewriteIdsAsServerIds(self, entry, cache_guid, commit_session): + """Convert ID fields in a client sync entry to server IDs. + + A commit batch sent by a client may contain new items for which the + server has not generated IDs yet. And within a commit batch, later + items are allowed to refer to earlier items. This method will + generate server IDs for new items, as well as rewrite references + to items whose server IDs were generated earlier in the batch. + + Args: + entry: The client sync entry to modify. + cache_guid: The globally unique ID of the client that sent this + commit request. + commit_session: A dictionary mapping the original IDs to the new server + IDs, for any items committed earlier in the batch. + """ + if entry.version == 0: + data_type = GetEntryType(entry) + if entry.HasField('client_defined_unique_tag'): + # When present, this should determine the item's ID. + new_id = self._ClientTagToId(data_type, entry.client_defined_unique_tag) + else: + new_id = self._ClientIdToId(data_type, cache_guid, entry.id_string) + entry.originator_cache_guid = cache_guid + entry.originator_client_item_id = entry.id_string + commit_session[entry.id_string] = new_id # Remember the remapping. + entry.id_string = new_id + if entry.parent_id_string in commit_session: + entry.parent_id_string = commit_session[entry.parent_id_string] + if entry.insert_after_item_id in commit_session: + entry.insert_after_item_id = commit_session[entry.insert_after_item_id] + + def ValidateCommitEntries(self, entries): + """Raise an exception if a commit batch contains any global errors. + + Arguments: + entries: an iterable containing commit-form SyncEntity protocol buffers. + + Raises: + MigrationDoneError: if any of the entries reference a recently-migrated + datatype. + """ + server_ids_in_commit = set() + local_ids_in_commit = set() + for entry in entries: + if entry.version: + server_ids_in_commit.add(entry.id_string) + else: + local_ids_in_commit.add(entry.id_string) + if entry.HasField('parent_id_string'): + if entry.parent_id_string not in local_ids_in_commit: + server_ids_in_commit.add(entry.parent_id_string) + + versions_present = {} + for server_id in server_ids_in_commit: + parsed = self._ExtractIdInfo(server_id) + if parsed: + datatype, version, _ = parsed + versions_present.setdefault(datatype, []).append(version) + + self.migration_history.CheckAllCurrent( + dict((k, min(v)) for k, v in versions_present.iteritems())) + + def CommitEntry(self, entry, cache_guid, commit_session): + """Attempt to commit one entry to the user's account. + + Args: + entry: A SyncEntity protobuf representing desired object changes. + cache_guid: A string value uniquely identifying the client; this + is used for ID generation and will determine the originator_cache_guid + if the entry is new. + commit_session: A dictionary mapping client IDs to server IDs for any + objects committed earlier this session. If the entry gets a new ID + during commit, the change will be recorded here. + Returns: + A SyncEntity reflecting the post-commit value of the entry, or None + if the entry was not committed due to an error. + """ + entry = copy.deepcopy(entry) + + # Generate server IDs for this entry, and write generated server IDs + # from earlier entries into the message's fields, as appropriate. The + # ID generation state is stored in 'commit_session'. + self._RewriteIdsAsServerIds(entry, cache_guid, commit_session) + + # Perform the optimistic concurrency check on the entry's version number. + # Clients are not allowed to commit unless they indicate that they've seen + # the most recent version of an object. + if not self._CheckVersionForCommit(entry): + return None + + # Check the validity of the parent ID; it must exist at this point. + # TODO(nick): Implement cycle detection and resolution. + if not self._CheckParentIdForCommit(entry): + return None + + self._CopyOverImmutableFields(entry); + + # At this point, the commit is definitely going to happen. + + # Deletion works by storing a limited record for an entry, called a + # tombstone. A sync server must track deleted IDs forever, since it does + # not keep track of client knowledge (there's no deletion ACK event). + if entry.deleted: + def MakeTombstone(id_string): + """Make a tombstone entry that will replace the entry being deleted. + + Args: + id_string: Index of the SyncEntity to be deleted. + Returns: + A new SyncEntity reflecting the fact that the entry is deleted. + """ + # Only the ID, version and deletion state are preserved on a tombstone. + # TODO(nick): Does the production server not preserve the type? Not + # doing so means that tombstones cannot be filtered based on + # requested_types at GetUpdates time. + tombstone = sync_pb2.SyncEntity() + tombstone.id_string = id_string + tombstone.deleted = True + tombstone.name = '' + return tombstone + + def IsChild(child_id): + """Check if a SyncEntity is a child of entry, or any of its children. + + Args: + child_id: Index of the SyncEntity that is a possible child of entry. + Returns: + True if it is a child; false otherwise. + """ + if child_id not in self._entries: + return False + if self._entries[child_id].parent_id_string == entry.id_string: + return True + return IsChild(self._entries[child_id].parent_id_string) + + # Identify any children entry might have. + child_ids = [child.id_string for child in self._entries.itervalues() + if IsChild(child.id_string)] + + # Mark all children that were identified as deleted. + for child_id in child_ids: + self._SaveEntry(MakeTombstone(child_id)) + + # Delete entry itself. + entry = MakeTombstone(entry.id_string) + else: + # Comments in sync.proto detail how the representation of positional + # ordering works: either the 'insert_after_item_id' field or the + # 'position_in_parent' field may determine the sibling order during + # Commit operations. The 'position_in_parent' field provides an absolute + # ordering in GetUpdates contexts. Here we assume the client will + # always send a valid position_in_parent (this is the newer style), and + # we ignore insert_after_item_id (an older style). + self._WritePosition(entry, entry.parent_id_string) + + # Preserve the originator info, which the client is not required to send + # when updating. + base_entry = self._entries.get(entry.id_string) + if base_entry and not entry.HasField('originator_cache_guid'): + entry.originator_cache_guid = base_entry.originator_cache_guid + entry.originator_client_item_id = base_entry.originator_client_item_id + + # Store the current time since the Unix epoch in milliseconds. + entry.mtime = (int((time.mktime(time.gmtime()) - + time.mktime(UNIX_TIME_EPOCH))*1000)) + + # Commit the change. This also updates the version number. + self._SaveEntry(entry) + return entry + + def _RewriteVersionInId(self, id_string): + """Rewrites an ID so that its migration version becomes current.""" + parsed_id = self._ExtractIdInfo(id_string) + if not parsed_id: + return id_string + datatype, old_migration_version, inner_id = parsed_id + return self._MakeCurrentId(datatype, inner_id) + + def TriggerMigration(self, datatypes): + """Cause a migration to occur for a set of datatypes on this account. + + Clients will see the MIGRATION_DONE error for these datatypes until they + resync them. + """ + versions_to_remap = self.migration_history.Bump(datatypes) + all_entries = self._entries.values() + self._entries.clear() + for entry in all_entries: + new_id = self._RewriteVersionInId(entry.id_string) + entry.id_string = new_id + if entry.HasField('parent_id_string'): + entry.parent_id_string = self._RewriteVersionInId( + entry.parent_id_string) + self._entries[entry.id_string] = entry + + def TriggerSyncTabFavicons(self): + """Set the 'sync_tab_favicons' field to this account's nigori node. + + If the field is not currently set, will write a new nigori node entry + with the field set. Else does nothing. + """ + + nigori_tag = "google_chrome_nigori" + nigori_original = self._entries.get(self._ServerTagToId(nigori_tag)) + if (nigori_original.specifics.nigori.sync_tab_favicons): + return + nigori_new = copy.deepcopy(nigori_original) + nigori_new.specifics.nigori.sync_tabs = True + self._SaveEntry(nigori_new) + + def TriggerCreateSyncedBookmarks(self): + """Create the Synced Bookmarks folder under the Bookmarks permanent item. + + Clients will then receive the Synced Bookmarks folder on future + GetUpdates, and new bookmarks can be added within the Synced Bookmarks + folder. + """ + + synced_bookmarks_spec, = [spec for spec in self._PERMANENT_ITEM_SPECS + if spec.name == "Synced Bookmarks"] + self._CreatePermanentItem(synced_bookmarks_spec) + + def TriggerEnableKeystoreEncryption(self): + """Create the keystore_encryption experiment entity and enable it. + + A new entity within the EXPERIMENTS datatype is created with the unique + client tag "keystore_encryption" if it doesn't already exist. The + keystore_encryption message is then filled with |enabled| set to true. + """ + + experiment_id = self._ServerTagToId("google_chrome_experiments") + keystore_encryption_id = self._ClientTagToId( + EXPERIMENTS, + KEYSTORE_ENCRYPTION_EXPERIMENT_TAG) + keystore_entry = self._entries.get(keystore_encryption_id) + if keystore_entry is None: + keystore_entry = sync_pb2.SyncEntity() + keystore_entry.id_string = keystore_encryption_id + keystore_entry.name = "Keystore Encryption" + keystore_entry.client_defined_unique_tag = ( + KEYSTORE_ENCRYPTION_EXPERIMENT_TAG) + keystore_entry.folder = False + keystore_entry.deleted = False + keystore_entry.specifics.CopyFrom(GetDefaultEntitySpecifics(EXPERIMENTS)) + self._WritePosition(keystore_entry, experiment_id) + + keystore_entry.specifics.experiments.keystore_encryption.enabled = True + + self._SaveEntry(keystore_entry) + + def TriggerRotateKeystoreKeys(self): + """Rotate the current set of keystore encryption keys. + + |self._keys| will have a new random encryption key appended to it. We touch + the nigori node so that each client will receive the new encryption keys + only once. + """ + + # Add a new encryption key. + self._keys += [MakeNewKeystoreKey(), ] + + # Increment the nigori node's timestamp, so clients will get the new keys + # on their next GetUpdates (any time the nigori node is sent back, we also + # send back the keystore keys). + nigori_tag = "google_chrome_nigori" + self._SaveEntry(self._entries.get(self._ServerTagToId(nigori_tag))) + + def SetInducedError(self, error, error_frequency, + sync_count_before_errors): + self.induced_error = error + self.induced_error_frequency = error_frequency + self.sync_count_before_errors = sync_count_before_errors + + def GetInducedError(self): + return self.induced_error + + +class TestServer(object): + """An object to handle requests for one (and only one) Chrome Sync account. + + TestServer consumes the sync command messages that are the outermost + layers of the protocol, performs the corresponding actions on its + SyncDataModel, and constructs an appropriate response message. + """ + + def __init__(self): + # The implementation supports exactly one account; its state is here. + self.account = SyncDataModel() + self.account_lock = threading.Lock() + # Clients that have talked to us: a map from the full client ID + # to its nickname. + self.clients = {} + self.client_name_generator = ('+' * times + chr(c) + for times in xrange(0, sys.maxint) for c in xrange(ord('A'), ord('Z'))) + self.transient_error = False + self.sync_count = 0 + + def GetShortClientName(self, query): + parsed = cgi.parse_qs(query[query.find('?')+1:]) + client_id = parsed.get('client_id') + if not client_id: + return '?' + client_id = client_id[0] + if client_id not in self.clients: + self.clients[client_id] = self.client_name_generator.next() + return self.clients[client_id] + + def CheckStoreBirthday(self, request): + """Raises StoreBirthdayError if the request's birthday is a mismatch.""" + if not request.HasField('store_birthday'): + return + if self.account.StoreBirthday() != request.store_birthday: + raise StoreBirthdayError + + def CheckTransientError(self): + """Raises TransientError if transient_error variable is set.""" + if self.transient_error: + raise TransientError + + def CheckSendError(self): + """Raises SyncInducedError if needed.""" + if (self.account.induced_error.error_type != + sync_enums_pb2.SyncEnums.UNKNOWN): + # Always means return the given error for all requests. + if self.account.induced_error_frequency == ERROR_FREQUENCY_ALWAYS: + raise SyncInducedError + # This means the FIRST 2 requests of every 3 requests + # return an error. Don't switch the order of failures. There are + # test cases that rely on the first 2 being the failure rather than + # the last 2. + elif (self.account.induced_error_frequency == + ERROR_FREQUENCY_TWO_THIRDS): + if (((self.sync_count - + self.account.sync_count_before_errors) % 3) != 0): + raise SyncInducedError + else: + raise InducedErrorFrequencyNotDefined + + def HandleMigrate(self, path): + query = urlparse.urlparse(path)[4] + code = 200 + self.account_lock.acquire() + try: + datatypes = [DataTypeStringToSyncTypeLoose(x) + for x in urlparse.parse_qs(query).get('type',[])] + if datatypes: + self.account.TriggerMigration(datatypes) + response = 'Migrated datatypes %s' % ( + ' and '.join(SyncTypeToString(x).upper() for x in datatypes)) + else: + response = 'Please specify one or more <i>type=name</i> parameters' + code = 400 + except DataTypeIdNotRecognized, error: + response = 'Could not interpret datatype name' + code = 400 + finally: + self.account_lock.release() + return (code, '<html><title>Migration: %d</title><H1>%d %s</H1></html>' % + (code, code, response)) + + def HandleSetInducedError(self, path): + query = urlparse.urlparse(path)[4] + self.account_lock.acquire() + code = 200 + response = 'Success' + error = sync_pb2.ClientToServerResponse.Error() + try: + error_type = urlparse.parse_qs(query)['error'] + action = urlparse.parse_qs(query)['action'] + error.error_type = int(error_type[0]) + error.action = int(action[0]) + try: + error.url = (urlparse.parse_qs(query)['url'])[0] + except KeyError: + error.url = '' + try: + error.error_description =( + (urlparse.parse_qs(query)['error_description'])[0]) + except KeyError: + error.error_description = '' + try: + error_frequency = int((urlparse.parse_qs(query)['frequency'])[0]) + except KeyError: + error_frequency = ERROR_FREQUENCY_ALWAYS + self.account.SetInducedError(error, error_frequency, self.sync_count) + response = ('Error = %d, action = %d, url = %s, description = %s' % + (error.error_type, error.action, + error.url, + error.error_description)) + except error: + response = 'Could not parse url' + code = 400 + finally: + self.account_lock.release() + return (code, '<html><title>SetError: %d</title><H1>%d %s</H1></html>' % + (code, code, response)) + + def HandleCreateBirthdayError(self): + self.account.ResetStoreBirthday() + return ( + 200, + '<html><title>Birthday error</title><H1>Birthday error</H1></html>') + + def HandleSetTransientError(self): + self.transient_error = True + return ( + 200, + '<html><title>Transient error</title><H1>Transient error</H1></html>') + + def HandleSetSyncTabFavicons(self): + """Set 'sync_tab_favicons' field of the nigori node for this account.""" + self.account.TriggerSyncTabFavicons() + return ( + 200, + '<html><title>Tab Favicons</title><H1>Tab Favicons</H1></html>') + + def HandleCreateSyncedBookmarks(self): + """Create the Synced Bookmarks folder under Bookmarks.""" + self.account.TriggerCreateSyncedBookmarks() + return ( + 200, + '<html><title>Synced Bookmarks</title><H1>Synced Bookmarks</H1></html>') + + def HandleEnableKeystoreEncryption(self): + """Enables the keystore encryption experiment.""" + self.account.TriggerEnableKeystoreEncryption() + return ( + 200, + '<html><title>Enable Keystore Encryption</title>' + '<H1>Enable Keystore Encryption</H1></html>') + + def HandleRotateKeystoreKeys(self): + """Rotate the keystore encryption keys.""" + self.account.TriggerRotateKeystoreKeys() + return ( + 200, + '<html><title>Rotate Keystore Keys</title>' + '<H1>Rotate Keystore Keys</H1></html>') + + def HandleCommand(self, query, raw_request): + """Decode and handle a sync command from a raw input of bytes. + + This is the main entry point for this class. It is safe to call this + method from multiple threads. + + Args: + raw_request: An iterable byte sequence to be interpreted as a sync + protocol command. + Returns: + A tuple (response_code, raw_response); the first value is an HTTP + result code, while the second value is a string of bytes which is the + serialized reply to the command. + """ + self.account_lock.acquire() + self.sync_count += 1 + def print_context(direction): + print '[Client %s %s %s.py]' % (self.GetShortClientName(query), direction, + __name__), + + try: + request = sync_pb2.ClientToServerMessage() + request.MergeFromString(raw_request) + contents = request.message_contents + + response = sync_pb2.ClientToServerResponse() + response.error_code = sync_enums_pb2.SyncEnums.SUCCESS + self.CheckStoreBirthday(request) + response.store_birthday = self.account.store_birthday + self.CheckTransientError() + self.CheckSendError() + + print_context('->') + + if contents == sync_pb2.ClientToServerMessage.AUTHENTICATE: + print 'Authenticate' + # We accept any authentication token, and support only one account. + # TODO(nick): Mock out the GAIA authentication as well; hook up here. + response.authenticate.user.email = 'syncjuser@chromium' + response.authenticate.user.display_name = 'Sync J User' + elif contents == sync_pb2.ClientToServerMessage.COMMIT: + print 'Commit %d item(s)' % len(request.commit.entries) + self.HandleCommit(request.commit, response.commit) + elif contents == sync_pb2.ClientToServerMessage.GET_UPDATES: + print 'GetUpdates', + self.HandleGetUpdates(request.get_updates, response.get_updates) + print_context('<-') + print '%d update(s)' % len(response.get_updates.entries) + else: + print 'Unrecognizable sync request!' + return (400, None) # Bad request. + return (200, response.SerializeToString()) + except MigrationDoneError, error: + print_context('<-') + print 'MIGRATION_DONE: <%s>' % (ShortDatatypeListSummary(error.datatypes)) + response = sync_pb2.ClientToServerResponse() + response.store_birthday = self.account.store_birthday + response.error_code = sync_enums_pb2.SyncEnums.MIGRATION_DONE + response.migrated_data_type_id[:] = [ + SyncTypeToProtocolDataTypeId(x) for x in error.datatypes] + return (200, response.SerializeToString()) + except StoreBirthdayError, error: + print_context('<-') + print 'NOT_MY_BIRTHDAY' + response = sync_pb2.ClientToServerResponse() + response.store_birthday = self.account.store_birthday + response.error_code = sync_enums_pb2.SyncEnums.NOT_MY_BIRTHDAY + return (200, response.SerializeToString()) + except TransientError, error: + ### This is deprecated now. Would be removed once test cases are removed. + print_context('<-') + print 'TRANSIENT_ERROR' + response.store_birthday = self.account.store_birthday + response.error_code = sync_enums_pb2.SyncEnums.TRANSIENT_ERROR + return (200, response.SerializeToString()) + except SyncInducedError, error: + print_context('<-') + print 'INDUCED_ERROR' + response.store_birthday = self.account.store_birthday + error = self.account.GetInducedError() + response.error.error_type = error.error_type + response.error.url = error.url + response.error.error_description = error.error_description + response.error.action = error.action + return (200, response.SerializeToString()) + finally: + self.account_lock.release() + + def HandleCommit(self, commit_message, commit_response): + """Respond to a Commit request by updating the user's account state. + + Commit attempts stop after the first error, returning a CONFLICT result + for any unattempted entries. + + Args: + commit_message: A sync_pb.CommitMessage protobuf holding the content + of the client's request. + commit_response: A sync_pb.CommitResponse protobuf into which a reply + to the client request will be written. + """ + commit_response.SetInParent() + batch_failure = False + session = {} # Tracks ID renaming during the commit operation. + guid = commit_message.cache_guid + + self.account.ValidateCommitEntries(commit_message.entries) + + for entry in commit_message.entries: + server_entry = None + if not batch_failure: + # Try to commit the change to the account. + server_entry = self.account.CommitEntry(entry, guid, session) + + # An entryresponse is returned in both success and failure cases. + reply = commit_response.entryresponse.add() + if not server_entry: + reply.response_type = sync_pb2.CommitResponse.CONFLICT + reply.error_message = 'Conflict.' + batch_failure = True # One failure halts the batch. + else: + reply.response_type = sync_pb2.CommitResponse.SUCCESS + # These are the properties that the server is allowed to override + # during commit; the client wants to know their values at the end + # of the operation. + reply.id_string = server_entry.id_string + if not server_entry.deleted: + # Note: the production server doesn't actually send the + # parent_id_string on commit responses, so we don't either. + reply.position_in_parent = server_entry.position_in_parent + reply.version = server_entry.version + reply.name = server_entry.name + reply.non_unique_name = server_entry.non_unique_name + else: + reply.version = entry.version + 1 + + def HandleGetUpdates(self, update_request, update_response): + """Respond to a GetUpdates request by querying the user's account. + + Args: + update_request: A sync_pb.GetUpdatesMessage protobuf holding the content + of the client's request. + update_response: A sync_pb.GetUpdatesResponse protobuf into which a reply + to the client request will be written. + """ + update_response.SetInParent() + update_sieve = UpdateSieve(update_request, self.account.migration_history) + + print CallerInfoToString(update_request.caller_info.source), + print update_sieve.SummarizeRequest() + + update_sieve.CheckMigrationState() + + new_timestamp, entries, remaining = self.account.GetChanges(update_sieve) + + update_response.changes_remaining = remaining + sending_nigori_node = False + for entry in entries: + if entry.name == 'Nigori': + sending_nigori_node = True + reply = update_response.entries.add() + reply.CopyFrom(entry) + update_sieve.SaveProgress(new_timestamp, update_response) + + if update_request.need_encryption_key or sending_nigori_node: + update_response.encryption_keys.extend(self.account.GetKeystoreKeys()) diff --git a/sync/tools/testserver/chromiumsync_test.py b/sync/tools/testserver/chromiumsync_test.py new file mode 100755 index 0000000..e56c04b --- /dev/null +++ b/sync/tools/testserver/chromiumsync_test.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python +# Copyright 2013 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""Tests exercising chromiumsync and SyncDataModel.""" + +import pickle +import unittest + +import autofill_specifics_pb2 +import bookmark_specifics_pb2 +import chromiumsync +import sync_pb2 +import theme_specifics_pb2 + +class SyncDataModelTest(unittest.TestCase): + def setUp(self): + self.model = chromiumsync.SyncDataModel() + # The Synced Bookmarks folder is not created by default + self._expect_synced_bookmarks_folder = False + + def AddToModel(self, proto): + self.model._entries[proto.id_string] = proto + + def GetChangesFromTimestamp(self, requested_types, timestamp): + message = sync_pb2.GetUpdatesMessage() + message.from_timestamp = timestamp + for data_type in requested_types: + getattr(message.requested_types, + chromiumsync.SYNC_TYPE_TO_DESCRIPTOR[ + data_type].name).SetInParent() + return self.model.GetChanges( + chromiumsync.UpdateSieve(message, self.model.migration_history)) + + def FindMarkerByNumber(self, markers, datatype): + """Search a list of progress markers and find the one for a datatype.""" + for marker in markers: + if marker.data_type_id == datatype.number: + return marker + self.fail('Required marker not found: %s' % datatype.name) + + def testPermanentItemSpecs(self): + specs = chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS + + declared_specs = set(['0']) + for spec in specs: + self.assertTrue(spec.parent_tag in declared_specs, 'parent tags must ' + 'be declared before use') + declared_specs.add(spec.tag) + + unique_datatypes = set([x.sync_type for x in specs]) + self.assertEqual(unique_datatypes, + set(chromiumsync.ALL_TYPES[1:]), + 'Every sync datatype should have a permanent folder ' + 'associated with it') + + def testSaveEntry(self): + proto = sync_pb2.SyncEntity() + proto.id_string = 'abcd' + proto.version = 0 + self.assertFalse(self.model._ItemExists(proto.id_string)) + self.model._SaveEntry(proto) + self.assertEqual(1, proto.version) + self.assertTrue(self.model._ItemExists(proto.id_string)) + self.model._SaveEntry(proto) + self.assertEqual(2, proto.version) + proto.version = 0 + self.assertTrue(self.model._ItemExists(proto.id_string)) + self.assertEqual(2, self.model._entries[proto.id_string].version) + + def testCreatePermanentItems(self): + self.model._CreateDefaultPermanentItems(chromiumsync.ALL_TYPES) + self.assertEqual(len(chromiumsync.ALL_TYPES) + 1, + len(self.model._entries)) + + def ExpectedPermanentItemCount(self, sync_type): + if sync_type == chromiumsync.BOOKMARK: + if self._expect_synced_bookmarks_folder: + return 4 + else: + return 3 + else: + return 1 + + def testGetChangesFromTimestampZeroForEachType(self): + all_types = chromiumsync.ALL_TYPES[1:] + for sync_type in all_types: + self.model = chromiumsync.SyncDataModel() + request_types = [sync_type] + + version, changes, remaining = ( + self.GetChangesFromTimestamp(request_types, 0)) + + expected_count = self.ExpectedPermanentItemCount(sync_type) + self.assertEqual(expected_count, version) + self.assertEqual(expected_count, len(changes)) + for change in changes: + self.assertTrue(change.HasField('server_defined_unique_tag')) + self.assertEqual(change.version, change.sync_timestamp) + self.assertTrue(change.version <= version) + + # Test idempotence: another GetUpdates from ts=0 shouldn't recreate. + version, changes, remaining = ( + self.GetChangesFromTimestamp(request_types, 0)) + self.assertEqual(expected_count, version) + self.assertEqual(expected_count, len(changes)) + self.assertEqual(0, remaining) + + # Doing a wider GetUpdates from timestamp zero shouldn't recreate either. + new_version, changes, remaining = ( + self.GetChangesFromTimestamp(all_types, 0)) + if self._expect_synced_bookmarks_folder: + self.assertEqual(len(chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS), + new_version) + else: + self.assertEqual( + len(chromiumsync.SyncDataModel._PERMANENT_ITEM_SPECS) -1, + new_version) + self.assertEqual(new_version, len(changes)) + self.assertEqual(0, remaining) + version, changes, remaining = ( + self.GetChangesFromTimestamp(request_types, 0)) + self.assertEqual(new_version, version) + self.assertEqual(expected_count, len(changes)) + self.assertEqual(0, remaining) + + def testBatchSize(self): + for sync_type in chromiumsync.ALL_TYPES[1:]: + specifics = chromiumsync.GetDefaultEntitySpecifics(sync_type) + self.model = chromiumsync.SyncDataModel() + request_types = [sync_type] + + for i in range(self.model._BATCH_SIZE*3): + entry = sync_pb2.SyncEntity() + entry.id_string = 'batch test %d' % i + entry.specifics.CopyFrom(specifics) + self.model._SaveEntry(entry) + last_bit = self.ExpectedPermanentItemCount(sync_type) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, 0)) + self.assertEqual(self.model._BATCH_SIZE, version) + self.assertEqual(self.model._BATCH_SIZE*2 + last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(self.model._BATCH_SIZE*2, version) + self.assertEqual(self.model._BATCH_SIZE + last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(self.model._BATCH_SIZE*3, version) + self.assertEqual(last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(self.model._BATCH_SIZE*3 + last_bit, version) + self.assertEqual(0, changes_remaining) + + # Now delete a third of the items. + for i in xrange(self.model._BATCH_SIZE*3 - 1, 0, -3): + entry = sync_pb2.SyncEntity() + entry.id_string = 'batch test %d' % i + entry.deleted = True + self.model._SaveEntry(entry) + + # The batch counts shouldn't change. + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, 0)) + self.assertEqual(self.model._BATCH_SIZE, len(changes)) + self.assertEqual(self.model._BATCH_SIZE*2 + last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(self.model._BATCH_SIZE, len(changes)) + self.assertEqual(self.model._BATCH_SIZE + last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(self.model._BATCH_SIZE, len(changes)) + self.assertEqual(last_bit, changes_remaining) + version, changes, changes_remaining = ( + self.GetChangesFromTimestamp(request_types, version)) + self.assertEqual(last_bit, len(changes)) + self.assertEqual(self.model._BATCH_SIZE*4 + last_bit, version) + self.assertEqual(0, changes_remaining) + + def testCommitEachDataType(self): + for sync_type in chromiumsync.ALL_TYPES[1:]: + specifics = chromiumsync.GetDefaultEntitySpecifics(sync_type) + self.model = chromiumsync.SyncDataModel() + my_cache_guid = '112358132134' + parent = 'foobar' + commit_session = {} + + # Start with a GetUpdates from timestamp 0, to populate permanent items. + original_version, original_changes, changes_remaining = ( + self.GetChangesFromTimestamp([sync_type], 0)) + + def DoCommit(original=None, id_string='', name=None, parent=None, + position=0): + proto = sync_pb2.SyncEntity() + if original is not None: + proto.version = original.version + proto.id_string = original.id_string + proto.parent_id_string = original.parent_id_string + proto.name = original.name + else: + proto.id_string = id_string + proto.version = 0 + proto.specifics.CopyFrom(specifics) + if name is not None: + proto.name = name + if parent: + proto.parent_id_string = parent.id_string + proto.insert_after_item_id = 'please discard' + proto.position_in_parent = position + proto.folder = True + proto.deleted = False + result = self.model.CommitEntry(proto, my_cache_guid, commit_session) + self.assertTrue(result) + return (proto, result) + + # Commit a new item. + proto1, result1 = DoCommit(name='namae', id_string='Foo', + parent=original_changes[-1], position=100) + # Commit an item whose parent is another item (referenced via the + # pre-commit ID). + proto2, result2 = DoCommit(name='Secondo', id_string='Bar', + parent=proto1, position=-100) + # Commit a sibling of the second item. + proto3, result3 = DoCommit(name='Third!', id_string='Baz', + parent=proto1, position=-50) + + self.assertEqual(3, len(commit_session)) + for p, r in [(proto1, result1), (proto2, result2), (proto3, result3)]: + self.assertNotEqual(r.id_string, p.id_string) + self.assertEqual(r.originator_client_item_id, p.id_string) + self.assertEqual(r.originator_cache_guid, my_cache_guid) + self.assertTrue(r is not self.model._entries[r.id_string], + "Commit result didn't make a defensive copy.") + self.assertTrue(p is not self.model._entries[r.id_string], + "Commit result didn't make a defensive copy.") + self.assertEqual(commit_session.get(p.id_string), r.id_string) + self.assertTrue(r.version > original_version) + self.assertEqual(result1.parent_id_string, proto1.parent_id_string) + self.assertEqual(result2.parent_id_string, result1.id_string) + version, changes, remaining = ( + self.GetChangesFromTimestamp([sync_type], original_version)) + self.assertEqual(3, len(changes)) + self.assertEqual(0, remaining) + self.assertEqual(original_version + 3, version) + self.assertEqual([result1, result2, result3], changes) + for c in changes: + self.assertTrue(c is not self.model._entries[c.id_string], + "GetChanges didn't make a defensive copy.") + self.assertTrue(result2.position_in_parent < result3.position_in_parent) + self.assertEqual(-100, result2.position_in_parent) + + # Now update the items so that the second item is the parent of the + # first; with the first sandwiched between two new items (4 and 5). + # Do this in a new commit session, meaning we'll reference items from + # the first batch by their post-commit, server IDs. + commit_session = {} + old_cache_guid = my_cache_guid + my_cache_guid = 'A different GUID' + proto2b, result2b = DoCommit(original=result2, + parent=original_changes[-1]) + proto4, result4 = DoCommit(id_string='ID4', name='Four', + parent=result2, position=-200) + proto1b, result1b = DoCommit(original=result1, + parent=result2, position=-150) + proto5, result5 = DoCommit(id_string='ID5', name='Five', parent=result2, + position=150) + + self.assertEqual(2, len(commit_session), 'Only new items in second ' + 'batch should be in the session') + for p, r, original in [(proto2b, result2b, proto2), + (proto4, result4, proto4), + (proto1b, result1b, proto1), + (proto5, result5, proto5)]: + self.assertEqual(r.originator_client_item_id, original.id_string) + if original is not p: + self.assertEqual(r.id_string, p.id_string, + 'Ids should be stable after first commit') + self.assertEqual(r.originator_cache_guid, old_cache_guid) + else: + self.assertNotEqual(r.id_string, p.id_string) + self.assertEqual(r.originator_cache_guid, my_cache_guid) + self.assertEqual(commit_session.get(p.id_string), r.id_string) + self.assertTrue(r is not self.model._entries[r.id_string], + "Commit result didn't make a defensive copy.") + self.assertTrue(p is not self.model._entries[r.id_string], + "Commit didn't make a defensive copy.") + self.assertTrue(r.version > p.version) + version, changes, remaining = ( + self.GetChangesFromTimestamp([sync_type], original_version)) + self.assertEqual(5, len(changes)) + self.assertEqual(0, remaining) + self.assertEqual(original_version + 7, version) + self.assertEqual([result3, result2b, result4, result1b, result5], changes) + for c in changes: + self.assertTrue(c is not self.model._entries[c.id_string], + "GetChanges didn't make a defensive copy.") + self.assertTrue(result4.parent_id_string == + result1b.parent_id_string == + result5.parent_id_string == + result2b.id_string) + self.assertTrue(result4.position_in_parent < + result1b.position_in_parent < + result5.position_in_parent) + + def testUpdateSieve(self): + # from_timestamp, legacy mode + autofill = chromiumsync.SYNC_TYPE_FIELDS['autofill'] + theme = chromiumsync.SYNC_TYPE_FIELDS['theme'] + msg = sync_pb2.GetUpdatesMessage() + msg.from_timestamp = 15412 + msg.requested_types.autofill.SetInParent() + msg.requested_types.theme.SetInParent() + + sieve = chromiumsync.UpdateSieve(msg) + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 15412, + chromiumsync.AUTOFILL: 15412, + chromiumsync.THEME: 15412}) + + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(15412, response) + self.assertEqual(0, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(15413, response) + self.assertEqual(0, len(response.new_progress_marker)) + self.assertTrue(response.HasField('new_timestamp')) + self.assertEqual(15413, response.new_timestamp) + + # Existing tokens + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = pickle.dumps((15412, 1)) + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 1)) + sieve = chromiumsync.UpdateSieve(msg) + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 15412, + chromiumsync.AUTOFILL: 15412, + chromiumsync.THEME: 15413}) + + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(15413, response) + self.assertEqual(1, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = response.new_progress_marker[0] + self.assertEqual(marker.data_type_id, autofill.number) + self.assertEqual(pickle.loads(marker.token), (15413, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + # Empty tokens indicating from timestamp = 0 + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = pickle.dumps((412, 1)) + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = '' + sieve = chromiumsync.UpdateSieve(msg) + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 0, + chromiumsync.AUTOFILL: 412, + chromiumsync.THEME: 0}) + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(1, response) + self.assertEqual(1, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = response.new_progress_marker[0] + self.assertEqual(marker.data_type_id, theme.number) + self.assertEqual(pickle.loads(marker.token), (1, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(412, response) + self.assertEqual(1, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = response.new_progress_marker[0] + self.assertEqual(marker.data_type_id, theme.number) + self.assertEqual(pickle.loads(marker.token), (412, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(413, response) + self.assertEqual(2, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = self.FindMarkerByNumber(response.new_progress_marker, theme) + self.assertEqual(pickle.loads(marker.token), (413, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, autofill) + self.assertEqual(pickle.loads(marker.token), (413, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + # Migration token timestamps (client gives timestamp, server returns token) + # These are for migrating from the old 'timestamp' protocol to the + # progressmarker protocol, and have nothing to do with the MIGRATION_DONE + # error code. + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.timestamp_token_for_migration = 15213 + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.timestamp_token_for_migration = 15211 + sieve = chromiumsync.UpdateSieve(msg) + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 15211, + chromiumsync.AUTOFILL: 15213, + chromiumsync.THEME: 15211}) + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(16000, response) # There were updates + self.assertEqual(2, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = self.FindMarkerByNumber(response.new_progress_marker, theme) + self.assertEqual(pickle.loads(marker.token), (16000, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, autofill) + self.assertEqual(pickle.loads(marker.token), (16000, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.timestamp_token_for_migration = 3000 + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.timestamp_token_for_migration = 3000 + sieve = chromiumsync.UpdateSieve(msg) + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 3000, + chromiumsync.AUTOFILL: 3000, + chromiumsync.THEME: 3000}) + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(3000, response) # Already up to date + self.assertEqual(2, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + marker = self.FindMarkerByNumber(response.new_progress_marker, theme) + self.assertEqual(pickle.loads(marker.token), (3000, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, autofill) + self.assertEqual(pickle.loads(marker.token), (3000, 1)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + + def testCheckRaiseTransientError(self): + testserver = chromiumsync.TestServer() + http_code, raw_respon = testserver.HandleSetTransientError() + self.assertEqual(http_code, 200) + try: + testserver.CheckTransientError() + self.fail('Should have raised transient error exception') + except chromiumsync.TransientError: + self.assertTrue(testserver.transient_error) + + def testUpdateSieveStoreMigration(self): + autofill = chromiumsync.SYNC_TYPE_FIELDS['autofill'] + theme = chromiumsync.SYNC_TYPE_FIELDS['theme'] + migrator = chromiumsync.MigrationHistory() + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = pickle.dumps((15412, 1)) + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 1)) + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + + migrator.Bump([chromiumsync.BOOKMARK, chromiumsync.PASSWORD]) # v=2 + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + self.assertEqual(sieve._state, + {chromiumsync.TOP_LEVEL: 15412, + chromiumsync.AUTOFILL: 15412, + chromiumsync.THEME: 15413}) + + migrator.Bump([chromiumsync.AUTOFILL, chromiumsync.PASSWORD]) # v=3 + sieve = chromiumsync.UpdateSieve(msg, migrator) + try: + sieve.CheckMigrationState() + self.fail('Should have raised.') + except chromiumsync.MigrationDoneError, error: + # We want this to happen. + self.assertEqual([chromiumsync.AUTOFILL], error.datatypes) + + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = '' + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 1)) + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(15412, response) # There were updates + self.assertEqual(1, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, autofill) + self.assertEqual(pickle.loads(marker.token), (15412, 3)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = pickle.dumps((15412, 3)) + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 1)) + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + + migrator.Bump([chromiumsync.THEME, chromiumsync.AUTOFILL]) # v=4 + migrator.Bump([chromiumsync.AUTOFILL]) # v=5 + sieve = chromiumsync.UpdateSieve(msg, migrator) + try: + sieve.CheckMigrationState() + self.fail("Should have raised.") + except chromiumsync.MigrationDoneError, error: + # We want this to happen. + self.assertEqual(set([chromiumsync.THEME, chromiumsync.AUTOFILL]), + set(error.datatypes)) + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = '' + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 1)) + sieve = chromiumsync.UpdateSieve(msg, migrator) + try: + sieve.CheckMigrationState() + self.fail("Should have raised.") + except chromiumsync.MigrationDoneError, error: + # We want this to happen. + self.assertEqual([chromiumsync.THEME], error.datatypes) + + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = '' + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = '' + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + response = sync_pb2.GetUpdatesResponse() + sieve.SaveProgress(15412, response) # There were updates + self.assertEqual(2, len(response.new_progress_marker)) + self.assertFalse(response.HasField('new_timestamp')) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, autofill) + self.assertEqual(pickle.loads(marker.token), (15412, 5)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + marker = self.FindMarkerByNumber(response.new_progress_marker, theme) + self.assertEqual(pickle.loads(marker.token), (15412, 4)) + self.assertFalse(marker.HasField('timestamp_token_for_migration')) + msg = sync_pb2.GetUpdatesMessage() + marker = msg.from_progress_marker.add() + marker.data_type_id = autofill.number + marker.token = pickle.dumps((15412, 5)) + marker = msg.from_progress_marker.add() + marker.data_type_id = theme.number + marker.token = pickle.dumps((15413, 4)) + sieve = chromiumsync.UpdateSieve(msg, migrator) + sieve.CheckMigrationState() + + def testCreateSyncedBookmaks(self): + version1, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], 0)) + id_string = self.model._MakeCurrentId(chromiumsync.BOOKMARK, + '<server tag>synced_bookmarks') + self.assertFalse(self.model._ItemExists(id_string)) + self._expect_synced_bookmarks_folder = True + self.model.TriggerCreateSyncedBookmarks() + self.assertTrue(self.model._ItemExists(id_string)) + + # Check that the version changed when the folder was created and the only + # change was the folder creation. + version2, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], version1)) + self.assertEqual(len(changes), 1) + self.assertEqual(changes[0].id_string, id_string) + self.assertNotEqual(version1, version2) + self.assertEqual( + self.ExpectedPermanentItemCount(chromiumsync.BOOKMARK), + version2) + + # Ensure getting from timestamp 0 includes the folder. + version, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.BOOKMARK], 0)) + self.assertEqual( + self.ExpectedPermanentItemCount(chromiumsync.BOOKMARK), + len(changes)) + self.assertEqual(version2, version) + + def testGetKey(self): + [key1] = self.model.GetKeystoreKeys() + [key2] = self.model.GetKeystoreKeys() + self.assertTrue(len(key1)) + self.assertEqual(key1, key2) + + # Trigger the rotation. A subsequent GetUpdates should return the nigori + # node (whose timestamp was bumped by the rotation). + version1, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.NIGORI], 0)) + self.model.TriggerRotateKeystoreKeys() + version2, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.NIGORI], version1)) + self.assertNotEqual(version1, version2) + self.assertEquals(len(changes), 1) + self.assertEquals(changes[0].name, "Nigori") + + # The current keys should contain the old keys, with the new key appended. + [key1, key3] = self.model.GetKeystoreKeys() + self.assertEquals(key1, key2) + self.assertNotEqual(key1, key3) + self.assertTrue(len(key3) > 0) + + def testTriggerEnableKeystoreEncryption(self): + version1, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.EXPERIMENTS], 0)) + keystore_encryption_id_string = ( + self.model._ClientTagToId( + chromiumsync.EXPERIMENTS, + chromiumsync.KEYSTORE_ENCRYPTION_EXPERIMENT_TAG)) + + self.assertFalse(self.model._ItemExists(keystore_encryption_id_string)) + self.model.TriggerEnableKeystoreEncryption() + self.assertTrue(self.model._ItemExists(keystore_encryption_id_string)) + + # The creation of the experiment should be downloaded on the next + # GetUpdates. + version2, changes, remaining = ( + self.GetChangesFromTimestamp([chromiumsync.EXPERIMENTS], version1)) + self.assertEqual(len(changes), 1) + self.assertEqual(changes[0].id_string, keystore_encryption_id_string) + self.assertNotEqual(version1, version2) + + # Verify the experiment was created properly and is enabled. + self.assertEqual(chromiumsync.KEYSTORE_ENCRYPTION_EXPERIMENT_TAG, + changes[0].client_defined_unique_tag) + self.assertTrue(changes[0].HasField("specifics")) + self.assertTrue(changes[0].specifics.HasField("experiments")) + self.assertTrue( + changes[0].specifics.experiments.HasField("keystore_encryption")) + self.assertTrue( + changes[0].specifics.experiments.keystore_encryption.enabled) + +if __name__ == '__main__': + unittest.main() diff --git a/sync/tools/testserver/run_sync_testserver.cc b/sync/tools/testserver/run_sync_testserver.cc new file mode 100644 index 0000000..74f186d --- /dev/null +++ b/sync/tools/testserver/run_sync_testserver.cc @@ -0,0 +1,121 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stdio.h> + +#include "base/at_exit.h" +#include "base/command_line.h" +#include "base/file_path.h" +#include "base/logging.h" +#include "base/message_loop.h" +#include "base/process_util.h" +#include "base/string_number_conversions.h" +#include "base/test/test_timeouts.h" +#include "net/test/python_utils.h" +#include "sync/test/local_sync_test_server.h" + +static void PrintUsage() { + printf("run_sync_testserver [--port=<port>] [--xmpp-port=<xmpp_port>]\n"); +} + +// Launches the chromiumsync_test.py or xmppserver_test.py scripts, which test +// the sync HTTP and XMPP sever functionality respectively. +static bool RunSyncTest(const FilePath::StringType& sync_test_script_name) { + scoped_ptr<syncer::LocalSyncTestServer> test_server( + new syncer::LocalSyncTestServer()); + if (!test_server->SetPythonPath()) { + LOG(ERROR) << "Error trying to set python path. Exiting."; + return false; + } + + FilePath sync_test_script_path; + if (!test_server->GetTestScriptPath(sync_test_script_name, + &sync_test_script_path)) { + LOG(ERROR) << "Error trying to get path for test script " + << sync_test_script_name; + return false; + } + + CommandLine python_command(CommandLine::NO_PROGRAM); + if (!GetPythonCommand(&python_command)) { + LOG(ERROR) << "Could not get python runtime command."; + return false; + } + + python_command.AppendArgPath(sync_test_script_path); + if (!base::LaunchProcess(python_command, base::LaunchOptions(), NULL)) { + LOG(ERROR) << "Failed to launch test script " << sync_test_script_name; + return false; + } + return true; +} + +// Gets a port value from the switch with name |switch_name| and writes it to +// |port|. Returns true if a port was provided and false otherwise. +static bool GetPortFromSwitch(const std::string& switch_name, uint16* port) { + DCHECK(port != NULL) << "|port| is NULL"; + CommandLine* command_line = CommandLine::ForCurrentProcess(); + int port_int = 0; + if (command_line->HasSwitch(switch_name)) { + std::string port_str = command_line->GetSwitchValueASCII(switch_name); + if (!base::StringToInt(port_str, &port_int)) { + return false; + } + } + *port = static_cast<uint16>(port_int); + return true; +} + +int main(int argc, const char* argv[]) { + base::AtExitManager at_exit_manager; + MessageLoopForIO message_loop; + + // Process command line + CommandLine::Init(argc, argv); + CommandLine* command_line = CommandLine::ForCurrentProcess(); + + if (!logging::InitLogging( + FILE_PATH_LITERAL("sync_testserver.log"), + logging::LOG_TO_BOTH_FILE_AND_SYSTEM_DEBUG_LOG, + logging::LOCK_LOG_FILE, + logging::APPEND_TO_OLD_LOG_FILE, + logging::DISABLE_DCHECK_FOR_NON_OFFICIAL_RELEASE_BUILDS)) { + printf("Error: could not initialize logging. Exiting.\n"); + return -1; + } + + TestTimeouts::Initialize(); + + if (command_line->HasSwitch("help")) { + PrintUsage(); + return 0; + } + + if (command_line->HasSwitch("sync-test")) { + return RunSyncTest(FILE_PATH_LITERAL("chromiumsync_test.py")) ? 0 : -1; + } + + if (command_line->HasSwitch("xmpp-test")) { + return RunSyncTest(FILE_PATH_LITERAL("xmppserver_test.py")) ? 0 : -1; + } + + uint16 port = 0; + GetPortFromSwitch("port", &port); + + uint16 xmpp_port = 0; + GetPortFromSwitch("xmpp-port", &xmpp_port); + + scoped_ptr<syncer::LocalSyncTestServer> test_server( + new syncer::LocalSyncTestServer(port, xmpp_port)); + if (!test_server->Start()) { + printf("Error: failed to start python sync test server. Exiting.\n"); + return -1; + } + + printf("Python sync test server running at %s (type ctrl+c to exit)\n", + test_server->host_port_pair().ToString().c_str()); + + message_loop.Run(); + return 0; +} diff --git a/sync/tools/testserver/sync_testserver.py b/sync/tools/testserver/sync_testserver.py new file mode 100755 index 0000000..6925776 --- /dev/null +++ b/sync/tools/testserver/sync_testserver.py @@ -0,0 +1,447 @@ +#!/usr/bin/env python +# Copyright 2013 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""This is a python sync server used for testing Chrome Sync. + +By default, it listens on an ephemeral port and xmpp_port and sends the port +numbers back to the originating process over a pipe. The originating process can +specify an explicit port and xmpp_port if necessary. +""" + +import asyncore +import BaseHTTPServer +import errno +import os +import select +import socket +import sys +import urlparse + +import chromiumsync +import echo_message +import testserver_base +import xmppserver + + +class SyncHTTPServer(testserver_base.ClientRestrictingServerMixIn, + testserver_base.BrokenPipeHandlerMixIn, + testserver_base.StoppableHTTPServer): + """An HTTP server that handles sync commands.""" + + def __init__(self, server_address, xmpp_port, request_handler_class): + testserver_base.StoppableHTTPServer.__init__(self, + server_address, + request_handler_class) + self._sync_handler = chromiumsync.TestServer() + self._xmpp_socket_map = {} + self._xmpp_server = xmppserver.XmppServer( + self._xmpp_socket_map, ('localhost', xmpp_port)) + self.xmpp_port = self._xmpp_server.getsockname()[1] + self.authenticated = True + + def GetXmppServer(self): + return self._xmpp_server + + def HandleCommand(self, query, raw_request): + return self._sync_handler.HandleCommand(query, raw_request) + + def HandleRequestNoBlock(self): + """Handles a single request. + + Copied from SocketServer._handle_request_noblock(). + """ + + try: + request, client_address = self.get_request() + except socket.error: + return + if self.verify_request(request, client_address): + try: + self.process_request(request, client_address) + except Exception: + self.handle_error(request, client_address) + self.close_request(request) + + def SetAuthenticated(self, auth_valid): + self.authenticated = auth_valid + + def GetAuthenticated(self): + return self.authenticated + + def serve_forever(self): + """This is a merge of asyncore.loop() and SocketServer.serve_forever(). + """ + + def HandleXmppSocket(fd, socket_map, handler): + """Runs the handler for the xmpp connection for fd. + + Adapted from asyncore.read() et al. + """ + + xmpp_connection = socket_map.get(fd) + # This could happen if a previous handler call caused fd to get + # removed from socket_map. + if xmpp_connection is None: + return + try: + handler(xmpp_connection) + except (asyncore.ExitNow, KeyboardInterrupt, SystemExit): + raise + except: + xmpp_connection.handle_error() + + while True: + read_fds = [ self.fileno() ] + write_fds = [] + exceptional_fds = [] + + for fd, xmpp_connection in self._xmpp_socket_map.items(): + is_r = xmpp_connection.readable() + is_w = xmpp_connection.writable() + if is_r: + read_fds.append(fd) + if is_w: + write_fds.append(fd) + if is_r or is_w: + exceptional_fds.append(fd) + + try: + read_fds, write_fds, exceptional_fds = ( + select.select(read_fds, write_fds, exceptional_fds)) + except select.error, err: + if err.args[0] != errno.EINTR: + raise + else: + continue + + for fd in read_fds: + if fd == self.fileno(): + self.HandleRequestNoBlock() + continue + HandleXmppSocket(fd, self._xmpp_socket_map, + asyncore.dispatcher.handle_read_event) + + for fd in write_fds: + HandleXmppSocket(fd, self._xmpp_socket_map, + asyncore.dispatcher.handle_write_event) + + for fd in exceptional_fds: + HandleXmppSocket(fd, self._xmpp_socket_map, + asyncore.dispatcher.handle_expt_event) + + +class SyncPageHandler(testserver_base.BasePageHandler): + """Handler for the main HTTP sync server.""" + + def __init__(self, request, client_address, sync_http_server): + get_handlers = [self.ChromiumSyncTimeHandler, + self.ChromiumSyncMigrationOpHandler, + self.ChromiumSyncCredHandler, + self.ChromiumSyncXmppCredHandler, + self.ChromiumSyncDisableNotificationsOpHandler, + self.ChromiumSyncEnableNotificationsOpHandler, + self.ChromiumSyncSendNotificationOpHandler, + self.ChromiumSyncBirthdayErrorOpHandler, + self.ChromiumSyncTransientErrorOpHandler, + self.ChromiumSyncErrorOpHandler, + self.ChromiumSyncSyncTabFaviconsOpHandler, + self.ChromiumSyncCreateSyncedBookmarksOpHandler, + self.ChromiumSyncEnableKeystoreEncryptionOpHandler, + self.ChromiumSyncRotateKeystoreKeysOpHandler] + + post_handlers = [self.ChromiumSyncCommandHandler, + self.ChromiumSyncTimeHandler] + testserver_base.BasePageHandler.__init__(self, request, client_address, + sync_http_server, [], get_handlers, + [], post_handlers, []) + + + def ChromiumSyncTimeHandler(self): + """Handle Chromium sync .../time requests. + + The syncer sometimes checks server reachability by examining /time. + """ + + test_name = "/chromiumsync/time" + if not self._ShouldHandleRequest(test_name): + return False + + # Chrome hates it if we send a response before reading the request. + if self.headers.getheader('content-length'): + length = int(self.headers.getheader('content-length')) + _raw_request = self.rfile.read(length) + + self.send_response(200) + self.send_header('Content-Type', 'text/plain') + self.end_headers() + self.wfile.write('0123456789') + return True + + def ChromiumSyncCommandHandler(self): + """Handle a chromiumsync command arriving via http. + + This covers all sync protocol commands: authentication, getupdates, and + commit. + """ + + test_name = "/chromiumsync/command" + if not self._ShouldHandleRequest(test_name): + return False + + length = int(self.headers.getheader('content-length')) + raw_request = self.rfile.read(length) + http_response = 200 + raw_reply = None + if not self.server.GetAuthenticated(): + http_response = 401 + challenge = 'GoogleLogin realm="http://%s", service="chromiumsync"' % ( + self.server.server_address[0]) + else: + http_response, raw_reply = self.server.HandleCommand( + self.path, raw_request) + + ### Now send the response to the client. ### + self.send_response(http_response) + if http_response == 401: + self.send_header('www-Authenticate', challenge) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncMigrationOpHandler(self): + test_name = "/chromiumsync/migrate" + if not self._ShouldHandleRequest(test_name): + return False + + http_response, raw_reply = self.server._sync_handler.HandleMigrate( + self.path) + self.send_response(http_response) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncCredHandler(self): + test_name = "/chromiumsync/cred" + if not self._ShouldHandleRequest(test_name): + return False + try: + query = urlparse.urlparse(self.path)[4] + cred_valid = urlparse.parse_qs(query)['valid'] + if cred_valid[0] == 'True': + self.server.SetAuthenticated(True) + else: + self.server.SetAuthenticated(False) + except Exception: + self.server.SetAuthenticated(False) + + http_response = 200 + raw_reply = 'Authenticated: %s ' % self.server.GetAuthenticated() + self.send_response(http_response) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncXmppCredHandler(self): + test_name = "/chromiumsync/xmppcred" + if not self._ShouldHandleRequest(test_name): + return False + xmpp_server = self.server.GetXmppServer() + try: + query = urlparse.urlparse(self.path)[4] + cred_valid = urlparse.parse_qs(query)['valid'] + if cred_valid[0] == 'True': + xmpp_server.SetAuthenticated(True) + else: + xmpp_server.SetAuthenticated(False) + except: + xmpp_server.SetAuthenticated(False) + + http_response = 200 + raw_reply = 'XMPP Authenticated: %s ' % xmpp_server.GetAuthenticated() + self.send_response(http_response) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncDisableNotificationsOpHandler(self): + test_name = "/chromiumsync/disablenotifications" + if not self._ShouldHandleRequest(test_name): + return False + self.server.GetXmppServer().DisableNotifications() + result = 200 + raw_reply = ('<html><title>Notifications disabled</title>' + '<H1>Notifications disabled</H1></html>') + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncEnableNotificationsOpHandler(self): + test_name = "/chromiumsync/enablenotifications" + if not self._ShouldHandleRequest(test_name): + return False + self.server.GetXmppServer().EnableNotifications() + result = 200 + raw_reply = ('<html><title>Notifications enabled</title>' + '<H1>Notifications enabled</H1></html>') + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncSendNotificationOpHandler(self): + test_name = "/chromiumsync/sendnotification" + if not self._ShouldHandleRequest(test_name): + return False + query = urlparse.urlparse(self.path)[4] + query_params = urlparse.parse_qs(query) + channel = '' + data = '' + if 'channel' in query_params: + channel = query_params['channel'][0] + if 'data' in query_params: + data = query_params['data'][0] + self.server.GetXmppServer().SendNotification(channel, data) + result = 200 + raw_reply = ('<html><title>Notification sent</title>' + '<H1>Notification sent with channel "%s" ' + 'and data "%s"</H1></html>' + % (channel, data)) + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncBirthdayErrorOpHandler(self): + test_name = "/chromiumsync/birthdayerror" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = self.server._sync_handler.HandleCreateBirthdayError() + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncTransientErrorOpHandler(self): + test_name = "/chromiumsync/transienterror" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = self.server._sync_handler.HandleSetTransientError() + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncErrorOpHandler(self): + test_name = "/chromiumsync/error" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = self.server._sync_handler.HandleSetInducedError( + self.path) + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncSyncTabFaviconsOpHandler(self): + test_name = "/chromiumsync/synctabfavicons" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = self.server._sync_handler.HandleSetSyncTabFavicons() + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncCreateSyncedBookmarksOpHandler(self): + test_name = "/chromiumsync/createsyncedbookmarks" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = self.server._sync_handler.HandleCreateSyncedBookmarks() + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncEnableKeystoreEncryptionOpHandler(self): + test_name = "/chromiumsync/enablekeystoreencryption" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = ( + self.server._sync_handler.HandleEnableKeystoreEncryption()) + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + def ChromiumSyncRotateKeystoreKeysOpHandler(self): + test_name = "/chromiumsync/rotatekeystorekeys" + if not self._ShouldHandleRequest(test_name): + return False + result, raw_reply = ( + self.server._sync_handler.HandleRotateKeystoreKeys()) + self.send_response(result) + self.send_header('Content-Type', 'text/html') + self.send_header('Content-Length', len(raw_reply)) + self.end_headers() + self.wfile.write(raw_reply) + return True + + +class SyncServerRunner(testserver_base.TestServerRunner): + """TestServerRunner for the net test servers.""" + + def __init__(self): + super(SyncServerRunner, self).__init__() + + def create_server(self, server_data): + port = self.options.port + host = self.options.host + xmpp_port = self.options.xmpp_port + server = SyncHTTPServer((host, port), xmpp_port, SyncPageHandler) + print 'Sync HTTP server started on port %d...' % server.server_port + print 'Sync XMPP server started on port %d...' % server.xmpp_port + server_data['port'] = server.server_port + server_data['xmpp_port'] = server.xmpp_port + return server + + def run_server(self): + testserver_base.TestServerRunner.run_server(self) + + def add_options(self): + testserver_base.TestServerRunner.add_options(self) + self.option_parser.add_option('--xmpp-port', default='0', type='int', + help='Port used by the XMPP server. If ' + 'unspecified, the XMPP server will listen on ' + 'an ephemeral port.') + # Override the default logfile name used in testserver.py. + self.option_parser.set_defaults(log_file='sync_testserver.log') + +if __name__ == '__main__': + sys.exit(SyncServerRunner().main()) diff --git a/sync/tools/testserver/xmppserver.py b/sync/tools/testserver/xmppserver.py new file mode 100644 index 0000000..f9599c0 --- /dev/null +++ b/sync/tools/testserver/xmppserver.py @@ -0,0 +1,594 @@ +# Copyright 2013 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""A bare-bones and non-compliant XMPP server. + +Just enough of the protocol is implemented to get it to work with +Chrome's sync notification system. +""" + +import asynchat +import asyncore +import base64 +import re +import socket +from xml.dom import minidom + +# pychecker complains about the use of fileno(), which is implemented +# by asyncore by forwarding to an internal object via __getattr__. +__pychecker__ = 'no-classattr' + + +class Error(Exception): + """Error class for this module.""" + pass + + +class UnexpectedXml(Error): + """Raised when an unexpected XML element has been encountered.""" + + def __init__(self, xml_element): + xml_text = xml_element.toxml() + Error.__init__(self, 'Unexpected XML element', xml_text) + + +def ParseXml(xml_string): + """Parses the given string as XML and returns a minidom element + object. + """ + dom = minidom.parseString(xml_string) + + # minidom handles xmlns specially, but there's a bug where it sets + # the attribute value to None, which causes toxml() or toprettyxml() + # to break. + def FixMinidomXmlnsBug(xml_element): + if xml_element.getAttribute('xmlns') is None: + xml_element.setAttribute('xmlns', '') + + def ApplyToAllDescendantElements(xml_element, fn): + fn(xml_element) + for node in xml_element.childNodes: + if node.nodeType == node.ELEMENT_NODE: + ApplyToAllDescendantElements(node, fn) + + root = dom.documentElement + ApplyToAllDescendantElements(root, FixMinidomXmlnsBug) + return root + + +def CloneXml(xml): + """Returns a deep copy of the given XML element. + + Args: + xml: The XML element, which should be something returned from + ParseXml() (i.e., a root element). + """ + return xml.ownerDocument.cloneNode(True).documentElement + + +class StanzaParser(object): + """A hacky incremental XML parser. + + StanzaParser consumes data incrementally via FeedString() and feeds + its delegate complete parsed stanzas (i.e., XML documents) via + FeedStanza(). Any stanzas passed to FeedStanza() are unlinked after + the callback is done. + + Use like so: + + class MyClass(object): + ... + def __init__(self, ...): + ... + self._parser = StanzaParser(self) + ... + + def SomeFunction(self, ...): + ... + self._parser.FeedString(some_data) + ... + + def FeedStanza(self, stanza): + ... + print stanza.toprettyxml() + ... + """ + + # NOTE(akalin): The following regexps are naive, but necessary since + # none of the existing Python 2.4/2.5 XML libraries support + # incremental parsing. This works well enough for our purposes. + # + # The regexps below assume that any present XML element starts at + # the beginning of the string, but there may be trailing whitespace. + + # Matches an opening stream tag (e.g., '<stream:stream foo="bar">') + # (assumes that the stream XML namespace is defined in the tag). + _stream_re = re.compile(r'^(<stream:stream [^>]*>)\s*') + + # Matches an empty element tag (e.g., '<foo bar="baz"/>'). + _empty_element_re = re.compile(r'^(<[^>]*/>)\s*') + + # Matches a non-empty element (e.g., '<foo bar="baz">quux</foo>'). + # Does *not* handle nested elements. + _non_empty_element_re = re.compile(r'^(<([^ >]*)[^>]*>.*?</\2>)\s*') + + # The closing tag for a stream tag. We have to insert this + # ourselves since all XML stanzas are children of the stream tag, + # which is never closed until the connection is closed. + _stream_suffix = '</stream:stream>' + + def __init__(self, delegate): + self._buffer = '' + self._delegate = delegate + + def FeedString(self, data): + """Consumes the given string data, possibly feeding one or more + stanzas to the delegate. + """ + self._buffer += data + while (self._ProcessBuffer(self._stream_re, self._stream_suffix) or + self._ProcessBuffer(self._empty_element_re) or + self._ProcessBuffer(self._non_empty_element_re)): + pass + + def _ProcessBuffer(self, regexp, xml_suffix=''): + """If the buffer matches the given regexp, removes the match from + the buffer, appends the given suffix, parses it, and feeds it to + the delegate. + + Returns: + Whether or not the buffer matched the given regexp. + """ + results = regexp.match(self._buffer) + if not results: + return False + xml_text = self._buffer[:results.end()] + xml_suffix + self._buffer = self._buffer[results.end():] + stanza = ParseXml(xml_text) + self._delegate.FeedStanza(stanza) + # Needed because stanza may have cycles. + stanza.unlink() + return True + + +class Jid(object): + """Simple struct for an XMPP jid (essentially an e-mail address with + an optional resource string). + """ + + def __init__(self, username, domain, resource=''): + self.username = username + self.domain = domain + self.resource = resource + + def __str__(self): + jid_str = "%s@%s" % (self.username, self.domain) + if self.resource: + jid_str += '/' + self.resource + return jid_str + + def GetBareJid(self): + return Jid(self.username, self.domain) + + +class IdGenerator(object): + """Simple class to generate unique IDs for XMPP messages.""" + + def __init__(self, prefix): + self._prefix = prefix + self._id = 0 + + def GetNextId(self): + next_id = "%s.%s" % (self._prefix, self._id) + self._id += 1 + return next_id + + +class HandshakeTask(object): + """Class to handle the initial handshake with a connected XMPP + client. + """ + + # The handshake states in order. + (_INITIAL_STREAM_NEEDED, + _AUTH_NEEDED, + _AUTH_STREAM_NEEDED, + _BIND_NEEDED, + _SESSION_NEEDED, + _FINISHED) = range(6) + + # Used when in the _INITIAL_STREAM_NEEDED and _AUTH_STREAM_NEEDED + # states. Not an XML object as it's only the opening tag. + # + # The from and id attributes are filled in later. + _STREAM_DATA = ( + '<stream:stream from="%s" id="%s" ' + 'version="1.0" xmlns:stream="http://etherx.jabber.org/streams" ' + 'xmlns="jabber:client">') + + # Used when in the _INITIAL_STREAM_NEEDED state. + _AUTH_STANZA = ParseXml( + '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' + ' <mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl">' + ' <mechanism>PLAIN</mechanism>' + ' <mechanism>X-GOOGLE-TOKEN</mechanism>' + ' </mechanisms>' + '</stream:features>') + + # Used when in the _AUTH_NEEDED state. + _AUTH_SUCCESS_STANZA = ParseXml( + '<success xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + + # Used when in the _AUTH_NEEDED state. + _AUTH_FAILURE_STANZA = ParseXml( + '<failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl"/>') + + # Used when in the _AUTH_STREAM_NEEDED state. + _BIND_STANZA = ParseXml( + '<stream:features xmlns:stream="http://etherx.jabber.org/streams">' + ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"/>' + ' <session xmlns="urn:ietf:params:xml:ns:xmpp-session"/>' + '</stream:features>') + + # Used when in the _BIND_NEEDED state. + # + # The id and jid attributes are filled in later. + _BIND_RESULT_STANZA = ParseXml( + '<iq id="" type="result">' + ' <bind xmlns="urn:ietf:params:xml:ns:xmpp-bind">' + ' <jid/>' + ' </bind>' + '</iq>') + + # Used when in the _SESSION_NEEDED state. + # + # The id attribute is filled in later. + _IQ_RESPONSE_STANZA = ParseXml('<iq id="" type="result"/>') + + def __init__(self, connection, resource_prefix, authenticated): + self._connection = connection + self._id_generator = IdGenerator(resource_prefix) + self._username = '' + self._domain = '' + self._jid = None + self._authenticated = authenticated + self._resource_prefix = resource_prefix + self._state = self._INITIAL_STREAM_NEEDED + + def FeedStanza(self, stanza): + """Inspects the given stanza and changes the handshake state if needed. + + Called when a stanza is received from the client. Inspects the + stanza to make sure it has the expected attributes given the + current state, advances the state if needed, and sends a reply to + the client if needed. + """ + def ExpectStanza(stanza, name): + if stanza.tagName != name: + raise UnexpectedXml(stanza) + + def ExpectIq(stanza, type, name): + ExpectStanza(stanza, 'iq') + if (stanza.getAttribute('type') != type or + stanza.firstChild.tagName != name): + raise UnexpectedXml(stanza) + + def GetStanzaId(stanza): + return stanza.getAttribute('id') + + def HandleStream(stanza): + ExpectStanza(stanza, 'stream:stream') + domain = stanza.getAttribute('to') + if domain: + self._domain = domain + SendStreamData() + + def SendStreamData(): + next_id = self._id_generator.GetNextId() + stream_data = self._STREAM_DATA % (self._domain, next_id) + self._connection.SendData(stream_data) + + def GetUserDomain(stanza): + encoded_username_password = stanza.firstChild.data + username_password = base64.b64decode(encoded_username_password) + (_, username_domain, _) = username_password.split('\0') + # The domain may be omitted. + # + # If we were using python 2.5, we'd be able to do: + # + # username, _, domain = username_domain.partition('@') + # if not domain: + # domain = self._domain + at_pos = username_domain.find('@') + if at_pos != -1: + username = username_domain[:at_pos] + domain = username_domain[at_pos+1:] + else: + username = username_domain + domain = self._domain + return (username, domain) + + def Finish(): + self._state = self._FINISHED + self._connection.HandshakeDone(self._jid) + + if self._state == self._INITIAL_STREAM_NEEDED: + HandleStream(stanza) + self._connection.SendStanza(self._AUTH_STANZA, False) + self._state = self._AUTH_NEEDED + + elif self._state == self._AUTH_NEEDED: + ExpectStanza(stanza, 'auth') + (self._username, self._domain) = GetUserDomain(stanza) + if self._authenticated: + self._connection.SendStanza(self._AUTH_SUCCESS_STANZA, False) + self._state = self._AUTH_STREAM_NEEDED + else: + self._connection.SendStanza(self._AUTH_FAILURE_STANZA, False) + Finish() + + elif self._state == self._AUTH_STREAM_NEEDED: + HandleStream(stanza) + self._connection.SendStanza(self._BIND_STANZA, False) + self._state = self._BIND_NEEDED + + elif self._state == self._BIND_NEEDED: + ExpectIq(stanza, 'set', 'bind') + stanza_id = GetStanzaId(stanza) + resource_element = stanza.getElementsByTagName('resource')[0] + resource = resource_element.firstChild.data + full_resource = '%s.%s' % (self._resource_prefix, resource) + response = CloneXml(self._BIND_RESULT_STANZA) + response.setAttribute('id', stanza_id) + self._jid = Jid(self._username, self._domain, full_resource) + jid_text = response.parentNode.createTextNode(str(self._jid)) + response.getElementsByTagName('jid')[0].appendChild(jid_text) + self._connection.SendStanza(response) + self._state = self._SESSION_NEEDED + + elif self._state == self._SESSION_NEEDED: + ExpectIq(stanza, 'set', 'session') + stanza_id = GetStanzaId(stanza) + xml = CloneXml(self._IQ_RESPONSE_STANZA) + xml.setAttribute('id', stanza_id) + self._connection.SendStanza(xml) + Finish() + + +def AddrString(addr): + return '%s:%d' % addr + + +class XmppConnection(asynchat.async_chat): + """A single XMPP client connection. + + This class handles the connection to a single XMPP client (via a + socket). It does the XMPP handshake and also implements the (old) + Google notification protocol. + """ + + # Used for acknowledgements to the client. + # + # The from and id attributes are filled in later. + _IQ_RESPONSE_STANZA = ParseXml('<iq from="" id="" type="result"/>') + + def __init__(self, sock, socket_map, delegate, addr, authenticated): + """Starts up the xmpp connection. + + Args: + sock: The socket to the client. + socket_map: A map from sockets to their owning objects. + delegate: The delegate, which is notified when the XMPP + handshake is successful, when the connection is closed, and + when a notification has to be broadcast. + addr: The host/port of the client. + """ + # We do this because in versions of python < 2.6, + # async_chat.__init__ doesn't take a map argument nor pass it to + # dispatcher.__init__. We rely on the fact that + # async_chat.__init__ calls dispatcher.__init__ as the last thing + # it does, and that calling dispatcher.__init__ with socket=None + # and map=None is essentially a no-op. + asynchat.async_chat.__init__(self) + asyncore.dispatcher.__init__(self, sock, socket_map) + + self.set_terminator(None) + + self._delegate = delegate + self._parser = StanzaParser(self) + self._jid = None + + self._addr = addr + addr_str = AddrString(self._addr) + self._handshake_task = HandshakeTask(self, addr_str, authenticated) + print 'Starting connection to %s' % self + + def __str__(self): + if self._jid: + return str(self._jid) + else: + return AddrString(self._addr) + + # async_chat implementation. + + def collect_incoming_data(self, data): + self._parser.FeedString(data) + + # This is only here to make pychecker happy. + def found_terminator(self): + asynchat.async_chat.found_terminator(self) + + def close(self): + print "Closing connection to %s" % self + self._delegate.OnXmppConnectionClosed(self) + asynchat.async_chat.close(self) + + # Called by self._parser.FeedString(). + def FeedStanza(self, stanza): + if self._handshake_task: + self._handshake_task.FeedStanza(stanza) + elif stanza.tagName == 'iq' and stanza.getAttribute('type') == 'result': + # Ignore all client acks. + pass + elif (stanza.firstChild and + stanza.firstChild.namespaceURI == 'google:push'): + self._HandlePushCommand(stanza) + else: + raise UnexpectedXml(stanza) + + # Called by self._handshake_task. + def HandshakeDone(self, jid): + if jid: + self._jid = jid + self._handshake_task = None + self._delegate.OnXmppHandshakeDone(self) + print "Handshake done for %s" % self + else: + print "Handshake failed for %s" % self + self.close() + + def _HandlePushCommand(self, stanza): + if stanza.tagName == 'iq' and stanza.firstChild.tagName == 'subscribe': + # Subscription request. + self._SendIqResponseStanza(stanza) + elif stanza.tagName == 'message' and stanza.firstChild.tagName == 'push': + # Send notification request. + self._delegate.ForwardNotification(self, stanza) + else: + raise UnexpectedXml(command_xml) + + def _SendIqResponseStanza(self, iq): + stanza = CloneXml(self._IQ_RESPONSE_STANZA) + stanza.setAttribute('from', str(self._jid.GetBareJid())) + stanza.setAttribute('id', iq.getAttribute('id')) + self.SendStanza(stanza) + + def SendStanza(self, stanza, unlink=True): + """Sends a stanza to the client. + + Args: + stanza: The stanza to send. + unlink: Whether to unlink stanza after sending it. (Pass in + False if stanza is a constant.) + """ + self.SendData(stanza.toxml()) + if unlink: + stanza.unlink() + + def SendData(self, data): + """Sends raw data to the client. + """ + # We explicitly encode to ascii as that is what the client expects + # (some minidom library functions return unicode strings). + self.push(data.encode('ascii')) + + def ForwardNotification(self, notification_stanza): + """Forwards a notification to the client.""" + notification_stanza.setAttribute('from', str(self._jid.GetBareJid())) + notification_stanza.setAttribute('to', str(self._jid)) + self.SendStanza(notification_stanza, False) + + +class XmppServer(asyncore.dispatcher): + """The main XMPP server class. + + The XMPP server starts accepting connections on the given address + and spawns off XmppConnection objects for each one. + + Use like so: + + socket_map = {} + xmpp_server = xmppserver.XmppServer(socket_map, ('127.0.0.1', 5222)) + asyncore.loop(30.0, False, socket_map) + """ + + # Used when sending a notification. + _NOTIFICATION_STANZA = ParseXml( + '<message>' + ' <push xmlns="google:push">' + ' <data/>' + ' </push>' + '</message>') + + def __init__(self, socket_map, addr): + asyncore.dispatcher.__init__(self, None, socket_map) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.set_reuse_addr() + self.bind(addr) + self.listen(5) + self._socket_map = socket_map + self._connections = set() + self._handshake_done_connections = set() + self._notifications_enabled = True + self._authenticated = True + + def handle_accept(self): + (sock, addr) = self.accept() + xmpp_connection = XmppConnection( + sock, self._socket_map, self, addr, self._authenticated) + self._connections.add(xmpp_connection) + # Return the new XmppConnection for testing. + return xmpp_connection + + def close(self): + # A copy is necessary since calling close on each connection + # removes it from self._connections. + for connection in self._connections.copy(): + connection.close() + asyncore.dispatcher.close(self) + + def EnableNotifications(self): + self._notifications_enabled = True + + def DisableNotifications(self): + self._notifications_enabled = False + + def MakeNotification(self, channel, data): + """Makes a notification from the given channel and encoded data. + + Args: + channel: The channel on which to send the notification. + data: The notification payload. + """ + notification_stanza = CloneXml(self._NOTIFICATION_STANZA) + push_element = notification_stanza.getElementsByTagName('push')[0] + push_element.setAttribute('channel', channel) + data_element = push_element.getElementsByTagName('data')[0] + encoded_data = base64.b64encode(data) + data_text = notification_stanza.parentNode.createTextNode(encoded_data) + data_element.appendChild(data_text) + return notification_stanza + + def SendNotification(self, channel, data): + """Sends a notification to all connections. + + Args: + channel: The channel on which to send the notification. + data: The notification payload. + """ + notification_stanza = self.MakeNotification(channel, data) + self.ForwardNotification(None, notification_stanza) + notification_stanza.unlink() + + def SetAuthenticated(self, auth_valid): + self._authenticated = auth_valid + + def GetAuthenticated(self): + return self._authenticated + + # XmppConnection delegate methods. + def OnXmppHandshakeDone(self, xmpp_connection): + self._handshake_done_connections.add(xmpp_connection) + + def OnXmppConnectionClosed(self, xmpp_connection): + self._connections.discard(xmpp_connection) + self._handshake_done_connections.discard(xmpp_connection) + + def ForwardNotification(self, unused_xmpp_connection, notification_stanza): + if self._notifications_enabled: + for connection in self._handshake_done_connections: + print 'Sending notification to %s' % connection + connection.ForwardNotification(notification_stanza) + else: + print 'Notifications disabled; dropping notification' diff --git a/sync/tools/testserver/xmppserver_test.py b/sync/tools/testserver/xmppserver_test.py new file mode 100755 index 0000000..1a539d1 --- /dev/null +++ b/sync/tools/testserver/xmppserver_test.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python +# Copyright 2013 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""Tests exercising the various classes in xmppserver.py.""" + +import unittest + +import base64 +import xmppserver + +class XmlUtilsTest(unittest.TestCase): + + def testParseXml(self): + xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>""" + xml = xmppserver.ParseXml(xml_text) + self.assertEqual(xml.toxml(), xml_text) + + def testCloneXml(self): + xml = xmppserver.ParseXml('<foo/>') + xml_clone = xmppserver.CloneXml(xml) + xml_clone.setAttribute('bar', 'baz') + self.assertEqual(xml, xml) + self.assertEqual(xml_clone, xml_clone) + self.assertNotEqual(xml, xml_clone) + + def testCloneXmlUnlink(self): + xml_text = '<foo/>' + xml = xmppserver.ParseXml(xml_text) + xml_clone = xmppserver.CloneXml(xml) + xml.unlink() + self.assertEqual(xml.parentNode, None) + self.assertNotEqual(xml_clone.parentNode, None) + self.assertEqual(xml_clone.toxml(), xml_text) + +class StanzaParserTest(unittest.TestCase): + + def setUp(self): + self.stanzas = [] + + def FeedStanza(self, stanza): + # We can't append stanza directly because it is unlinked after + # this callback. + self.stanzas.append(stanza.toxml()) + + def testBasic(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<foo') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString('/><bar></bar>') + self.assertEqual(self.stanzas[0], '<foo/>') + self.assertEqual(self.stanzas[1], '<bar/>') + + def testStream(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<stream') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString(':stream foo="bar" xmlns:stream="baz">') + self.assertEqual(self.stanzas[0], + '<stream:stream foo="bar" xmlns:stream="baz"/>') + + def testNested(self): + parser = xmppserver.StanzaParser(self) + parser.FeedString('<foo') + self.assertEqual(len(self.stanzas), 0) + parser.FeedString(' bar="baz"') + parser.FeedString('><baz/><blah>meh</blah></foo>') + self.assertEqual(self.stanzas[0], + '<foo bar="baz"><baz/><blah>meh</blah></foo>') + + +class JidTest(unittest.TestCase): + + def testBasic(self): + jid = xmppserver.Jid('foo', 'bar.com') + self.assertEqual(str(jid), 'foo@bar.com') + + def testResource(self): + jid = xmppserver.Jid('foo', 'bar.com', 'resource') + self.assertEqual(str(jid), 'foo@bar.com/resource') + + def testGetBareJid(self): + jid = xmppserver.Jid('foo', 'bar.com', 'resource') + self.assertEqual(str(jid.GetBareJid()), 'foo@bar.com') + + +class IdGeneratorTest(unittest.TestCase): + + def testBasic(self): + id_generator = xmppserver.IdGenerator('foo') + for i in xrange(0, 100): + self.assertEqual('foo.%d' % i, id_generator.GetNextId()) + + +class HandshakeTaskTest(unittest.TestCase): + + def setUp(self): + self.Reset() + + def Reset(self): + self.data_received = 0 + self.handshake_done = False + self.jid = None + + def SendData(self, _): + self.data_received += 1 + + def SendStanza(self, _, unused=True): + self.data_received += 1 + + def HandshakeDone(self, jid): + self.handshake_done = True + self.jid = jid + + def DoHandshake(self, resource_prefix, resource, username, + initial_stream_domain, auth_domain, auth_stream_domain): + self.Reset() + handshake_task = ( + xmppserver.HandshakeTask(self, resource_prefix, True)) + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', initial_stream_domain) + self.assertEqual(self.data_received, 0) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 2) + + if auth_domain: + username_domain = '%s@%s' % (username, auth_domain) + else: + username_domain = username + auth_string = base64.b64encode('\0%s\0bar' % username_domain) + auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) + handshake_task.FeedStanza(auth_xml) + self.assertEqual(self.data_received, 3) + + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', auth_stream_domain) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 5) + + bind_xml = xmppserver.ParseXml( + '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource) + handshake_task.FeedStanza(bind_xml) + self.assertEqual(self.data_received, 6) + + self.assertFalse(self.handshake_done) + + session_xml = xmppserver.ParseXml( + '<iq type="set"><session></session></iq>') + handshake_task.FeedStanza(session_xml) + self.assertEqual(self.data_received, 7) + + self.assertTrue(self.handshake_done) + + self.assertEqual(self.jid.username, username) + self.assertEqual(self.jid.domain, + auth_stream_domain or auth_domain or + initial_stream_domain) + self.assertEqual(self.jid.resource, + '%s.%s' % (resource_prefix, resource)) + + handshake_task.FeedStanza('<ignored/>') + self.assertEqual(self.data_received, 7) + + def DoHandshakeUnauthenticated(self, resource_prefix, resource, username, + initial_stream_domain): + self.Reset() + handshake_task = ( + xmppserver.HandshakeTask(self, resource_prefix, False)) + stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>') + stream_xml.setAttribute('to', initial_stream_domain) + self.assertEqual(self.data_received, 0) + handshake_task.FeedStanza(stream_xml) + self.assertEqual(self.data_received, 2) + + self.assertFalse(self.handshake_done) + + auth_string = base64.b64encode('\0%s\0bar' % username) + auth_xml = xmppserver.ParseXml('<auth>%s</auth>'% auth_string) + handshake_task.FeedStanza(auth_xml) + self.assertEqual(self.data_received, 3) + + self.assertTrue(self.handshake_done) + + self.assertEqual(self.jid, None) + + handshake_task.FeedStanza('<ignored/>') + self.assertEqual(self.data_received, 3) + + def testBasic(self): + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', 'quux.com') + + def testDomainBehavior(self): + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', 'quux.com') + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', 'baz.com', '') + self.DoHandshake('resource_prefix', 'resource', + 'foo', 'bar.com', '', '') + self.DoHandshake('resource_prefix', 'resource', + 'foo', '', '', '') + + def testBasicUnauthenticated(self): + self.DoHandshakeUnauthenticated('resource_prefix', 'resource', + 'foo', 'bar.com') + + +class FakeSocket(object): + """A fake socket object used for testing. + """ + + def __init__(self): + self._sent_data = [] + + def GetSentData(self): + return self._sent_data + + # socket-like methods. + def fileno(self): + return 0 + + def setblocking(self, int): + pass + + def getpeername(self): + return ('', 0) + + def send(self, data): + self._sent_data.append(data) + pass + + def close(self): + pass + + +class XmppConnectionTest(unittest.TestCase): + + def setUp(self): + self.connections = set() + self.fake_socket = FakeSocket() + + # XmppConnection delegate methods. + def OnXmppHandshakeDone(self, xmpp_connection): + self.connections.add(xmpp_connection) + + def OnXmppConnectionClosed(self, xmpp_connection): + self.connections.discard(xmpp_connection) + + def ForwardNotification(self, unused_xmpp_connection, notification_stanza): + for connection in self.connections: + connection.ForwardNotification(notification_stanza) + + def testBasic(self): + socket_map = {} + xmpp_connection = xmppserver.XmppConnection( + self.fake_socket, socket_map, self, ('', 0), True) + self.assertEqual(len(socket_map), 1) + self.assertEqual(len(self.connections), 0) + xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar')) + self.assertEqual(len(socket_map), 1) + self.assertEqual(len(self.connections), 1) + + sent_data = self.fake_socket.GetSentData() + + # Test subscription request. + self.assertEqual(len(sent_data), 0) + xmpp_connection.collect_incoming_data( + '<iq><subscribe xmlns="google:push"></subscribe></iq>') + self.assertEqual(len(sent_data), 1) + + # Test acks. + xmpp_connection.collect_incoming_data('<iq type="result"/>') + self.assertEqual(len(sent_data), 1) + + # Test notification. + xmpp_connection.collect_incoming_data( + '<message><push xmlns="google:push"/></message>') + self.assertEqual(len(sent_data), 2) + + # Test unexpected stanza. + def SendUnexpectedStanza(): + xmpp_connection.collect_incoming_data('<foo/>') + self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) + + # Test unexpected notifier command. + def SendUnexpectedNotifierCommand(): + xmpp_connection.collect_incoming_data( + '<iq><foo xmlns="google:notifier"/></iq>') + self.assertRaises(xmppserver.UnexpectedXml, + SendUnexpectedNotifierCommand) + + # Test close. + xmpp_connection.close() + self.assertEqual(len(socket_map), 0) + self.assertEqual(len(self.connections), 0) + + def testBasicUnauthenticated(self): + socket_map = {} + xmpp_connection = xmppserver.XmppConnection( + self.fake_socket, socket_map, self, ('', 0), False) + self.assertEqual(len(socket_map), 1) + self.assertEqual(len(self.connections), 0) + xmpp_connection.HandshakeDone(None) + self.assertEqual(len(socket_map), 0) + self.assertEqual(len(self.connections), 0) + + # Test unexpected stanza. + def SendUnexpectedStanza(): + xmpp_connection.collect_incoming_data('<foo/>') + self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza) + + # Test redundant close. + xmpp_connection.close() + self.assertEqual(len(socket_map), 0) + self.assertEqual(len(self.connections), 0) + + +class FakeXmppServer(xmppserver.XmppServer): + """A fake XMPP server object used for testing. + """ + + def __init__(self): + self._socket_map = {} + self._fake_sockets = set() + self._next_jid_suffix = 1 + xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0)) + + def GetSocketMap(self): + return self._socket_map + + def GetFakeSockets(self): + return self._fake_sockets + + def AddHandshakeCompletedConnection(self): + """Creates a new XMPP connection and completes its handshake. + """ + xmpp_connection = self.handle_accept() + jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com') + self._next_jid_suffix += 1 + xmpp_connection.HandshakeDone(jid) + + # XmppServer overrides. + def accept(self): + fake_socket = FakeSocket() + self._fake_sockets.add(fake_socket) + return (fake_socket, ('', 0)) + + def close(self): + self._fake_sockets.clear() + xmppserver.XmppServer.close(self) + + +class XmppServerTest(unittest.TestCase): + + def setUp(self): + self.xmpp_server = FakeXmppServer() + + def AssertSentDataLength(self, expected_length): + for fake_socket in self.xmpp_server.GetFakeSockets(): + self.assertEqual(len(fake_socket.GetSentData()), expected_length) + + def testBasic(self): + socket_map = self.xmpp_server.GetSocketMap() + self.assertEqual(len(socket_map), 1) + self.xmpp_server.AddHandshakeCompletedConnection() + self.assertEqual(len(socket_map), 2) + self.xmpp_server.close() + self.assertEqual(len(socket_map), 0) + + def testMakeNotification(self): + notification = self.xmpp_server.MakeNotification('channel', 'data') + expected_xml = ( + '<message>' + ' <push channel="channel" xmlns="google:push">' + ' <data>%s</data>' + ' </push>' + '</message>' % base64.b64encode('data')) + self.assertEqual(notification.toxml(), expected_xml) + + def testSendNotification(self): + # Add a few connections. + for _ in xrange(0, 7): + self.xmpp_server.AddHandshakeCompletedConnection() + + self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7) + + self.AssertSentDataLength(0) + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(1) + + def testEnableDisableNotifications(self): + # Add a few connections. + for _ in xrange(0, 5): + self.xmpp_server.AddHandshakeCompletedConnection() + + self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5) + + self.AssertSentDataLength(0) + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(1) + + self.xmpp_server.EnableNotifications() + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(2) + + self.xmpp_server.DisableNotifications() + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(2) + + self.xmpp_server.DisableNotifications() + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(2) + + self.xmpp_server.EnableNotifications() + self.xmpp_server.SendNotification('channel', 'data') + self.AssertSentDataLength(3) + + +if __name__ == '__main__': + unittest.main() |