summaryrefslogtreecommitdiffstats
path: root/net/socket/tcp_pinger.h
blob: 96fa4fd5e29102466141071b8eb6de8bedfeef65 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef NET_SOCKET_TCP_PINGER_H_
#define NET_SOCKET_TCP_PINGER_H_

#include "base/compiler_specific.h"
#include "base/ref_counted.h"
#include "base/scoped_ptr.h"
#include "base/task.h"
#include "base/third_party/dynamic_annotations/dynamic_annotations.h"
#include "base/thread.h"
#include "base/waitable_event.h"
#include "net/base/address_list.h"
#include "net/base/completion_callback.h"
#include "net/base/net_errors.h"
#include "net/socket/tcp_client_socket.h"

namespace base {
class TimeDelta;
}

namespace net {

// Simple class to wait until a TCP server is accepting connections.
class TCPPinger {
 public:
  explicit TCPPinger(const net::AddressList& addr)
    : io_thread_("TCPPinger"),
      worker_(new Worker(addr)) {
    worker_->AddRef();
    // Start up a throwaway IO thread just for this.
    // TODO(dkegel): use some existing thread pool instead?
    base::Thread::Options options;
    options.message_loop_type = MessageLoop::TYPE_IO;
    io_thread_.StartWithOptions(options);
  }

  ~TCPPinger() {
    io_thread_.message_loop()->ReleaseSoon(FROM_HERE, worker_);
  }

  int Ping() {
    // Default is 10 tries, each with a timeout of 1000ms,
    // for a total max timeout of 10 seconds.
    return Ping(base::TimeDelta::FromMilliseconds(1000), 10);
  }

  int Ping(base::TimeDelta tryTimeout, int nTries) {
    int err = ERR_IO_PENDING;
    // Post a request to do the connect on that thread.
    for (int i = 0; i < nTries; i++) {
      io_thread_.message_loop()->PostTask(FROM_HERE,
        NewRunnableMethod(worker_,
        &net::TCPPinger::Worker::DoConnect));
      // Timeout here in case remote host offline
      err = worker_->TimedWaitForResult(tryTimeout);
      if (err == net::OK)
        break;
      PlatformThread::Sleep(static_cast<int>(tryTimeout.InMilliseconds()));

      // Cancel leftover activity, if any
      io_thread_.message_loop()->PostTask(FROM_HERE,
        NewRunnableMethod(worker_,
        &net::TCPPinger::Worker::DoDisconnect));
      worker_->WaitForResult();
    }
    return err;
  }

 private:

  // Inner class to handle all actual socket calls.
  // This makes the outer interface simpler,
  // and helps us obey the "all socket calls
  // must be on same thread" restriction.
  class Worker : public base::RefCountedThreadSafe<Worker> {
   public:
    explicit Worker(const net::AddressList& addr)
      : event_(false, false),
        net_error_(ERR_IO_PENDING),
        addr_(addr),
        ALLOW_THIS_IN_INITIALIZER_LIST(connect_callback_(this,
            &net::TCPPinger::Worker::ConnectDone)) {
    }

    void DoConnect() {
      sock_.reset(new TCPClientSocket(addr_, NULL));
      int rv = sock_->Connect(&connect_callback_);
      // Regardless of success or failure, if we're done now,
      // signal the customer.
      if (rv != ERR_IO_PENDING)
        ConnectDone(rv);
    }

    void DoDisconnect() {
      sock_.reset();
      event_.Signal();
    }

    void ConnectDone(int rv) {
      sock_.reset();
      net_error_ = rv;
      event_.Signal();
    }

    int TimedWaitForResult(base::TimeDelta tryTimeout) {
      event_.TimedWait(tryTimeout);
      // In case of timeout, the value of net_error_ should be ERR_IO_PENDING.
      // However, a harmless data race can happen if TimedWait times out right
      // before event_.Signal() is called in ConnectDone().
      return ANNOTATE_UNPROTECTED_READ(net_error_);
    }

    int WaitForResult() {
      event_.Wait();
      return net_error_;
    }

   private:
    friend class base::RefCountedThreadSafe<Worker>;

    ~Worker() {}

    base::WaitableEvent event_;
    int net_error_;
    net::AddressList addr_;
    scoped_ptr<TCPClientSocket> sock_;
    net::CompletionCallbackImpl<Worker> connect_callback_;
  };

  base::Thread io_thread_;
  Worker* worker_;
  DISALLOW_COPY_AND_ASSIGN(TCPPinger);
};

}  // namespace net

#endif  // NET_SOCKET_TCP_PINGER_H_