// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "device/test/usb_test_gadget.h" #include #include #include "base/command_line.h" #include "base/compiler_specific.h" #include "base/files/file.h" #include "base/files/file_path.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/path_service.h" #include "base/process/process_handle.h" #include "base/run_loop.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" #include "base/threading/platform_thread.h" #include "base/time/time.h" #include "device/usb/usb_device.h" #include "device/usb/usb_device_handle.h" #include "device/usb/usb_service.h" #include "net/proxy/proxy_service.h" #include "net/url_request/url_fetcher.h" #include "net/url_request/url_fetcher_delegate.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_builder.h" #include "net/url_request/url_request_context_getter.h" #include "url/gurl.h" using ::base::PlatformThread; using ::base::TimeDelta; namespace device { namespace { static const char kCommandLineSwitch[] = "enable-gadget-tests"; static const int kClaimRetries = 100; // 5 seconds static const int kDisconnectRetries = 100; // 5 seconds static const int kRetryPeriod = 50; // 0.05 seconds static const int kReconnectRetries = 100; // 5 seconds static const int kUpdateRetries = 100; // 5 seconds struct UsbTestGadgetConfiguration { UsbTestGadget::Type type; const char* http_resource; uint16 product_id; }; static const struct UsbTestGadgetConfiguration kConfigurations[] = { {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0}, {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1}, {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2}, {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3}, {UsbTestGadget::ECHO, "/echo/configure", 0x58F4}, }; class UsbTestGadgetImpl : public UsbTestGadget { public: ~UsbTestGadgetImpl() override; bool Unclaim() override; bool Disconnect() override; bool Reconnect() override; bool SetType(Type type) override; UsbDevice* GetDevice() const override; std::string GetSerialNumber() const override; protected: UsbTestGadgetImpl(); private: scoped_ptr CreateURLFetcher( const GURL& url, net::URLFetcher::RequestType request_type, net::URLFetcherDelegate* delegate); int SimplePOSTRequest(const GURL& url, const std::string& form_data); bool FindUnclaimed(); bool GetVersion(std::string* version); bool Update(); bool FindClaimed(); bool ReadLocalVersion(std::string* version); bool ReadLocalPackage(std::string* package); bool ReadFile(const base::FilePath& file_path, std::string* content); class Delegate : public net::URLFetcherDelegate { public: Delegate() {} ~Delegate() override {} void WaitForCompletion() { run_loop_.Run(); } void OnURLFetchComplete(const net::URLFetcher* source) override { run_loop_.Quit(); } private: base::RunLoop run_loop_; DISALLOW_COPY_AND_ASSIGN(Delegate); }; scoped_refptr device_; std::string device_address_; scoped_ptr request_context_; std::string session_id_; UsbService* usb_service_; friend class UsbTestGadget; DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl); }; } // namespace bool UsbTestGadget::IsTestEnabled() { base::CommandLine* command_line = CommandLine::ForCurrentProcess(); return command_line->HasSwitch(kCommandLineSwitch); } scoped_ptr UsbTestGadget::Claim() { scoped_ptr gadget(new UsbTestGadgetImpl); int retries = kClaimRetries; while (!gadget->FindUnclaimed()) { if (--retries == 0) { LOG(ERROR) << "Failed to find an unclaimed device."; return scoped_ptr(); } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kClaimRetries - retries) << " retries to find an unclaimed device."; return gadget.Pass(); } UsbTestGadgetImpl::UsbTestGadgetImpl() { net::URLRequestContextBuilder context_builder; context_builder.set_proxy_service(net::ProxyService::CreateDirect()); request_context_.reset(context_builder.Build()); base::ProcessId process_id = base::GetCurrentProcId(); session_id_ = base::StringPrintf( "%s:%p", base::HexEncode(&process_id, sizeof(process_id)).c_str(), this); usb_service_ = UsbService::GetInstance(NULL); } UsbTestGadgetImpl::~UsbTestGadgetImpl() { if (!device_address_.empty()) { Unclaim(); } } UsbDevice* UsbTestGadgetImpl::GetDevice() const { return device_.get(); } std::string UsbTestGadgetImpl::GetSerialNumber() const { return device_address_; } scoped_ptr UsbTestGadgetImpl::CreateURLFetcher( const GURL& url, net::URLFetcher::RequestType request_type, net::URLFetcherDelegate* delegate) { scoped_ptr url_fetcher( net::URLFetcher::Create(url, request_type, delegate)); url_fetcher->SetRequestContext( new net::TrivialURLRequestContextGetter( request_context_.get(), base::MessageLoop::current()->message_loop_proxy())); return url_fetcher; } int UsbTestGadgetImpl::SimplePOSTRequest(const GURL& url, const std::string& form_data) { Delegate delegate; scoped_ptr url_fetcher = CreateURLFetcher(url, net::URLFetcher::POST, &delegate); url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data); url_fetcher->Start(); delegate.WaitForCompletion(); return url_fetcher->GetResponseCode(); } bool UsbTestGadgetImpl::FindUnclaimed() { std::vector > devices; usb_service_->GetDevices(&devices); for (std::vector >::const_iterator iter = devices.begin(); iter != devices.end(); ++iter) { const scoped_refptr &device = *iter; if (device->vendor_id() == 0x18D1 && device->product_id() == 0x58F0) { base::string16 serial_utf16; if (!device->GetSerialNumber(&serial_utf16)) { continue; } const std::string serial = base::UTF16ToUTF8(serial_utf16); const GURL url("http://" + serial + "/claim"); const std::string form_data = base::StringPrintf( "session_id=%s", net::EscapeUrlEncodedData(session_id_, true).c_str()); const int response_code = SimplePOSTRequest(url, form_data); if (response_code == 200) { device_address_ = serial; device_ = device; break; } // The device is probably claimed by another process. if (response_code != 403) { LOG(WARNING) << "Unexpected HTTP " << response_code << " from /claim."; } } } std::string local_version; std::string version; if (!ReadLocalVersion(&local_version) || !GetVersion(&version)) { return false; } if (version == local_version) { return true; } return Update(); } bool UsbTestGadgetImpl::GetVersion(std::string* version) { Delegate delegate; const GURL url("http://" + device_address_ + "/version"); scoped_ptr url_fetcher = CreateURLFetcher(url, net::URLFetcher::GET, &delegate); url_fetcher->Start(); delegate.WaitForCompletion(); const int response_code = url_fetcher->GetResponseCode(); if (response_code != 200) { VLOG(2) << "Unexpected HTTP " << response_code << " from /version."; return false; } STLClearObject(version); if (!url_fetcher->GetResponseAsString(version)) { VLOG(2) << "Failed to read body from /version."; return false; } return true; } bool UsbTestGadgetImpl::Update() { std::string version; if (!ReadLocalVersion(&version)) { return false; } LOG(INFO) << "Updating " << device_address_ << " to " << version << "..."; Delegate delegate; const GURL url("http://" + device_address_ + "/update"); scoped_ptr url_fetcher = CreateURLFetcher(url, net::URLFetcher::POST, &delegate); const std::string mime_header = base::StringPrintf( "--foo\r\n" "Content-Disposition: form-data; name=\"file\"; " "filename=\"usb_gadget-%s.zip\"\r\n" "Content-Type: application/octet-stream\r\n" "\r\n", version.c_str()); const std::string mime_footer("\r\n--foo--\r\n"); std::string package; if (!ReadLocalPackage(&package)) { return false; } url_fetcher->SetUploadData("multipart/form-data; boundary=foo", mime_header + package + mime_footer); url_fetcher->Start(); delegate.WaitForCompletion(); const int response_code = url_fetcher->GetResponseCode(); if (response_code != 200) { LOG(ERROR) << "Unexpected HTTP " << response_code << " from /update."; return false; } int retries = kUpdateRetries; std::string new_version; while (!GetVersion(&new_version) || new_version != version) { if (--retries == 0) { LOG(ERROR) << "Device not responding with new version."; return false; } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kUpdateRetries - retries) << " retries to see the new version."; // Release the old reference to the device and try to open a new one. device_ = NULL; retries = kReconnectRetries; while (!FindClaimed()) { if (--retries == 0) { LOG(ERROR) << "Failed to find updated device."; return false; } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kReconnectRetries - retries) << " retries to find the updated device."; return true; } bool UsbTestGadgetImpl::FindClaimed() { CHECK(!device_.get()); std::string expected_serial = GetSerialNumber(); std::vector > devices; usb_service_->GetDevices(&devices); for (std::vector >::iterator iter = devices.begin(); iter != devices.end(); ++iter) { scoped_refptr &device = *iter; if (device->vendor_id() == 0x18D1) { const uint16 product_id = device->product_id(); bool found = false; for (size_t i = 0; i < arraysize(kConfigurations); ++i) { if (product_id == kConfigurations[i].product_id) { found = true; break; } } if (!found) { continue; } base::string16 serial_utf16; if (!device->GetSerialNumber(&serial_utf16)) { continue; } std::string serial = base::UTF16ToUTF8(serial_utf16); if (serial != expected_serial) { continue; } device_ = device; return true; } } return false; } bool UsbTestGadgetImpl::ReadLocalVersion(std::string* version) { base::FilePath file_path; CHECK(PathService::Get(base::DIR_EXE, &file_path)); file_path = file_path.AppendASCII("usb_gadget.zip.md5"); return ReadFile(file_path, version); } bool UsbTestGadgetImpl::ReadLocalPackage(std::string* package) { base::FilePath file_path; CHECK(PathService::Get(base::DIR_EXE, &file_path)); file_path = file_path.AppendASCII("usb_gadget.zip"); return ReadFile(file_path, package); } bool UsbTestGadgetImpl::ReadFile(const base::FilePath& file_path, std::string* content) { base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ); if (!file.IsValid()) { LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": " << base::File::ErrorToString(file.error_details()); return false; } STLClearObject(content); int rv; do { char buf[4096]; rv = file.ReadAtCurrentPos(buf, sizeof buf); if (rv == -1) { LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": " << base::File::ErrorToString(file.error_details()); return false; } content->append(buf, rv); } while (rv > 0); return true; } bool UsbTestGadgetImpl::Unclaim() { VLOG(1) << "Releasing the device at " << device_address_ << "."; const GURL url("http://" + device_address_ + "/unclaim"); const int response_code = SimplePOSTRequest(url, ""); if (response_code != 200) { LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim."; return false; } return true; } bool UsbTestGadgetImpl::SetType(Type type) { const struct UsbTestGadgetConfiguration* config = NULL; for (size_t i = 0; i < arraysize(kConfigurations); ++i) { if (kConfigurations[i].type == type) { config = &kConfigurations[i]; } } CHECK(config); const GURL url("http://" + device_address_ + config->http_resource); const int response_code = SimplePOSTRequest(url, ""); if (response_code != 200) { LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << config->http_resource << "."; return false; } // Release the old reference to the device and try to open a new one. int retries = kReconnectRetries; while (true) { device_ = NULL; if (FindClaimed() && device_->product_id() == config->product_id) { break; } if (--retries == 0) { LOG(ERROR) << "Failed to find updated device."; return false; } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kReconnectRetries - retries) << " retries to find the updated device."; return true; } bool UsbTestGadgetImpl::Disconnect() { const GURL url("http://" + device_address_ + "/disconnect"); const int response_code = SimplePOSTRequest(url, ""); if (response_code != 200) { LOG(ERROR) << "Unexpected HTTP " << response_code << " from /disconnect."; return false; } // Release the old reference to the device and wait until it can't be found. int retries = kDisconnectRetries; while (true) { device_ = NULL; if (!FindClaimed()) { break; } if (--retries == 0) { LOG(ERROR) << "Device did not disconnect."; return false; } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kDisconnectRetries - retries) << " retries for the device to disconnect."; return true; } bool UsbTestGadgetImpl::Reconnect() { const GURL url("http://" + device_address_ + "/reconnect"); const int response_code = SimplePOSTRequest(url, ""); if (response_code != 200) { LOG(ERROR) << "Unexpected HTTP " << response_code << " from /reconnect."; return false; } int retries = kDisconnectRetries; while (true) { if (FindClaimed()) { break; } if (--retries == 0) { LOG(ERROR) << "Device did not reconnect."; return false; } PlatformThread::Sleep(TimeDelta::FromMilliseconds(kRetryPeriod)); } VLOG(1) << "It took " << (kDisconnectRetries - retries) << " retries for the device to reconnect."; return true; } } // namespace device