summaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/base/force_tls_state.cc110
-rw-r--r--net/base/force_tls_state.h35
2 files changed, 135 insertions, 10 deletions
diff --git a/net/base/force_tls_state.cc b/net/base/force_tls_state.cc
index ea2e2f8..eca45a7 100644
--- a/net/base/force_tls_state.cc
+++ b/net/base/force_tls_state.cc
@@ -4,33 +4,64 @@
#include "net/base/force_tls_state.h"
+#include "base/json_reader.h"
+#include "base/json_writer.h"
#include "base/logging.h"
+#include "base/scoped_ptr.h"
#include "base/string_tokenizer.h"
#include "base/string_util.h"
+#include "base/values.h"
#include "googleurl/src/gurl.h"
#include "net/base/registry_controlled_domain.h"
namespace net {
-ForceTLSState::ForceTLSState() {
+ForceTLSState::ForceTLSState()
+ : delegate_(NULL) {
}
void ForceTLSState::DidReceiveHeader(const GURL& url,
const std::string& value) {
- // TODO(abarth): Actually parse |value| once the spec settles down.
- EnableHost(url.host());
+ int max_age;
+ bool include_subdomains;
+
+ if (!ParseHeader(value, &max_age, &include_subdomains))
+ return;
+
+ base::Time current_time(base::Time::Now());
+ base::TimeDelta max_age_delta = base::TimeDelta::FromSeconds(max_age);
+ base::Time expiry = current_time + max_age_delta;
+
+ EnableHost(url.host(), expiry, include_subdomains);
}
-void ForceTLSState::EnableHost(const std::string& host) {
+void ForceTLSState::EnableHost(const std::string& host, base::Time expiry,
+ bool include_subdomains) {
// TODO(abarth): Canonicalize host.
AutoLock lock(lock_);
- enabled_hosts_.insert(host);
+
+ State state = {expiry, include_subdomains};
+ enabled_hosts_[host] = state;
+ DirtyNotify();
}
bool ForceTLSState::IsEnabledForHost(const std::string& host) {
// TODO(abarth): Canonicalize host.
+ // TODO: check for subdomains too.
+
AutoLock lock(lock_);
- return enabled_hosts_.find(host) != enabled_hosts_.end();
+ std::map<std::string, State>::iterator i = enabled_hosts_.find(host);
+ if (i == enabled_hosts_.end())
+ return false;
+
+ base::Time current_time(base::Time::Now());
+ if (current_time > i->second.expiry) {
+ enabled_hosts_.erase(i);
+ DirtyNotify();
+ return false;
+ }
+
+ return true;
}
// "X-Force-TLS" ":" "max-age" "=" delta-seconds *1INCLUDESUBDOMAINS
@@ -130,4 +161,71 @@ bool ForceTLSState::ParseHeader(const std::string& value,
}
}
+void ForceTLSState::SetDelegate(ForceTLSState::Delegate* delegate) {
+ AutoLock lock(lock_);
+
+ delegate_ = delegate;
+}
+
+bool ForceTLSState::Serialise(std::string* output) {
+ AutoLock lock(lock_);
+
+ DictionaryValue toplevel;
+ for (std::map<std::string, State>::const_iterator
+ i = enabled_hosts_.begin(); i != enabled_hosts_.end(); ++i) {
+ DictionaryValue* state = new DictionaryValue;
+ state->SetBoolean(L"include_subdomains", i->second.include_subdomains);
+ state->SetReal(L"expiry", i->second.expiry.ToDoubleT());
+
+ toplevel.Set(ASCIIToWide(i->first), state);
+ }
+
+ JSONWriter::Write(&toplevel, true /* pretty print */, output);
+ return true;
+}
+
+bool ForceTLSState::Deserialise(const std::string& input) {
+ AutoLock lock(lock_);
+
+ enabled_hosts_.clear();
+
+ scoped_ptr<Value> value(
+ JSONReader::Read(input, false /* do not allow trailing commas */));
+ if (!value.get() || !value->IsType(Value::TYPE_DICTIONARY))
+ return false;
+
+ DictionaryValue* dict_value = reinterpret_cast<DictionaryValue*>(value.get());
+ const base::Time current_time(base::Time::Now());
+
+ for (DictionaryValue::key_iterator
+ i = dict_value->begin_keys(); i != dict_value->end_keys(); ++i) {
+ DictionaryValue* state;
+ if (!dict_value->GetDictionary(*i, &state))
+ continue;
+
+ const std::string host = WideToASCII(*i);
+ bool include_subdomains;
+ double expiry;
+
+ if (!state->GetBoolean(L"include_subdomains", &include_subdomains) ||
+ !state->GetReal(L"expiry", &expiry)) {
+ continue;
+ }
+
+ base::Time expiry_time = base::Time::FromDoubleT(expiry);
+ if (expiry_time <= current_time)
+ continue;
+
+ State new_state = { expiry_time, include_subdomains };
+ enabled_hosts_[host] = new_state;
+ }
+
+ return enabled_hosts_.size() > 0;
+}
+
+void ForceTLSState::DirtyNotify() {
+ if (delegate_)
+ delegate_->StateIsDirty(this);
+}
+
} // namespace
diff --git a/net/base/force_tls_state.h b/net/base/force_tls_state.h
index e52adb9..068d73c 100644
--- a/net/base/force_tls_state.h
+++ b/net/base/force_tls_state.h
@@ -5,11 +5,13 @@
#ifndef NET_BASE_FORCE_TLS_STATE_H_
#define NET_BASE_FORCE_TLS_STATE_H_
-#include <set>
+#include <map>
#include <string>
#include "base/basictypes.h"
#include "base/lock.h"
+#include "base/ref_counted.h"
+#include "base/time.h"
class GURL;
@@ -21,7 +23,7 @@ namespace net {
// then we refuse to talk to the host over HTTP, treat all certificate errors as
// fatal, and refuse to load any mixed content.
//
-class ForceTLSState {
+class ForceTLSState : public base::RefCountedThreadSafe<ForceTLSState> {
public:
ForceTLSState();
@@ -30,7 +32,8 @@ class ForceTLSState {
void DidReceiveHeader(const GURL& url, const std::string& value);
// Enable ForceTLS for |host|.
- void EnableHost(const std::string& host);
+ void EnableHost(const std::string& host, base::Time expiry,
+ bool include_subdomains);
// Returns whether |host| has had ForceTLS enabled.
bool IsEnabledForHost(const std::string& host);
@@ -43,13 +46,37 @@ class ForceTLSState {
int* max_age,
bool* include_subdomains);
+ struct State {
+ base::Time expiry; // the absolute time (UTC) when this record expires
+ bool include_subdomains; // subdomains included?
+ };
+
+ class Delegate {
+ public:
+ // This function may not block and may be called with internal locks held.
+ // Thus it must not reenter the ForceTLSState object.
+ virtual void StateIsDirty(ForceTLSState* state) = 0;
+ };
+
+ void SetDelegate(Delegate*);
+
+ bool Serialise(std::string* output);
+ bool Deserialise(const std::string& state);
+
private:
+ // If we have a callback configured, call it to let our serialiser know that
+ // our state is dirty.
+ void DirtyNotify();
+
// The set of hosts that have enabled ForceTLS.
- std::set<std::string> enabled_hosts_;
+ std::map<std::string, State> enabled_hosts_;
// Protect access to our data members with this lock.
Lock lock_;
+ // Our delegate who gets notified when we are dirtied, or NULL.
+ Delegate* delegate_;
+
DISALLOW_COPY_AND_ASSIGN(ForceTLSState);
};