diff options
Diffstat (limited to 'src/ssl/test/runner/conn.go')
-rw-r--r-- | src/ssl/test/runner/conn.go | 117 |
1 files changed, 84 insertions, 33 deletions
diff --git a/src/ssl/test/runner/conn.go b/src/ssl/test/runner/conn.go index adbc1c3..39bdfda 100644 --- a/src/ssl/test/runner/conn.go +++ b/src/ssl/test/runner/conn.go @@ -12,6 +12,7 @@ import ( "crypto/ecdsa" "crypto/subtle" "crypto/x509" + "encoding/binary" "errors" "fmt" "io" @@ -39,6 +40,7 @@ type Conn struct { extendedMasterSecret bool // whether this session used an extended master secret cipherSuite *cipherSuite ocspResponse []byte // stapled OCSP response + sctList []byte // signed certificate timestamp list peerCertificates []*x509.Certificate // verifiedChains contains the certificate chains that we built, as // opposed to the ones presented by the server. @@ -48,6 +50,11 @@ type Conn struct { // firstFinished contains the first Finished hash sent during the // handshake. This is the "tls-unique" channel binding value. firstFinished [12]byte + // clientCertSignatureHash contains the TLS hash id for the hash that + // was used by the client to sign the handshake with a client + // certificate. This is only set by a server and is zero if no client + // certificates were used. + clientCertSignatureHash uint8 clientRandom, serverRandom [32]byte masterSecret [48]byte @@ -87,6 +94,8 @@ func (c *Conn) init() { c.out.isDTLS = c.isDTLS c.in.config = c.config c.out.config = c.config + + c.out.updateOutSeq() } // Access to net.Conn methods. @@ -134,6 +143,7 @@ type halfConn struct { cipher interface{} // cipher algorithm mac macFunction seq [8]byte // 64-bit sequence number + outSeq [8]byte // Mapped sequence number bfree *block // list of free blocks nextCipher interface{} // next encryption state @@ -189,10 +199,6 @@ func (hc *halfConn) incSeq(isOutgoing bool) { if hc.isDTLS { // Increment up to the epoch in DTLS. limit = 2 - - if isOutgoing && hc.config.Bugs.SequenceNumberIncrement != 0 { - increment = hc.config.Bugs.SequenceNumberIncrement - } } for i := 7; i >= limit; i-- { increment += uint64(hc.seq[i]) @@ -206,6 +212,8 @@ func (hc *halfConn) incSeq(isOutgoing bool) { if increment != 0 { panic("TLS: sequence number wraparound") } + + hc.updateOutSeq() } // incNextSeq increments the starting sequence number for the next epoch. @@ -241,6 +249,22 @@ func (hc *halfConn) incEpoch() { hc.seq[i] = 0 } } + + hc.updateOutSeq() +} + +func (hc *halfConn) updateOutSeq() { + if hc.config.Bugs.SequenceNumberMapping != nil { + seqU64 := binary.BigEndian.Uint64(hc.seq[:]) + seqU64 = hc.config.Bugs.SequenceNumberMapping(seqU64) + binary.BigEndian.PutUint64(hc.outSeq[:], seqU64) + + // The DTLS epoch cannot be changed. + copy(hc.outSeq[:2], hc.seq[:2]) + return + } + + copy(hc.outSeq[:], hc.seq[:]) } func (hc *halfConn) recordHeaderLen() int { @@ -397,6 +421,8 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) // // However, our behavior matches OpenSSL, so we leak // only as much as they do. + case nullCipher: + break default: panic("unknown cipher type") } @@ -460,7 +486,7 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { // mac if hc.mac != nil { - mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) + mac := hc.mac.MAC(hc.outDigestBuf, hc.outSeq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) n := len(b.data) b.resize(n + len(mac)) @@ -478,7 +504,7 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { case *tlsAead: payloadLen := len(b.data) - recordHeaderLen - explicitIVLen b.resize(len(b.data) + c.Overhead()) - nonce := hc.seq[:] + nonce := hc.outSeq[:] if c.explicitNonce { nonce = b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] } @@ -486,7 +512,7 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { payload = payload[:payloadLen] var additionalData [13]byte - copy(additionalData[:], hc.seq[:]) + copy(additionalData[:], hc.outSeq[:]) copy(additionalData[8:], b.data[:3]) additionalData[11] = byte(payloadLen >> 8) additionalData[12] = byte(payloadLen) @@ -502,6 +528,8 @@ func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock)) c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix) c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock) + case nullCipher: + break default: panic("unknown cipher type") } @@ -630,10 +658,10 @@ func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) { if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil { // RFC suggests that EOF without an alertCloseNotify is // an error, but popular web sites seem to do this, - // so we can't make it an error. - // if err == io.EOF { - // err = io.ErrUnexpectedEOF - // } + // so we can't make it an error, outside of tests. + if err == io.EOF && c.config.Bugs.ExpectCloseNotify { + err = io.ErrUnexpectedEOF + } if e, ok := err.(net.Error); !ok || !e.Temporary() { c.in.setErrorLocked(err) } @@ -722,6 +750,10 @@ func (c *Conn) readRecord(want recordType) error { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete")) } + case recordTypeAlert: + // Looking for a close_notify. Note: unlike a real + // implementation, this is not tolerant of additional records. + // See the documentation for ExpectCloseNotify. } Again: @@ -784,7 +816,7 @@ Again: // A client might need to process a HelloRequest from // the server, thus receiving a handshake message when // application data is expected is ok. - if !c.isClient { + if !c.isClient || want != recordTypeApplicationData { return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } } @@ -799,13 +831,8 @@ Again: // sendAlert sends a TLS alert message. // c.out.Mutex <= L. -func (c *Conn) sendAlertLocked(err alert) error { - switch err { - case alertNoRenegotiation, alertCloseNotify: - c.tmp[0] = alertLevelWarning - default: - c.tmp[0] = alertLevelError - } +func (c *Conn) sendAlertLocked(level byte, err alert) error { + c.tmp[0] = level c.tmp[1] = byte(err) if c.config.Bugs.FragmentAlert { c.writeRecord(recordTypeAlert, c.tmp[0:1]) @@ -813,8 +840,8 @@ func (c *Conn) sendAlertLocked(err alert) error { } else { c.writeRecord(recordTypeAlert, c.tmp[0:2]) } - // closeNotify is a special case in that it isn't an error: - if err != alertCloseNotify { + // Error alerts are fatal to the connection. + if level == alertLevelError { return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) } return nil @@ -823,9 +850,17 @@ func (c *Conn) sendAlertLocked(err alert) error { // sendAlert sends a TLS alert message. // L < c.out.Mutex. func (c *Conn) sendAlert(err alert) error { + level := byte(alertLevelError) + if err == alertNoRenegotiation || err == alertCloseNotify { + level = alertLevelWarning + } + return c.SendAlert(level, err) +} + +func (c *Conn) SendAlert(level byte, err alert) error { c.out.Lock() defer c.out.Unlock() - return c.sendAlertLocked(err) + return c.sendAlertLocked(level, err) } // writeV2Record writes a record for a V2ClientHello. @@ -841,13 +876,6 @@ func (c *Conn) writeV2Record(data []byte) (n int, err error) { // to the connection and updates the record layer state. // c.out.Mutex <= L. func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { - if typ != recordTypeAlert && c.config.Bugs.SendWarningAlerts != 0 { - alert := make([]byte, 2) - alert[0] = alertLevelWarning - alert[1] = byte(c.config.Bugs.SendWarningAlerts) - c.writeRecord(recordTypeAlert, alert) - } - if c.isDTLS { return c.dtlsWriteRecord(typ, data) } @@ -856,9 +884,9 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { b := c.out.newBlock() first := true isClientHello := typ == recordTypeHandshake && len(data) > 0 && data[0] == typeClientHello - for len(data) > 0 { + for len(data) > 0 || first { m := len(data) - if m > maxPlaintext { + if m > maxPlaintext && !c.config.Bugs.SendLargeRecords { m = maxPlaintext } if typ == recordTypeHandshake && c.config.Bugs.MaxHandshakeRecordLength > 0 && m > c.config.Bugs.MaxHandshakeRecordLength { @@ -1038,6 +1066,9 @@ func (c *Conn) readHandshake() (interface{}, error) { // sequence number expectations but otherwise ignores them. func (c *Conn) skipPacket(packet []byte) error { for len(packet) > 0 { + if len(packet) < 13 { + return errors.New("tls: bad packet") + } // Dropped packets are completely ignored save to update // expected sequence numbers for this and the next epoch. (We // don't assert on the contents of the packets both for @@ -1057,6 +1088,9 @@ func (c *Conn) skipPacket(packet []byte) error { } c.in.incNextSeq() } + if len(packet) < 13+int(length) { + return errors.New("tls: bad packet") + } packet = packet[13+length:] } return nil @@ -1113,7 +1147,7 @@ func (c *Conn) Write(b []byte) (int, error) { } if c.config.Bugs.SendSpuriousAlert != 0 { - c.sendAlertLocked(c.config.Bugs.SendSpuriousAlert) + c.sendAlertLocked(alertLevelError, c.config.Bugs.SendSpuriousAlert) } // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext @@ -1240,10 +1274,22 @@ func (c *Conn) Close() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if c.handshakeComplete { + if c.handshakeComplete && !c.config.Bugs.NoCloseNotify { alertErr = c.sendAlert(alertCloseNotify) } + // Consume a close_notify from the peer if one hasn't been received + // already. This avoids the peer from failing |SSL_shutdown| due to a + // write failing. + if c.handshakeComplete && alertErr == nil && c.config.Bugs.ExpectCloseNotify { + for c.in.error() == nil { + c.readRecord(recordTypeAlert) + } + if c.in.error() != io.EOF { + alertErr = c.in.error() + } + } + if err := c.conn.Close(); err != nil { return err } @@ -1273,6 +1319,9 @@ func (c *Conn) Handshake() error { }) c.conn.Write([]byte{alertLevelError, byte(alertInternalError)}) } + if data := c.config.Bugs.AppDataBeforeHandshake; data != nil { + c.writeRecord(recordTypeApplicationData, data) + } if c.isClient { c.handshakeErr = c.clientHandshake() } else { @@ -1304,6 +1353,8 @@ func (c *Conn) ConnectionState() ConnectionState { state.ChannelID = c.channelID state.SRTPProtectionProfile = c.srtpProtectionProfile state.TLSUnique = c.firstFinished[:] + state.SCTList = c.sctList + state.ClientCertSignatureHash = c.clientCertSignatureHash } return state |