summaryrefslogtreecommitdiffstats
path: root/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java
diff options
context:
space:
mode:
Diffstat (limited to 'bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java')
-rw-r--r--bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java78
1 files changed, 51 insertions, 27 deletions
diff --git a/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java
index 3819251..84ccfcb 100644
--- a/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java
+++ b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DTLSReliableHandshake.java
@@ -10,12 +10,11 @@ import org.bouncycastle.util.Integers;
class DTLSReliableHandshake
{
-
private final static int MAX_RECEIVE_AHEAD = 10;
private final DTLSRecordLayer recordLayer;
- private TlsHandshakeHash hash = new DeferredHash();
+ private TlsHandshakeHash handshakeHash;
private Hashtable currentInboundFlight = new Hashtable();
private Hashtable previousInboundFlight = null;
@@ -27,25 +26,31 @@ class DTLSReliableHandshake
DTLSReliableHandshake(TlsContext context, DTLSRecordLayer transport)
{
this.recordLayer = transport;
- this.hash.init(context);
+ this.handshakeHash = new DeferredHash();
+ this.handshakeHash.init(context);
}
void notifyHelloComplete()
{
- this.hash = this.hash.commit();
+ this.handshakeHash = handshakeHash.notifyPRFDetermined();
+ }
+
+ TlsHandshakeHash getHandshakeHash()
+ {
+ return handshakeHash;
}
- byte[] getCurrentHash()
+ TlsHandshakeHash prepareToFinish()
{
- TlsHandshakeHash copyOfHash = hash.fork();
- byte[] result = new byte[copyOfHash.getDigestSize()];
- copyOfHash.doFinal(result, 0);
+ TlsHandshakeHash result = handshakeHash;
+ this.handshakeHash = handshakeHash.stopTracking();
return result;
}
void sendMessage(short msg_type, byte[] body)
throws IOException
{
+ TlsUtils.checkUint24(body.length);
if (!sending)
{
@@ -62,10 +67,21 @@ class DTLSReliableHandshake
updateHandshakeMessagesDigest(message);
}
- Message receiveMessage()
+ byte[] receiveMessageBody(short msg_type)
throws IOException
{
+ Message message = receiveMessage();
+ if (message.getType() != msg_type)
+ {
+ throw new TlsFatalAlert(AlertDescription.unexpected_message);
+ }
+ return message.getBody();
+ }
+
+ Message receiveMessage()
+ throws IOException
+ {
if (sending)
{
sending = false;
@@ -93,7 +109,6 @@ class DTLSReliableHandshake
for (; ; )
{
-
int receiveLimit = recordLayer.getReceiveLimit();
if (buf == null || buf.length < receiveLimit)
{
@@ -280,7 +295,7 @@ class DTLSReliableHandshake
void resetHandshakeMessagesDigest()
{
- hash.reset();
+ handshakeHash.reset();
}
/**
@@ -328,8 +343,8 @@ class DTLSReliableHandshake
TlsUtils.writeUint16(message.getSeq(), buf, 4);
TlsUtils.writeUint24(0, buf, 6);
TlsUtils.writeUint24(body.length, buf, 9);
- hash.update(buf, 0, buf.length);
- hash.update(body, 0, body.length);
+ handshakeHash.update(buf, 0, buf.length);
+ handshakeHash.update(body, 0, body.length);
}
return message;
}
@@ -337,7 +352,6 @@ class DTLSReliableHandshake
private void writeMessage(Message message)
throws IOException
{
-
int sendLimit = recordLayer.getSendLimit();
int fragmentLimit = sendLimit - 12;
@@ -364,18 +378,15 @@ class DTLSReliableHandshake
private void writeHandshakeFragment(Message message, int fragment_offset, int fragment_length)
throws IOException
{
-
- ByteArrayOutputStream buf = new ByteArrayOutputStream();
- TlsUtils.writeUint8(message.getType(), buf);
- TlsUtils.writeUint24(message.getBody().length, buf);
- TlsUtils.writeUint16(message.getSeq(), buf);
- TlsUtils.writeUint24(fragment_offset, buf);
- TlsUtils.writeUint24(fragment_length, buf);
- buf.write(message.getBody(), fragment_offset, fragment_length);
-
- byte[] fragment = buf.toByteArray();
-
- recordLayer.send(fragment, 0, fragment.length);
+ RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length);
+ TlsUtils.writeUint8(message.getType(), fragment);
+ TlsUtils.writeUint24(message.getBody().length, fragment);
+ TlsUtils.writeUint16(message.getSeq(), fragment);
+ TlsUtils.writeUint24(fragment_offset, fragment);
+ TlsUtils.writeUint24(fragment_length, fragment);
+ fragment.write(message.getBody(), fragment_offset, fragment_length);
+
+ fragment.sendToRecordLayer(recordLayer);
}
private static boolean checkAll(Hashtable inboundFlight)
@@ -402,7 +413,6 @@ class DTLSReliableHandshake
static class Message
{
-
private final int message_seq;
private final short msg_type;
private final byte[] body;
@@ -429,4 +439,18 @@ class DTLSReliableHandshake
return body;
}
}
+
+ static class RecordLayerBuffer extends ByteArrayOutputStream
+ {
+ RecordLayerBuffer(int size)
+ {
+ super(size);
+ }
+
+ void sendToRecordLayer(DTLSRecordLayer recordLayer) throws IOException
+ {
+ recordLayer.send(buf, 0, count);
+ buf = null;
+ }
+ }
}