summaryrefslogtreecommitdiffstats
path: root/net/socket/websocket_endpoint_lock_manager_unittest.cc
blob: 6808d3733fc17c28fbd4f38234cc46b81eec34d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/socket/websocket_endpoint_lock_manager.h"

#include "base/logging.h"
#include "base/macros.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
#include "base/time/time.h"
#include "net/base/net_errors.h"
#include "net/socket/next_proto.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/stream_socket.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace net {

namespace {

// Inert StreamSocket: every operation either fails with ERR_FAILED or reports
// an empty/negative answer. In this file it is only ever handed to
// RememberSocket() and UnlockSocket() as an identity for the lock manager.
// TODO(ricea): If you need to use this in another file, please move it to
// socket_test_util.h.
class FakeStreamSocket : public StreamSocket {
 public:
  FakeStreamSocket() {}

  // Socket implementation: all I/O and buffer configuration fails.
  int Read(IOBuffer* buf,
           int buf_len,
           const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int Write(IOBuffer* buf,
            int buf_len,
            const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  int SetReceiveBufferSize(int32_t size) override { return ERR_FAILED; }
  int SetSendBufferSize(int32_t size) override { return ERR_FAILED; }

  // StreamSocket implementation: connecting fails, address queries fail, and
  // the socket always claims to be unconnected, unused and non-SSL.
  int Connect(const CompletionCallback& callback) override {
    return ERR_FAILED;
  }
  void Disconnect() override {}
  bool IsConnected() const override { return false; }
  bool IsConnectedAndIdle() const override { return false; }
  int GetPeerAddress(IPEndPoint* address) const override { return ERR_FAILED; }
  int GetLocalAddress(IPEndPoint* address) const override { return ERR_FAILED; }
  const BoundNetLog& NetLog() const override { return bound_net_log_; }
  void SetSubresourceSpeculation() override {}
  void SetOmniboxSpeculation() override {}
  bool WasEverUsed() const override { return false; }
  bool UsingTCPFastOpen() const override { return false; }
  bool WasNpnNegotiated() const override { return false; }
  NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
  bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
  void GetConnectionAttempts(ConnectionAttempts* out) const override {
    out->clear();
  }
  void ClearConnectionAttempts() override {}
  void AddConnectionAttempts(const ConnectionAttempts& attempts) override {}
  int64_t GetTotalReceivedBytes() const override {
    NOTIMPLEMENTED();
    return 0;
  }

 private:
  // Default-constructed log returned by NetLog().
  BoundNetLog bound_net_log_;

  DISALLOW_COPY_AND_ASSIGN(FakeStreamSocket);
};

class FakeWaiter : public WebSocketEndpointLockManager::Waiter {
 public:
  FakeWaiter() : called_(false) {}

  void GotEndpointLock() override {
    CHECK(!called_);
    called_ = true;
  }

  bool called() const { return called_; }

 private:
  bool called_;
};

// FakeWaiter whose owner can block, by running a base::RunLoop, until the
// endpoint lock has been granted.
class BlockingWaiter : public FakeWaiter {
 public:
  void GotEndpointLock() override {
    FakeWaiter::GotEndpointLock();
    // Ask the RunLoop used by WaitForLock() to stop.
    run_loop_.Quit();
  }

  // Returns once GotEndpointLock() has been called. If that has already
  // happened, returns without running the loop at all.
  void WaitForLock() {
    for (;;) {
      if (called())
        return;
      run_loop_.Run();
    }
  }

 private:
  base::RunLoop run_loop_;
};

class WebSocketEndpointLockManagerTest : public ::testing::Test {
 protected:
  WebSocketEndpointLockManagerTest()
      : instance_(WebSocketEndpointLockManager::GetInstance()) {}
  ~WebSocketEndpointLockManagerTest() override {
    // Permit any pending asynchronous unlock operations to complete.
    RunUntilIdle();
    // If this check fails then subsequent tests may fail.
    CHECK(instance_->IsEmpty());
  }

  WebSocketEndpointLockManager* instance() const { return instance_; }

  IPEndPoint DummyEndpoint() {
    IPAddressNumber ip_address_number;
    CHECK(ParseIPLiteralToNumber("127.0.0.1", &ip_address_number));
    return IPEndPoint(ip_address_number, 80);
  }

  void UnlockDummyEndpoint(int times) {
    for (int i = 0; i < times; ++i) {
      instance()->UnlockEndpoint(DummyEndpoint());
      RunUntilIdle();
    }
  }

  static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); }

