diff options
Diffstat (limited to 'guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java')
-rw-r--r-- | guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java | 1066 |
1 files changed, 329 insertions, 737 deletions
diff --git a/guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java b/guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java index b45d127..622454d 100644 --- a/guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java +++ b/guava-gwt/src-super/com/google/common/collect/super/com/google/common/collect/TreeMultiset.java @@ -17,217 +17,168 @@ package com.google.common.collect; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.GwtCompatible; -import com.google.common.base.Objects; -import com.google.common.primitives.Ints; +import static com.google.common.collect.BstSide.LEFT; +import static com.google.common.collect.BstSide.RIGHT; import java.io.Serializable; import java.util.Comparator; import java.util.ConcurrentModificationException; import java.util.Iterator; -import java.util.NoSuchElementException; import javax.annotation.Nullable; +import com.google.common.annotations.GwtCompatible; +import com.google.common.primitives.Ints; + /** - * A multiset which maintains the ordering of its elements, according to either their natural order - * or an explicit {@link Comparator}. In all cases, this implementation uses - * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to - * determine equivalence of instances. + * A multiset which maintains the ordering of its elements, according to either + * their natural order or an explicit {@link Comparator}. In all cases, this + * implementation uses {@link Comparable#compareTo} or {@link + * Comparator#compare} instead of {@link Object#equals} to determine + * equivalence of instances. * - * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the - * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the - * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}. - * - * <p>See the Guava User Guide article on <a href= - * "http://code.google.com/p/guava-libraries/wiki/NewCollectionTypesExplained#Multiset"> - * {@code Multiset}</a>. + * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as + * explained by the {@link Comparable} class specification. Otherwise, the + * resulting multiset will violate the {@link java.util.Collection} contract, + * which is specified in terms of {@link Object#equals}. * * @author Louis Wasserman * @author Jared Levy * @since 2.0 (imported from Google Collections Library) */ @GwtCompatible(emulated = true) -public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable { +public final class TreeMultiset<E> extends AbstractSortedMultiset<E> + implements Serializable { /** - * Creates a new, empty multiset, sorted according to the elements' natural order. All elements - * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all - * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a - * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the - * user attempts to add an element to the multiset that violates this constraint (for example, - * the user attempts to add a string element to a set whose elements are integers), the - * {@code add(Object)} call will throw a {@code ClassCastException}. + * Creates a new, empty multiset, sorted according to the elements' natural + * order. All elements inserted into the multiset must implement the + * {@code Comparable} interface. Furthermore, all such elements must be + * <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a + * {@code ClassCastException} for any elements {@code e1} and {@code e2} in + * the multiset. If the user attempts to add an element to the multiset that + * violates this constraint (for example, the user attempts to add a string + * element to a set whose elements are integers), the {@code add(Object)} + * call will throw a {@code ClassCastException}. * - * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific - * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. + * <p>The type specification is {@code <E extends Comparable>}, instead of the + * more specific {@code <E extends Comparable<? super E>>}, to support + * classes defined without generics. */ public static <E extends Comparable> TreeMultiset<E> create() { return new TreeMultiset<E>(Ordering.natural()); } /** - * Creates a new, empty multiset, sorted according to the specified comparator. All elements - * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator: - * {@code comparator.compare(e1, - * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in - * the multiset. If the user attempts to add an element to the multiset that violates this - * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}. + * Creates a new, empty multiset, sorted according to the specified + * comparator. All elements inserted into the multiset must be <i>mutually + * comparable</i> by the specified comparator: {@code comparator.compare(e1, + * e2)} must not throw a {@code ClassCastException} for any elements {@code + * e1} and {@code e2} in the multiset. If the user attempts to add an element + * to the multiset that violates this constraint, the {@code add(Object)} call + * will throw a {@code ClassCastException}. * - * @param comparator - * the comparator that will be used to sort this multiset. A null value indicates that - * the elements' <i>natural ordering</i> should be used. + * @param comparator the comparator that will be used to sort this multiset. A + * null value indicates that the elements' <i>natural ordering</i> should + * be used. */ @SuppressWarnings("unchecked") - public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) { + public static <E> TreeMultiset<E> create( + @Nullable Comparator<? super E> comparator) { return (comparator == null) - ? new TreeMultiset<E>((Comparator) Ordering.natural()) - : new TreeMultiset<E>(comparator); + ? new TreeMultiset<E>((Comparator) Ordering.natural()) + : new TreeMultiset<E>(comparator); } /** - * Creates an empty multiset containing the given initial elements, sorted according to the - * elements' natural order. + * Creates an empty multiset containing the given initial elements, sorted + * according to the elements' natural order. * - * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}. + * <p>This implementation is highly efficient when {@code elements} is itself + * a {@link Multiset}. * - * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific - * {@code <E extends Comparable<? super E>>}, to support classes defined without generics. + * <p>The type specification is {@code <E extends Comparable>}, instead of the + * more specific {@code <E extends Comparable<? super E>>}, to support + * classes defined without generics. */ - public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) { + public static <E extends Comparable> TreeMultiset<E> create( + Iterable<? extends E> elements) { TreeMultiset<E> multiset = create(); Iterables.addAll(multiset, elements); return multiset; } - private final transient Reference<AvlNode<E>> rootReference; - private final transient GeneralRange<E> range; - private final transient AvlNode<E> header; - - TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) { - super(range.comparator()); - this.rootReference = rootReference; - this.range = range; - this.header = endLink; + /** + * Returns an iterator over the elements contained in this collection. + */ + @Override + public Iterator<E> iterator() { + // Needed to avoid Javadoc bug. + return super.iterator(); } - TreeMultiset(Comparator<? super E> comparator) { + private TreeMultiset(Comparator<? super E> comparator) { super(comparator); this.range = GeneralRange.all(comparator); - this.header = new AvlNode<E>(null, 1); - successor(header, header); - this.rootReference = new Reference<AvlNode<E>>(); + this.rootReference = new Reference<Node<E>>(); } - /** - * A function which can be summed across a subtree. - */ - private enum Aggregate { - SIZE { - @Override - int nodeAggregate(AvlNode<?> node) { - return node.elemCount; - } + private TreeMultiset(GeneralRange<E> range, Reference<Node<E>> root) { + super(range.comparator()); + this.range = range; + this.rootReference = root; + } - @Override - long treeAggregate(@Nullable AvlNode<?> root) { - return (root == null) ? 0 : root.totalCount; - } - }, - DISTINCT { - @Override - int nodeAggregate(AvlNode<?> node) { - return 1; - } + @SuppressWarnings("unchecked") + E checkElement(Object o) { + return (E) o; + } - @Override - long treeAggregate(@Nullable AvlNode<?> root) { - return (root == null) ? 0 : root.distinctElements; - } - }; - abstract int nodeAggregate(AvlNode<?> node); + private transient final GeneralRange<E> range; - abstract long treeAggregate(@Nullable AvlNode<?> root); - } + private transient final Reference<Node<E>> rootReference; - private long aggregateForEntries(Aggregate aggr) { - AvlNode<E> root = rootReference.get(); - long total = aggr.treeAggregate(root); - if (range.hasLowerBound()) { - total -= aggregateBelowRange(aggr, root); - } - if (range.hasUpperBound()) { - total -= aggregateAboveRange(aggr, root); - } - return total; - } + static final class Reference<T> { + T value; - private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) { - if (node == null) { - return 0; - } - int cmp = comparator().compare(range.getLowerEndpoint(), node.elem); - if (cmp < 0) { - return aggregateBelowRange(aggr, node.left); - } else if (cmp == 0) { - switch (range.getLowerBoundType()) { - case OPEN: - return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left); - case CLOSED: - return aggr.treeAggregate(node.left); - default: - throw new AssertionError(); - } - } else { - return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node) - + aggregateBelowRange(aggr, node.right); - } - } + public Reference() {} - private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) { - if (node == null) { - return 0; + public T get() { + return value; } - int cmp = comparator().compare(range.getUpperEndpoint(), node.elem); - if (cmp > 0) { - return aggregateAboveRange(aggr, node.right); - } else if (cmp == 0) { - switch (range.getUpperBoundType()) { - case OPEN: - return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right); - case CLOSED: - return aggr.treeAggregate(node.right); - default: - throw new AssertionError(); + + public boolean compareAndSet(T expected, T newValue) { + if (value == expected) { + value = newValue; + return true; } - } else { - return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node) - + aggregateAboveRange(aggr, node.left); + return false; } } @Override - public int size() { - return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE)); + int distinctElements() { + Node<E> root = rootReference.get(); + return Ints.checkedCast(BstRangeOps.totalInRange(distinctAggregate(), range, root)); } @Override - int distinctElements() { - return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT)); + public int size() { + Node<E> root = rootReference.get(); + return Ints.saturatedCast(BstRangeOps.totalInRange(sizeAggregate(), range, root)); } @Override public int count(@Nullable Object element) { try { - @SuppressWarnings("unchecked") - E e = (E) element; - AvlNode<E> root = rootReference.get(); - if (!range.contains(e) || root == null) { - return 0; + E e = checkElement(element); + if (range.contains(e)) { + Node<E> node = BstOperations.seek(comparator(), rootReference.get(), e); + return countOrZero(node); } - return root.count(comparator(), e); + return 0; } catch (ClassCastException e) { return 0; } catch (NullPointerException e) { @@ -235,713 +186,354 @@ public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements } } + private int mutate(@Nullable E e, MultisetModifier modifier) { + BstMutationRule<E, Node<E>> mutationRule = BstMutationRule.createRule( + modifier, + BstCountBasedBalancePolicies. + <E, Node<E>>singleRebalancePolicy(distinctAggregate()), + nodeFactory()); + BstMutationResult<E, Node<E>> mutationResult = + BstOperations.mutate(comparator(), mutationRule, rootReference.get(), e); + if (!rootReference.compareAndSet( + mutationResult.getOriginalRoot(), mutationResult.getChangedRoot())) { + throw new ConcurrentModificationException(); + } + Node<E> original = mutationResult.getOriginalTarget(); + return countOrZero(original); + } + @Override - public int add(@Nullable E element, int occurrences) { - checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences); + public int add(E element, int occurrences) { + checkElement(element); if (occurrences == 0) { return count(element); } checkArgument(range.contains(element)); - AvlNode<E> root = rootReference.get(); - if (root == null) { - comparator().compare(element, element); - AvlNode<E> newRoot = new AvlNode<E>(element, occurrences); - successor(header, newRoot, header); - rootReference.checkAndSet(root, newRoot); - return 0; - } - int[] result = new int[1]; // used as a mutable int reference to hold result - AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result); - rootReference.checkAndSet(root, newRoot); - return result[0]; + return mutate(element, new AddModifier(occurrences)); } @Override public int remove(@Nullable Object element, int occurrences) { - checkArgument(occurrences >= 0, "occurrences must be >= 0 but was %s", occurrences); - if (occurrences == 0) { + if (element == null) { + return 0; + } else if (occurrences == 0) { return count(element); } - AvlNode<E> root = rootReference.get(); - int[] result = new int[1]; // used as a mutable int reference to hold result - AvlNode<E> newRoot; try { - @SuppressWarnings("unchecked") - E e = (E) element; - if (!range.contains(e) || root == null) { - return 0; - } - newRoot = root.remove(comparator(), e, occurrences, result); + E e = checkElement(element); + return range.contains(e) ? mutate(e, new RemoveModifier(occurrences)) : 0; } catch (ClassCastException e) { return 0; - } catch (NullPointerException e) { - return 0; } - rootReference.checkAndSet(root, newRoot); - return result[0]; } @Override - public int setCount(@Nullable E element, int count) { - checkArgument(count >= 0); - if (!range.contains(element)) { - checkArgument(count == 0); - return 0; - } - - AvlNode<E> root = rootReference.get(); - if (root == null) { - if (count > 0) { - add(element, count); - } - return 0; - } - int[] result = new int[1]; // used as a mutable int reference to hold result - AvlNode<E> newRoot = root.setCount(comparator(), element, count, result); - rootReference.checkAndSet(root, newRoot); - return result[0]; + public boolean setCount(E element, int oldCount, int newCount) { + checkElement(element); + checkArgument(range.contains(element)); + return mutate(element, new ConditionalSetCountModifier(oldCount, newCount)) + == oldCount; } @Override - public boolean setCount(@Nullable E element, int oldCount, int newCount) { - checkArgument(newCount >= 0); - checkArgument(oldCount >= 0); + public int setCount(E element, int count) { + checkElement(element); checkArgument(range.contains(element)); - - AvlNode<E> root = rootReference.get(); - if (root == null) { - if (oldCount == 0) { - if (newCount > 0) { - add(element, newCount); - } - return true; - } else { - return false; - } - } - int[] result = new int[1]; // used as a mutable int reference to hold result - AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result); - rootReference.checkAndSet(root, newRoot); - return result[0] == oldCount; + return mutate(element, new SetCountModifier(count)); } - private Entry<E> wrapEntry(final AvlNode<E> baseEntry) { - return new Multisets.AbstractEntry<E>() { - @Override - public E getElement() { - return baseEntry.getElement(); - } - - @Override - public int getCount() { - int result = baseEntry.getCount(); - if (result == 0) { - return count(getElement()); - } else { - return result; - } - } - }; + private BstPathFactory<Node<E>, BstInOrderPath<Node<E>>> pathFactory() { + return BstInOrderPath.inOrderFactory(); } - /** - * Returns the first node in the tree that is in range. - */ - @Nullable private AvlNode<E> firstNode() { - AvlNode<E> root = rootReference.get(); - if (root == null) { - return null; - } - AvlNode<E> node; - if (range.hasLowerBound()) { - E endpoint = range.getLowerEndpoint(); - node = rootReference.get().ceiling(comparator(), endpoint); - if (node == null) { - return null; - } - if (range.getLowerBoundType() == BoundType.OPEN - && comparator().compare(endpoint, node.getElement()) == 0) { - node = node.succ; - } - } else { - node = header.succ; - } - return (node == header || !range.contains(node.getElement())) ? null : node; + @Override + Iterator<Entry<E>> entryIterator() { + Node<E> root = rootReference.get(); + final BstInOrderPath<Node<E>> startingPath = + BstRangeOps.furthestPath(range, LEFT, pathFactory(), root); + return iteratorInDirection(startingPath, RIGHT); } - @Nullable private AvlNode<E> lastNode() { - AvlNode<E> root = rootReference.get(); - if (root == null) { - return null; - } - AvlNode<E> node; - if (range.hasUpperBound()) { - E endpoint = range.getUpperEndpoint(); - node = rootReference.get().floor(comparator(), endpoint); - if (node == null) { - return null; - } - if (range.getUpperBoundType() == BoundType.OPEN - && comparator().compare(endpoint, node.getElement()) == 0) { - node = node.pred; - } - } else { - node = header.pred; - } - return (node == header || !range.contains(node.getElement())) ? null : node; + @Override + Iterator<Entry<E>> descendingEntryIterator() { + Node<E> root = rootReference.get(); + final BstInOrderPath<Node<E>> startingPath = + BstRangeOps.furthestPath(range, RIGHT, pathFactory(), root); + return iteratorInDirection(startingPath, LEFT); } - @Override - Iterator<Entry<E>> entryIterator() { + private Iterator<Entry<E>> iteratorInDirection( + @Nullable BstInOrderPath<Node<E>> start, final BstSide direction) { + final Iterator<BstInOrderPath<Node<E>>> pathIterator = + new AbstractLinkedIterator<BstInOrderPath<Node<E>>>(start) { + @Override + protected BstInOrderPath<Node<E>> computeNext(BstInOrderPath<Node<E>> previous) { + if (!previous.hasNext(direction)) { + return null; + } + BstInOrderPath<Node<E>> next = previous.next(direction); + // TODO(user): only check against one side + return range.contains(next.getTip().getKey()) ? next : null; + } + }; return new Iterator<Entry<E>>() { - AvlNode<E> current = firstNode(); - Entry<E> prevEntry; + E toRemove = null; @Override public boolean hasNext() { - if (current == null) { - return false; - } else if (range.tooHigh(current.getElement())) { - current = null; - return false; - } else { - return true; - } + return pathIterator.hasNext(); } @Override public Entry<E> next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - Entry<E> result = wrapEntry(current); - prevEntry = result; - if (current.succ == header) { - current = null; - } else { - current = current.succ; - } - return result; + BstInOrderPath<Node<E>> path = pathIterator.next(); + return new LiveEntry( + toRemove = path.getTip().getKey(), path.getTip().elemCount()); } @Override public void remove() { - checkState(prevEntry != null); - setCount(prevEntry.getElement(), 0); - prevEntry = null; + checkState(toRemove != null); + setCount(toRemove, 0); + toRemove = null; } }; } - @Override - Iterator<Entry<E>> descendingEntryIterator() { - return new Iterator<Entry<E>>() { - AvlNode<E> current = lastNode(); - Entry<E> prevEntry = null; + class LiveEntry extends Multisets.AbstractEntry<E> { + private Node<E> expectedRoot; + private final E element; + private int count; - @Override - public boolean hasNext() { - if (current == null) { - return false; - } else if (range.tooLow(current.getElement())) { - current = null; - return false; - } else { - return true; - } - } + private LiveEntry(E element, int count) { + this.expectedRoot = rootReference.get(); + this.element = element; + this.count = count; + } - @Override - public Entry<E> next() { - if (!hasNext()) { - throw new NoSuchElementException(); - } - Entry<E> result = wrapEntry(current); - prevEntry = result; - if (current.pred == header) { - current = null; - } else { - current = current.pred; - } - return result; - } + @Override + public E getElement() { + return element; + } - @Override - public void remove() { - checkState(prevEntry != null); - setCount(prevEntry.getElement(), 0); - prevEntry = null; + @Override + public int getCount() { + if (rootReference.get() == expectedRoot) { + return count; + } else { + // check for updates + expectedRoot = rootReference.get(); + return count = TreeMultiset.this.count(element); } - }; + } } @Override - public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) { - return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo( - comparator(), - upperBound, - boundType)), header); + public void clear() { + Node<E> root = rootReference.get(); + Node<E> cleared = BstRangeOps.minusRange(range, + BstCountBasedBalancePolicies.<E, Node<E>>fullRebalancePolicy(distinctAggregate()), + nodeFactory(), root); + if (!rootReference.compareAndSet(root, cleared)) { + throw new ConcurrentModificationException(); + } } @Override - public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) { - return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo( - comparator(), - lowerBound, - boundType)), header); + public SortedMultiset<E> headMultiset(E upperBound, BoundType boundType) { + checkNotNull(upperBound); + return new TreeMultiset<E>( + range.intersect(GeneralRange.upTo(comparator, upperBound, boundType)), rootReference); } - static int distinctElements(@Nullable AvlNode<?> node) { - return (node == null) ? 0 : node.distinctElements; + @Override + public SortedMultiset<E> tailMultiset(E lowerBound, BoundType boundType) { + checkNotNull(lowerBound); + return new TreeMultiset<E>( + range.intersect(GeneralRange.downTo(comparator, lowerBound, boundType)), rootReference); } - private static final class Reference<T> { - @Nullable private T value; - - @Nullable public T get() { - return value; - } - - public void checkAndSet(@Nullable T expected, T newValue) { - if (value != expected) { - throw new ConcurrentModificationException(); - } - value = newValue; - } + /** + * {@inheritDoc} + * + * @since 11.0 + */ + @Override + public Comparator<? super E> comparator() { + return super.comparator(); } - private static final class AvlNode<E> extends Multisets.AbstractEntry<E> { - @Nullable private final E elem; + private static final class Node<E> extends BstNode<E, Node<E>> implements Serializable { + private final long size; + private final int distinct; - // elemCount is 0 iff this node has been deleted. - private int elemCount; - - private int distinctElements; - private long totalCount; - private int height; - private AvlNode<E> left; - private AvlNode<E> right; - private AvlNode<E> pred; - private AvlNode<E> succ; - - AvlNode(@Nullable E elem, int elemCount) { + private Node(E key, int elemCount, @Nullable Node<E> left, + @Nullable Node<E> right) { + super(key, left, right); checkArgument(elemCount > 0); - this.elem = elem; - this.elemCount = elemCount; - this.totalCount = elemCount; - this.distinctElements = 1; - this.height = 1; - this.left = null; - this.right = null; - } - - public int count(Comparator<? super E> comparator, E e) { - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - return (left == null) ? 0 : left.count(comparator, e); - } else if (cmp > 0) { - return (right == null) ? 0 : right.count(comparator, e); - } else { - return elemCount; - } - } - - private AvlNode<E> addRightChild(E e, int count) { - right = new AvlNode<E>(e, count); - successor(this, right, succ); - height = Math.max(2, height); - distinctElements++; - totalCount += count; - return this; + this.size = (long) elemCount + sizeOrZero(left) + sizeOrZero(right); + this.distinct = 1 + distinctOrZero(left) + distinctOrZero(right); } - private AvlNode<E> addLeftChild(E e, int count) { - left = new AvlNode<E>(e, count); - successor(pred, left, this); - height = Math.max(2, height); - distinctElements++; - totalCount += count; - return this; + int elemCount() { + long result = size - sizeOrZero(childOrNull(LEFT)) + - sizeOrZero(childOrNull(RIGHT)); + return Ints.checkedCast(result); } - AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) { - /* - * It speeds things up considerably to unconditionally add count to totalCount here, - * but that destroys failure atomicity in the case of count overflow. =( - */ - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - AvlNode<E> initLeft = left; - if (initLeft == null) { - result[0] = 0; - return addLeftChild(e, count); - } - int initHeight = initLeft.height; - - left = initLeft.add(comparator, e, count, result); - if (result[0] == 0) { - distinctElements++; - } - this.totalCount += count; - return (left.height == initHeight) ? this : rebalance(); - } else if (cmp > 0) { - AvlNode<E> initRight = right; - if (initRight == null) { - result[0] = 0; - return addRightChild(e, count); - } - int initHeight = initRight.height; - - right = initRight.add(comparator, e, count, result); - if (result[0] == 0) { - distinctElements++; - } - this.totalCount += count; - return (right.height == initHeight) ? this : rebalance(); - } - - // adding count to me! No rebalance possible. - result[0] = elemCount; - long resultCount = (long) elemCount + count; - checkArgument(resultCount <= Integer.MAX_VALUE); - this.elemCount += count; - this.totalCount += count; - return this; + private Node(E key, int elemCount) { + this(key, elemCount, null, null); } - AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) { - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - AvlNode<E> initLeft = left; - if (initLeft == null) { - result[0] = 0; - return this; - } + private static final long serialVersionUID = 0; + } - left = initLeft.remove(comparator, e, count, result); + private static long sizeOrZero(@Nullable Node<?> node) { + return (node == null) ? 0 : node.size; + } - if (result[0] > 0) { - if (count >= result[0]) { - this.distinctElements--; - this.totalCount -= result[0]; - } else { - this.totalCount -= count; - } - } - return (result[0] == 0) ? this : rebalance(); - } else if (cmp > 0) { - AvlNode<E> initRight = right; - if (initRight == null) { - result[0] = 0; - return this; - } + private static int distinctOrZero(@Nullable Node<?> node) { + return (node == null) ? 0 : node.distinct; + } - right = initRight.remove(comparator, e, count, result); + private static int countOrZero(@Nullable Node<?> entry) { + return (entry == null) ? 0 : entry.elemCount(); + } - if (result[0] > 0) { - if (count >= result[0]) { - this.distinctElements--; - this.totalCount -= result[0]; - } else { - this.totalCount -= count; - } - } - return rebalance(); - } + @SuppressWarnings("unchecked") + private BstAggregate<Node<E>> distinctAggregate() { + return (BstAggregate) DISTINCT_AGGREGATE; + } - // removing count from me! - result[0] = elemCount; - if (count >= elemCount) { - return deleteMe(); - } else { - this.elemCount -= count; - this.totalCount -= count; - return this; - } + private static final BstAggregate<Node<Object>> DISTINCT_AGGREGATE = + new BstAggregate<Node<Object>>() { + @Override + public int entryValue(Node<Object> entry) { + return 1; } - AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) { - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - AvlNode<E> initLeft = left; - if (initLeft == null) { - result[0] = 0; - return (count > 0) ? addLeftChild(e, count) : this; - } - - left = initLeft.setCount(comparator, e, count, result); - - if (count == 0 && result[0] != 0) { - this.distinctElements--; - } else if (count > 0 && result[0] == 0) { - this.distinctElements++; - } - - this.totalCount += count - result[0]; - return rebalance(); - } else if (cmp > 0) { - AvlNode<E> initRight = right; - if (initRight == null) { - result[0] = 0; - return (count > 0) ? addRightChild(e, count) : this; - } - - right = initRight.setCount(comparator, e, count, result); - - if (count == 0 && result[0] != 0) { - this.distinctElements--; - } else if (count > 0 && result[0] == 0) { - this.distinctElements++; - } - - this.totalCount += count - result[0]; - return rebalance(); - } - - // setting my count - result[0] = elemCount; - if (count == 0) { - return deleteMe(); - } - this.totalCount += count - elemCount; - this.elemCount = count; - return this; + @Override + public long treeValue(@Nullable Node<Object> tree) { + return distinctOrZero(tree); } + }; - AvlNode<E> setCount( - Comparator<? super E> comparator, - @Nullable E e, - int expectedCount, - int newCount, - int[] result) { - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - AvlNode<E> initLeft = left; - if (initLeft == null) { - result[0] = 0; - if (expectedCount == 0 && newCount > 0) { - return addLeftChild(e, newCount); - } - return this; - } - - left = initLeft.setCount(comparator, e, expectedCount, newCount, result); + @SuppressWarnings("unchecked") + private BstAggregate<Node<E>> sizeAggregate() { + return (BstAggregate) SIZE_AGGREGATE; + } - if (result[0] == expectedCount) { - if (newCount == 0 && result[0] != 0) { - this.distinctElements--; - } else if (newCount > 0 && result[0] == 0) { - this.distinctElements++; - } - this.totalCount += newCount - result[0]; - } - return rebalance(); - } else if (cmp > 0) { - AvlNode<E> initRight = right; - if (initRight == null) { - result[0] = 0; - if (expectedCount == 0 && newCount > 0) { - return addRightChild(e, newCount); - } - return this; + private static final BstAggregate<Node<Object>> SIZE_AGGREGATE = + new BstAggregate<Node<Object>>() { + @Override + public int entryValue(Node<Object> entry) { + return entry.elemCount(); } - right = initRight.setCount(comparator, e, expectedCount, newCount, result); - - if (result[0] == expectedCount) { - if (newCount == 0 && result[0] != 0) { - this.distinctElements--; - } else if (newCount > 0 && result[0] == 0) { - this.distinctElements++; - } - this.totalCount += newCount - result[0]; + @Override + public long treeValue(@Nullable Node<Object> tree) { + return sizeOrZero(tree); } - return rebalance(); - } + }; - // setting my count - result[0] = elemCount; - if (expectedCount == elemCount) { - if (newCount == 0) { - return deleteMe(); - } - this.totalCount += newCount - elemCount; - this.elemCount = newCount; - } - return this; - } + @SuppressWarnings("unchecked") + private BstNodeFactory<Node<E>> nodeFactory() { + return (BstNodeFactory) NODE_FACTORY; + } - private AvlNode<E> deleteMe() { - int oldElemCount = this.elemCount; - this.elemCount = 0; - successor(pred, succ); - if (left == null) { - return right; - } else if (right == null) { - return left; - } else if (left.height >= right.height) { - AvlNode<E> newTop = pred; - // newTop is the maximum node in my left subtree - newTop.left = left.removeMax(newTop); - newTop.right = right; - newTop.distinctElements = distinctElements - 1; - newTop.totalCount = totalCount - oldElemCount; - return newTop.rebalance(); - } else { - AvlNode<E> newTop = succ; - newTop.right = right.removeMin(newTop); - newTop.left = left; - newTop.distinctElements = distinctElements - 1; - newTop.totalCount = totalCount - oldElemCount; - return newTop.rebalance(); - } - } + private static final BstNodeFactory<Node<Object>> NODE_FACTORY = + new BstNodeFactory<Node<Object>>() { + @Override + public Node<Object> createNode(Node<Object> source, @Nullable Node<Object> left, + @Nullable Node<Object> right) { + return new Node<Object>(source.getKey(), source.elemCount(), left, right); + } + }; - // Removes the minimum node from this subtree to be reused elsewhere - private AvlNode<E> removeMin(AvlNode<E> node) { - if (left == null) { - return right; - } else { - left = left.removeMin(node); - distinctElements--; - totalCount -= node.elemCount; - return rebalance(); - } - } + private abstract class MultisetModifier implements BstModifier<E, Node<E>> { + abstract int newCount(int oldCount); - // Removes the maximum node from this subtree to be reused elsewhere - private AvlNode<E> removeMax(AvlNode<E> node) { - if (right == null) { - return left; + @Nullable + @Override + public BstModificationResult<Node<E>> modify(E key, @Nullable Node<E> originalEntry) { + int oldCount = countOrZero(originalEntry); + int newCount = newCount(oldCount); + if (oldCount == newCount) { + return BstModificationResult.identity(originalEntry); + } else if (newCount == 0) { + return BstModificationResult.rebalancingChange(originalEntry, null); + } else if (oldCount == 0) { + return BstModificationResult.rebalancingChange(null, new Node<E>(key, newCount)); } else { - right = right.removeMax(node); - distinctElements--; - totalCount -= node.elemCount; - return rebalance(); + return BstModificationResult.rebuildingChange(originalEntry, + new Node<E>(originalEntry.getKey(), newCount)); } } + } - private void recomputeMultiset() { - this.distinctElements = 1 + TreeMultiset.distinctElements(left) - + TreeMultiset.distinctElements(right); - this.totalCount = elemCount + totalCount(left) + totalCount(right); - } - - private void recomputeHeight() { - this.height = 1 + Math.max(height(left), height(right)); - } - - private void recompute() { - recomputeMultiset(); - recomputeHeight(); - } + private final class AddModifier extends MultisetModifier { + private final int countToAdd; - private AvlNode<E> rebalance() { - switch (balanceFactor()) { - case -2: - if (right.balanceFactor() > 0) { - right = right.rotateRight(); - } - return rotateLeft(); - case 2: - if (left.balanceFactor() < 0) { - left = left.rotateLeft(); - } - return rotateRight(); - default: - recomputeHeight(); - return this; - } + private AddModifier(int countToAdd) { + checkArgument(countToAdd > 0); + this.countToAdd = countToAdd; } - private int balanceFactor() { - return height(left) - height(right); + @Override + int newCount(int oldCount) { + checkArgument(countToAdd <= Integer.MAX_VALUE - oldCount, "Cannot add this many elements"); + return oldCount + countToAdd; } + } - private AvlNode<E> rotateLeft() { - checkState(right != null); - AvlNode<E> newTop = right; - this.right = newTop.left; - newTop.left = this; - newTop.totalCount = this.totalCount; - newTop.distinctElements = this.distinctElements; - this.recompute(); - newTop.recomputeHeight(); - return newTop; - } + private final class RemoveModifier extends MultisetModifier { + private final int countToRemove; - private AvlNode<E> rotateRight() { - checkState(left != null); - AvlNode<E> newTop = left; - this.left = newTop.right; - newTop.right = this; - newTop.totalCount = this.totalCount; - newTop.distinctElements = this.distinctElements; - this.recompute(); - newTop.recomputeHeight(); - return newTop; + private RemoveModifier(int countToRemove) { + checkArgument(countToRemove > 0); + this.countToRemove = countToRemove; } - private static long totalCount(@Nullable AvlNode<?> node) { - return (node == null) ? 0 : node.totalCount; + @Override + int newCount(int oldCount) { + return Math.max(0, oldCount - countToRemove); } + } - private static int height(@Nullable AvlNode<?> node) { - return (node == null) ? 0 : node.height; - } + private final class SetCountModifier extends MultisetModifier { + private final int countToSet; - @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) { - int cmp = comparator.compare(e, elem); - if (cmp < 0) { - return (left == null) ? this : Objects.firstNonNull(left.ceiling(comparator, e), this); - } else if (cmp == 0) { - return this; - } else { - return (right == null) ? null : right.ceiling(comparator, e); - } - } - - @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) { - int cmp = comparator.compare(e, elem); - if (cmp > 0) { - return (right == null) ? this : Objects.firstNonNull(right.floor(comparator, e), this); - } else if (cmp == 0) { - return this; - } else { - return (left == null) ? null : left.floor(comparator, e); - } + private SetCountModifier(int countToSet) { + checkArgument(countToSet >= 0); + this.countToSet = countToSet; } @Override - public E getElement() { - return elem; + int newCount(int oldCount) { + return countToSet; } + } - @Override - public int getCount() { - return elemCount; + private final class ConditionalSetCountModifier extends MultisetModifier { + private final int expectedCount; + private final int setCount; + + private ConditionalSetCountModifier(int expectedCount, int setCount) { + checkArgument(setCount >= 0 & expectedCount >= 0); + this.expectedCount = expectedCount; + this.setCount = setCount; } @Override - public String toString() { - return Multisets.immutableEntry(getElement(), getCount()).toString(); + int newCount(int oldCount) { + return (oldCount == expectedCount) ? setCount : oldCount; } } - private static <T> void successor(AvlNode<T> a, AvlNode<T> b) { - a.succ = b; - b.pred = a; - } - - private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) { - successor(a, b); - successor(b, c); - } - /* - * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that - * calls the comparator to compare the two keys. If that change is made, - * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets. + * TODO(jlevy): Decide whether entrySet() should return entries with an + * equals() method that calls the comparator to compare the two keys. If that + * change is made, AbstractMultiset.equals() can simply check whether two + * multisets have equal entry sets. */ } - |