diff options
Diffstat (limited to 'bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java')
-rw-r--r-- | bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java | 78 |
1 files changed, 51 insertions, 27 deletions
diff --git a/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java index 3819251..84ccfcb 100644 --- a/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java +++ b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java @@ -10,12 +10,11 @@ import org.bouncycastle.util.Integers; class DTLSReliableHandshake { - private final static int MAX_RECEIVE_AHEAD = 10; private final DTLSRecordLayer recordLayer; - private TlsHandshakeHash hash = new DeferredHash(); + private TlsHandshakeHash handshakeHash; private Hashtable currentInboundFlight = new Hashtable(); private Hashtable previousInboundFlight = null; @@ -27,25 +26,31 @@ class DTLSReliableHandshake DTLSReliableHandshake(TlsContext context, DTLSRecordLayer transport) { this.recordLayer = transport; - this.hash.init(context); + this.handshakeHash = new DeferredHash(); + this.handshakeHash.init(context); } void notifyHelloComplete() { - this.hash = this.hash.commit(); + this.handshakeHash = handshakeHash.notifyPRFDetermined(); + } + + TlsHandshakeHash getHandshakeHash() + { + return handshakeHash; } - byte[] getCurrentHash() + TlsHandshakeHash prepareToFinish() { - TlsHandshakeHash copyOfHash = hash.fork(); - byte[] result = new byte[copyOfHash.getDigestSize()]; - copyOfHash.doFinal(result, 0); + TlsHandshakeHash result = handshakeHash; + this.handshakeHash = handshakeHash.stopTracking(); return result; } void sendMessage(short msg_type, byte[] body) throws IOException { + TlsUtils.checkUint24(body.length); if (!sending) { @@ -62,10 +67,21 @@ class DTLSReliableHandshake updateHandshakeMessagesDigest(message); } - Message receiveMessage() + byte[] receiveMessageBody(short msg_type) throws IOException { + Message message = receiveMessage(); + if (message.getType() != msg_type) + { + throw new TlsFatalAlert(AlertDescription.unexpected_message); + } + return message.getBody(); + } + + Message receiveMessage() + throws IOException + { if (sending) { sending = false; @@ -93,7 +109,6 @@ class DTLSReliableHandshake for (; ; ) { - int receiveLimit = recordLayer.getReceiveLimit(); if (buf == null || buf.length < receiveLimit) { @@ -280,7 +295,7 @@ class DTLSReliableHandshake void resetHandshakeMessagesDigest() { - hash.reset(); + handshakeHash.reset(); } /** @@ -328,8 +343,8 @@ class DTLSReliableHandshake TlsUtils.writeUint16(message.getSeq(), buf, 4); TlsUtils.writeUint24(0, buf, 6); TlsUtils.writeUint24(body.length, buf, 9); - hash.update(buf, 0, buf.length); - hash.update(body, 0, body.length); + handshakeHash.update(buf, 0, buf.length); + handshakeHash.update(body, 0, body.length); } return message; } @@ -337,7 +352,6 @@ class DTLSReliableHandshake private void writeMessage(Message message) throws IOException { - int sendLimit = recordLayer.getSendLimit(); int fragmentLimit = sendLimit - 12; @@ -364,18 +378,15 @@ class DTLSReliableHandshake private void writeHandshakeFragment(Message message, int fragment_offset, int fragment_length) throws IOException { - - ByteArrayOutputStream buf = new ByteArrayOutputStream(); - TlsUtils.writeUint8(message.getType(), buf); - TlsUtils.writeUint24(message.getBody().length, buf); - TlsUtils.writeUint16(message.getSeq(), buf); - TlsUtils.writeUint24(fragment_offset, buf); - TlsUtils.writeUint24(fragment_length, buf); - buf.write(message.getBody(), fragment_offset, fragment_length); - - byte[] fragment = buf.toByteArray(); - - recordLayer.send(fragment, 0, fragment.length); + RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length); + TlsUtils.writeUint8(message.getType(), fragment); + TlsUtils.writeUint24(message.getBody().length, fragment); + TlsUtils.writeUint16(message.getSeq(), fragment); + TlsUtils.writeUint24(fragment_offset, fragment); + TlsUtils.writeUint24(fragment_length, fragment); + fragment.write(message.getBody(), fragment_offset, fragment_length); + + fragment.sendToRecordLayer(recordLayer); } private static boolean checkAll(Hashtable inboundFlight) @@ -402,7 +413,6 @@ class DTLSReliableHandshake static class Message { - private final int message_seq; private final short msg_type; private final byte[] body; @@ -429,4 +439,18 @@ class DTLSReliableHandshake return body; } } + + static class RecordLayerBuffer extends ByteArrayOutputStream + { + RecordLayerBuffer(int size) + { + super(size); + } + + void sendToRecordLayer(DTLSRecordLayer recordLayer) throws IOException + { + recordLayer.send(buf, 0, count); + buf = null; + } + } } |