diff options
author | wtc@chromium.org <wtc@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 01:58:06 +0000 |
---|---|---|
committer | wtc@chromium.org <wtc@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2014-08-22 01:59:40 +0000 |
commit | 6d51582e8d510bdfb0606733a971064d59294d48 (patch) | |
tree | dd1e428c3506eec273e706704bd1f04e4131d794 | |
parent | 4975b85b32094c18b948d5ddc0c2e4fda3cc091c (diff) | |
download | chromium_src-6d51582e8d510bdfb0606733a971064d59294d48.zip chromium_src-6d51582e8d510bdfb0606733a971064d59294d48.tar.gz chromium_src-6d51582e8d510bdfb0606733a971064d59294d48.tar.bz2 |
Refactoring: Create per-connection packet writers in QuicDispatcher.
To make porting the QUIC EndToEndTest to Chromium possible with fewer
Chromium-specific parts in shared code, I've made QuicDispatcher expose
and accept a QuicDispatcher::PacketWriterFactory which it uses to create
a new packet writer wrapper for every QuicConnection. I also changed
QuicConnection to accept a QuicConnection::PacketWriterFactory (a second
new type of factory) rather than the writer itself in its constructor,
since the per-connection packet writers need to be created with the
connection already existing.
Merge internal CL: 73064412
Written by Daniel Ziegler <dmziegler@chromium.org>
Original review URL: https://codereview.chromium.org/467963002/
R=rch@chromium.org,wtc@chromium.org
BUG=
Review URL: https://codereview.chromium.org/475113005
Cr-Commit-Position: refs/heads/master@{#291314}
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@291314 0039d316-1c4b-4281-b951-d872f2087c98
41 files changed, 702 insertions, 185 deletions
diff --git a/net/BUILD.gn b/net/BUILD.gn index aa2fb79..0fdacb5 100644 --- a/net/BUILD.gn +++ b/net/BUILD.gn @@ -993,6 +993,8 @@ if (is_linux) { "tools/quic/quic_in_memory_cache.h", "tools/quic/quic_packet_writer_wrapper.cc", "tools/quic/quic_packet_writer_wrapper.h", + "tools/quic/quic_per_connection_packet_writer.cc", + "tools/quic/quic_per_connection_packet_writer.h", "tools/quic/quic_server.cc", "tools/quic/quic_server.h", "tools/quic/quic_server_session.cc", diff --git a/net/net.gyp b/net/net.gyp index c6007ae..6fa0f62 100644 --- a/net/net.gyp +++ b/net/net.gyp @@ -1484,6 +1484,8 @@ 'tools/quic/quic_in_memory_cache.h', 'tools/quic/quic_packet_writer_wrapper.cc', 'tools/quic/quic_packet_writer_wrapper.h', + 'tools/quic/quic_per_connection_packet_writer.cc', + 'tools/quic/quic_per_connection_packet_writer.h', 'tools/quic/quic_server.cc', 'tools/quic/quic_server.h', 'tools/quic/quic_server_session.cc', diff --git a/net/quic/quic_blocked_writer_interface.h b/net/quic/quic_blocked_writer_interface.h index a6f6faf..1c3a6fe 100644 --- a/net/quic/quic_blocked_writer_interface.h +++ b/net/quic/quic_blocked_writer_interface.h @@ -24,4 +24,18 @@ class NET_EXPORT_PRIVATE QuicBlockedWriterInterface { } // namespace net +#if defined(COMPILER_GCC) +namespace BASE_HASH_NAMESPACE { +// Hash pointers as if they were int's, but bring more entropy to the lower +// bits. +template <> +struct hash<net::QuicBlockedWriterInterface*> { + std::size_t operator()(const net::QuicBlockedWriterInterface* ptr) const { + size_t k = reinterpret_cast<size_t>(ptr); + return k + (k >> 6); + } +}; +} +#endif + #endif // NET_QUIC_QUIC_BLOCKED_WRITER_INTERFACE_H_ diff --git a/net/quic/quic_client_session.cc b/net/quic/quic_client_session.cc index df874dd..d563973 100644 --- a/net/quic/quic_client_session.cc +++ b/net/quic/quic_client_session.cc @@ -18,7 +18,6 @@ #include "net/quic/crypto/quic_server_info.h" #include "net/quic/quic_connection_helper.h" #include "net/quic/quic_crypto_client_stream_factory.h" -#include "net/quic/quic_default_packet_writer.h" #include "net/quic/quic_server_id.h" #include "net/quic/quic_stream_factory.h" #include "net/spdy/spdy_session.h" @@ -137,7 +136,6 @@ void QuicClientSession::StreamRequest::OnRequestCompleteFailure(int rv) { QuicClientSession::QuicClientSession( QuicConnection* connection, scoped_ptr<DatagramClientSocket> socket, - scoped_ptr<QuicDefaultPacketWriter> writer, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, TransportSecurityState* transport_security_state, @@ -152,7 +150,6 @@ QuicClientSession::QuicClientSession( require_confirmation_(false), stream_factory_(stream_factory), socket_(socket.Pass()), - writer_(writer.Pass()), read_buffer_(new IOBufferWithSize(kMaxPacketSize)), transport_security_state_(transport_security_state), server_info_(server_info.Pass()), diff --git a/net/quic/quic_client_session.h b/net/quic/quic_client_session.h index 9808cac..41dff64 100644 --- a/net/quic/quic_client_session.h +++ b/net/quic/quic_client_session.h @@ -30,7 +30,6 @@ class CertVerifyResult; class DatagramClientSocket; class QuicConnectionHelper; class QuicCryptoClientStreamFactory; -class QuicDefaultPacketWriter; class QuicServerId; class QuicServerInfo; class QuicStreamFactory; @@ -93,7 +92,6 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicClientSessionBase { // TODO(rch): decouple the factory from the session via a Delegate interface. QuicClientSession(QuicConnection* connection, scoped_ptr<DatagramClientSocket> socket, - scoped_ptr<QuicDefaultPacketWriter> writer, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, TransportSecurityState* transport_security_state, @@ -226,7 +224,6 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicClientSessionBase { scoped_ptr<QuicCryptoClientStream> crypto_stream_; QuicStreamFactory* stream_factory_; scoped_ptr<DatagramClientSocket> socket_; - scoped_ptr<QuicDefaultPacketWriter> writer_; scoped_refptr<IOBufferWithSize> read_buffer_; TransportSecurityState* transport_security_state_; scoped_ptr<QuicServerInfo> server_info_; diff --git a/net/quic/quic_client_session_test.cc b/net/quic/quic_client_session_test.cc index 1e13e89..39b2ae5 100644 --- a/net/quic/quic_client_session_test.cc +++ b/net/quic/quic_client_session_test.cc @@ -20,7 +20,6 @@ #include "net/quic/crypto/quic_decrypter.h" #include "net/quic/crypto/quic_encrypter.h" #include "net/quic/crypto/quic_server_info.h" -#include "net/quic/quic_default_packet_writer.h" #include "net/quic/test_tools/crypto_test_utils.h" #include "net/quic/test_tools/quic_client_session_peer.h" #include "net/quic/test_tools/quic_test_utils.h" @@ -39,43 +38,12 @@ namespace { const char kServerHostname[] = "www.example.org"; const uint16 kServerPort = 80; -class TestPacketWriter : public QuicDefaultPacketWriter { - public: - TestPacketWriter(QuicVersion version) : version_(version) {} - - // QuicPacketWriter - virtual WriteResult WritePacket( - const char* buffer, size_t buf_len, - const IPAddressNumber& self_address, - const IPEndPoint& peer_address) OVERRIDE { - SimpleQuicFramer framer(SupportedVersions(version_)); - QuicEncryptedPacket packet(buffer, buf_len); - EXPECT_TRUE(framer.ProcessPacket(packet)); - header_ = framer.header(); - return WriteResult(WRITE_STATUS_OK, packet.length()); - } - - virtual bool IsWriteBlockedDataBuffered() const OVERRIDE { - // Chrome sockets' Write() methods buffer the data until the Write is - // permitted. - return true; - } - - // Returns the header from the last packet written. - const QuicPacketHeader& header() { return header_; } - - private: - QuicVersion version_; - QuicPacketHeader header_; -}; - class QuicClientSessionTest : public ::testing::TestWithParam<QuicVersion> { protected: QuicClientSessionTest() - : writer_(new TestPacketWriter(GetParam())), - connection_( + : connection_( new PacketSavingConnection(false, SupportedVersions(GetParam()))), - session_(connection_, GetSocket().Pass(), writer_.Pass(), NULL, NULL, + session_(connection_, GetSocket().Pass(), NULL, NULL, &transport_security_state_, make_scoped_ptr((QuicServerInfo*)NULL), QuicServerId(kServerHostname, kServerPort, false, @@ -107,7 +75,6 @@ class QuicClientSessionTest : public ::testing::TestWithParam<QuicVersion> { ASSERT_EQ(OK, callback_.WaitForResult()); } - scoped_ptr<QuicDefaultPacketWriter> writer_; PacketSavingConnection* connection_; CapturingNetLog net_log_; MockClientSocketFactory socket_factory_; diff --git a/net/quic/quic_connection.cc b/net/quic/quic_connection.cc index 35617c8..c116ba5 100644 --- a/net/quic/quic_connection.cc +++ b/net/quic/quic_connection.cc @@ -190,14 +190,14 @@ QuicConnection::QueuedPacket::QueuedPacket(SerializedPacket packet, QuicConnection::QuicConnection(QuicConnectionId connection_id, IPEndPoint address, QuicConnectionHelperInterface* helper, - QuicPacketWriter* writer, + const PacketWriterFactory& writer_factory, bool owns_writer, bool is_server, const QuicVersionVector& supported_versions) : framer_(supported_versions, helper->GetClock()->ApproximateNow(), is_server), helper_(helper), - writer_(writer), + writer_(writer_factory.Create(this)), owns_writer_(owns_writer), encryption_level_(ENCRYPTION_NONE), clock_(helper->GetClock()), diff --git a/net/quic/quic_connection.h b/net/quic/quic_connection.h index 4d57208..13952c5 100644 --- a/net/quic/quic_connection.h +++ b/net/quic/quic_connection.h @@ -236,13 +236,21 @@ class NET_EXPORT_PRIVATE QuicConnection BUNDLE_PENDING_ACK = 2, }; - // Constructs a new QuicConnection for |connection_id| and |address|. - // |helper| must outlive this connection, and if |owns_writer| is false, so - // must |writer|. + class PacketWriterFactory { + public: + virtual ~PacketWriterFactory() {} + + virtual QuicPacketWriter* Create(QuicConnection* connection) const = 0; + }; + + // Constructs a new QuicConnection for |connection_id| and |address|. Invokes + // writer_factory->Create() to get a writer; |owns_writer| specifies whether + // the connection takes ownership of the returned writer. |helper| must + // outlive this connection. QuicConnection(QuicConnectionId connection_id, IPEndPoint address, QuicConnectionHelperInterface* helper, - QuicPacketWriter* writer, + const PacketWriterFactory& writer_factory, bool owns_writer, bool is_server, const QuicVersionVector& supported_versions); diff --git a/net/quic/quic_connection_test.cc b/net/quic/quic_connection_test.cc index c0025da..2821064 100644 --- a/net/quic/quic_connection_test.cc +++ b/net/quic/quic_connection_test.cc @@ -38,6 +38,7 @@ using testing::Contains; using testing::DoAll; using testing::InSequence; using testing::InvokeWithoutArgs; +using testing::NiceMock; using testing::Ref; using testing::Return; using testing::SaveArg; @@ -411,21 +412,20 @@ class TestConnection : public QuicConnection { TestConnection(QuicConnectionId connection_id, IPEndPoint address, TestConnectionHelper* helper, - TestPacketWriter* writer, + const PacketWriterFactory& factory, bool is_server, QuicVersion version) : QuicConnection(connection_id, address, helper, - writer, - false /* owns_writer */, + factory, + /* owns_writer= */ false, is_server, - SupportedVersions(version)), - writer_(writer) { + SupportedVersions(version)) { // Disable tail loss probes for most tests. QuicSentPacketManagerPeer::SetMaxTailLossProbes( QuicConnectionPeer::GetSentPacketManager(this), 0); - writer_->set_is_server(is_server); + writer()->set_is_server(is_server); } void SendAck() { @@ -537,11 +537,11 @@ class TestConnection : public QuicConnection { void SetSupportedVersions(const QuicVersionVector& versions) { QuicConnectionPeer::GetFramer(this)->SetSupportedVersions(versions); - writer_->SetSupportedVersions(versions); + writer()->SetSupportedVersions(versions); } void set_is_server(bool is_server) { - writer_->set_is_server(is_server); + writer()->set_is_server(is_server); QuicConnectionPeer::SetIsServer(this, is_server); } @@ -578,7 +578,9 @@ class TestConnection : public QuicConnection { using QuicConnection::SelectMutualVersion; private: - TestPacketWriter* writer_; + TestPacketWriter* writer() { + return static_cast<TestPacketWriter*>(QuicConnection::writer()); + } DISALLOW_COPY_AND_ASSIGN(TestConnection); }; @@ -601,6 +603,16 @@ class FecQuicConnectionDebugVisitor QuicPacketHeader revived_header_; }; +class MockPacketWriterFactory : public QuicConnection::PacketWriterFactory { + public: + MockPacketWriterFactory(QuicPacketWriter* writer) { + ON_CALL(*this, Create(_)).WillByDefault(Return(writer)); + } + virtual ~MockPacketWriterFactory() {} + + MOCK_CONST_METHOD1(Create, QuicPacketWriter*(QuicConnection* connection)); +}; + class QuicConnectionTest : public ::testing::TestWithParam<QuicVersion> { protected: QuicConnectionTest() @@ -611,8 +623,9 @@ class QuicConnectionTest : public ::testing::TestWithParam<QuicVersion> { loss_algorithm_(new MockLossAlgorithm()), helper_(new TestConnectionHelper(&clock_, &random_generator_)), writer_(new TestPacketWriter(version())), + factory_(writer_.get()), connection_(connection_id_, IPEndPoint(), helper_.get(), - writer_.get(), false, version()), + factory_, false, version()), frame1_(1, false, 0, MakeIOVector(data1)), frame2_(1, false, 3, MakeIOVector(data2)), sequence_number_length_(PACKET_6BYTE_SEQUENCE_NUMBER), @@ -960,6 +973,7 @@ class QuicConnectionTest : public ::testing::TestWithParam<QuicVersion> { MockRandom random_generator_; scoped_ptr<TestConnectionHelper> helper_; scoped_ptr<TestPacketWriter> writer_; + NiceMock<MockPacketWriterFactory> factory_; TestConnection connection_; StrictMock<MockConnectionVisitor> visitor_; @@ -3923,9 +3937,9 @@ TEST_P(QuicConnectionTest, OnPacketHeaderDebugVisitor) { TEST_P(QuicConnectionTest, Pacing) { TestConnection server(connection_id_, IPEndPoint(), helper_.get(), - writer_.get(), true, version()); + factory_, /* is_server= */ true, version()); TestConnection client(connection_id_, IPEndPoint(), helper_.get(), - writer_.get(), false, version()); + factory_, /* is_server= */ false, version()); EXPECT_FALSE(client.sent_packet_manager().using_pacing()); EXPECT_FALSE(server.sent_packet_manager().using_pacing()); } diff --git a/net/quic/quic_dispatcher.cc b/net/quic/quic_dispatcher.cc index 802f00b..3c867b0 100644 --- a/net/quic/quic_dispatcher.cc +++ b/net/quic/quic_dispatcher.cc @@ -12,6 +12,7 @@ #include "net/quic/quic_blocked_writer_interface.h" #include "net/quic/quic_connection_helper.h" #include "net/quic/quic_flags.h" +#include "net/quic/quic_per_connection_packet_writer.h" #include "net/quic/quic_time_wait_list_manager.h" #include "net/quic/quic_utils.h" @@ -154,15 +155,37 @@ class QuicDispatcher::QuicFramerVisitor : public QuicFramerVisitorInterface { QuicConnectionId connection_id_; }; +QuicPacketWriter* QuicDispatcher::DefaultPacketWriterFactory::Create( + QuicServerPacketWriter* writer, + QuicConnection* connection) { + return new QuicPerConnectionPacketWriter(writer, connection); +} + +QuicDispatcher::PacketWriterFactoryAdapter::PacketWriterFactoryAdapter( + QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + +QuicDispatcher::PacketWriterFactoryAdapter::~PacketWriterFactoryAdapter() {} + +QuicPacketWriter* QuicDispatcher::PacketWriterFactoryAdapter::Create( + QuicConnection* connection) const { + return dispatcher_->packet_writer_factory_->Create( + dispatcher_->writer_.get(), + connection); +} + QuicDispatcher::QuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, const QuicVersionVector& supported_versions, + PacketWriterFactory* packet_writer_factory, QuicConnectionHelperInterface* helper) : config_(config), crypto_config_(crypto_config), helper_(helper), delete_sessions_alarm_( helper_->CreateAlarm(new DeleteSessionsAlarm(this))), + packet_writer_factory_(packet_writer_factory), + connection_writer_factory_(this), supported_versions_(supported_versions), current_packet_(NULL), framer_(supported_versions, /*unused*/ QuicTime::Zero(), true), @@ -339,17 +362,9 @@ QuicSession* QuicDispatcher::CreateQuicSession( QuicConnectionId connection_id, const IPEndPoint& server_address, const IPEndPoint& client_address) { - QuicPerConnectionPacketWriter* per_connection_packet_writer = - new QuicPerConnectionPacketWriter(writer_.get()); - QuicConnection* connection = - CreateQuicConnection(connection_id, - server_address, - client_address, - per_connection_packet_writer); QuicServerSession* session = new QuicServerSession( config_, - connection, - per_connection_packet_writer, + CreateQuicConnection(connection_id, server_address, client_address), this); session->InitializeSession(crypto_config_); return session; @@ -358,19 +373,14 @@ QuicSession* QuicDispatcher::CreateQuicSession( QuicConnection* QuicDispatcher::CreateQuicConnection( QuicConnectionId connection_id, const IPEndPoint& server_address, - const IPEndPoint& client_address, - QuicPerConnectionPacketWriter* writer) { - QuicConnection* connection; - connection = new QuicConnection( - connection_id, - client_address, - helper_, - writer, - false /* owns_writer */, - true /* is_server */, - supported_versions_); - writer->set_connection(connection); - return connection; + const IPEndPoint& client_address) { + return new QuicConnection(connection_id, + client_address, + helper_, + connection_writer_factory_, + /* owns_writer= */ true, + /* is_server= */ true, + supported_versions_); } QuicTimeWaitListManager* QuicDispatcher::CreateQuicTimeWaitListManager() { diff --git a/net/quic/quic_dispatcher.h b/net/quic/quic_dispatcher.h index 2687e8b..a4dba81 100644 --- a/net/quic/quic_dispatcher.h +++ b/net/quic/quic_dispatcher.h @@ -22,17 +22,6 @@ #include "net/quic/quic_server_session.h" #include "net/quic/quic_time_wait_list_manager.h" -#if defined(COMPILER_GCC) -namespace BASE_HASH_NAMESPACE { -template <> -struct hash<net::QuicBlockedWriterInterface*> { - std::size_t operator()(const net::QuicBlockedWriterInterface* ptr) const { - return hash<size_t>()(reinterpret_cast<size_t>(ptr)); - } -}; -} -#endif - namespace net { class QuicConfig; @@ -57,15 +46,40 @@ class QuicDispatcher : public QuicBlockedWriterInterface, public QuicServerSessionVisitor, public ProcessPacketInterface { public: + // Creates per-connection packet writers out of the QuicDispatcher's shared + // QuicPacketWriter. The per-connection writers' IsWriteBlocked() state must + // always be the same as the shared writer's IsWriteBlocked(), or else the + // QuicDispatcher::OnCanWrite logic will not work. (This will hopefully be + // cleaned up for bug 16950226.) + class PacketWriterFactory { + public: + virtual ~PacketWriterFactory() {} + + virtual QuicPacketWriter* Create(QuicServerPacketWriter* writer, + QuicConnection* connection) = 0; + }; + + // Creates ordinary QuicPerConnectionPacketWriter instances. + class DefaultPacketWriterFactory : public PacketWriterFactory { + public: + virtual ~DefaultPacketWriterFactory() {} + + virtual QuicPacketWriter* Create( + QuicServerPacketWriter* writer, + QuicConnection* connection) OVERRIDE; + }; + // Ideally we'd have a linked_hash_set: the boolean is unused. typedef linked_hash_map<QuicBlockedWriterInterface*, bool> WriteBlockedList; - // Due to the way delete_sessions_closure_ is registered, the Dispatcher - // must live until epoll_server Shutdown. |supported_versions| specifies the - // list of supported QUIC versions. + // Due to the way delete_sessions_closure_ is registered, the Dispatcher must + // live until epoll_server Shutdown. |supported_versions| specifies the list + // of supported QUIC versions. Takes ownership of |packet_writer_factory|, + // which is used to create per-connection writers. QuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, const QuicVersionVector& supported_versions, + PacketWriterFactory* packet_writer_factory, QuicConnectionHelperInterface* helper); virtual ~QuicDispatcher(); @@ -113,8 +127,7 @@ class QuicDispatcher : public QuicBlockedWriterInterface, virtual QuicConnection* CreateQuicConnection( QuicConnectionId connection_id, const IPEndPoint& server_address, - const IPEndPoint& client_address, - QuicPerConnectionPacketWriter* writer); + const IPEndPoint& client_address); // Called by |framer_visitor_| when the public header has been parsed. virtual bool OnUnauthenticatedPublicHeader( @@ -157,10 +170,29 @@ class QuicDispatcher : public QuicBlockedWriterInterface, QuicServerPacketWriter* writer() { return writer_.get(); } + const QuicConnection::PacketWriterFactory& connection_writer_factory() { + return connection_writer_factory_; + } + private: class QuicFramerVisitor; friend class net::test::QuicDispatcherPeer; + // An adapter that creates packet writers using the dispatcher's + // PacketWriterFactory and shared writer. Essentially, it just curries the + // writer argument away from QuicDispatcher::PacketWriterFactory. + class PacketWriterFactoryAdapter : + public QuicConnection::PacketWriterFactory { + public: + PacketWriterFactoryAdapter(QuicDispatcher* dispatcher); + virtual ~PacketWriterFactoryAdapter (); + + virtual QuicPacketWriter* Create(QuicConnection* connection) const OVERRIDE; + + private: + QuicDispatcher* dispatcher_; + }; + // Called by |framer_visitor_| when the private header has been parsed // of a data packet that is destined for the time wait manager. void OnUnauthenticatedHeader(const QuicPacketHeader& header); @@ -195,6 +227,12 @@ class QuicDispatcher : public QuicBlockedWriterInterface, // The writer to write to the socket with. scoped_ptr<QuicServerPacketWriter> writer_; + // Used to create per-connection packet writers, not |writer_| itself. + scoped_ptr<PacketWriterFactory> packet_writer_factory_; + + // Passed in to QuicConnection for it to create the per-connection writers + PacketWriterFactoryAdapter connection_writer_factory_; + // This vector contains QUIC versions which we currently support. // This should be ordered such that the highest supported version is the first // element, with subsequent elements in descending order (versions can be diff --git a/net/quic/quic_http_stream_test.cc b/net/quic/quic_http_stream_test.cc index 7cb227c6..7ebf8e5 100644 --- a/net/quic/quic_http_stream_test.cc +++ b/net/quic/quic_http_stream_test.cc @@ -59,12 +59,12 @@ class TestQuicConnection : public QuicConnection { QuicConnectionId connection_id, IPEndPoint address, QuicConnectionHelper* helper, - QuicPacketWriter* writer) + const QuicConnection::PacketWriterFactory& writer_factory) : QuicConnection(connection_id, address, helper, - writer, - false /* owns_writer */, + writer_factory, + true /* owns_writer */, false /* is_server */, versions) { } @@ -103,6 +103,20 @@ class AutoClosingStream : public QuicHttpStream { } }; +class TestPacketWriterFactory : public QuicConnection::PacketWriterFactory { + public: + explicit TestPacketWriterFactory(DatagramClientSocket* socket) + : socket_(socket) {} + virtual ~TestPacketWriterFactory() {} + + virtual QuicPacketWriter* Create(QuicConnection* connection) const OVERRIDE { + return new QuicDefaultPacketWriter(socket_); + } + + private: + DatagramClientSocket* socket_; +}; + } // namespace class QuicHttpStreamPeer { @@ -202,10 +216,10 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<QuicVersion> { EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)).Times(AnyNumber()); helper_.reset(new QuicConnectionHelper(runner_.get(), &clock_, &random_generator_)); - writer_.reset(new QuicDefaultPacketWriter(socket)); + TestPacketWriterFactory writer_factory(socket); connection_ = new TestQuicConnection(SupportedVersions(GetParam()), connection_id_, peer_addr_, - helper_.get(), writer_.get()); + helper_.get(), writer_factory); connection_->set_visitor(&visitor_); connection_->SetSendAlgorithm(send_algorithm_); connection_->SetReceiveAlgorithm(receive_algorithm_); @@ -213,7 +227,7 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<QuicVersion> { session_.reset( new QuicClientSession(connection_, scoped_ptr<DatagramClientSocket>(socket), - writer_.Pass(), NULL, + NULL, &crypto_client_stream_factory_, &transport_security_state_, make_scoped_ptr((QuicServerInfo*)NULL), @@ -300,7 +314,6 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<QuicVersion> { scoped_ptr<QuicConnectionHelper> helper_; testing::StrictMock<MockConnectionVisitor> visitor_; scoped_ptr<QuicHttpStream> stream_; - scoped_ptr<QuicDefaultPacketWriter> writer_; TransportSecurityState transport_security_state_; scoped_ptr<QuicClientSession> session_; QuicCryptoClientConfig crypto_config_; diff --git a/net/quic/quic_per_connection_packet_writer.cc b/net/quic/quic_per_connection_packet_writer.cc index 6347bc2..efb4fec 100644 --- a/net/quic/quic_per_connection_packet_writer.cc +++ b/net/quic/quic_per_connection_packet_writer.cc @@ -5,24 +5,30 @@ #include "net/quic/quic_per_connection_packet_writer.h" #include "net/quic/quic_server_packet_writer.h" -#include "net/quic/quic_types.h" namespace net { QuicPerConnectionPacketWriter::QuicPerConnectionPacketWriter( - QuicServerPacketWriter* writer) - : weak_factory_(this), writer_(writer) { + QuicServerPacketWriter* shared_writer, + QuicConnection* connection) + : weak_factory_(this), + shared_writer_(shared_writer), + connection_(connection) { } QuicPerConnectionPacketWriter::~QuicPerConnectionPacketWriter() { } +QuicPacketWriter* QuicPerConnectionPacketWriter::shared_writer() const { + return shared_writer_; +} + WriteResult QuicPerConnectionPacketWriter::WritePacket( const char* buffer, size_t buf_len, const IPAddressNumber& self_address, const IPEndPoint& peer_address) { - return writer_->WritePacketWithCallback( + return shared_writer_->WritePacketWithCallback( buffer, buf_len, self_address, @@ -32,15 +38,15 @@ WriteResult QuicPerConnectionPacketWriter::WritePacket( } bool QuicPerConnectionPacketWriter::IsWriteBlockedDataBuffered() const { - return writer_->IsWriteBlockedDataBuffered(); + return shared_writer_->IsWriteBlockedDataBuffered(); } bool QuicPerConnectionPacketWriter::IsWriteBlocked() const { - return writer_->IsWriteBlocked(); + return shared_writer_->IsWriteBlocked(); } void QuicPerConnectionPacketWriter::SetWritable() { - writer_->SetWritable(); + shared_writer_->SetWritable(); } void QuicPerConnectionPacketWriter::OnWriteComplete(WriteResult result) { diff --git a/net/quic/quic_per_connection_packet_writer.h b/net/quic/quic_per_connection_packet_writer.h index 96a37b6..e88a37d 100644 --- a/net/quic/quic_per_connection_packet_writer.h +++ b/net/quic/quic_per_connection_packet_writer.h @@ -5,13 +5,9 @@ #ifndef NET_QUIC_QUIC_PER_CONNECTION_PACKET_WRITER_H_ #define NET_QUIC_QUIC_PER_CONNECTION_PACKET_WRITER_H_ -#include "base/basictypes.h" #include "base/memory/weak_ptr.h" -#include "net/base/ip_endpoint.h" #include "net/quic/quic_connection.h" #include "net/quic/quic_packet_writer.h" -#include "net/quic/quic_protocol.h" -#include "net/quic/quic_types.h" namespace net { @@ -21,16 +17,18 @@ class QuicServerPacketWriter; // writes to the shared QuicServerPacketWriter complete. // This class is necessary because multiple connections can share the same // QuicServerPacketWriter, so it has no way to know which connection to notify. -// TODO(dmz) Try to merge with Chrome's default packet writer class QuicPerConnectionPacketWriter : public QuicPacketWriter { public: - QuicPerConnectionPacketWriter(QuicServerPacketWriter* writer); + // Does not take ownership of |shared_writer| or |connection|. + QuicPerConnectionPacketWriter(QuicServerPacketWriter* shared_writer, + QuicConnection* connection); virtual ~QuicPerConnectionPacketWriter(); - // Set the connection to notify after writes complete. - void set_connection(QuicConnection* connection) { connection_ = connection; } + QuicPacketWriter* shared_writer() const; + QuicConnection* connection() const { return connection_; } - // QuicPacketWriter + // Default implementation of the QuicPacketWriter interface: Passes everything + // to |shared_writer_|. virtual WriteResult WritePacket(const char* buffer, size_t buf_len, const IPAddressNumber& self_address, @@ -43,8 +41,8 @@ class QuicPerConnectionPacketWriter : public QuicPacketWriter { void OnWriteComplete(WriteResult result); base::WeakPtrFactory<QuicPerConnectionPacketWriter> weak_factory_; - QuicServerPacketWriter* writer_; // Not owned. - QuicConnection* connection_; + QuicServerPacketWriter* shared_writer_; // Not owned. + QuicConnection* connection_; // Not owned. DISALLOW_COPY_AND_ASSIGN(QuicPerConnectionPacketWriter); }; diff --git a/net/quic/quic_server.cc b/net/quic/quic_server.cc index e2a9da2..b8013c4 100644 --- a/net/quic/quic_server.cc +++ b/net/quic/quic_server.cc @@ -104,6 +104,7 @@ int QuicServer::Listen(const IPEndPoint& address) { new QuicDispatcher(config_, crypto_config_, supported_versions_, + new QuicDispatcher::DefaultPacketWriterFactory(), &helper_)); QuicServerPacketWriter* writer = new QuicServerPacketWriter( socket_.get(), diff --git a/net/quic/quic_server_session.cc b/net/quic/quic_server_session.cc index 399b772..c7141e6 100644 --- a/net/quic/quic_server_session.cc +++ b/net/quic/quic_server_session.cc @@ -15,10 +15,8 @@ namespace net { QuicServerSession::QuicServerSession( const QuicConfig& config, QuicConnection* connection, - QuicPerConnectionPacketWriter* connection_packet_writer, QuicServerSessionVisitor* visitor) : QuicSession(connection, config), - connection_packet_writer_(connection_packet_writer), visitor_(visitor) {} QuicServerSession::~QuicServerSession() {} diff --git a/net/quic/quic_server_session.h b/net/quic/quic_server_session.h index 49c0337..344feb2 100644 --- a/net/quic/quic_server_session.h +++ b/net/quic/quic_server_session.h @@ -44,10 +44,8 @@ class QuicServerSessionVisitor { class QuicServerSession : public QuicSession { public: - // Takes ownership of connection_packet_writer QuicServerSession(const QuicConfig& config, QuicConnection* connection, - QuicPerConnectionPacketWriter* connection_packet_writer, QuicServerSessionVisitor* visitor); // Override the base class to notify the owner of the connection close. @@ -83,7 +81,6 @@ class QuicServerSession : public QuicSession { friend class test::QuicServerSessionPeer; scoped_ptr<QuicCryptoServerStream> crypto_stream_; - scoped_ptr<QuicPerConnectionPacketWriter> connection_packet_writer_; QuicServerSessionVisitor* visitor_; DISALLOW_COPY_AND_ASSIGN(QuicServerSession); diff --git a/net/quic/quic_server_test.cc b/net/quic/quic_server_test.cc index 74db34d..5aa1524 100644 --- a/net/quic/quic_server_test.cc +++ b/net/quic/quic_server_test.cc @@ -22,7 +22,10 @@ class QuicChromeServerDispatchPacketTest : public ::testing::Test { public: QuicChromeServerDispatchPacketTest() : crypto_config_("blah", QuicRandom::GetInstance()), - dispatcher_(config_, crypto_config_, &helper_) { + dispatcher_(config_, + crypto_config_, + new QuicDispatcher::DefaultPacketWriterFactory(), + &helper_) { dispatcher_.Initialize(NULL); } diff --git a/net/quic/quic_stream_factory.cc b/net/quic/quic_stream_factory.cc index aa1c0f8..b012913 100644 --- a/net/quic/quic_stream_factory.cc +++ b/net/quic/quic_stream_factory.cc @@ -98,6 +98,26 @@ QuicConfig InitializeQuicConfig(bool enable_time_based_loss_detection, return config; } +class DefaultPacketWriterFactory : public QuicConnection::PacketWriterFactory { + public: + explicit DefaultPacketWriterFactory(DatagramClientSocket* socket) + : socket_(socket) {} + virtual ~DefaultPacketWriterFactory() {} + + virtual QuicPacketWriter* Create(QuicConnection* connection) const OVERRIDE; + + private: + DatagramClientSocket* socket_; +}; + +QuicPacketWriter* DefaultPacketWriterFactory::Create( + QuicConnection* connection) const { + scoped_ptr<QuicDefaultPacketWriter> writer( + new QuicDefaultPacketWriter(socket_)); + writer->SetConnection(connection); + return writer.release(); +} + } // namespace QuicStreamFactory::IpAliasKey::IpAliasKey() {} @@ -820,8 +840,7 @@ int QuicStreamFactory::CreateSession( return rv; } - scoped_ptr<QuicDefaultPacketWriter> writer( - new QuicDefaultPacketWriter(socket.get())); + DefaultPacketWriterFactory packet_writer_factory(socket.get()); if (!helper_.get()) { helper_.reset(new QuicConnectionHelper( @@ -832,11 +851,10 @@ int QuicStreamFactory::CreateSession( QuicConnection* connection = new QuicConnection(connection_id, addr, helper_.get(), - writer.get(), - false /* owns_writer */, + packet_writer_factory, + true /* owns_writer */, false /* is_server */, supported_versions_); - writer->SetConnection(connection); connection->set_max_packet_length(max_packet_length_); InitializeCachedStateInCryptoConfig(server_id, server_info); @@ -858,7 +876,7 @@ int QuicStreamFactory::CreateSession( } *session = new QuicClientSession( - connection, socket.Pass(), writer.Pass(), this, + connection, socket.Pass(), this, quic_crypto_client_stream_factory_, transport_security_state_, server_info.Pass(), server_id, config, &crypto_config_, base::MessageLoop::current()->message_loop_proxy().get(), diff --git a/net/quic/test_tools/mock_quic_dispatcher.cc b/net/quic/test_tools/mock_quic_dispatcher.cc index f3f41a5..9c9da43 100644 --- a/net/quic/test_tools/mock_quic_dispatcher.cc +++ b/net/quic/test_tools/mock_quic_dispatcher.cc @@ -12,8 +12,13 @@ namespace test { MockQuicDispatcher::MockQuicDispatcher( const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, + QuicDispatcher::PacketWriterFactory* packet_writer_factory, QuicConnectionHelperInterface* helper) - : QuicDispatcher(config, crypto_config, QuicSupportedVersions(), helper) { + : QuicDispatcher(config, + crypto_config, + QuicSupportedVersions(), + packet_writer_factory, + helper) { } MockQuicDispatcher::~MockQuicDispatcher() { diff --git a/net/quic/test_tools/mock_quic_dispatcher.h b/net/quic/test_tools/mock_quic_dispatcher.h index 93ecda4..f923790 100644 --- a/net/quic/test_tools/mock_quic_dispatcher.h +++ b/net/quic/test_tools/mock_quic_dispatcher.h @@ -19,6 +19,7 @@ class MockQuicDispatcher : public QuicDispatcher { public: MockQuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, + PacketWriterFactory* packet_writer_factory, QuicConnectionHelperInterface* helper); virtual ~MockQuicDispatcher(); diff --git a/net/quic/test_tools/quic_connection_peer.cc b/net/quic/test_tools/quic_connection_peer.cc index d1aa8a5..b752656 100644 --- a/net/quic/test_tools/quic_connection_peer.cc +++ b/net/quic/test_tools/quic_connection_peer.cc @@ -167,6 +167,7 @@ QuicFramer* QuicConnectionPeer::GetFramer(QuicConnection* connection) { return &connection->framer_; } +// static QuicFecGroup* QuicConnectionPeer::GetFecGroup(QuicConnection* connection, int fec_group) { connection->last_header_.fec_group = fec_group; diff --git a/net/quic/test_tools/quic_test_utils.cc b/net/quic/test_tools/quic_test_utils.cc index c781b23..19fd682 100644 --- a/net/quic/test_tools/quic_test_utils.cc +++ b/net/quic/test_tools/quic_test_utils.cc @@ -221,12 +221,29 @@ void MockHelper::AdvanceTime(QuicTime::Delta delta) { clock_.AdvanceTime(delta); } +namespace { +class NiceMockPacketWriterFactory + : public QuicConnection::PacketWriterFactory { + public: + NiceMockPacketWriterFactory() {} + virtual ~NiceMockPacketWriterFactory() {} + + virtual QuicPacketWriter* Create( + QuicConnection* /*connection*/) const override { + return new testing::NiceMock<MockPacketWriter>(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(NiceMockPacketWriterFactory); +}; +} // namespace + MockConnection::MockConnection(bool is_server) : QuicConnection(kTestConnectionId, IPEndPoint(TestPeerIPAddress(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -235,8 +252,8 @@ MockConnection::MockConnection(IPEndPoint address, bool is_server) : QuicConnection(kTestConnectionId, address, new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -246,8 +263,8 @@ MockConnection::MockConnection(QuicConnectionId connection_id, : QuicConnection(connection_id, IPEndPoint(TestPeerIPAddress(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -257,8 +274,8 @@ MockConnection::MockConnection(bool is_server, : QuicConnection(kTestConnectionId, IPEndPoint(TestPeerIPAddress(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, supported_versions), helper_(helper()) { } @@ -606,5 +623,53 @@ QuicVersionVector SupportedVersions(QuicVersion version) { return versions; } +TestWriterFactory::TestWriterFactory() : current_writer_(NULL) {} +TestWriterFactory::~TestWriterFactory() {} + +QuicPacketWriter* TestWriterFactory::Create(QuicServerPacketWriter* writer, + QuicConnection* connection) { + return new PerConnectionPacketWriter(this, writer, connection); +} + +void TestWriterFactory::OnPacketSent(WriteResult result) { + if (current_writer_ != NULL) { + current_writer_->connection()->OnPacketSent(result); + current_writer_ = NULL; + } +} + +void TestWriterFactory::Unregister(PerConnectionPacketWriter* writer) { + if (current_writer_ == writer) { + current_writer_ = NULL; + } +} + +TestWriterFactory::PerConnectionPacketWriter::PerConnectionPacketWriter( + TestWriterFactory* factory, + QuicServerPacketWriter* writer, + QuicConnection* connection) + : QuicPerConnectionPacketWriter(writer, connection), + factory_(factory) { +} + +TestWriterFactory::PerConnectionPacketWriter::~PerConnectionPacketWriter() { + factory_->Unregister(this); +} + +WriteResult TestWriterFactory::PerConnectionPacketWriter::WritePacket( + const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) { + // A DCHECK(factory_current_writer_ == NULL) would be wrong here -- this class + // may be used in a setting where connection()->OnPacketSent() is called in a + // different way, so TestWriterFactory::OnPacketSent might never be called. + factory_->current_writer_ = this; + return QuicPerConnectionPacketWriter::WritePacket(buffer, + buf_len, + self_address, + peer_address); +} + } // namespace test } // namespace net diff --git a/net/quic/test_tools/quic_test_utils.h b/net/quic/test_tools/quic_test_utils.h index ca2df7e..2350668 100644 --- a/net/quic/test_tools/quic_test_utils.h +++ b/net/quic/test_tools/quic_test_utils.h @@ -16,7 +16,9 @@ #include "net/quic/quic_ack_notifier.h" #include "net/quic/quic_client_session_base.h" #include "net/quic/quic_connection.h" +#include "net/quic/quic_dispatcher.h" #include "net/quic/quic_framer.h" +#include "net/quic/quic_per_connection_packet_writer.h" #include "net/quic/quic_sent_packet_manager.h" #include "net/quic/quic_session.h" #include "net/quic/test_tools/mock_clock.h" @@ -543,6 +545,46 @@ class MockNetworkChangeVisitor : DISALLOW_COPY_AND_ASSIGN(MockNetworkChangeVisitor); }; +// Creates per-connection packet writers that register themselves with the +// TestWriterFactory on each write so that TestWriterFactory::OnPacketSent can +// be routed to the appropriate QuicConnection. +class TestWriterFactory : public QuicDispatcher::PacketWriterFactory { + public: + TestWriterFactory(); + virtual ~TestWriterFactory(); + + virtual QuicPacketWriter* Create(QuicServerPacketWriter* writer, + QuicConnection* connection) OVERRIDE; + + // Calls OnPacketSent on the last QuicConnection to write through one of the + // packet writers created by this factory. + void OnPacketSent(WriteResult result); + + private: + class PerConnectionPacketWriter : public QuicPerConnectionPacketWriter { + public: + PerConnectionPacketWriter(TestWriterFactory* factory, + QuicServerPacketWriter* writer, + QuicConnection* connection); + virtual ~PerConnectionPacketWriter(); + + virtual WriteResult WritePacket( + const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) OVERRIDE; + + private: + TestWriterFactory* factory_; + }; + + // If an asynchronous write is happening and |writer| gets deleted, this + // clears the pointer to it to prevent use-after-free. + void Unregister(PerConnectionPacketWriter* writer); + + PerConnectionPacketWriter* current_writer_; +}; + } // namespace test } // namespace net diff --git a/net/tools/quic/end_to_end_test.cc b/net/tools/quic/end_to_end_test.cc index 24ad3905..1f4103f 100644 --- a/net/tools/quic/end_to_end_test.cc +++ b/net/tools/quic/end_to_end_test.cc @@ -161,11 +161,17 @@ vector<TestParams> GetTestParams() { class ServerDelegate : public PacketDroppingTestWriter::Delegate { public: - explicit ServerDelegate(QuicDispatcher* dispatcher) - : dispatcher_(dispatcher) {} + ServerDelegate(TestWriterFactory* writer_factory, + QuicDispatcher* dispatcher) + : writer_factory_(writer_factory), + dispatcher_(dispatcher) {} virtual ~ServerDelegate() {} + virtual void OnPacketSent(WriteResult result) override { + writer_factory_->OnPacketSent(result); + } virtual void OnCanWrite() OVERRIDE { dispatcher_->OnCanWrite(); } private: + TestWriterFactory* writer_factory_; QuicDispatcher* dispatcher_; }; @@ -173,6 +179,7 @@ class ClientDelegate : public PacketDroppingTestWriter::Delegate { public: explicit ClientDelegate(QuicClient* client) : client_(client) {} virtual ~ClientDelegate() {} + virtual void OnPacketSent(WriteResult result) OVERRIDE {} virtual void OnCanWrite() OVERRIDE { EpollEvent event(EPOLLOUT, false); client_->OnEvent(client_->fd(), &event); @@ -326,7 +333,7 @@ class EndToEndTest : public ::testing::TestWithParam<TestParams> { virtual void SetUp() OVERRIDE { // The ownership of these gets transferred to the QuicPacketWriterWrapper - // and QuicDispatcher when Initialize() is executed. + // and TestWriterFactory when Initialize() is executed. client_writer_ = new PacketDroppingTestWriter(); server_writer_ = new PacketDroppingTestWriter(); } @@ -346,10 +353,13 @@ class EndToEndTest : public ::testing::TestWithParam<TestParams> { server_thread_->GetPort()); QuicDispatcher* dispatcher = QuicServerPeer::GetDispatcher(server_thread_->server()); + TestWriterFactory* packet_writer_factory = new TestWriterFactory(); + QuicDispatcherPeer::SetPacketWriterFactory(dispatcher, + packet_writer_factory); QuicDispatcherPeer::UseWriter(dispatcher, server_writer_); server_writer_->Initialize( QuicDispatcherPeer::GetHelper(dispatcher), - new ServerDelegate(dispatcher)); + new ServerDelegate(packet_writer_factory, dispatcher)); server_thread_->Start(); server_started_ = true; } @@ -1150,7 +1160,7 @@ TEST_P(EndToEndTest, ConnectionMigrationClientIPChanged) { writer->set_writer(new QuicDefaultPacketWriter(client_->client()->fd())); QuicConnectionPeer::SetWriter(client_->client()->session()->connection(), writer, - true /* owns_writer */); + /* owns_writer= */ true); client_->SendSynchronousRequest("/bar"); diff --git a/net/tools/quic/quic_client.cc b/net/tools/quic/quic_client.cc index ca87c2f..48f789f 100644 --- a/net/tools/quic/quic_client.cc +++ b/net/tools/quic/quic_client.cc @@ -97,6 +97,18 @@ bool QuicClient::Initialize() { return true; } +QuicClient::DummyPacketWriterFactory::DummyPacketWriterFactory( + QuicPacketWriter* writer) + : writer_(writer) {} + +QuicClient::DummyPacketWriterFactory::~DummyPacketWriterFactory() {} + +QuicPacketWriter* QuicClient::DummyPacketWriterFactory::Create( + QuicConnection* /*connection*/) const { + return writer_; +} + + bool QuicClient::CreateUDPSocket() { int address_family = server_address_.GetSockAddrFamily(); fd_ = socket(address_family, SOCK_DGRAM | SOCK_NONBLOCK, IPPROTO_UDP); @@ -179,15 +191,18 @@ bool QuicClient::StartConnect() { QuicPacketWriter* writer = CreateQuicPacketWriter(); + DummyPacketWriterFactory factory(writer); + session_.reset(new QuicClientSession( config_, new QuicConnection(GenerateConnectionId(), server_address_, helper_.get(), - writer, - false /* owns_writer */, - false /* is_server */, + factory, + /* owns_writer= */ false, + /* is_server= */ false, supported_versions_))); + // Reset |writer_| after |session_| so that the old writer outlives the old // session. if (writer_.get() != writer) { diff --git a/net/tools/quic/quic_client.h b/net/tools/quic/quic_client.h index a4fb6c3..8b0f1fe 100644 --- a/net/tools/quic/quic_client.h +++ b/net/tools/quic/quic_client.h @@ -188,6 +188,18 @@ class QuicClient : public EpollCallbackInterface, private: friend class net::tools::test::QuicClientPeer; + // A packet writer factory that always returns the same writer + class DummyPacketWriterFactory : public QuicConnection::PacketWriterFactory { + public: + DummyPacketWriterFactory(QuicPacketWriter* writer); + virtual ~DummyPacketWriterFactory(); + + virtual QuicPacketWriter* Create(QuicConnection* connection) const OVERRIDE; + + private: + QuicPacketWriter* writer_; + }; + // Used during initialization: creates the UDP socket FD, sets socket options, // and binds the socket to our address. bool CreateUDPSocket(); diff --git a/net/tools/quic/quic_dispatcher.cc b/net/tools/quic/quic_dispatcher.cc index cf3a46b..05cb39d 100644 --- a/net/tools/quic/quic_dispatcher.cc +++ b/net/tools/quic/quic_dispatcher.cc @@ -15,6 +15,7 @@ #include "net/tools/epoll_server/epoll_server.h" #include "net/tools/quic/quic_default_packet_writer.h" #include "net/tools/quic/quic_epoll_connection_helper.h" +#include "net/tools/quic/quic_per_connection_packet_writer.h" #include "net/tools/quic/quic_socket_utils.h" #include "net/tools/quic/quic_time_wait_list_manager.h" @@ -159,15 +160,37 @@ class QuicDispatcher::QuicFramerVisitor : public QuicFramerVisitorInterface { QuicConnectionId connection_id_; }; +QuicPacketWriter* QuicDispatcher::DefaultPacketWriterFactory::Create( + QuicPacketWriter* writer, + QuicConnection* connection) { + return new QuicPerConnectionPacketWriter(writer, connection); +} + +QuicDispatcher::PacketWriterFactoryAdapter::PacketWriterFactoryAdapter( + QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + +QuicDispatcher::PacketWriterFactoryAdapter::~PacketWriterFactoryAdapter() {} + +QuicPacketWriter* QuicDispatcher::PacketWriterFactoryAdapter::Create( + QuicConnection* connection) const { + return dispatcher_->packet_writer_factory_->Create( + dispatcher_->writer_.get(), + connection); +} + QuicDispatcher::QuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, const QuicVersionVector& supported_versions, + PacketWriterFactory* packet_writer_factory, EpollServer* epoll_server) : config_(config), crypto_config_(crypto_config), delete_sessions_alarm_(new DeleteSessionsAlarm(this)), epoll_server_(epoll_server), helper_(new QuicEpollConnectionHelper(epoll_server_)), + packet_writer_factory_(packet_writer_factory), + connection_writer_factory_(this), supported_versions_(supported_versions), current_packet_(NULL), framer_(supported_versions, /*unused*/ QuicTime::Zero(), true), @@ -366,9 +389,9 @@ QuicConnection* QuicDispatcher::CreateQuicConnection( return new QuicConnection(connection_id, client_address, helper_.get(), - writer_.get(), - false /* owns_writer */, - true /* is_server */, + connection_writer_factory_, + /* owns_writer= */ true, + /* is_server= */ true, supported_versions_); } diff --git a/net/tools/quic/quic_dispatcher.h b/net/tools/quic/quic_dispatcher.h index 8d7b03b..ba6a7ad 100644 --- a/net/tools/quic/quic_dispatcher.h +++ b/net/tools/quic/quic_dispatcher.h @@ -20,18 +20,6 @@ #include "net/tools/quic/quic_server_session.h" #include "net/tools/quic/quic_time_wait_list_manager.h" -#if defined(COMPILER_GCC) -namespace BASE_HASH_NAMESPACE { -template<> -struct hash<net::QuicBlockedWriterInterface*> { - std::size_t operator()( - const net::QuicBlockedWriterInterface* ptr) const { - return hash<size_t>()(reinterpret_cast<size_t>(ptr)); - } -}; -} -#endif - namespace net { class EpollServer; @@ -61,15 +49,40 @@ class ProcessPacketInterface { class QuicDispatcher : public QuicServerSessionVisitor, public ProcessPacketInterface { public: + // Creates per-connection packet writers out of the QuicDispatcher's shared + // QuicPacketWriter. The per-connection writers' IsWriteBlocked() state must + // always be the same as the shared writer's IsWriteBlocked(), or else the + // QuicDispatcher::OnCanWrite logic will not work. (This will hopefully be + // cleaned up for bug 16950226.) + class PacketWriterFactory { + public: + virtual ~PacketWriterFactory() {} + + virtual QuicPacketWriter* Create(QuicPacketWriter* writer, + QuicConnection* connection) = 0; + }; + + // Creates ordinary QuicPerConnectionPacketWriter instances. + class DefaultPacketWriterFactory : public PacketWriterFactory { + public: + virtual ~DefaultPacketWriterFactory() {} + + virtual QuicPacketWriter* Create( + QuicPacketWriter* writer, + QuicConnection* connection) OVERRIDE; + }; + // Ideally we'd have a linked_hash_set: the boolean is unused. typedef linked_hash_map<QuicBlockedWriterInterface*, bool> WriteBlockedList; - // Due to the way delete_sessions_closure_ is registered, the Dispatcher - // must live until epoll_server Shutdown. |supported_versions| specifies the - // list of supported QUIC versions. + // Due to the way delete_sessions_closure_ is registered, the Dispatcher must + // live until epoll_server Shutdown. |supported_versions| specifies the list + // of supported QUIC versions. Takes ownership of |packet_writer_factory|, + // which is used to create per-connection writers. QuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, const QuicVersionVector& supported_versions, + PacketWriterFactory* packet_writer_factory, EpollServer* epoll_server); virtual ~QuicDispatcher(); @@ -164,10 +177,29 @@ class QuicDispatcher : public QuicServerSessionVisitor, QuicPacketWriter* writer() { return writer_.get(); } + const QuicConnection::PacketWriterFactory& connection_writer_factory() { + return connection_writer_factory_; + } + private: class QuicFramerVisitor; friend class net::tools::test::QuicDispatcherPeer; + // An adapter that creates packet writers using the dispatcher's + // PacketWriterFactory and shared writer. Essentially, it just curries the + // writer argument away from QuicDispatcher::PacketWriterFactory. + class PacketWriterFactoryAdapter : + public QuicConnection::PacketWriterFactory { + public: + PacketWriterFactoryAdapter(QuicDispatcher* dispatcher); + virtual ~PacketWriterFactoryAdapter (); + + virtual QuicPacketWriter* Create(QuicConnection* connection) const OVERRIDE; + + private: + QuicDispatcher* dispatcher_; + }; + // Called by |framer_visitor_| when the private header has been parsed // of a data packet that is destined for the time wait manager. void OnUnauthenticatedHeader(const QuicPacketHeader& header); @@ -204,6 +236,12 @@ class QuicDispatcher : public QuicServerSessionVisitor, // The writer to write to the socket with. scoped_ptr<QuicPacketWriter> writer_; + // Used to create per-connection packet writers, not |writer_| itself. + scoped_ptr<PacketWriterFactory> packet_writer_factory_; + + // Passed in to QuicConnection for it to create the per-connection writers + PacketWriterFactoryAdapter connection_writer_factory_; + // This vector contains QUIC versions which we currently support. // This should be ordered such that the highest supported version is the first // element, with subsequent elements in descending order (versions can be diff --git a/net/tools/quic/quic_dispatcher_test.cc b/net/tools/quic/quic_dispatcher_test.cc index 4c778cb..33685c9 100644 --- a/net/tools/quic/quic_dispatcher_test.cc +++ b/net/tools/quic/quic_dispatcher_test.cc @@ -48,6 +48,7 @@ class TestDispatcher : public QuicDispatcher { : QuicDispatcher(config, crypto_config, QuicSupportedVersions(), + new QuicDispatcher::DefaultPacketWriterFactory(), eps) { } @@ -271,12 +272,11 @@ class BlockingWriter : public QuicPacketWriterWrapper { size_t buf_len, const IPAddressNumber& self_client_address, const IPEndPoint& peer_client_address) OVERRIDE { - if (write_blocked_) { - return WriteResult(WRITE_STATUS_BLOCKED, EAGAIN); - } else { - return QuicPacketWriterWrapper::WritePacket( - buffer, buf_len, self_client_address, peer_client_address); - } + // It would be quite possible to actually implement this method here with + // the fake blocked status, but it would be significantly more work in + // Chromium, and since it's not called anyway, don't bother. + LOG(DFATAL) << "Not supported"; + return WriteResult(); } bool write_blocked_; @@ -286,6 +286,8 @@ class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest { public: virtual void SetUp() { writer_ = new BlockingWriter; + QuicDispatcherPeer::SetPacketWriterFactory(&dispatcher_, + new TestWriterFactory()); QuicDispatcherPeer::UseWriter(&dispatcher_, writer_); IPEndPoint client_address(net::test::Loopback4(), 1); diff --git a/net/tools/quic/quic_per_connection_packet_writer.cc b/net/tools/quic/quic_per_connection_packet_writer.cc new file mode 100644 index 0000000..508946b --- /dev/null +++ b/net/tools/quic/quic_per_connection_packet_writer.cc @@ -0,0 +1,46 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/quic/quic_per_connection_packet_writer.h" + +namespace net { + +namespace tools { + +QuicPerConnectionPacketWriter::QuicPerConnectionPacketWriter( + QuicPacketWriter* shared_writer, + QuicConnection* connection) + : shared_writer_(shared_writer), + connection_(connection) { +} + +QuicPerConnectionPacketWriter::~QuicPerConnectionPacketWriter() { +} + +WriteResult QuicPerConnectionPacketWriter::WritePacket( + const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) { + return shared_writer_->WritePacket(buffer, + buf_len, + self_address, + peer_address); +} + +bool QuicPerConnectionPacketWriter::IsWriteBlockedDataBuffered() const { + return shared_writer_->IsWriteBlockedDataBuffered(); +} + +bool QuicPerConnectionPacketWriter::IsWriteBlocked() const { + return shared_writer_->IsWriteBlocked(); +} + +void QuicPerConnectionPacketWriter::SetWritable() { + shared_writer_->SetWritable(); +} + +} // namespace tools + +} // namespace net diff --git a/net/tools/quic/quic_per_connection_packet_writer.h b/net/tools/quic/quic_per_connection_packet_writer.h new file mode 100644 index 0000000..a442a9a --- /dev/null +++ b/net/tools/quic/quic_per_connection_packet_writer.h @@ -0,0 +1,48 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_QUIC_QUIC_PER_CONNECTION_PACKET_WRITER_H_ +#define NET_TOOLS_QUIC_QUIC_PER_CONNECTION_PACKET_WRITER_H_ + +#include "net/quic/quic_connection.h" +#include "net/quic/quic_packet_writer.h" + +namespace net { + +namespace tools { + +// A connection-specific packet writer that wraps a shared writer and keeps a +// reference to the connection. +class QuicPerConnectionPacketWriter : public QuicPacketWriter { + public: + // Does not take ownership of |shared_writer| or |connection|. + QuicPerConnectionPacketWriter(QuicPacketWriter* shared_writer, + QuicConnection* connection); + virtual ~QuicPerConnectionPacketWriter(); + + QuicPacketWriter* shared_writer() const { return shared_writer_; } + QuicConnection* connection() const { return connection_; } + + // Default implementation of the QuicPacketWriter interface: Passes everything + // to |shared_writer_|. + virtual WriteResult WritePacket(const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) OVERRIDE; + virtual bool IsWriteBlockedDataBuffered() const OVERRIDE; + virtual bool IsWriteBlocked() const OVERRIDE; + virtual void SetWritable() OVERRIDE; + + private: + QuicPacketWriter* shared_writer_; // Not owned. + QuicConnection* connection_; // Not owned. + + DISALLOW_COPY_AND_ASSIGN(QuicPerConnectionPacketWriter); +}; + +} // namespace tools + +} // namespace net + +#endif // NET_TOOLS_QUIC_QUIC_PER_CONNECTION_PACKET_WRITER_H_ diff --git a/net/tools/quic/quic_server.cc b/net/tools/quic/quic_server.cc index 651f900..fa93678 100644 --- a/net/tools/quic/quic_server.cc +++ b/net/tools/quic/quic_server.cc @@ -166,6 +166,7 @@ QuicDispatcher* QuicServer::CreateQuicDispatcher() { config_, crypto_config_, supported_versions_, + new QuicDispatcher::DefaultPacketWriterFactory(), &epoll_server_); } diff --git a/net/tools/quic/quic_server_test.cc b/net/tools/quic/quic_server_test.cc index e2d7c4f..1107ded 100644 --- a/net/tools/quic/quic_server_test.cc +++ b/net/tools/quic/quic_server_test.cc @@ -21,7 +21,10 @@ class QuicServerDispatchPacketTest : public ::testing::Test { public: QuicServerDispatchPacketTest() : crypto_config_("blah", QuicRandom::GetInstance()), - dispatcher_(config_, crypto_config_, &eps_) { + dispatcher_(config_, + crypto_config_, + new QuicDispatcher::DefaultPacketWriterFactory(), + &eps_) { dispatcher_.Initialize(1234); } diff --git a/net/tools/quic/test_tools/mock_quic_dispatcher.cc b/net/tools/quic/test_tools/mock_quic_dispatcher.cc index 13271ca..120c279 100644 --- a/net/tools/quic/test_tools/mock_quic_dispatcher.cc +++ b/net/tools/quic/test_tools/mock_quic_dispatcher.cc @@ -13,10 +13,12 @@ namespace test { MockQuicDispatcher::MockQuicDispatcher( const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, + QuicDispatcher::PacketWriterFactory* packet_writer_factory, EpollServer* eps) : QuicDispatcher(config, crypto_config, QuicSupportedVersions(), + packet_writer_factory, eps) {} MockQuicDispatcher::~MockQuicDispatcher() {} diff --git a/net/tools/quic/test_tools/mock_quic_dispatcher.h b/net/tools/quic/test_tools/mock_quic_dispatcher.h index d155911..df32ce4 100644 --- a/net/tools/quic/test_tools/mock_quic_dispatcher.h +++ b/net/tools/quic/test_tools/mock_quic_dispatcher.h @@ -21,6 +21,7 @@ class MockQuicDispatcher : public QuicDispatcher { public: MockQuicDispatcher(const QuicConfig& config, const QuicCryptoServerConfig& crypto_config, + PacketWriterFactory* packet_writer_factory, EpollServer* eps); virtual ~MockQuicDispatcher(); diff --git a/net/tools/quic/test_tools/packet_dropping_test_writer.h b/net/tools/quic/test_tools/packet_dropping_test_writer.h index 3509722..b7babad 100644 --- a/net/tools/quic/test_tools/packet_dropping_test_writer.h +++ b/net/tools/quic/test_tools/packet_dropping_test_writer.h @@ -30,6 +30,7 @@ class PacketDroppingTestWriter : public QuicPacketWriterWrapper { class Delegate { public: virtual ~Delegate() {} + virtual void OnPacketSent(WriteResult result) = 0; virtual void OnCanWrite() = 0; }; diff --git a/net/tools/quic/test_tools/quic_dispatcher_peer.cc b/net/tools/quic/test_tools/quic_dispatcher_peer.cc index 26ff490..1900420 100644 --- a/net/tools/quic/test_tools/quic_dispatcher_peer.cc +++ b/net/tools/quic/test_tools/quic_dispatcher_peer.cc @@ -31,6 +31,13 @@ QuicPacketWriter* QuicDispatcherPeer::GetWriter(QuicDispatcher* dispatcher) { } // static +void QuicDispatcherPeer::SetPacketWriterFactory( + QuicDispatcher* dispatcher, + QuicDispatcher::PacketWriterFactory* packet_writer_factory) { + dispatcher->packet_writer_factory_.reset(packet_writer_factory); +} + +// static QuicEpollConnectionHelper* QuicDispatcherPeer::GetHelper( QuicDispatcher* dispatcher) { return dispatcher->helper_.get(); diff --git a/net/tools/quic/test_tools/quic_dispatcher_peer.h b/net/tools/quic/test_tools/quic_dispatcher_peer.h index 1d614d3..6271615 100644 --- a/net/tools/quic/test_tools/quic_dispatcher_peer.h +++ b/net/tools/quic/test_tools/quic_dispatcher_peer.h @@ -22,12 +22,16 @@ class QuicDispatcherPeer { QuicDispatcher* dispatcher, QuicTimeWaitListManager* time_wait_list_manager); - // Injects |writer| into |dispatcher| as the top level writer. + // Injects |writer| into |dispatcher| as the shared writer. static void UseWriter(QuicDispatcher* dispatcher, QuicPacketWriterWrapper* writer); static QuicPacketWriter* GetWriter(QuicDispatcher* dispatcher); + static void SetPacketWriterFactory( + QuicDispatcher* dispatcher, + QuicDispatcher::PacketWriterFactory* packet_writer_factory); + static QuicEpollConnectionHelper* GetHelper(QuicDispatcher* dispatcher); static QuicConnection* CreateQuicConnection( diff --git a/net/tools/quic/test_tools/quic_test_utils.cc b/net/tools/quic/test_tools/quic_test_utils.cc index 255fbc0..2e8c96d 100644 --- a/net/tools/quic/test_tools/quic_test_utils.cc +++ b/net/tools/quic/test_tools/quic_test_utils.cc @@ -18,12 +18,29 @@ namespace net { namespace tools { namespace test { +namespace { +class NiceMockPacketWriterFactory + : public QuicConnection::PacketWriterFactory { + public: + NiceMockPacketWriterFactory() {} + virtual ~NiceMockPacketWriterFactory() {} + + virtual QuicPacketWriter* Create( + QuicConnection* /*connection*/) const override { + return new testing::NiceMock<MockPacketWriter>(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(NiceMockPacketWriterFactory); +}; +} // namespace + MockConnection::MockConnection(bool is_server) : QuicConnection(kTestConnectionId, IPEndPoint(net::test::Loopback4(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -32,8 +49,8 @@ MockConnection::MockConnection(IPEndPoint address, bool is_server) : QuicConnection(kTestConnectionId, address, new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -43,8 +60,8 @@ MockConnection::MockConnection(QuicConnectionId connection_id, : QuicConnection(connection_id, IPEndPoint(net::test::Loopback4(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -54,8 +71,8 @@ MockConnection::MockConnection(bool is_server, : QuicConnection(kTestConnectionId, IPEndPoint(net::test::Loopback4(), kTestPort), new testing::NiceMock<MockHelper>(), - new testing::NiceMock<MockPacketWriter>(), - true /* owns_writer */, + NiceMockPacketWriterFactory(), + /* owns_writer= */ true, is_server, QuicSupportedVersions()), helper_(helper()) { } @@ -112,6 +129,54 @@ MockAckNotifierDelegate::MockAckNotifierDelegate() { MockAckNotifierDelegate::~MockAckNotifierDelegate() { } +TestWriterFactory::TestWriterFactory() : current_writer_(NULL) {} +TestWriterFactory::~TestWriterFactory() {} + +QuicPacketWriter* TestWriterFactory::Create(QuicPacketWriter* writer, + QuicConnection* connection) { + return new PerConnectionPacketWriter(this, writer, connection); +} + +void TestWriterFactory::OnPacketSent(WriteResult result) { + if (current_writer_ != NULL) { + current_writer_->connection()->OnPacketSent(result); + current_writer_ = NULL; + } +} + +void TestWriterFactory::Unregister(PerConnectionPacketWriter* writer) { + if (current_writer_ == writer) { + current_writer_ = NULL; + } +} + +TestWriterFactory::PerConnectionPacketWriter::PerConnectionPacketWriter( + TestWriterFactory* factory, + QuicPacketWriter* writer, + QuicConnection* connection) + : QuicPerConnectionPacketWriter(writer, connection), + factory_(factory) { +} + +TestWriterFactory::PerConnectionPacketWriter::~PerConnectionPacketWriter() { + factory_->Unregister(this); +} + +WriteResult TestWriterFactory::PerConnectionPacketWriter::WritePacket( + const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) { + // A DCHECK(factory_current_writer_ == NULL) would be wrong here -- this class + // may be used in a setting where connection()->OnPacketSent() is called in a + // different way, so TestWriterFactory::OnPacketSent might never be called. + factory_->current_writer_ = this; + return QuicPerConnectionPacketWriter::WritePacket(buffer, + buf_len, + self_address, + peer_address); +} + } // namespace test } // namespace tools } // namespace net diff --git a/net/tools/quic/test_tools/quic_test_utils.h b/net/tools/quic/test_tools/quic_test_utils.h index b6449d3..1eca03b 100644 --- a/net/tools/quic/test_tools/quic_test_utils.h +++ b/net/tools/quic/test_tools/quic_test_utils.h @@ -14,6 +14,8 @@ #include "net/quic/quic_packet_writer.h" #include "net/quic/quic_session.h" #include "net/spdy/spdy_framer.h" +#include "net/tools/quic/quic_dispatcher.h" +#include "net/tools/quic/quic_per_connection_packet_writer.h" #include "net/tools/quic/quic_server_session.h" #include "testing/gmock/include/gmock/gmock.h" @@ -158,6 +160,46 @@ class MockAckNotifierDelegate : public QuicAckNotifier::DelegateInterface { DISALLOW_COPY_AND_ASSIGN(MockAckNotifierDelegate); }; +// Creates per-connection packet writers that register themselves with the +// TestWriterFactory on each write so that TestWriterFactory::OnPacketSent can +// be routed to the appropriate QuicConnection. +class TestWriterFactory : public QuicDispatcher::PacketWriterFactory { + public: + TestWriterFactory(); + virtual ~TestWriterFactory(); + + virtual QuicPacketWriter* Create(QuicPacketWriter* writer, + QuicConnection* connection) override; + + // Calls OnPacketSent on the last QuicConnection to write through one of the + // packet writers created by this factory. + void OnPacketSent(WriteResult result); + + private: + class PerConnectionPacketWriter : public QuicPerConnectionPacketWriter { + public: + PerConnectionPacketWriter(TestWriterFactory* factory, + QuicPacketWriter* writer, + QuicConnection* connection); + virtual ~PerConnectionPacketWriter(); + + virtual WriteResult WritePacket( + const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address) OVERRIDE; + + private: + TestWriterFactory* factory_; + }; + + // If an asynchronous write is happening and |writer| gets deleted, this + // clears the pointer to it to prevent use-after-free. + void Unregister(PerConnectionPacketWriter* writer); + + PerConnectionPacketWriter* current_writer_; +}; + } // namespace test } // namespace tools } // namespace net |