diff options
Diffstat (limited to 'src/ssl/d1_pkt.c')
-rw-r--r-- | src/ssl/d1_pkt.c | 475 |
1 files changed, 394 insertions, 81 deletions
diff --git a/src/ssl/d1_pkt.c b/src/ssl/d1_pkt.c index e2d505c..553499f 100644 --- a/src/ssl/d1_pkt.c +++ b/src/ssl/d1_pkt.c @@ -109,8 +109,6 @@ * copied and put under another distribution licence * [including the GNU Public Licence.] */ -#include <openssl/ssl.h> - #include <assert.h> #include <stdio.h> #include <string.h> @@ -124,66 +122,270 @@ #include "internal.h" +/* mod 128 saturating subtract of two 64-bit values in big-endian order */ +static int satsub64be(const uint8_t *v1, const uint8_t *v2) { + int ret, sat, brw, i; + + if (sizeof(long) == 8) { + do { + const union { + long one; + char little; + } is_endian = {1}; + long l; + + if (is_endian.little) { + break; + } + /* not reached on little-endians */ + /* following test is redundant, because input is + * always aligned, but I take no chances... */ + if (((size_t)v1 | (size_t)v2) & 0x7) { + break; + } + + l = *((long *)v1); + l -= *((long *)v2); + if (l > 128) { + return 128; + } else if (l < -128) { + return -128; + } else { + return (int)l; + } + } while (0); + } + + ret = (int)v1[7] - (int)v2[7]; + sat = 0; + brw = ret >> 8; /* brw is either 0 or -1 */ + if (ret & 0x80) { + for (i = 6; i >= 0; i--) { + brw += (int)v1[i] - (int)v2[i]; + sat |= ~brw; + brw >>= 8; + } + } else { + for (i = 6; i >= 0; i--) { + brw += (int)v1[i] - (int)v2[i]; + sat |= brw; + brw >>= 8; + } + } + brw <<= 8; /* brw is either 0 or -256 */ + + if (sat & 0xff) { + return brw | 0x80; + } else { + return brw + (ret & 0xFF); + } +} + +static int dtls1_record_replay_check(SSL *s, DTLS1_BITMAP *bitmap); +static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap); +static int dtls1_process_record(SSL *s); static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, unsigned int len, enum dtls1_use_epoch_t use_epoch); -/* dtls1_get_record reads a new input record. On success, it places it in - * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if - * more data is needed. */ -static int dtls1_get_record(SSL *ssl) { -again: - /* Read a new packet if there is no unconsumed one. */ - if (ssl_read_buffer_len(ssl) == 0) { - int ret = ssl_read_buffer_extend_to(ssl, 0 /* unused */); - if (ret <= 0) { - return ret; - } +static int dtls1_process_record(SSL *s) { + int al; + SSL3_RECORD *rr = &s->s3->rrec; + + /* check is not needed I believe */ + if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) { + al = SSL_AD_RECORD_OVERFLOW; + OPENSSL_PUT_ERROR(SSL, dtls1_process_record, + SSL_R_ENCRYPTED_LENGTH_TOO_LONG); + goto f_err; } - assert(ssl_read_buffer_len(ssl) > 0); - /* Ensure the packet is large enough to decrypt in-place. */ - if (ssl_read_buffer_len(ssl) < ssl_record_prefix_len(ssl)) { - ssl_read_buffer_clear(ssl); - goto again; + /* |rr->data| points to |rr->length| bytes of ciphertext in |s->packet|. */ + rr->data = &s->packet[DTLS1_RT_HEADER_LENGTH]; + + uint8_t seq[8]; + seq[0] = rr->epoch >> 8; + seq[1] = rr->epoch & 0xff; + memcpy(&seq[2], &rr->seq_num[2], 6); + + /* Decrypt the packet in-place. Note it is important that |SSL_AEAD_CTX_open| + * not write beyond |rr->length|. There may be another record in the packet. + * + * TODO(davidben): This assumes |s->version| is the same as the record-layer + * version which isn't always true, but it only differs with the NULL cipher + * which ignores the parameter. */ + size_t plaintext_len; + if (!SSL_AEAD_CTX_open(s->aead_read_ctx, rr->data, &plaintext_len, rr->length, + rr->type, s->version, seq, rr->data, rr->length)) { + /* Bad packets are silently dropped in DTLS. Clear the error queue of any + * errors decryption may have added. */ + ERR_clear_error(); + rr->length = 0; + s->packet_length = 0; + goto err; + } + + if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) { + al = SSL_AD_RECORD_OVERFLOW; + OPENSSL_PUT_ERROR(SSL, dtls1_process_record, SSL_R_DATA_LENGTH_TOO_LONG); + goto f_err; } + assert(plaintext_len < (1u << 16)); + rr->length = plaintext_len; - uint8_t *out = ssl_read_buffer(ssl) + ssl_record_prefix_len(ssl); - size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl); - uint8_t type, alert; - size_t len, consumed; - switch (dtls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out, - ssl_read_buffer(ssl), ssl_read_buffer_len(ssl))) { - case ssl_open_record_success: - ssl_read_buffer_consume(ssl, consumed); + rr->off = 0; + /* So at this point the following is true + * ssl->s3->rrec.type is the type of record + * ssl->s3->rrec.length == number of bytes in record + * ssl->s3->rrec.off == offset to first valid byte + * ssl->s3->rrec.data == the first byte of the record body. */ + + /* we have pulled in a full packet so zero things */ + s->packet_length = 0; + return 1; + +f_err: + ssl3_send_alert(s, SSL3_AL_FATAL, al); - if (len > 0xffff) { - OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); - return -1; +err: + return 0; +} + +/* Call this to get a new input record. + * It will return <= 0 if more data is needed, normally due to an error + * or non-blocking IO. + * When it finishes, one packet has been decoded and can be found in + * ssl->s3->rrec.type - is the type of record + * ssl->s3->rrec.data, - data + * ssl->s3->rrec.length, - number of bytes + * + * used only by dtls1_read_bytes */ +int dtls1_get_record(SSL *s) { + uint8_t ssl_major, ssl_minor; + int n; + SSL3_RECORD *rr; + uint8_t *p = NULL; + uint16_t version; + + rr = &(s->s3->rrec); + + /* get something from the wire */ +again: + /* check if we have the header */ + if ((s->rstate != SSL_ST_READ_BODY) || + (s->packet_length < DTLS1_RT_HEADER_LENGTH)) { + n = ssl3_read_n(s, DTLS1_RT_HEADER_LENGTH, 0); + /* read timeout is handled by dtls1_read_bytes */ + if (n <= 0) { + return n; /* error or non-blocking */ + } + + /* this packet contained a partial record, dump it */ + if (s->packet_length != DTLS1_RT_HEADER_LENGTH) { + s->packet_length = 0; + goto again; + } + + s->rstate = SSL_ST_READ_BODY; + + p = s->packet; + + if (s->msg_callback) { + s->msg_callback(0, 0, SSL3_RT_HEADER, p, DTLS1_RT_HEADER_LENGTH, s, + s->msg_callback_arg); + } + + /* Pull apart the header into the DTLS1_RECORD */ + rr->type = *(p++); + ssl_major = *(p++); + ssl_minor = *(p++); + version = (((uint16_t)ssl_major) << 8) | ssl_minor; + + /* sequence number is 64 bits, with top 2 bytes = epoch */ + n2s(p, rr->epoch); + + memcpy(&(s->s3->read_sequence[2]), p, 6); + p += 6; + + n2s(p, rr->length); + + /* Lets check version */ + if (s->s3->have_version) { + if (version != s->version) { + /* The record's version doesn't match, so silently drop it. + * + * TODO(davidben): This doesn't work. The DTLS record layer is not + * packet-based, so the remainder of the packet isn't dropped and we + * get a framing error. It's also unclear what it means to silently + * drop a record in a packet containing two records. */ + rr->length = 0; + s->packet_length = 0; + goto again; } + } - SSL3_RECORD *rr = &ssl->s3->rrec; - rr->type = type; - rr->length = (uint16_t)len; - rr->off = 0; - rr->data = out; - return 1; + if ((version & 0xff00) != (s->version & 0xff00)) { + /* wrong version, silently discard record */ + rr->length = 0; + s->packet_length = 0; + goto again; + } - case ssl_open_record_discard: - ssl_read_buffer_consume(ssl, consumed); + if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) { + /* record too long, silently discard it */ + rr->length = 0; + s->packet_length = 0; goto again; + } - case ssl_open_record_error: - ssl3_send_alert(ssl, SSL3_AL_FATAL, alert); - return -1; + /* now s->rstate == SSL_ST_READ_BODY */ + } + + /* s->rstate == SSL_ST_READ_BODY, get and decode the data */ + + if (rr->length > s->packet_length - DTLS1_RT_HEADER_LENGTH) { + /* now s->packet_length == DTLS1_RT_HEADER_LENGTH */ + n = ssl3_read_n(s, rr->length, 1); + /* This packet contained a partial record, dump it. */ + if (n != rr->length) { + rr->length = 0; + s->packet_length = 0; + goto again; + } - case ssl_open_record_partial: - /* Impossible in DTLS. */ - break; + /* now n == rr->length, + * and s->packet_length == DTLS1_RT_HEADER_LENGTH + rr->length */ } + s->rstate = SSL_ST_READ_HEADER; /* set state for later operations */ - assert(0); - OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); - return -1; + if (rr->epoch != s->d1->r_epoch) { + /* This record is from the wrong epoch. If it is the next epoch, it could be + * buffered. For simplicity, drop it and expect retransmit to handle it + * later; DTLS is supposed to handle packet loss. */ + rr->length = 0; + s->packet_length = 0; + goto again; + } + + /* Check whether this is a repeat, or aged record. */ + if (!dtls1_record_replay_check(s, &s->d1->bitmap)) { + rr->length = 0; + s->packet_length = 0; /* dump this record */ + goto again; /* get another record */ + } + + /* just read a 0 length packet */ + if (rr->length == 0) { + goto again; + } + + if (!dtls1_process_record(s)) { + rr->length = 0; + s->packet_length = 0; /* dump this record */ + goto again; /* get another record */ + } + dtls1_record_bitmap_update(s, &s->d1->bitmap); /* Mark receipt of record. */ + + return 1; } int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek) { @@ -191,11 +393,7 @@ int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek) { } void dtls1_read_close_notify(SSL *ssl) { - /* Bidirectional shutdown doesn't make sense for an unordered transport. DTLS - * alerts also aren't delivered reliably, so we may even time out because the - * peer never received our close_notify. Report to the caller that the channel - * has fully shut down. */ - ssl->shutdown |= SSL_RECEIVED_SHUTDOWN; + dtls1_read_bytes(ssl, 0, NULL, 0, 0); } /* Return up to 'len' payload bytes received in 'type' records. @@ -203,6 +401,7 @@ void dtls1_read_close_notify(SSL *ssl) { * * - SSL3_RT_HANDSHAKE (when ssl3_get_message calls us) * - SSL3_RT_APPLICATION_DATA (when ssl3_read calls us) + * - 0 (during a shutdown, no data has to be returned) * * If we don't have stored data to work from, read a SSL/TLS record first * (possibly multiple records if we still don't have anything to return). @@ -230,9 +429,11 @@ int dtls1_read_bytes(SSL *s, int type, unsigned char *buf, int len, int peek) { SSL3_RECORD *rr; void (*cb)(const SSL *ssl, int type2, int val) = NULL; - if ((type != SSL3_RT_APPLICATION_DATA && type != SSL3_RT_HANDSHAKE) || - (peek && type != SSL3_RT_APPLICATION_DATA)) { - OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); + /* XXX: check what the second '&& type' is about */ + if ((type && (type != SSL3_RT_APPLICATION_DATA) && + (type != SSL3_RT_HANDSHAKE) && type) || + (peek && (type != SSL3_RT_APPLICATION_DATA))) { + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, ERR_R_INTERNAL_ERROR); return -1; } @@ -243,7 +444,7 @@ int dtls1_read_bytes(SSL *s, int type, unsigned char *buf, int len, int peek) { return i; } if (i == 0) { - OPENSSL_PUT_ERROR(SSL, SSL_R_SSL_HANDSHAKE_FAILURE); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_SSL_HANDSHAKE_FAILURE); return -1; } } @@ -263,7 +464,7 @@ start: } /* get new packet if necessary */ - if (rr->length == 0) { + if (rr->length == 0 || s->rstate == SSL_ST_READ_BODY) { ret = dtls1_get_record(s); if (ret <= 0) { ret = dtls1_read_failed(s, ret); @@ -306,15 +507,10 @@ start: /* TODO(davidben): Is this check redundant with the handshake_func * check? */ al = SSL_AD_UNEXPECTED_MESSAGE; - OPENSSL_PUT_ERROR(SSL, SSL_R_APP_DATA_IN_HANDSHAKE); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_APP_DATA_IN_HANDSHAKE); goto f_err; } - /* Discard empty records. */ - if (rr->length == 0) { - goto start; - } - if (len <= 0) { return len; } @@ -330,9 +526,8 @@ start: rr->length -= n; rr->off += n; if (rr->length == 0) { + s->rstate = SSL_ST_READ_HEADER; rr->off = 0; - /* The record has been consumed, so we may now clear the buffer. */ - ssl_read_buffer_discard(s); } } @@ -347,7 +542,7 @@ start: /* Alerts may not be fragmented. */ if (rr->length < 2) { al = SSL_AD_DECODE_ERROR; - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_ALERT); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_BAD_ALERT); goto f_err; } @@ -381,7 +576,8 @@ start: s->rwstate = SSL_NOTHING; s->s3->fatal_alert = alert_descr; - OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, + SSL_AD_REASON_OFFSET + alert_descr); BIO_snprintf(tmp, sizeof tmp, "%d", alert_descr); ERR_add_error_data(2, "SSL alert number ", tmp); s->shutdown |= SSL_RECEIVED_SHUTDOWN; @@ -389,19 +585,26 @@ start: return 0; } else { al = SSL_AD_ILLEGAL_PARAMETER; - OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_ALERT_TYPE); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_UNKNOWN_ALERT_TYPE); goto f_err; } goto start; } + if (s->shutdown & SSL_SENT_SHUTDOWN) { + /* but we have not received a shutdown */ + s->rwstate = SSL_NOTHING; + rr->length = 0; + return 0; + } + if (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC) { /* 'Change Cipher Spec' is just a single byte, so we know exactly what the * record payload has to look like */ if (rr->length != 1 || rr->off != 0 || rr->data[0] != SSL3_MT_CCS) { al = SSL_AD_ILLEGAL_PARAMETER; - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_CHANGE_CIPHER_SPEC); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_BAD_CHANGE_CIPHER_SPEC); goto f_err; } @@ -438,7 +641,7 @@ start: if (rr->type == SSL3_RT_HANDSHAKE && !s->in_handshake) { if (rr->length < DTLS1_HM_HEADER_LENGTH) { al = SSL_AD_DECODE_ERROR; - OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_BAD_HANDSHAKE_RECORD); goto f_err; } struct hm_header_st msg_hdr; @@ -466,7 +669,7 @@ start: assert(rr->type != SSL3_RT_CHANGE_CIPHER_SPEC && rr->type != SSL3_RT_ALERT); al = SSL_AD_UNEXPECTED_MESSAGE; - OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD); + OPENSSL_PUT_ERROR(SSL, dtls1_read_bytes, SSL_R_UNEXPECTED_RECORD); f_err: ssl3_send_alert(s, SSL3_AL_FATAL, al); @@ -483,13 +686,13 @@ int dtls1_write_app_data(SSL *s, const void *buf_, int len) { return i; } if (i == 0) { - OPENSSL_PUT_ERROR(SSL, SSL_R_SSL_HANDSHAKE_FAILURE); + OPENSSL_PUT_ERROR(SSL, dtls1_write_app_data, SSL_R_SSL_HANDSHAKE_FAILURE); return -1; } } if (len > SSL3_RT_MAX_PLAIN_LENGTH) { - OPENSSL_PUT_ERROR(SSL, SSL_R_DTLS_MESSAGE_TOO_BIG); + OPENSSL_PUT_ERROR(SSL, dtls1_write_app_data, SSL_R_DTLS_MESSAGE_TOO_BIG); return -1; } @@ -510,11 +713,73 @@ int dtls1_write_bytes(SSL *s, int type, const void *buf, int len, return i; } +/* dtls1_seal_record seals a new record of type |type| and plaintext |in| and + * writes it to |out|. At most |max_out| bytes will be written. It returns one + * on success and zero on error. On success, it updates the write sequence + * number. */ +static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len, + size_t max_out, uint8_t type, const uint8_t *in, + size_t in_len, enum dtls1_use_epoch_t use_epoch) { + if (max_out < DTLS1_RT_HEADER_LENGTH) { + OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL); + return 0; + } + + /* Determine the parameters for the current epoch. */ + uint16_t epoch = s->d1->w_epoch; + SSL_AEAD_CTX *aead = s->aead_write_ctx; + uint8_t *seq = s->s3->write_sequence; + if (use_epoch == dtls1_use_previous_epoch) { + /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1 + * (negotiated cipher) exist. */ + assert(s->d1->w_epoch == 1); + epoch = s->d1->w_epoch - 1; + aead = NULL; + seq = s->d1->last_write_sequence; + } + + out[0] = type; + + uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION; + out[1] = wire_version >> 8; + out[2] = wire_version & 0xff; + + out[3] = epoch >> 8; + out[4] = epoch & 0xff; + memcpy(&out[5], &seq[2], 6); + + size_t ciphertext_len; + if (!SSL_AEAD_CTX_seal(aead, out + DTLS1_RT_HEADER_LENGTH, &ciphertext_len, + max_out - DTLS1_RT_HEADER_LENGTH, type, wire_version, + &out[3] /* seq */, in, in_len) || + !ssl3_record_sequence_update(&seq[2], 6)) { + return 0; + } + + if (ciphertext_len >= 1 << 16) { + OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, ERR_R_OVERFLOW); + return 0; + } + out[11] = ciphertext_len >> 8; + out[12] = ciphertext_len & 0xff; + + *out_len = DTLS1_RT_HEADER_LENGTH + ciphertext_len; + + if (s->msg_callback) { + s->msg_callback(1 /* write */, 0, SSL3_RT_HEADER, out, + DTLS1_RT_HEADER_LENGTH, s, s->msg_callback_arg); + } + + return 1; +} + static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, unsigned int len, enum dtls1_use_epoch_t use_epoch) { + SSL3_BUFFER *wb = &s->s3->wbuf; + /* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there * is never pending data. */ - assert(!ssl_write_buffer_is_pending(s)); + assert(s->s3->wbuf.left == 0); /* If we have an alert to send, lets send it */ if (s->s3->alert_dispatch) { @@ -525,8 +790,7 @@ static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, /* if it went, fall through and send more stuff */ } - if (len > SSL3_RT_MAX_PLAIN_LENGTH) { - OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); + if (wb->buf == NULL && !ssl3_setup_write_buffer(s)) { return -1; } @@ -534,15 +798,21 @@ static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, return 0; } - size_t max_out = len + ssl_max_seal_overhead(s); - uint8_t *out; + /* Align the output so the ciphertext is aligned to |SSL3_ALIGN_PAYLOAD|. */ + uintptr_t align = (uintptr_t)wb->buf + DTLS1_RT_HEADER_LENGTH; + align = (0 - align) & (SSL3_ALIGN_PAYLOAD - 1); + uint8_t *out = wb->buf + align; + wb->offset = align; + size_t max_out = wb->len - wb->offset; + size_t ciphertext_len; - if (!ssl_write_buffer_init(s, &out, max_out) || - !dtls_seal_record(s, out, &ciphertext_len, max_out, type, buf, len, - use_epoch)) { + if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len, + use_epoch)) { return -1; } - ssl_write_buffer_set_len(s, ciphertext_len); + + /* now let's set up wb */ + wb->left = ciphertext_len; /* memorize arguments so that ssl3_write_pending can detect bad write retries * later */ @@ -555,6 +825,49 @@ static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, return ssl3_write_pending(s, type, buf, len); } +static int dtls1_record_replay_check(SSL *s, DTLS1_BITMAP *bitmap) { + int cmp; + unsigned int shift; + const uint8_t *seq = s->s3->read_sequence; + + cmp = satsub64be(seq, bitmap->max_seq_num); + if (cmp > 0) { + memcpy(s->s3->rrec.seq_num, seq, 8); + return 1; /* this record in new */ + } + shift = -cmp; + if (shift >= sizeof(bitmap->map) * 8) { + return 0; /* stale, outside the window */ + } else if (bitmap->map & (((uint64_t)1) << shift)) { + return 0; /* record previously received */ + } + + memcpy(s->s3->rrec.seq_num, seq, 8); + return 1; +} + +static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap) { + int cmp; + unsigned int shift; + const uint8_t *seq = s->s3->read_sequence; + + cmp = satsub64be(seq, bitmap->max_seq_num); + if (cmp > 0) { + shift = cmp; + if (shift < sizeof(bitmap->map) * 8) { + bitmap->map <<= shift, bitmap->map |= 1UL; + } else { + bitmap->map = 1UL; + } + memcpy(bitmap->max_seq_num, seq, 8); + } else { + shift = -cmp; + if (shift < sizeof(bitmap->map) * 8) { + bitmap->map |= ((uint64_t)1) << shift; + } + } +} + int dtls1_dispatch_alert(SSL *s) { int i, j; void (*cb)(const SSL *ssl, int type, int val) = NULL; |