diff options
Diffstat (limited to 'net/socket_stream')
-rw-r--r-- | net/socket_stream/socket_stream_job.cc | 21 | ||||
-rw-r--r-- | net/socket_stream/socket_stream_job.h | 5 |
2 files changed, 23 insertions, 3 deletions
diff --git a/net/socket_stream/socket_stream_job.cc b/net/socket_stream/socket_stream_job.cc index b52945a..0349de8 100644 --- a/net/socket_stream/socket_stream_job.cc +++ b/net/socket_stream/socket_stream_job.cc @@ -5,7 +5,9 @@ #include "net/socket_stream/socket_stream_job.h" #include "base/memory/singleton.h" +#include "net/base/transport_security_state.h" #include "net/socket_stream/socket_stream_job_manager.h" +#include "net/url_request/url_request_context.h" namespace net { @@ -18,8 +20,23 @@ SocketStreamJob::ProtocolFactory* SocketStreamJob::RegisterProtocolFactory( // static SocketStreamJob* SocketStreamJob::CreateSocketStreamJob( - const GURL& url, SocketStream::Delegate* delegate) { - return SocketStreamJobManager::GetInstance()->CreateJob(url, delegate); + const GURL& url, + SocketStream::Delegate* delegate, + const URLRequestContext& context) { + GURL socket_url(url); + TransportSecurityState::DomainState domain_state; + if (url.scheme() == "ws" && + context.transport_security_state() && + context.transport_security_state()->IsEnabledForHost( + &domain_state, url.host(), context.IsSNIAvailable()) && + domain_state.mode == TransportSecurityState::DomainState::MODE_STRICT) { + url_canon::Replacements<char> replacements; + static const char kNewScheme[] = "wss"; + replacements.SetScheme(kNewScheme, + url_parse::Component(0, strlen(kNewScheme))); + socket_url = url.ReplaceComponents(replacements); + } + return SocketStreamJobManager::GetInstance()->CreateJob(socket_url, delegate); } SocketStreamJob::SocketStreamJob() {} diff --git a/net/socket_stream/socket_stream_job.h b/net/socket_stream/socket_stream_job.h index 9a4577e..24eaa19 100644 --- a/net/socket_stream/socket_stream_job.h +++ b/net/socket_stream/socket_stream_job.h @@ -32,7 +32,9 @@ class SocketStreamJob : public base::RefCountedThreadSafe<SocketStreamJob> { ProtocolFactory* factory); static SocketStreamJob* CreateSocketStreamJob( - const GURL& url, SocketStream::Delegate* delegate); + const GURL& url, + SocketStream::Delegate* delegate, + const URLRequestContext& context); SocketStreamJob(); void InitSocketStream(SocketStream* socket) { @@ -61,6 +63,7 @@ class SocketStreamJob : public base::RefCountedThreadSafe<SocketStreamJob> { virtual void DetachDelegate(); protected: + friend class WebSocketJobTest; friend class base::RefCountedThreadSafe<SocketStreamJob>; virtual ~SocketStreamJob(); |