  WebSocketEndpointLockManager* const instance_;
  ScopedWebSocketEndpointZeroUnlockDelay zero_unlock_delay_;
};

TEST_F(WebSocketEndpointLockManagerTest, GetInstanceWorks) {
  // Deliberately empty: the fixture constructor calls GetInstance() and the
  // fixture destructor checks IsEmpty(), which is everything being tested.
}

TEST_F(WebSocketEndpointLockManagerTest, LockEndpointReturnsOkOnce) {
  // Only the first locker succeeds synchronously; the second is queued.
  FakeWaiter holder;
  FakeWaiter queued;
  const IPEndPoint endpoint = DummyEndpoint();
  EXPECT_EQ(OK, instance()->LockEndpoint(endpoint, &holder));
  EXPECT_EQ(ERR_IO_PENDING, instance()->LockEndpoint(endpoint, &queued));

  // One unlock passes the lock to |queued|, the second releases it.
  UnlockDummyEndpoint(2);
}

TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledOnOk) {
  // When LockEndpoint() returns OK the caller already owns the lock, so no
  // GotEndpointLock() notification should follow, even after tasks run.
  FakeWaiter sole_waiter;
  const int result = instance()->LockEndpoint(DummyEndpoint(), &sole_waiter);
  EXPECT_EQ(OK, result);
  RunUntilIdle();
  EXPECT_FALSE(sole_waiter.called());

  UnlockDummyEndpoint(1);
}

TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockNotCalledImmediately) {
  FakeWaiter owner;
  FakeWaiter pending;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &owner));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &pending));

  // Nothing has been unlocked, so draining the task queue must not grant the
  // lock to the queued waiter.
  RunUntilIdle();
  EXPECT_FALSE(pending.called());

  UnlockDummyEndpoint(2);
}

TEST_F(WebSocketEndpointLockManagerTest, GotEndpointLockCalledWhenUnlocked) {
  FakeWaiter first;
  FakeWaiter second;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &first));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &second));

  // Once the first holder unlocks and tasks run, the queued waiter is told it
  // now owns the endpoint.
  instance()->UnlockEndpoint(DummyEndpoint());
  RunUntilIdle();
  EXPECT_TRUE(second.called());

  // |second| holds the lock at this point; release it.
  UnlockDummyEndpoint(1);
}

TEST_F(WebSocketEndpointLockManagerTest,
       EndpointUnlockedIfWaiterAlreadyDeleted) {
  FakeWaiter original_owner;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &original_owner));

  // Queue a waiter, then destroy it before the endpoint is ever unlocked.
  {
    FakeWaiter doomed_waiter;
    EXPECT_EQ(ERR_IO_PENDING,
              instance()->LockEndpoint(DummyEndpoint(), &doomed_waiter));
  }

  // With the queued waiter gone, unlocking should leave the endpoint free so
  // that a brand-new LockEndpoint() succeeds synchronously.
  instance()->UnlockEndpoint(DummyEndpoint());
  RunUntilIdle();

  FakeWaiter later_owner;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &later_owner));

  UnlockDummyEndpoint(1);
}

TEST_F(WebSocketEndpointLockManagerTest, RememberSocketWorks) {
  FakeWaiter lock_holder;
  FakeWaiter next_in_line;
  FakeStreamSocket fake_socket;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &lock_holder));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &next_in_line));

  // Tie the socket to the endpoint, then release through the socket instead
  // of the endpoint; the queued waiter should still be granted the lock.
  instance()->RememberSocket(&fake_socket, DummyEndpoint());
  instance()->UnlockSocket(&fake_socket);
  RunUntilIdle();
  EXPECT_TRUE(next_in_line.called());

  UnlockDummyEndpoint(1);
}

