aboutsummaryrefslogtreecommitdiffstats
path: root/guava/src/com/google/common/math/BigIntegerMath.java
blob: 99e69fe956cd6c46235d67d6cd496e4a102207eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
/*
 * Copyright (C) 2011 The Guava Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.common.math;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.math.MathPreconditions.checkNonNegative;
import static com.google.common.math.MathPreconditions.checkPositive;
import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
import static java.math.RoundingMode.CEILING;
import static java.math.RoundingMode.FLOOR;
import static java.math.RoundingMode.HALF_EVEN;

import com.google.common.annotations.Beta;
import com.google.common.annotations.VisibleForTesting;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

/**
 * A class for arithmetic on values of type {@code BigInteger}.
 *
 * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
 * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
 *
 * <p>Similar functionality for {@code int} and for {@code long} can be found in
 * {@link IntMath} and {@link LongMath} respectively.
 *
 * @author Louis Wasserman
 * @since 11.0
 */
@Beta
public final class BigIntegerMath {
  /**
   * Tests whether {@code x} is an exact power of two, that is, a strictly positive value with a
   * single set bit.
   *
   * @param x the value to test
   * @return {@code true} if {@code x} is positive and has exactly one set bit; {@code false} for
   *         zero, negative values and everything else
   */
  public static boolean isPowerOfTwo(BigInteger x) {
    checkNotNull(x);
    if (x.signum() <= 0) {
      return false;
    }
    // Exactly one bit is set precisely when the highest and lowest set bits coincide.
    int highestSetBit = x.bitLength() - 1;
    return highestSetBit == x.getLowestSetBit();
  }

  /**
   * Computes the base-2 logarithm of {@code x} as an {@code int}, rounded as directed by
   * {@code mode}.
   *
   * @param x the value whose logarithm is taken; must be strictly positive
   * @param mode how to round when {@code x} is not an exact power of two
   * @return the rounded base-2 logarithm of {@code x}
   * @throws IllegalArgumentException if {@code x <= 0}
   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
   *         is not a power of two
   */
  @SuppressWarnings("fallthrough")
  public static int log2(BigInteger x, RoundingMode mode) {
    checkPositive("x", checkNotNull(x));
    final int floorLog = x.bitLength() - 1;
    switch (mode) {
      case UNNECESSARY:
        checkRoundingUnnecessary(isPowerOfTwo(x));
        // fall through
      case DOWN:
      case FLOOR:
        return floorLog;

      case UP:
      case CEILING:
        if (isPowerOfTwo(x)) {
          return floorLog;
        }
        return floorLog + 1;

      case HALF_DOWN:
      case HALF_UP:
      case HALF_EVEN: {
        /*
         * sqrt(2) is irrational, so log2(x) can never land exactly on floorLog + 0.5 and the
         * three HALF_* modes all agree. We only need to know which side of 2^(floorLog + 0.5)
         * the value x falls on.
         */
        boolean roundsDown;
        if (floorLog < SQRT2_PRECOMPUTE_THRESHOLD) {
          // Small enough: compare x directly against the precomputed bits shifted into place.
          BigInteger midpoint =
              SQRT2_PRECOMPUTED_BITS.shiftRight(SQRT2_PRECOMPUTE_THRESHOLD - floorLog);
          roundsDown = x.compareTo(midpoint) <= 0;
        } else {
          // Otherwise compare x^2 against 2^(2 * floorLog + 1) by looking at its bit length.
          int squaredFloorLog = x.pow(2).bitLength() - 1;
          roundsDown = squaredFloorLog < 2 * floorLog + 1;
        }
        return roundsDown ? floorLog : floorLog + 1;
      }

      default:
        throw new AssertionError();
    }
  }

  /*
   * The maximum number of bits in a square root for which we'll precompute an explicit half power
   * of two. This can be any value, but higher values incur more class load time and linearly
   * increasing memory consumption.
   *
   * log2 only takes its precomputed comparison path for the HALF_* rounding modes when
   * floor(log2(x)) is strictly below this threshold.
   */
  @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;

