diff options
Diffstat (limited to 'net/base')
-rw-r--r-- | net/base/force_tls_state.cc | 110 | ||||
-rw-r--r-- | net/base/force_tls_state.h | 35 |
2 files changed, 135 insertions, 10 deletions
diff --git a/net/base/force_tls_state.cc b/net/base/force_tls_state.cc index ea2e2f8..eca45a7 100644 --- a/net/base/force_tls_state.cc +++ b/net/base/force_tls_state.cc @@ -4,33 +4,64 @@ #include "net/base/force_tls_state.h" +#include "base/json_reader.h" +#include "base/json_writer.h" #include "base/logging.h" +#include "base/scoped_ptr.h" #include "base/string_tokenizer.h" #include "base/string_util.h" +#include "base/values.h" #include "googleurl/src/gurl.h" #include "net/base/registry_controlled_domain.h" namespace net { -ForceTLSState::ForceTLSState() { +ForceTLSState::ForceTLSState() + : delegate_(NULL) { } void ForceTLSState::DidReceiveHeader(const GURL& url, const std::string& value) { - // TODO(abarth): Actually parse |value| once the spec settles down. - EnableHost(url.host()); + int max_age; + bool include_subdomains; + + if (!ParseHeader(value, &max_age, &include_subdomains)) + return; + + base::Time current_time(base::Time::Now()); + base::TimeDelta max_age_delta = base::TimeDelta::FromSeconds(max_age); + base::Time expiry = current_time + max_age_delta; + + EnableHost(url.host(), expiry, include_subdomains); } -void ForceTLSState::EnableHost(const std::string& host) { +void ForceTLSState::EnableHost(const std::string& host, base::Time expiry, + bool include_subdomains) { // TODO(abarth): Canonicalize host. AutoLock lock(lock_); - enabled_hosts_.insert(host); + + State state = {expiry, include_subdomains}; + enabled_hosts_[host] = state; + DirtyNotify(); } bool ForceTLSState::IsEnabledForHost(const std::string& host) { // TODO(abarth): Canonicalize host. + // TODO: check for subdomains too. + AutoLock lock(lock_); - return enabled_hosts_.find(host) != enabled_hosts_.end(); + std::map<std::string, State>::iterator i = enabled_hosts_.find(host); + if (i == enabled_hosts_.end()) + return false; + + base::Time current_time(base::Time::Now()); + if (current_time > i->second.expiry) { + enabled_hosts_.erase(i); + DirtyNotify(); + return false; + } + + return true; } // "X-Force-TLS" ":" "max-age" "=" delta-seconds *1INCLUDESUBDOMAINS @@ -130,4 +161,71 @@ bool ForceTLSState::ParseHeader(const std::string& value, } } +void ForceTLSState::SetDelegate(ForceTLSState::Delegate* delegate) { + AutoLock lock(lock_); + + delegate_ = delegate; +} + +bool ForceTLSState::Serialise(std::string* output) { + AutoLock lock(lock_); + + DictionaryValue toplevel; + for (std::map<std::string, State>::const_iterator + i = enabled_hosts_.begin(); i != enabled_hosts_.end(); ++i) { + DictionaryValue* state = new DictionaryValue; + state->SetBoolean(L"include_subdomains", i->second.include_subdomains); + state->SetReal(L"expiry", i->second.expiry.ToDoubleT()); + + toplevel.Set(ASCIIToWide(i->first), state); + } + + JSONWriter::Write(&toplevel, true /* pretty print */, output); + return true; +} + +bool ForceTLSState::Deserialise(const std::string& input) { + AutoLock lock(lock_); + + enabled_hosts_.clear(); + + scoped_ptr<Value> value( + JSONReader::Read(input, false /* do not allow trailing commas */)); + if (!value.get() || !value->IsType(Value::TYPE_DICTIONARY)) + return false; + + DictionaryValue* dict_value = reinterpret_cast<DictionaryValue*>(value.get()); + const base::Time current_time(base::Time::Now()); + + for (DictionaryValue::key_iterator + i = dict_value->begin_keys(); i != dict_value->end_keys(); ++i) { + DictionaryValue* state; + if (!dict_value->GetDictionary(*i, &state)) + continue; + + const std::string host = WideToASCII(*i); + bool include_subdomains; + double expiry; + + if (!state->GetBoolean(L"include_subdomains", &include_subdomains) || + !state->GetReal(L"expiry", &expiry)) { + continue; + } + + base::Time expiry_time = base::Time::FromDoubleT(expiry); + if (expiry_time <= current_time) + continue; + + State new_state = { expiry_time, include_subdomains }; + enabled_hosts_[host] = new_state; + } + + return enabled_hosts_.size() > 0; +} + +void ForceTLSState::DirtyNotify() { + if (delegate_) + delegate_->StateIsDirty(this); +} + } // namespace diff --git a/net/base/force_tls_state.h b/net/base/force_tls_state.h index e52adb9..068d73c 100644 --- a/net/base/force_tls_state.h +++ b/net/base/force_tls_state.h @@ -5,11 +5,13 @@ #ifndef NET_BASE_FORCE_TLS_STATE_H_ #define NET_BASE_FORCE_TLS_STATE_H_ -#include <set> +#include <map> #include <string> #include "base/basictypes.h" #include "base/lock.h" +#include "base/ref_counted.h" +#include "base/time.h" class GURL; @@ -21,7 +23,7 @@ namespace net { // then we refuse to talk to the host over HTTP, treat all certificate errors as // fatal, and refuse to load any mixed content. // -class ForceTLSState { +class ForceTLSState : public base::RefCountedThreadSafe<ForceTLSState> { public: ForceTLSState(); @@ -30,7 +32,8 @@ class ForceTLSState { void DidReceiveHeader(const GURL& url, const std::string& value); // Enable ForceTLS for |host|. - void EnableHost(const std::string& host); + void EnableHost(const std::string& host, base::Time expiry, + bool include_subdomains); // Returns whether |host| has had ForceTLS enabled. bool IsEnabledForHost(const std::string& host); @@ -43,13 +46,37 @@ class ForceTLSState { int* max_age, bool* include_subdomains); + struct State { + base::Time expiry; // the absolute time (UTC) when this record expires + bool include_subdomains; // subdomains included? + }; + + class Delegate { + public: + // This function may not block and may be called with internal locks held. + // Thus it must not reenter the ForceTLSState object. + virtual void StateIsDirty(ForceTLSState* state) = 0; + }; + + void SetDelegate(Delegate*); + + bool Serialise(std::string* output); + bool Deserialise(const std::string& state); + private: + // If we have a callback configured, call it to let our serialiser know that + // our state is dirty. + void DirtyNotify(); + // The set of hosts that have enabled ForceTLS. - std::set<std::string> enabled_hosts_; + std::map<std::string, State> enabled_hosts_; // Protect access to our data members with this lock. Lock lock_; + // Our delegate who gets notified when we are dirtied, or NULL. + Delegate* delegate_; + DISALLOW_COPY_AND_ASSIGN(ForceTLSState); }; |