diff options
Diffstat (limited to 'src/ssl/d1_both.c')
-rw-r--r-- | src/ssl/d1_both.c | 74 |
1 files changed, 17 insertions, 57 deletions
diff --git a/src/ssl/d1_both.c b/src/ssl/d1_both.c index 662f518..ac35a66 100644 --- a/src/ssl/d1_both.c +++ b/src/ssl/d1_both.c @@ -259,11 +259,10 @@ static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start, /* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or * SSL3_RT_CHANGE_CIPHER_SPEC) */ -int dtls1_do_write(SSL *s, int type) { +int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) { int ret; int curr_mtu; unsigned int len, frag_off; - size_t max_overhead = 0; /* AHA! Figure out the MTU, and stick to the right size */ if (s->d1->mtu < dtls1_min_mtu() && @@ -286,12 +285,7 @@ int dtls1_do_write(SSL *s, int type) { } /* Determine the maximum overhead of the current cipher. */ - if (s->aead_write_ctx != NULL) { - max_overhead = EVP_AEAD_max_overhead(s->aead_write_ctx->ctx.aead); - if (s->aead_write_ctx->variable_nonce_included_in_record) { - max_overhead += s->aead_write_ctx->variable_nonce_len; - } - } + size_t max_overhead = SSL_AEAD_CTX_max_overhead(s->aead_write_ctx); frag_off = 0; while (s->init_num) { @@ -356,7 +350,8 @@ int dtls1_do_write(SSL *s, int type) { len = s->init_num; } - ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len); + ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len, + use_epoch); if (ret < 0) { return -1; } @@ -409,8 +404,7 @@ static int dtls1_discard_fragment_body(SSL *s, size_t frag_len) { uint8_t discard[256]; while (frag_len > 0) { size_t chunk = frag_len < sizeof(discard) ? frag_len : sizeof(discard); - int ret = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, discard, chunk, - 0); + int ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, discard, chunk, 0); if (ret != chunk) { return 0; } @@ -485,8 +479,8 @@ static int dtls1_process_fragment(SSL *s) { * body across two records. Change this interface to consume the fragment in * one pass. */ uint8_t header[DTLS1_HM_HEADER_LENGTH]; - int ret = s->method->ssl_read_bytes(s, SSL3_RT_HANDSHAKE, header, - DTLS1_HM_HEADER_LENGTH, 0); + int ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, header, + DTLS1_HM_HEADER_LENGTH, 0); if (ret <= 0) { return ret; } @@ -538,8 +532,8 @@ static int dtls1_process_fragment(SSL *s) { assert(msg_len > 0); /* Read the body of the fragment. */ - ret = s->method->ssl_read_bytes( - s, SSL3_RT_HANDSHAKE, frag->fragment + frag_off, frag_len, 0); + ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, frag->fragment + frag_off, + frag_len, 0); if (ret != frag_len) { OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE); ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); @@ -690,7 +684,7 @@ int dtls1_send_change_cipher_spec(SSL *s, int a, int b) { } /* SSL3_ST_CW_CHANGE_B */ - return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC); + return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC, dtls1_use_current_epoch); } int dtls1_read_failed(SSL *s, int code) { @@ -730,7 +724,6 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) { int ret; /* XDTLS: for now assuming that read/writes are blocking */ unsigned long header_length; - uint8_t save_write_sequence[8]; /* assert(s->init_num == 0); assert(s->init_off == 0); */ @@ -749,45 +742,18 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) { frag->msg_header.msg_len, frag->msg_header.seq, 0, frag->msg_header.frag_len); - /* Save current state. */ - SSL_AEAD_CTX *aead_write_ctx = s->aead_write_ctx; - uint16_t epoch = s->d1->w_epoch; - /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1 * (negotiated cipher) exist. */ - assert(epoch == 0 || epoch == 1); - assert(frag->msg_header.epoch <= epoch); - const int fragment_from_previous_epoch = (epoch == 1 && - frag->msg_header.epoch == 0); - if (fragment_from_previous_epoch) { - /* Rewind to the previous epoch. - * - * TODO(davidben): Instead of swapping out connection-global state, this - * logic should pass a "use previous epoch" parameter down to lower-level - * functions. */ - s->d1->w_epoch = frag->msg_header.epoch; - s->aead_write_ctx = NULL; - memcpy(save_write_sequence, s->s3->write_sequence, - sizeof(s->s3->write_sequence)); - memcpy(s->s3->write_sequence, s->d1->last_write_sequence, - sizeof(s->s3->write_sequence)); - } else { - /* Otherwise the messages must be from the same epoch. */ - assert(frag->msg_header.epoch == epoch); + assert(s->d1->w_epoch == 0 || s->d1->w_epoch == 1); + assert(frag->msg_header.epoch <= s->d1->w_epoch); + enum dtls1_use_epoch_t use_epoch = dtls1_use_current_epoch; + if (s->d1->w_epoch == 1 && frag->msg_header.epoch == 0) { + use_epoch = dtls1_use_previous_epoch; } ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC - : SSL3_RT_HANDSHAKE); - - if (fragment_from_previous_epoch) { - /* Restore the current epoch. */ - s->aead_write_ctx = aead_write_ctx; - s->d1->w_epoch = epoch; - memcpy(s->d1->last_write_sequence, s->s3->write_sequence, - sizeof(s->s3->write_sequence)); - memcpy(s->s3->write_sequence, save_write_sequence, - sizeof(s->s3->write_sequence)); - } + : SSL3_RT_HANDSHAKE, + use_epoch); (void)BIO_flush(SSL_get_wbio(s)); return ret; @@ -917,9 +883,3 @@ void dtls1_get_message_header(uint8_t *data, n2l3(data, msg_hdr->frag_off); n2l3(data, msg_hdr->frag_len); } - -int dtls1_shutdown(SSL *s) { - int ret; - ret = ssl3_shutdown(s); - return ret; -} |