diff options
Diffstat (limited to 'src/ssl/test/runner/handshake_messages.go')
-rw-r--r-- | src/ssl/test/runner/handshake_messages.go | 236 |
1 files changed, 176 insertions, 60 deletions
diff --git a/src/ssl/test/runner/handshake_messages.go b/src/ssl/test/runner/handshake_messages.go index da85e7a..ce214fd 100644 --- a/src/ssl/test/runner/handshake_messages.go +++ b/src/ssl/test/runner/handshake_messages.go @@ -32,7 +32,6 @@ type clientHelloMsg struct { srtpProtectionProfiles []uint16 srtpMasterKeyIdentifier string sctListSupported bool - customExtension string } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -66,8 +65,7 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.extendedMasterSecret == m1.extendedMasterSecret && eqUint16s(m.srtpProtectionProfiles, m1.srtpProtectionProfiles) && m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier && - m.sctListSupported == m1.sctListSupported && - m.customExtension == m1.customExtension + m.sctListSupported == m1.sctListSupported } func (m *clientHelloMsg) marshal() []byte { @@ -121,7 +119,7 @@ func (m *clientHelloMsg) marshal() []byte { if len(m.alpnProtocols) > 0 { extensionsLength += 2 for _, s := range m.alpnProtocols { - if l := len(s); l > 255 { + if l := len(s); l == 0 || l > 255 { panic("invalid ALPN protocol") } extensionsLength++ @@ -140,10 +138,6 @@ func (m *clientHelloMsg) marshal() []byte { if m.sctListSupported { numExtensions++ } - if l := len(m.customExtension); l > 0 { - extensionsLength += l - numExtensions++ - } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -382,14 +376,6 @@ func (m *clientHelloMsg) marshal() []byte { z[1] = byte(extensionSignedCertificateTimestamp & 0xff) z = z[4:] } - if l := len(m.customExtension); l > 0 { - z[0] = byte(extensionCustom >> 8) - z[1] = byte(extensionCustom & 0xff) - z[2] = byte(l >> 8) - z[3] = byte(l & 0xff) - copy(z[4:], []byte(m.customExtension)) - z = z[4+l:] - } m.raw = x @@ -457,7 +443,6 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.signatureAndHashes = nil m.alpnProtocols = nil m.extendedMasterSecret = false - m.customExtension = "" if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -619,8 +604,6 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } m.sctListSupported = true - case extensionCustom: - m.customExtension = string(data[:length]) } data = data[length:] } @@ -642,15 +625,40 @@ type serverHelloMsg struct { ticketSupported bool secureRenegotiation []byte alpnProtocol string - alpnProtocolEmpty bool duplicateExtension bool channelIDRequested bool extendedMasterSecret bool srtpProtectionProfile uint16 srtpMasterKeyIdentifier string sctList []byte - customExtension string - npnLast bool +} + +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.isDTLS == m1.isDTLS && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + (m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && + m.alpnProtocol == m1.alpnProtocol && + m.duplicateExtension == m1.duplicateExtension && + m.channelIDRequested == m1.channelIDRequested && + m.extendedMasterSecret == m1.extendedMasterSecret && + m.srtpProtectionProfile == m1.srtpProtectionProfile && + m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier && + bytes.Equal(m.sctList, m1.sctList) } func (m *serverHelloMsg) marshal() []byte { @@ -687,7 +695,7 @@ func (m *serverHelloMsg) marshal() []byte { if m.channelIDRequested { numExtensions++ } - if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty { + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { if alpnLen >= 256 { panic("invalid ALPN protocol") } @@ -705,10 +713,6 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += len(m.sctList) numExtensions++ } - if l := len(m.customExtension); l > 0 { - extensionsLength += l - numExtensions++ - } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -743,7 +747,7 @@ func (m *serverHelloMsg) marshal() []byte { z[1] = 0xff z = z[4:] } - if m.nextProtoNeg && !m.npnLast { + if m.nextProtoNeg { z[0] = byte(extensionNextProtoNeg >> 8) z[1] = byte(extensionNextProtoNeg & 0xff) z[2] = byte(nextProtoLen >> 8) @@ -780,7 +784,7 @@ func (m *serverHelloMsg) marshal() []byte { copy(z, m.secureRenegotiation) z = z[len(m.secureRenegotiation):] } - if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty { + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { z[0] = byte(extensionALPN >> 8) z[1] = byte(extensionALPN & 0xff) l := 2 + 1 + alpnLen @@ -834,31 +838,6 @@ func (m *serverHelloMsg) marshal() []byte { copy(z[4:], m.sctList) z = z[4+l:] } - if l := len(m.customExtension); l > 0 { - z[0] = byte(extensionCustom >> 8) - z[1] = byte(extensionCustom & 0xff) - z[2] = byte(l >> 8) - z[3] = byte(l & 0xff) - copy(z[4:], []byte(m.customExtension)) - z = z[4+l:] - } - if m.nextProtoNeg && m.npnLast { - z[0] = byte(extensionNextProtoNeg >> 8) - z[1] = byte(extensionNextProtoNeg & 0xff) - z[2] = byte(nextProtoLen >> 8) - z[3] = byte(nextProtoLen) - z = z[4:] - - for _, v := range m.nextProtos { - l := len(v) - if l > 255 { - l = 255 - } - z[0] = byte(l) - copy(z[1:], []byte(v[0:l])) - z = z[1+l:] - } - } m.raw = x @@ -890,9 +869,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.ocspStapling = false m.ticketSupported = false m.alpnProtocol = "" - m.alpnProtocolEmpty = false m.extendedMasterSecret = false - m.customExtension = "" if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -963,7 +940,6 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } d = d[1:] m.alpnProtocol = string(d) - m.alpnProtocolEmpty = len(d) == 0 case extensionChannelID: if length > 0 { return false @@ -989,9 +965,14 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } m.srtpMasterKeyIdentifier = string(d[1:]) case extensionSignedCertificateTimestamp: - m.sctList = data[:length] - case extensionCustom: - m.customExtension = string(data[:length]) + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != len(data)-2 { + return false + } + m.sctList = data[2:length] } data = data[length:] } @@ -1004,6 +985,16 @@ type certificateMsg struct { certificates [][]byte } +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + func (m *certificateMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1081,6 +1072,16 @@ type serverKeyExchangeMsg struct { key []byte } +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1112,6 +1113,17 @@ type certificateStatusMsg struct { response []byte } +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + func (m *certificateStatusMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1163,6 +1175,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + func (m *serverHelloDoneMsg) marshal() []byte { x := make([]byte, 4) x[0] = typeServerHelloDone @@ -1178,6 +1195,16 @@ type clientKeyExchangeMsg struct { ciphertext []byte } +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1212,6 +1239,16 @@ type finishedMsg struct { verifyData []byte } +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + func (m *finishedMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1239,6 +1276,16 @@ type nextProtoMsg struct { proto string } +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + func (m *nextProtoMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1306,6 +1353,18 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) +} + func (m *certificateRequestMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1448,6 +1507,19 @@ type certificateVerifyMsg struct { signature []byte } +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.hasSignatureAndHash == m1.hasSignatureAndHash && + m.signatureAndHash.hash == m1.signatureAndHash.hash && + m.signatureAndHash.signature == m1.signatureAndHash.signature && + bytes.Equal(m.signature, m1.signature) +} + func (m *certificateVerifyMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1517,6 +1589,16 @@ type newSessionTicketMsg struct { ticket []byte } +func (m *newSessionTicketMsg) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ticket, m1.ticket) +} + func (m *newSessionTicketMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1569,6 +1651,19 @@ type v2ClientHelloMsg struct { challenge []byte } +func (m *v2ClientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*v2ClientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.sessionId, m1.sessionId) && + bytes.Equal(m.challenge, m1.challenge) +} + func (m *v2ClientHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1608,6 +1703,17 @@ type helloVerifyRequestMsg struct { cookie []byte } +func (m *helloVerifyRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*helloVerifyRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.cookie, m1.cookie) +} + func (m *helloVerifyRequestMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1649,6 +1755,16 @@ type encryptedExtensionsMsg struct { channelID []byte } +func (m *encryptedExtensionsMsg) equal(i interface{}) bool { + m1, ok := i.(*encryptedExtensionsMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.channelID, m1.channelID) +} + func (m *encryptedExtensionsMsg) marshal() []byte { if m.raw != nil { return m.raw |