  /*
   * A 257-bit constant (65 hex digits) that log2 shifts right by
   * (SQRT2_PRECOMPUTE_THRESHOLD - logFloor) to obtain the value x is compared against when
   * rounding with a HALF_* mode.
   *
   * NOTE(review): the digits look like floor(sqrt(2) * 2^256); confirm before changing
   * SQRT2_PRECOMPUTE_THRESHOLD, since the shift amount in log2 ties the two constants together.
   */
  @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
      new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);

  /**
   * Computes the base-10 logarithm of {@code x} as an {@code int}, rounded as directed by
   * {@code mode}.
   *
   * @param x the value whose logarithm is taken; must be strictly positive
   * @param mode how to round when {@code x} is not an exact power of ten
   * @return the rounded base-10 logarithm of {@code x}
   * @throws IllegalArgumentException if {@code x <= 0}
   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
   *         is not a power of ten
   */
  @SuppressWarnings("fallthrough")
  public static int log10(BigInteger x, RoundingMode mode) {
    checkPositive("x", x);
    if (fitsInLong(x)) {
      // Values in long range are delegated to the long implementation.
      return LongMath.log10(x.longValue(), mode);
    }

    // Repeated squaring: squares.get(i) holds 10^(2^i), for every such power that is <= x.
    // An initial capacity of 10 suffices for all x <= 10^(2^10).
    List<BigInteger> squares = new ArrayList<BigInteger>(10);
    for (BigInteger power = BigInteger.TEN; x.compareTo(power) >= 0; power = power.pow(2)) {
      squares.add(power);
    }

    // Build floor(log10(x)) one binary digit at a time, from the largest square downwards,
    // keeping floorPower == 10^floorLog throughout.
    BigInteger floorPower = BigInteger.ONE;
    int floorLog = 0;
    int index = squares.size();
    while (--index >= 0) {
      BigInteger candidate = squares.get(index).multiply(floorPower);
      int doubled = floorLog * 2;
      if (x.compareTo(candidate) >= 0) {
        floorPower = candidate;
        floorLog = doubled + 1;
      } else {
        floorLog = doubled;
      }
    }

    switch (mode) {
      case UNNECESSARY:
        checkRoundingUnnecessary(floorPower.equals(x));
        // fall through
      case FLOOR:
      case DOWN:
        return floorLog;

      case CEILING:
      case UP:
        if (floorPower.equals(x)) {
          return floorLog;
        }
        return floorLog + 1;

      case HALF_DOWN:
      case HALF_UP:
      case HALF_EVEN: {
        // sqrt(10) is irrational, so log10(x) - floorLog is never exactly 0.5; compare the
        // squares x^2 and 10^(2 * floorLog + 1) to see which side of the midpoint x lies on.
        BigInteger xSquared = x.pow(2);
        BigInteger midpointSquared = floorPower.pow(2).multiply(BigInteger.TEN);
        if (xSquared.compareTo(midpointSquared) <= 0) {
          return floorLog;
        }
        return floorLog + 1;
      }

      default:
        throw new AssertionError();
    }
  }

  /**
   * Computes the integer square root of {@code x}, rounded as directed by {@code mode}.
   *
   * @param x the value whose square root is taken; must be non-negative
   * @param mode how to round when {@code x} is not a perfect square
   * @return the rounded square root of {@code x}
   * @throws IllegalArgumentException if {@code x < 0}
   * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
   *         {@code sqrt(x)} is not an integer
   */
  @SuppressWarnings("fallthrough")
  public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
    checkNonNegative("x", x);
    if (fitsInLong(x)) {
      // Values in long range are delegated to the long implementation.
      long smallRoot = LongMath.sqrt(x.longValue(), mode);
      return BigInteger.valueOf(smallRoot);
    }
    final BigInteger floorRoot = sqrtFloor(x);
    switch (mode) {
      case UNNECESSARY:
        checkRoundingUnnecessary(floorRoot.pow(2).equals(x));
        // fall through
      case FLOOR:
      case DOWN:
        return floorRoot;

      case CEILING:
      case UP:
        if (floorRoot.pow(2).equals(x)) {
          return floorRoot;
        }
        return floorRoot.add(BigInteger.ONE);

      case HALF_DOWN:
      case HALF_UP:
      case HALF_EVEN: {
        /*
         * Round down iff x <= (floorRoot + 0.5)^2 = floorRoot^2 + floorRoot + 0.25. Both sides
         * of the comparison are integers apart from the 0.25, so it is enough to test
         * x <= floorRoot^2 + floorRoot.
         */
        BigInteger midpointFloor = floorRoot.pow(2).add(floorRoot);
        if (x.compareTo(midpointFloor) <= 0) {
          return floorRoot;
        }
        return floorRoot.add(BigInteger.ONE);
      }

      default:
        throw new AssertionError();
    }
  }

  private static BigInteger sqrtFloor(BigInteger x) {
    /*
     * Adapted from Hacker's Delight, Figure 11-1.
     *
     * Strategy: seed Newton's method with a cheap double-precision estimate of sqrt(x), then
     * refine with guess := (guess + x / guess) / 2 until the sequence stops decreasing.
     *
     * Two facts make this work:
     *
     * a) After the first step every guess is >= floor(sqrt(x)): the new guess is the arithmetic
     * mean of guess and x / guess, whose geometric mean is sqrt(x), and the arithmetic mean is
     * never below the geometric mean.
     *
     * b) The iteration converges to floor(sqrt(x)), doubling the number of correct digits each
     * step, so only O(log(digits)) steps are needed.
     *
     * The double-based seed may be above or below the true root, so at least one Newton step is
     * always taken before the "stopped decreasing" test is trusted.
     */
    final int floorLog2 = log2(x, FLOOR);
    BigInteger guess;
    if (floorLog2 >= DoubleUtils.MAX_DOUBLE_EXPONENT) {
      /*
       * Scale x down by an even power of two first. We have that x / 2^shift < 2^54, and the
       * seed becomes 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
       */
      int evenShift = (floorLog2 - DoubleUtils.SIGNIFICAND_BITS) & ~1;
      BigInteger scaled = x.shiftRight(evenShift);
      guess = sqrtApproxWithDoubles(scaled).shiftLeft(evenShift >> 1);
    } else {
      guess = sqrtApproxWithDoubles(x);
    }

    BigInteger next = guess.add(x.divide(guess)).shiftRight(1);
    if (guess.equals(next)) {
      // The seed was already a fixed point.
      return guess;
    }
    while (true) {
      guess = next;
      next = guess.add(x.divide(guess)).shiftRight(1);
      if (next.compareTo(guess) >= 0) {
        // No further decrease: guess is the answer.
        return guess;
      }
    }
  }

  /**
   * Approximates {@code sqrt(x)} by converting {@code x} to a {@code double}, taking
   * {@link Math#sqrt}, and rounding the result back to a {@code BigInteger} with
   * {@code HALF_EVEN}.
   */
  private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
    double approxValue = DoubleUtils.bigToDouble(x);
    double approxRoot = Math.sqrt(approxValue);
    return DoubleMath.roundToBigInteger(approxRoot, HALF_EVEN);
  }

  /**
   * Divides {@code p} by {@code q} and rounds the quotient to an integer using the given
   * {@code RoundingMode}.
   *
   * @param p the dividend
   * @param q the divisor
   * @param mode how to round a non-integral quotient
   * @return the rounded quotient {@code p / q}
   * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code p}
   *         is not an integer multiple of {@code q}
   */
  public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
    // Scale 0 makes BigDecimal round the quotient to a whole number under the requested mode.
    BigDecimal roundedQuotient = new BigDecimal(p).divide(new BigDecimal(q), 0, mode);
    return roundedQuotient.toBigIntegerExact();
  }

  /**
   * Returns {@code n!}, that is, the product of the first {@code n} positive
   * integers, or {@code 1} if {@code n == 0}.
   *
   * <p><b>Warning</b>: the result takes <i>O(n log n)</i> space, so use cautiously.
   *
   * <p>This uses an efficient binary recursive algorithm to compute the factorial
   * with balanced multiplies.  It also removes all the 2s from the intermediate
   * products (shifting them back in at the end).
   *
   * @param n the non-negative argument of the factorial
   * @return {@code n!} as a {@code BigInteger}
   * @throws IllegalArgumentException if {@code n < 0}
   */
  public static BigInteger factorial(int n) {
    checkNonNegative("n", n);

    // If the factorial is small enough, just use LongMath to do it.
    if (n < LongMath.FACTORIALS.length) {
      return BigInteger.valueOf(LongMath.FACTORIALS[n]);
    }

    /*
     * Pre-allocate space for our list of intermediate BigIntegers: roughly n * ceil(log2(n)) bits
     * packed into 64-bit chunks. The bit estimate is formed in long arithmetic because the int
     * product n * ceil(log2(n)) wraps negative once n reaches the tens of millions, which would
     * hand ArrayList a negative capacity. The chunk count itself always fits in an int.
     */
    long approxBits = (long) n * IntMath.log2(n, CEILING);
    int approxSize = (int) ((approxBits + Long.SIZE - 1) / Long.SIZE);
    List<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);

    // Start from the pre-computed maximum long factorial.
    int startingNumber = LongMath.FACTORIALS.length;
    long product = LongMath.FACTORIALS[startingNumber - 1];
    // Strip off 2s from this value; they are shifted back in at the very end.
    int shift = Long.numberOfTrailingZeros(product);
    product >>= shift;

    // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
    int productBits = LongMath.log2(product, FLOOR) + 1;
    int bits = LongMath.log2(startingNumber, FLOOR) + 1;
    // Check for the next power of two boundary, to save us a CLZ operation.
    // NOTE(review): nextPowerOfTwo starts out as the top bit of startingNumber, so the first loop
    // iteration always bumps bits; from then on bits runs one above the true bit length of num.
    // That only makes the overflow guard below more conservative.
    int nextPowerOfTwo = 1 << (bits - 1);

    // Iteratively multiply the longs as big as they can go.
    for (long num = startingNumber; num <= n; num++) {
      // Check to see if the floor(log2(num)) + 1 has changed.
      if ((num & nextPowerOfTwo) != 0) {
        nextPowerOfTwo <<= 1;
        bits++;
      }
      // Get rid of the 2s in num, remembering how many were removed.
      int tz = Long.numberOfTrailingZeros(num);
      long normalizedNum = num >> tz;
      shift += tz;
      // Adjust floor(log2(num)) + 1.
      int normalizedBits = bits - tz;
      // If it won't fit in a long, then we store off the intermediate product.
      if (normalizedBits + productBits >= Long.SIZE) {
        bignums.add(BigInteger.valueOf(product));
        product = 1;
        productBits = 0;
      }
      product *= normalizedNum;
      productBits = LongMath.log2(product, FLOOR) + 1;
    }
    // Check for leftovers.
    if (product > 1) {
      bignums.add(BigInteger.valueOf(product));
    }
    // Efficiently multiply all the intermediate products together, then restore the stripped 2s.
    return listProduct(bignums).shiftLeft(shift);
  }

  /** Returns the product of every element of {@code nums}, or {@code 1} for an empty list. */
  static BigInteger listProduct(List<BigInteger> nums) {
    int size = nums.size();
    return listProduct(nums, 0, size);
  }

  /**
   * Returns the product of {@code nums[start..end)}, splitting the range in half recursively so
   * that the two operands of each multiplication cover a similar number of elements.
   */
  static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
    int count = end - start;
    if (count == 0) {
      return BigInteger.ONE;
    }
    if (count == 1) {
      return nums.get(start);
    }
    if (count == 2 || count == 3) {
      // Short runs are multiplied left to right without further splitting.
      BigInteger product = nums.get(start);
      for (int i = start + 1; i < end; i++) {
        product = product.multiply(nums.get(i));
      }
      return product;
    }
    // Otherwise, split the range at its midpoint and combine the two halves.
    int middle = (start + end) >>> 1;
    BigInteger lowerHalf = listProduct(nums, start, middle);
    BigInteger upperHalf = listProduct(nums, middle, end);
    return lowerHalf.multiply(upperHalf);
  }

 /**
   * Computes the binomial coefficient "{@code n} choose {@code k}", that is,
   * {@code n! / (k! (n - k)!)}.
   *
   * <p><b>Warning</b>: the result can take as much as <i>O(k log n)</i> space.
   *
   * @param n the size of the set being chosen from; must be non-negative
   * @param k the number of elements chosen; must satisfy {@code 0 <= k <= n}
   * @return {@code n} choose {@code k}
   * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
   */
  public static BigInteger binomial(int n, int k) {
    checkNonNegative("n", n);
    checkNonNegative("k", k);
    checkArgument(k <= n, "k (%s) > n (%s)", k, n);
    // C(n, k) == C(n, n - k); work with the smaller of the two to shorten the loop below.
    if (k > (n >> 1)) {
      k = n - k;
    }
    // Hand off to LongMath whenever its lookup table says this (n, k) pair is within its range.
    boolean coveredByLongMath =
        k < LongMath.BIGGEST_BINOMIALS.length && n <= LongMath.BIGGEST_BINOMIALS[k];
    if (coveredByLongMath) {
      return BigInteger.valueOf(LongMath.binomial(n, k));
    }
    // Multiply by n, n - 1, ... while dividing by 1, 2, ...; each division is exact because
    // the running value is itself a binomial coefficient after every step.
    BigInteger accumulator = BigInteger.ONE;
    for (int numerator = n, denominator = 1; denominator <= k; numerator--, denominator++) {
      accumulator = accumulator
          .multiply(BigInteger.valueOf(numerator))
          .divide(BigInteger.valueOf(denominator));
    }
    return accumulator;
  }

  /**
   * Returns {@code true} if {@code x} survives a round trip through {@code long}, i.e. if
   * {@code BigInteger.valueOf(x.longValue()).equals(x)}.
   */
  static boolean fitsInLong(BigInteger x) {
    // bitLength() excludes the sign bit, so anything below 64 bits fits in a signed long.
    int magnitudeBits = x.bitLength();
    return magnitudeBits < Long.SIZE;
  }

  // Private constructor: this class only exposes static methods and is not meant to be
  // instantiated.
  private BigIntegerMath() {}
}