diff options
Diffstat (limited to 'bcprov/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java')
-rw-r--r-- | bcprov/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/bcprov/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java new file mode 100644 index 0000000..274e69a --- /dev/null +++ b/bcprov/src/main/java/org/bouncycastle/crypto/tls/DeferredHash.java @@ -0,0 +1,207 @@ +package org.bouncycastle.crypto.tls; + +import java.util.Enumeration; +import java.util.Hashtable; + +import org.bouncycastle.crypto.Digest; +import org.bouncycastle.util.Shorts; + +/** + * Buffers input until the hash algorithm is determined. + */ +class DeferredHash + implements TlsHandshakeHash +{ + protected static final int BUFFERING_HASH_LIMIT = 4; + + protected TlsContext context; + + private DigestInputBuffer buf; + private Hashtable hashes; + private Short prfHashAlgorithm; + + DeferredHash() + { + this.buf = new DigestInputBuffer(); + this.hashes = new Hashtable(); + this.prfHashAlgorithm = null; + } + + private DeferredHash(Short prfHashAlgorithm, Digest prfHash) + { + this.buf = null; + this.hashes = new Hashtable(); + this.prfHashAlgorithm = prfHashAlgorithm; + hashes.put(prfHashAlgorithm, prfHash); + } + + public void init(TlsContext context) + { + this.context = context; + } + + public TlsHandshakeHash notifyPRFDetermined() + { + int prfAlgorithm = context.getSecurityParameters().getPrfAlgorithm(); + if (prfAlgorithm == PRFAlgorithm.tls_prf_legacy) + { + CombinedHash legacyHash = new CombinedHash(); + legacyHash.init(context); + buf.updateDigest(legacyHash); + return legacyHash.notifyPRFDetermined(); + } + + this.prfHashAlgorithm = Shorts.valueOf(TlsUtils.getHashAlgorithmForPRFAlgorithm(prfAlgorithm)); + + checkTrackingHash(prfHashAlgorithm); + + return this; + } + + public void trackHashAlgorithm(short hashAlgorithm) + { + if (buf == null) + { + throw new IllegalStateException("Too late to track more hash algorithms"); + } + + checkTrackingHash(Shorts.valueOf(hashAlgorithm)); + } + + public void sealHashAlgorithms() + { + checkStopBuffering(); + } + + public TlsHandshakeHash stopTracking() + { + Digest prfHash = TlsUtils.cloneHash(prfHashAlgorithm.shortValue(), (Digest)hashes.get(prfHashAlgorithm)); + if (buf != null) + { + buf.updateDigest(prfHash); + } + DeferredHash result = new DeferredHash(prfHashAlgorithm, prfHash); + result.init(context); + return result; + } + + public Digest forkPRFHash() + { + checkStopBuffering(); + + if (buf != null) + { + Digest prfHash = TlsUtils.createHash(prfHashAlgorithm.shortValue()); + buf.updateDigest(prfHash); + return prfHash; + } + + return TlsUtils.cloneHash(prfHashAlgorithm.shortValue(), (Digest)hashes.get(prfHashAlgorithm)); + } + + public byte[] getFinalHash(short hashAlgorithm) + { + Digest d = (Digest)hashes.get(Shorts.valueOf(hashAlgorithm)); + if (d == null) + { + throw new IllegalStateException("HashAlgorithm " + hashAlgorithm + " is not being tracked"); + } + + d = TlsUtils.cloneHash(hashAlgorithm, d); + if (buf != null) + { + buf.updateDigest(d); + } + + byte[] bs = new byte[d.getDigestSize()]; + d.doFinal(bs, 0); + return bs; + } + + public String getAlgorithmName() + { + throw new IllegalStateException("Use fork() to get a definite Digest"); + } + + public int getDigestSize() + { + throw new IllegalStateException("Use fork() to get a definite Digest"); + } + + public void update(byte input) + { + if (buf != null) + { + buf.write(input); + return; + } + + Enumeration e = hashes.elements(); + while (e.hasMoreElements()) + { + Digest hash = (Digest)e.nextElement(); + hash.update(input); + } + } + + public void update(byte[] input, int inOff, int len) + { + if (buf != null) + { + buf.write(input, inOff, len); + return; + } + + Enumeration e = hashes.elements(); + while (e.hasMoreElements()) + { + Digest hash = (Digest)e.nextElement(); + hash.update(input, inOff, len); + } + } + + public int doFinal(byte[] output, int outOff) + { + throw new IllegalStateException("Use fork() to get a definite Digest"); + } + + public void reset() + { + if (buf != null) + { + buf.reset(); + return; + } + + Enumeration e = hashes.elements(); + while (e.hasMoreElements()) + { + Digest hash = (Digest)e.nextElement(); + hash.reset(); + } + } + + protected void checkStopBuffering() + { + if (buf != null && hashes.size() <= BUFFERING_HASH_LIMIT) + { + Enumeration e = hashes.elements(); + while (e.hasMoreElements()) + { + Digest hash = (Digest)e.nextElement(); + buf.updateDigest(hash); + } + + this.buf = null; + } + } + + protected void checkTrackingHash(Short hashAlgorithm) + { + if (!hashes.containsKey(hashAlgorithm)) + { + Digest hash = TlsUtils.createHash(hashAlgorithm.shortValue()); + hashes.put(hashAlgorithm, hash); + } + } +} |