// Unlocking an endpoint must also drop any socket remembered for it, leaving
// the manager completely empty.
TEST_F(WebSocketEndpointLockManagerTest, SocketAssociationForgottenOnUnlock) {
  FakeWaiter owner;
  FakeStreamSocket fake_socket;

  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &owner));
  instance()->RememberSocket(&fake_socket, DummyEndpoint());
  instance()->UnlockEndpoint(DummyEndpoint());
  RunUntilIdle();

  EXPECT_TRUE(instance()->IsEmpty());
}

// After the lock is handed to the next waiter, that waiter is free to
// associate its own socket with the endpoint via RememberSocket().
TEST_F(WebSocketEndpointLockManagerTest, NextWaiterCanCallRememberSocketAgain) {
  FakeWaiter waiter_a;
  FakeWaiter waiter_b;
  FakeStreamSocket socket_a;
  FakeStreamSocket socket_b;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &waiter_a));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &waiter_b));

  // |waiter_a| remembers a socket and then gives up the endpoint.
  instance()->RememberSocket(&socket_a, DummyEndpoint());
  instance()->UnlockEndpoint(DummyEndpoint());
  RunUntilIdle();
  EXPECT_TRUE(waiter_b.called());

  // |waiter_b| now owns the endpoint and remembers a different socket.
  instance()->RememberSocket(&socket_b, DummyEndpoint());

  UnlockDummyEndpoint(1);
}

// Once UnlockEndpoint() has been called, a following UnlockSocket() for the
// socket remembered on that endpoint must have no effect.
TEST_F(WebSocketEndpointLockManagerTest,
       UnlockSocketAfterUnlockEndpointDoesNothing) {
  FakeWaiter holder;
  FakeWaiter second;
  FakeWaiter third;
  FakeStreamSocket fake_socket;

  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &holder));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &second));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &third));
  instance()->RememberSocket(&fake_socket, DummyEndpoint());

  // Two release calls, but only the first should advance the queue.
  instance()->UnlockEndpoint(DummyEndpoint());
  instance()->UnlockSocket(&fake_socket);
  RunUntilIdle();
  EXPECT_TRUE(second.called());
  EXPECT_FALSE(third.called());

  UnlockDummyEndpoint(2);
}

// UnlockEndpoint() must never hand the lock over synchronously; the queued
// waiter is only notified once pending tasks are run.
TEST_F(WebSocketEndpointLockManagerTest, UnlockEndpointIsAsynchronous) {
  FakeWaiter current_owner;
  FakeWaiter next_owner;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &current_owner));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &next_owner));

  instance()->UnlockEndpoint(DummyEndpoint());
  // No tasks have run since UnlockEndpoint(), so no notification yet.
  EXPECT_FALSE(next_owner.called());
  RunUntilIdle();
  EXPECT_TRUE(next_owner.called());

  UnlockDummyEndpoint(1);
}

// With a non-zero unlock delay configured, the lock should not reach the next
// waiter until at least that much time has passed.
TEST_F(WebSocketEndpointLockManagerTest, UnlockEndpointIsDelayed) {
  // Under very slow configurations (typically memory checkers) this code
  // itself takes more than 1ms, so no real delay is needed to pass there.
  // Rather than lengthen the delay and slow the test everywhere, the test only
  // checks that the whole hand-off took at least 1ms. On ordinary builds,
  // Android included, 1ms is still ample to catch a missing delay.
  const base::TimeDelta kUnlockDelay = base::TimeDelta::FromMilliseconds(1);
  instance()->SetUnlockDelayForTesting(kUnlockDelay);

  FakeWaiter holder;
  BlockingWaiter blocked;
  EXPECT_EQ(OK, instance()->LockEndpoint(DummyEndpoint(), &holder));
  EXPECT_EQ(ERR_IO_PENDING,
            instance()->LockEndpoint(DummyEndpoint(), &blocked));

  const base::TimeTicks start = base::TimeTicks::Now();
  instance()->UnlockEndpoint(DummyEndpoint());
  blocked.WaitForLock();
  const base::TimeDelta elapsed = base::TimeTicks::Now() - start;
  EXPECT_GE(elapsed, kUnlockDelay);

  // Return to a zero delay before releasing the lock now held by |blocked|,
  // so that final unlock does not have to wait out the test delay.
  instance()->SetUnlockDelayForTesting(base::TimeDelta());
  UnlockDummyEndpoint(1);
}

}  // namespace

}  // namespace net