summaryrefslogtreecommitdiffstats
path: root/bcprov/src/main/java/org/bouncycastle/crypto/agreement/kdf/DHKEKGenerator.java
blob: 6feb5076e74850d48a3fb2208f6580e80276e6be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package org.bouncycastle.crypto.agreement.kdf;

import java.io.IOException;

import org.bouncycastle.asn1.ASN1EncodableVector;
import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.DEROctetString;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.DERTaggedObject;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.DerivationFunction;
import org.bouncycastle.crypto.DerivationParameters;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.util.Pack;

/**
 * RFC 2631 Diffie-Hellman KEK derivation function.
 * <p>
 * Output is produced one digest-sized block at a time. Block <code>i</code> (counting from 1)
 * is the digest of <code>z || OtherInfo(i)</code>, where OtherInfo is the DER encoding of:
 * <pre>
 *     SEQUENCE {
 *         SEQUENCE { algorithm OBJECT IDENTIFIER, counter OCTET STRING (4 bytes, big endian) },
 *         [0] EXPLICIT OCTET STRING   -- extra info, only present if it was supplied
 *         [2] EXPLICIT OCTET STRING   -- key size, 4 bytes, big endian
 *     }
 * </pre>
 * {@link #init(DerivationParameters)} must be called before
 * {@link #generateBytes(byte[], int, int)}.
 */
public class DHKEKGenerator
    implements DerivationFunction
{
    private final Digest digest;

    // OID placed in the KeySpecificInfo sequence of every block.
    private ASN1ObjectIdentifier algorithm;
    // value encoded as a 4 byte big endian integer in the [2] tagged field.
    private int                 keySize;
    // shared secret material hashed ahead of OtherInfo in every block.
    private byte[]              z;
    // optional extra info carried in the [0] tagged field, may be null.
    private byte[]              partyAInfo;

    /**
     * Create a generator that uses the passed in digest as its hash function.
     *
     * @param digest the digest used to produce each output block.
     */
    public DHKEKGenerator(
        Digest digest)
    {
        this.digest = digest;
    }

    /**
     * Initialise the generator from a set of {@link DHKDFParameters}.
     *
     * @param param the derivation parameters, these must be a DHKDFParameters instance.
     * @throws ClassCastException if param is not a DHKDFParameters.
     */
    public void init(DerivationParameters param)
    {
        DHKDFParameters params = (DHKDFParameters)param;

        this.algorithm = params.getAlgorithm();
        this.keySize = params.getKeySize();
        this.z = params.getZ();
        this.partyAInfo = params.getExtraInfo();
    }

    /**
     * Return the digest this generator was constructed with.
     *
     * @return the underlying digest.
     */
    public Digest getDigest()
    {
        return digest;
    }

    /**
     * Fill <code>len</code> bytes of <code>out</code>, starting at <code>outOff</code>, with
     * derived key material. The final block is truncated if <code>len</code> is not a
     * multiple of the digest size.
     *
     * @param out the buffer to receive the derived bytes.
     * @param outOff the offset in out to start writing at.
     * @param len the number of bytes to generate.
     * @return the number of bytes generated, this is always len.
     * @throws DataLengthException if out does not have room for len bytes after outOff.
     * @throws IllegalArgumentException if len is negative, or the OtherInfo structure
     * cannot be encoded.
     */
    public int generateBytes(byte[] out, int outOff, int len)
        throws DataLengthException, IllegalArgumentException
    {
        if (len < 0)
        {
            throw new IllegalArgumentException("Output length must not be negative");
        }

        if ((out.length - len) < outOff)
        {
            throw new DataLengthException("output buffer too small");
        }

        int digestSize = digest.getDigestSize();

        //
        // The block counter is a 32 bit value starting at 1, so at most
        // 2^32 - 1 blocks can be produced. As len is an int, the block
        // count below can never exceed 2^31 - 1 and no further upper
        // bound check is required. The sum is done as a long so that
        // len + digestSize - 1 cannot overflow.
        //
        int blockCount = (int)(((long)len + digestSize - 1) / digestSize);

        byte[] block = new byte[digestSize];
        int remaining = len;

        for (int counter = 1; counter <= blockCount; counter++)
        {
            // encode first, so a failure leaves the digest untouched.
            byte[] otherInfo = encodeOtherInfo(counter);

            digest.update(z, 0, z.length);
            digest.update(otherInfo, 0, otherInfo.length);
            digest.doFinal(block, 0);

            // only the last block may need truncating.
            int toCopy = Math.min(remaining, digestSize);

            System.arraycopy(block, 0, out, outOff, toCopy);
            outOff += toCopy;
            remaining -= toCopy;
        }

        digest.reset();

        return len;
    }

    /**
     * Build the DER encoded OtherInfo structure for the block with the passed in counter.
     *
     * @param counter the block counter, encoded as a 4 byte big endian integer.
     * @return the DER encoding of OtherInfo.
     * @throws IllegalArgumentException if the structure cannot be encoded, the underlying
     * IOException is attached as the cause.
     */
    private byte[] encodeOtherInfo(int counter)
    {
        // KeySpecificInfo
        ASN1EncodableVector keySpecificInfo = new ASN1EncodableVector();

        keySpecificInfo.add(algorithm);
        keySpecificInfo.add(new DEROctetString(Pack.intToBigEndian(counter)));

        // OtherInfo
        ASN1EncodableVector otherInfo = new ASN1EncodableVector();

        otherInfo.add(new DERSequence(keySpecificInfo));

        if (partyAInfo != null)
        {
            otherInfo.add(new DERTaggedObject(true, 0, new DEROctetString(partyAInfo)));
        }

        otherInfo.add(new DERTaggedObject(true, 2, new DEROctetString(Pack.intToBigEndian(keySize))));

        try
        {
            return new DERSequence(otherInfo).getEncoded(ASN1Encoding.DER);
        }
        catch (IOException e)
        {
            throw new IllegalArgumentException("unable to encode parameter info: " + e.getMessage(), e);
        }
    }
}