diff options
Diffstat (limited to 'src/ssl/d1_both.c')
-rw-r--r-- | src/ssl/d1_both.c | 112 |
1 files changed, 50 insertions, 62 deletions
diff --git a/src/ssl/d1_both.c b/src/ssl/d1_both.c index ac35a66..1acb3ce 100644 --- a/src/ssl/d1_both.c +++ b/src/ssl/d1_both.c @@ -111,6 +111,8 @@ * copied and put under another distribution licence * [including the GNU Public Licence.] */ +#include <openssl/ssl.h> + #include <assert.h> #include <limits.h> #include <stdio.h> @@ -147,52 +149,44 @@ static void dtls1_fix_message_header(SSL *s, unsigned long frag_off, unsigned long frag_len); static unsigned char *dtls1_write_message_header(SSL *s, unsigned char *p); -static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len, - int reassembly) { - hm_fragment *frag = NULL; - uint8_t *buf = NULL; - uint8_t *bitmask = NULL; - - frag = (hm_fragment *)OPENSSL_malloc(sizeof(hm_fragment)); +static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) { + hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment)); if (frag == NULL) { - OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE); + OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); return NULL; } + memset(frag, 0, sizeof(hm_fragment)); - if (frag_len) { - buf = (uint8_t *)OPENSSL_malloc(frag_len); - if (buf == NULL) { - OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE); - OPENSSL_free(frag); - return NULL; + /* If the handshake message is empty, |frag->fragment| and |frag->reassembly| + * are NULL. */ + if (frag_len > 0) { + frag->fragment = OPENSSL_malloc(frag_len); + if (frag->fragment == NULL) { + OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); + goto err; } - } - - /* zero length fragment gets zero frag->fragment */ - frag->fragment = buf; - /* Initialize reassembly bitmask if necessary */ - if (reassembly && frag_len > 0) { - if (frag_len + 7 < frag_len) { - OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_OVERFLOW); - return NULL; - } - size_t bitmask_len = (frag_len + 7) / 8; - bitmask = (uint8_t *)OPENSSL_malloc(bitmask_len); - if (bitmask == NULL) { - OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE); - if (buf != NULL) { - OPENSSL_free(buf); + if (reassembly) { + /* Initialize reassembly bitmask. */ + if (frag_len + 7 < frag_len) { + OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); + goto err; } - OPENSSL_free(frag); - return NULL; + size_t bitmask_len = (frag_len + 7) / 8; + frag->reassembly = OPENSSL_malloc(bitmask_len); + if (frag->reassembly == NULL) { + OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); + goto err; + } + memset(frag->reassembly, 0, bitmask_len); } - memset(bitmask, 0, bitmask_len); } - frag->reassembly = bitmask; - return frag; + +err: + dtls1_hm_fragment_free(frag); + return NULL; } void dtls1_hm_fragment_free(hm_fragment *frag) { @@ -326,7 +320,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) { if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) { /* To make forward progress, the MTU must, at minimum, fit the handshake * header and one byte of handshake body. */ - OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL); + OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL); return -1; } @@ -344,7 +338,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) { assert(type == SSL3_RT_CHANGE_CIPHER_SPEC); /* ChangeCipherSpec cannot be fragmented. */ if (s->init_num > curr_mtu) { - OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL); + OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL); return -1; } len = s->init_num; @@ -450,8 +444,7 @@ static hm_fragment *dtls1_get_buffered_message( frag->msg_header.msg_len != msg_hdr->msg_len) { /* The new fragment must be compatible with the previous fragments from * this message. */ - OPENSSL_PUT_ERROR(SSL, dtls1_get_buffered_message, - SSL_R_FRAGMENT_MISMATCH); + OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH); ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER); return NULL; } @@ -473,11 +466,7 @@ static size_t dtls1_max_handshake_message_len(const SSL *s) { /* dtls1_process_fragment reads a handshake fragment and processes it. It * returns one if a fragment was successfully processed and 0 or -1 on error. */ static int dtls1_process_fragment(SSL *s) { - /* Read handshake message header. - * - * TODO(davidben): ssl_read_bytes allows splitting the fragment header and - * body across two records. Change this interface to consume the fragment in - * one pass. */ + /* Read handshake message header. */ uint8_t header[DTLS1_HM_HEADER_LENGTH]; int ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, header, DTLS1_HM_HEADER_LENGTH, 0); @@ -485,7 +474,7 @@ static int dtls1_process_fragment(SSL *s) { return ret; } if (ret != DTLS1_HM_HEADER_LENGTH) { - OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); return -1; } @@ -494,14 +483,16 @@ static int dtls1_process_fragment(SSL *s) { struct hm_header_st msg_hdr; dtls1_get_message_header(header, &msg_hdr); + /* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There + * should be no need to reach into |s->s3->rrec.length|. */ const size_t frag_off = msg_hdr.frag_off; const size_t frag_len = msg_hdr.frag_len; const size_t msg_len = msg_hdr.msg_len; if (frag_off > msg_len || frag_off + frag_len < frag_off || frag_off + frag_len > msg_len || - msg_len > dtls1_max_handshake_message_len(s)) { - OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, - SSL_R_EXCESSIVE_MESSAGE_SIZE); + msg_len > dtls1_max_handshake_message_len(s) || + frag_len > s->s3->rrec.length) { + OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER); return -1; } @@ -535,8 +526,8 @@ static int dtls1_process_fragment(SSL *s) { ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, frag->fragment + frag_off, frag_len, 0); if (ret != frag_len) { - OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE); - ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); + ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR); return -1; } dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len); @@ -563,7 +554,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max, s->s3->tmp.reuse_message = 0; if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) { al = SSL_AD_UNEXPECTED_MESSAGE; - OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); goto f_err; } *ok = 1; @@ -589,22 +580,19 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max, assert(frag->reassembly == NULL); if (frag->msg_header.msg_len > (size_t)max) { - OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_EXCESSIVE_MESSAGE_SIZE); + OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); goto err; } + /* Reconstruct the assembled message. */ + size_t len; CBB cbb; + CBB_zero(&cbb); if (!BUF_MEM_grow(s->init_buf, (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH) || - !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max)) { - OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_MALLOC_FAILURE); - goto err; - } - - /* Reconstruct the assembled message. */ - size_t len; - if (!CBB_add_u8(&cbb, frag->msg_header.type) || + !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max) || + !CBB_add_u8(&cbb, frag->msg_header.type) || !CBB_add_u24(&cbb, frag->msg_header.msg_len) || !CBB_add_u16(&cbb, frag->msg_header.seq) || !CBB_add_u24(&cbb, 0 /* frag_off */) || @@ -612,7 +600,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max, !CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) || !CBB_finish(&cbb, NULL, &len)) { CBB_cleanup(&cbb); - OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_INTERNAL_ERROR); + OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); goto err; } assert(len == (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH); @@ -628,7 +616,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max, if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) { al = SSL_AD_UNEXPECTED_MESSAGE; - OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); goto f_err; } if (hash_message == ssl_hash_message && !ssl3_hash_current_message(s)) { |