Deterministic merging (#785)

pull/787/head
Michael Barry 2024-01-11 08:42:16 -05:00 zatwierdzone przez GitHub
rodzic 96eae6110b
commit 0dc2ee82e1
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 4AEE18F83AFDEB23
12 zmienionych plików z 285 dodań i 168 usunięć

Wyświetl plik

@ -15,7 +15,7 @@ public class BenchmarkKWayMerge {
public static void main(String[] args) { public static void main(String[] args) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
System.err.println(); System.err.println();
testMinHeap("quaternary", LongMinHeap::newArrayHeap); testMinHeap("quaternary", n -> LongMinHeap.newArrayHeap(n, Integer::compare));
System.err.println(String.join("\t", System.err.println(String.join("\t",
"priorityqueue", "priorityqueue",
Long.toString(testPriorityQueue(10).toMillis()), Long.toString(testPriorityQueue(10).toMillis()),

Wyświetl plik

@ -35,6 +35,7 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.TreeMap;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -80,7 +81,8 @@ public class VectorTile {
// TODO make these configurable // TODO make these configurable
private static final int EXTENT = 4096; private static final int EXTENT = 4096;
private static final double SIZE = 256d; private static final double SIZE = 256d;
private final Map<String, Layer> layers = new LinkedHashMap<>(); // use a treemap to ensure that layers are encoded in a consistent order
private final Map<String, Layer> layers = new TreeMap<>();
private LayerAttrStats.Updater.ForZoom layerStatsTracker = LayerAttrStats.Updater.ForZoom.NOOP; private LayerAttrStats.Updater.ForZoom layerStatsTracker = LayerAttrStats.Updater.ForZoom.NOOP;
private static int[] getCommands(Geometry input, int scale) { private static int[] getCommands(Geometry input, int scale) {

Wyświetl plik

@ -18,6 +18,7 @@
package com.onthegomap.planetiler.collection; package com.onthegomap.planetiler.collection;
import java.util.Arrays; import java.util.Arrays;
import java.util.function.IntBinaryOperator;
/** /**
* A min-heap stored in an array where each element has 4 children. * A min-heap stored in an array where each element has 4 children.
@ -38,24 +39,26 @@ import java.util.Arrays;
*/ */
class ArrayLongMinHeap implements LongMinHeap { class ArrayLongMinHeap implements LongMinHeap {
protected static final int NOT_PRESENT = -1; protected static final int NOT_PRESENT = -1;
protected final int[] tree; protected final int[] posToId;
protected final int[] positions; protected final int[] idToPos;
protected final long[] vals; protected final long[] posToValue;
protected final int max; protected final int max;
protected int size; protected int size;
private final IntBinaryOperator tieBreaker;
/** /**
* @param elements the number of elements that can be stored in this heap. Currently the heap cannot be resized or * @param elements the number of elements that can be stored in this heap. Currently the heap cannot be resized or
* shrunk/trimmed after initial creation. elements-1 is the maximum id that can be stored in this heap * shrunk/trimmed after initial creation. elements-1 is the maximum id that can be stored in this heap
*/ */
ArrayLongMinHeap(int elements) { ArrayLongMinHeap(int elements, IntBinaryOperator tieBreaker) {
// we use an offset of one to make the arithmetic a bit simpler/more efficient, the 0th elements are not used! // we use an offset of one to make the arithmetic a bit simpler/more efficient, the 0th elements are not used!
tree = new int[elements + 1]; posToId = new int[elements + 1];
positions = new int[elements + 1]; idToPos = new int[elements + 1];
Arrays.fill(positions, NOT_PRESENT); Arrays.fill(idToPos, NOT_PRESENT);
vals = new long[elements + 1]; posToValue = new long[elements + 1];
vals[0] = Long.MIN_VALUE; posToValue[0] = Long.MIN_VALUE;
this.max = elements; this.max = elements;
this.tieBreaker = tieBreaker;
} }
private static int firstChild(int index) { private static int firstChild(int index) {
@ -87,58 +90,59 @@ class ArrayLongMinHeap implements LongMinHeap {
" was pushed already, you need to use the update method if you want to change its value"); " was pushed already, you need to use the update method if you want to change its value");
} }
size++; size++;
tree[size] = id; posToId[size] = id;
positions[id] = size; idToPos[id] = size;
vals[size] = value; posToValue[size] = value;
percolateUp(size); percolateUp(size);
} }
@Override @Override
public boolean contains(int id) { public boolean contains(int id) {
checkIdInRange(id); checkIdInRange(id);
return positions[id] != NOT_PRESENT; return idToPos[id] != NOT_PRESENT;
} }
@Override @Override
public void update(int id, long value) { public void update(int id, long value) {
checkIdInRange(id); checkIdInRange(id);
int index = positions[id]; int pos = idToPos[id];
if (index < 0) { if (pos < 0) {
throw new IllegalStateException( throw new IllegalStateException(
"The heap does not contain: " + id + ". Use the contains method to check this before calling update"); "The heap does not contain: " + id + ". Use the contains method to check this before calling update");
} }
long prev = vals[index]; long prev = posToValue[pos];
vals[index] = value; posToValue[pos] = value;
if (value > prev) { int cmp = compareIdPos(value, prev, id, pos);
percolateDown(index); if (cmp > 0) {
} else if (value < prev) { percolateDown(pos);
percolateUp(index); } else if (cmp < 0) {
percolateUp(pos);
} }
} }
@Override @Override
public void updateHead(long value) { public void updateHead(long value) {
vals[1] = value; posToValue[1] = value;
percolateDown(1); percolateDown(1);
} }
@Override @Override
public int peekId() { public int peekId() {
return tree[1]; return posToId[1];
} }
@Override @Override
public long peekValue() { public long peekValue() {
return vals[1]; return posToValue[1];
} }
@Override @Override
public int poll() { public int poll() {
int id = peekId(); int id = peekId();
tree[1] = tree[size]; posToId[1] = posToId[size];
vals[1] = vals[size]; posToValue[1] = posToValue[size];
positions[tree[1]] = 1; idToPos[posToId[1]] = 1;
positions[id] = NOT_PRESENT; idToPos[id] = NOT_PRESENT;
size--; size--;
percolateDown(1); percolateDown(1);
return id; return id;
@ -147,29 +151,29 @@ class ArrayLongMinHeap implements LongMinHeap {
@Override @Override
public void clear() { public void clear() {
for (int i = 1; i <= size; i++) { for (int i = 1; i <= size; i++) {
positions[tree[i]] = NOT_PRESENT; idToPos[posToId[i]] = NOT_PRESENT;
} }
size = 0; size = 0;
} }
private void percolateUp(int index) { private void percolateUp(int pos) {
assert index != 0; assert pos != 0;
if (index == 1) { if (pos == 1) {
return; return;
} }
final int el = tree[index]; final int id = posToId[pos];
final long val = vals[index]; final long val = posToValue[pos];
// the finish condition (index==0) is covered here automatically because we set vals[0]=-inf // the finish condition (index==0) is covered here automatically because we set vals[0]=-inf
int parent; int parent;
long parentValue; long parentValue;
while (val < (parentValue = vals[parent = parent(index)])) { while (compareIdPos(val, parentValue = posToValue[parent = parent(pos)], id, parent) < 0) {
vals[index] = parentValue; posToValue[pos] = parentValue;
positions[tree[index] = tree[parent]] = index; idToPos[posToId[pos] = posToId[parent]] = pos;
index = parent; pos = parent;
} }
tree[index] = el; posToId[pos] = id;
vals[index] = val; posToValue[pos] = val;
positions[tree[index]] = index; idToPos[posToId[pos]] = pos;
} }
private void checkIdInRange(int id) { private void checkIdInRange(int id) {
@ -178,45 +182,65 @@ class ArrayLongMinHeap implements LongMinHeap {
} }
} }
private void percolateDown(int index) { private void percolateDown(int pos) {
if (size == 0) { if (size == 0) {
return; return;
} }
assert index > 0; assert pos > 0;
assert index <= size; assert pos <= size;
final int el = tree[index]; final int id = posToId[pos];
final long val = vals[index]; final long value = posToValue[pos];
int child; int child;
while ((child = firstChild(index)) <= size) { while ((child = firstChild(pos)) <= size) {
// optimization: this is a very hot code path for performance of k-way merging, // optimization: this is a very hot code path for performance of k-way merging,
// so manually-unroll the loop over the 4 child elements to find the minimum value // so manually-unroll the loop over the 4 child elements to find the minimum value
int minChild = child; int minChild = child;
long minValue = vals[child], value; long minValue = posToValue[child], childValue;
if (++child <= size) { if (++child <= size) {
if ((value = vals[child]) < minValue) { if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) {
minChild = child; minChild = child;
minValue = value; minValue = childValue;
} }
if (++child <= size) { if (++child <= size) {
if ((value = vals[child]) < minValue) { if (comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) {
minChild = child; minChild = child;
minValue = value; minValue = childValue;
} }
if (++child <= size && (value = vals[child]) < minValue) { if (++child <= size &&
comparePosPos(childValue = posToValue[child], minValue, child, minChild) < 0) {
minChild = child; minChild = child;
minValue = value; minValue = childValue;
} }
} }
} }
if (minValue >= val) { if (comparePosPos(value, minValue, pos, minChild) <= 0) {
break; break;
} }
vals[index] = minValue; posToValue[pos] = minValue;
positions[tree[index] = tree[minChild]] = index; idToPos[posToId[pos] = posToId[minChild]] = pos;
index = minChild; pos = minChild;
} }
tree[index] = el; posToId[pos] = id;
vals[index] = val; posToValue[pos] = value;
positions[el] = index; idToPos[id] = pos;
} }
/**
 * Orders two heap slots: first by value, then - for equal values - by the tie-breaker applied to the ids stored at
 * each position.
 *
 * @param val1 value of the first slot
 * @param val2 value of the second slot
 * @param pos1 heap position whose id is handed to the tie-breaker first
 * @param pos2 heap position whose id is handed to the tie-breaker second
 * @return {@code -1} when {@code val1 < val2}, the tie-breaker result when the values are equal, otherwise {@code 1}
 */
private int comparePosPos(long val1, long val2, int pos1, int pos2) {
  if (val1 < val2) {
    return -1;
  }
  // Equal values are only tie-broken when they are not Long.MIN_VALUE (the value the constructor stores in the
  // unused slot 0); a MIN_VALUE tie falls through to 1 - presumably so the sentinel's id is never consulted.
  final boolean tied = val1 == val2 && val1 != Long.MIN_VALUE;
  return tied ? tieBreaker.applyAsInt(posToId[pos1], posToId[pos2]) : 1;
}
/**
 * Orders an element identified directly by its id against the element stored at a heap position: first by value,
 * then - for equal values - by the tie-breaker applied to {@code id1} and the id stored at {@code pos2}.
 *
 * @param val1 value belonging to {@code id1}
 * @param val2 value of the slot at {@code pos2}
 * @param id1  id of the first element (not looked up through a position)
 * @param pos2 heap position of the second element
 * @return {@code -1} when {@code val1 < val2}, the tie-breaker result when the values are equal, otherwise {@code 1}
 */
private int compareIdPos(long val1, long val2, int id1, int pos2) {
  if (val1 < val2) {
    return -1;
  }
  // A tie at Long.MIN_VALUE (the value kept in the unused slot 0) is not tie-broken and compares as 1.
  final boolean tied = val1 == val2 && val1 != Long.MIN_VALUE;
  return tied ? tieBreaker.applyAsInt(id1, posToId[pos2]) : 1;
}
} }

Wyświetl plik

@ -253,7 +253,7 @@ class ExternalMergeSort implements FeatureSort {
} }
} }
return LongMerger.mergeIterators(iterators); return LongMerger.mergeIterators(iterators, SortableFeature.COMPARE_BYTES);
} }
public int chunks() { public int chunks() {

Wyświetl plik

@ -131,7 +131,7 @@ interface FeatureSort extends Iterable<SortableFeature>, DiskBacked, MemoryEstim
} }
} }
}); });
return new ParallelIterator(reader, LongMerger.mergeSuppliers(queues)); return new ParallelIterator(reader, LongMerger.mergeSuppliers(queues, SortableFeature.COMPARE_BYTES));
} }
record ParallelIterator(Worker reader, @Override Iterator<SortableFeature> iterator) record ParallelIterator(Worker reader, @Override Iterator<SortableFeature> iterator)

Wyświetl plik

@ -2,7 +2,7 @@ package com.onthegomap.planetiler.collection;
/** /**
* An item with a {@code long key} that can be used for sorting/grouping. * An item with a {@code long key} that can be used for sorting/grouping.
* * <p>
* These items can be sorted or grouped by {@link FeatureSort}/{@link FeatureGroup} implementations. Sorted lists can * These items can be sorted or grouped by {@link FeatureSort}/{@link FeatureGroup} implementations. Sorted lists can
* also be merged using {@link LongMerger}. * also be merged using {@link LongMerger}.
*/ */

Wyświetl plik

@ -1,6 +1,7 @@
package com.onthegomap.planetiler.collection; package com.onthegomap.planetiler.collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
@ -16,29 +17,34 @@ public class LongMerger {
private LongMerger() {} private LongMerger() {}
/** Merges sorted items from {@link Supplier Suppliers} that return {@code null} when there are no items left. */ /** Merges sorted items from {@link Supplier Suppliers} that return {@code null} when there are no items left. */
public static <T extends HasLongSortKey> Iterator<T> mergeSuppliers(List<? extends Supplier<T>> suppliers) { public static <T extends HasLongSortKey> Iterator<T> mergeSuppliers(List<? extends Supplier<T>> suppliers,
return mergeIterators(suppliers.stream().map(SupplierIterator::new).toList()); Comparator<T> tieBreaker) {
return mergeIterators(suppliers.stream().map(SupplierIterator::new).toList(), tieBreaker);
} }
/** Merges sorted iterators into a combined iterator over all the items. */ /** Merges sorted iterators into a combined iterator over all the items. */
public static <T extends HasLongSortKey> Iterator<T> mergeIterators(List<? extends Iterator<T>> iterators) { public static <T extends HasLongSortKey> Iterator<T> mergeIterators(List<? extends Iterator<T>> iterators,
Comparator<T> tieBreaker) {
return switch (iterators.size()) { return switch (iterators.size()) {
case 0 -> Collections.emptyIterator(); case 0 -> Collections.emptyIterator();
case 1 -> iterators.get(0); case 1 -> iterators.get(0);
case 2 -> new TwoWayMerge<>(iterators.get(0), iterators.get(1)); case 2 -> new TwoWayMerge<>(iterators.get(0), iterators.get(1), tieBreaker);
case 3 -> new ThreeWayMerge<>(iterators.get(0), iterators.get(1), iterators.get(2)); case 3 -> new ThreeWayMerge<>(iterators.get(0), iterators.get(1), iterators.get(2), tieBreaker);
default -> new KWayMerge<>(iterators); default -> new KWayMerge<>(iterators, tieBreaker);
}; };
} }
private static class TwoWayMerge<T extends HasLongSortKey> implements Iterator<T> { private static class TwoWayMerge<T extends HasLongSortKey> implements Iterator<T> {
private final Comparator<T> tieBreaker;
T a, b; T a, b;
long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE; long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE;
final Iterator<T> inputA, inputB; final Iterator<T> inputA, inputB;
TwoWayMerge(Iterator<T> inputA, Iterator<T> inputB) { TwoWayMerge(Iterator<T> inputA, Iterator<T> inputB, Comparator<T> tieBreaker) {
this.inputA = inputA; this.inputA = inputA;
this.inputB = inputB; this.inputB = inputB;
this.tieBreaker = tieBreaker;
if (inputA.hasNext()) { if (inputA.hasNext()) {
a = inputA.next(); a = inputA.next();
ak = a.key(); ak = a.key();
@ -57,7 +63,7 @@ public class LongMerger {
@Override @Override
public T next() { public T next() {
T result; T result;
if (ak < bk) { if (lessThan(ak, bk, a, b)) {
result = a; result = a;
if (inputA.hasNext()) { if (inputA.hasNext()) {
a = inputA.next(); a = inputA.next();
@ -80,14 +86,21 @@ public class LongMerger {
} }
return result; return result;
} }
/**
 * Returns {@code true} when item {@code a} should be emitted before item {@code b}: a strictly smaller key wins, and
 * equal keys are decided by {@code lessThanCmp} using this merge's tie-breaker.
 */
private boolean lessThan(long ak, long bk, T a, T b) {
  if (ak != bk) {
    return ak < bk;
  }
  return lessThanCmp(a, b, tieBreaker);
}
} }
private static class ThreeWayMerge<T extends HasLongSortKey> implements Iterator<T> { private static class ThreeWayMerge<T extends HasLongSortKey> implements Iterator<T> {
private final Comparator<T> tieBreaker;
T a, b, c; T a, b, c;
long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE, ck = Long.MAX_VALUE; long ak = Long.MAX_VALUE, bk = Long.MAX_VALUE, ck = Long.MAX_VALUE;
final Iterator<T> inputA, inputB, inputC; final Iterator<T> inputA, inputB, inputC;
ThreeWayMerge(Iterator<T> inputA, Iterator<T> inputB, Iterator<T> inputC) { ThreeWayMerge(Iterator<T> inputA, Iterator<T> inputB, Iterator<T> inputC, Comparator<T> tieBreaker) {
this.tieBreaker = tieBreaker;
this.inputA = inputA; this.inputA = inputA;
this.inputB = inputB; this.inputB = inputB;
this.inputC = inputC; this.inputC = inputC;
@ -114,8 +127,8 @@ public class LongMerger {
public T next() { public T next() {
T result; T result;
// use at most 2 comparisons to get the next item // use at most 2 comparisons to get the next item
if (ak < bk) { if (lessThan(ak, bk, a, b)) {
if (ak < ck) { if (lessThan(ak, ck, a, c)) {
// ACB / ABC // ACB / ABC
result = a; result = a;
if (inputA.hasNext()) { if (inputA.hasNext()) {
@ -136,7 +149,7 @@ public class LongMerger {
ck = Long.MAX_VALUE; ck = Long.MAX_VALUE;
} }
} }
} else if (ck < bk) { } else if (lessThan(ck, bk, c, b)) {
// CAB // CAB
result = c; result = c;
if (inputC.hasNext()) { if (inputC.hasNext()) {
@ -161,6 +174,21 @@ public class LongMerger {
} }
return result; return result;
} }
/**
 * Returns {@code true} when the first item should be emitted before the second: a strictly smaller key wins, and
 * equal keys are decided by {@code lessThanCmp} using this merge's tie-breaker.
 */
private boolean lessThan(long ak, long bk, T a, T b) {
  if (ak != bk) {
    return ak < bk;
  }
  return lessThanCmp(a, b, tieBreaker);
}
}
/**
 * Null-tolerant "strictly less than" used to break ties between two items.
 * <p>
 * A {@code null} first argument is never less than anything, a non-null first argument is always less than a
 * {@code null} second argument, and two non-null arguments are compared with {@code tieBreaker}.
 *
 * @param a          first item, may be {@code null}
 * @param b          second item, may be {@code null}
 * @param tieBreaker comparator consulted only when both items are non-null
 * @return {@code true} if {@code a} sorts strictly before {@code b}
 */
private static <T> boolean lessThanCmp(T a, T b, Comparator<T> tieBreaker) {
// nulls go at the end
// NOTE(review): a null item presumably stands for an exhausted input - confirm against the merge classes
if (a == null) {
return false;
} else if (b == null) {
return true;
} else {
return tieBreaker.compare(a, b) < 0;
}
} }
private static class KWayMerge<T extends HasLongSortKey> implements Iterator<T> { private static class KWayMerge<T extends HasLongSortKey> implements Iterator<T> {
@ -169,10 +197,10 @@ public class LongMerger {
private final LongMinHeap heap; private final LongMinHeap heap;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
KWayMerge(List<? extends Iterator<T>> inputIterators) { KWayMerge(List<? extends Iterator<T>> inputIterators, Comparator<T> tieBreaker) {
this.iterators = new Iterator[inputIterators.size()]; this.iterators = new Iterator[inputIterators.size()];
this.items = (T[]) new HasLongSortKey[inputIterators.size()]; this.items = (T[]) new HasLongSortKey[inputIterators.size()];
this.heap = LongMinHeap.newArrayHeap(inputIterators.size()); this.heap = LongMinHeap.newArrayHeap(inputIterators.size(), (a, b) -> tieBreaker.compare(items[a], items[b]));
int outIdx = 0; int outIdx = 0;
for (Iterator<T> iter : inputIterators) { for (Iterator<T> iter : inputIterators) {
if (iter.hasNext()) { if (iter.hasNext()) {

Wyświetl plik

@ -17,6 +17,8 @@
*/ */
package com.onthegomap.planetiler.collection; package com.onthegomap.planetiler.collection;
import java.util.function.IntBinaryOperator;
/** /**
* API for min-heaps that keeps track of {@code int} keys in a range from {@code [0, size)} ordered by {@code long} * API for min-heaps that keeps track of {@code int} keys in a range from {@code [0, size)} ordered by {@code long}
* values. * values.
@ -31,8 +33,8 @@ public interface LongMinHeap {
* <p> * <p>
* This is slightly faster than a traditional binary min heap due to a shallower, more cache-friendly memory layout. * This is slightly faster than a traditional binary min heap due to a shallower, more cache-friendly memory layout.
*/ */
static LongMinHeap newArrayHeap(int elements) { static LongMinHeap newArrayHeap(int elements, IntBinaryOperator tieBreaker) {
return new ArrayLongMinHeap(elements); return new ArrayLongMinHeap(elements, tieBreaker);
} }
int size(); int size();

Wyświetl plik

@ -1,12 +1,20 @@
package com.onthegomap.planetiler.collection; package com.onthegomap.planetiler.collection;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
public record SortableFeature(@Override long key, byte[] value) implements Comparable<SortableFeature>, HasLongSortKey { public record SortableFeature(@Override long key, byte[] value) implements Comparable<SortableFeature>, HasLongSortKey {
public static final Comparator<SortableFeature> COMPARE_BYTES = (a, b) -> Arrays.compareUnsigned(a.value, b.value);
@Override @Override
public int compareTo(SortableFeature o) { public int compareTo(SortableFeature o) {
return Long.compare(key, o.key); if (key < o.key) {
return -1;
} else if (key == o.key) {
return Arrays.compareUnsigned(value, o.value);
} else {
return 1;
}
} }
@Override @Override

Wyświetl plik

@ -205,8 +205,8 @@ public class CompareArchives {
compareList(name, "keys list", layer1.getKeysList(), layer2.getKeysList()); compareList(name, "keys list", layer1.getKeysList(), layer2.getKeysList());
compareList(name, "values list", layer1.getValuesList(), layer2.getValuesList()); compareList(name, "values list", layer1.getValuesList(), layer2.getValuesList());
if (compareValues(name, "features count", layer1.getFeaturesCount(), layer2.getFeaturesCount())) { if (compareValues(name, "features count", layer1.getFeaturesCount(), layer2.getFeaturesCount())) {
var ids1 = layer1.getFeaturesList().stream().map(f -> f.getId()); var ids1 = layer1.getFeaturesList().stream().map(f -> f.getId()).toList();
var ids2 = layer1.getFeaturesList().stream().map(f -> f.getId()); var ids2 = layer2.getFeaturesList().stream().map(f -> f.getId()).toList();
if (compareValues(name, "feature ids", Set.of(ids1), Set.of(ids2)) && if (compareValues(name, "feature ids", Set.of(ids1), Set.of(ids2)) &&
compareValues(name, "feature order", ids1, ids2)) { compareValues(name, "feature order", ids1, ids2)) {
for (int i = 0; i < layer1.getFeaturesCount() && i < layer2.getFeaturesCount(); i++) { for (int i = 0; i < layer1.getFeaturesCount() && i < layer2.getFeaturesCount(); i++) {

Wyświetl plik

@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.NoSuchElementException; import java.util.NoSuchElementException;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -12,21 +13,36 @@ import java.util.stream.Stream;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
class LongMergerTest { class LongMergerTest {
record Item(long key) implements HasLongSortKey {} record Item(long key, int secondary) implements HasLongSortKey, Comparable<Item> {
@Override
public int compareTo(Item o) {
int cmp = Long.compare(key, o.key);
if (cmp == 0) {
cmp = Integer.compare(secondary, o.secondary);
}
return cmp;
}
long value() {
return key + secondary;
}
}
record ItemList(List<Item> items) {} record ItemList(List<Item> items) {}
private static ItemList list(long... items) { private static ItemList list(boolean primaryKey, long... items) {
return new ItemList(LongStream.of(items).mapToObj(Item::new).toList()); return new ItemList(
LongStream.of(items).mapToObj(i -> primaryKey ? new Item(i, 0) : new Item(0, (int) i)).toList());
} }
private static List<Long> merge(ItemList... lists) { private static List<Long> merge(ItemList... lists) {
List<Long> list = new ArrayList<>(); List<Long> list = new ArrayList<>();
var iter = LongMerger.mergeIterators(Stream.of(lists) var iter = LongMerger.mergeIterators(Stream.of(lists)
.map(d -> d.items.iterator()) .map(d -> d.items.iterator())
.toList()); .toList(), Comparator.naturalOrder());
iter.forEachRemaining(item -> list.add(item.key)); iter.forEachRemaining(item -> list.add(item.value()));
assertThrows(NoSuchElementException.class, iter::next); assertThrows(NoSuchElementException.class, iter::next);
return list; return list;
} }
@ -36,10 +52,11 @@ class LongMergerTest {
assertEquals(List.of(), merge()); assertEquals(List.of(), merge());
} }
@Test @ParameterizedTest
void testMergeSupplier() { @ValueSource(booleans = {true, false})
void testMergeSupplier(boolean primaryKey) {
List<Long> list = new ArrayList<>(); List<Long> list = new ArrayList<>();
var iter = LongMerger.mergeSuppliers(Stream.of(new ItemList[]{list(1, 2)}) var iter = LongMerger.mergeSuppliers(Stream.of(new ItemList[]{list(primaryKey, 1, 2)})
.map(d -> d.items.iterator()) .map(d -> d.items.iterator())
.<Supplier<Item>>map(d -> () -> { .<Supplier<Item>>map(d -> () -> {
try { try {
@ -48,17 +65,18 @@ class LongMergerTest {
return null; return null;
} }
}) })
.toList()); .toList(), Comparator.naturalOrder());
iter.forEachRemaining(item -> list.add(item.key)); iter.forEachRemaining(item -> list.add(item.value()));
assertThrows(NoSuchElementException.class, iter::next); assertThrows(NoSuchElementException.class, iter::next);
assertEquals(List.of(1L, 2L), list); assertEquals(List.of(1L, 2L), list);
} }
@Test @ParameterizedTest
void testMerge1() { @ValueSource(booleans = {true, false})
assertEquals(List.of(), merge(list())); void testMerge1(boolean primaryKey) {
assertEquals(List.of(1L), merge(list(1))); assertEquals(List.of(), merge(list(primaryKey)));
assertEquals(List.of(1L, 2L), merge(list(1, 2))); assertEquals(List.of(1L), merge(list(primaryKey, 1)));
assertEquals(List.of(1L, 2L), merge(list(primaryKey, 1, 2)));
} }
@ParameterizedTest @ParameterizedTest
@ -73,17 +91,21 @@ class LongMergerTest {
"1 3,2,1 2 3", "1 3,2,1 2 3",
}, nullValues = {"null"}) }, nullValues = {"null"})
void testMerge2(String a, String b, String output) { void testMerge2(String a, String b, String output) {
var listA = list(parse(a)); for (boolean primaryKey : List.of(false, true)) {
var listB = list(parse(b)); var listA = list(primaryKey, parse(a));
var listB = list(primaryKey, parse(b));
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listA, listB) merge(listA, listB),
"primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listA) merge(listB, listA),
"primary=" + primaryKey
); );
} }
}
@ParameterizedTest @ParameterizedTest
@CsvSource(value = { @CsvSource(value = {
@ -98,40 +120,42 @@ class LongMergerTest {
"1 3,2,4,1 2 3 4", "1 3,2,4,1 2 3 4",
}, nullValues = {""}) }, nullValues = {""})
void testMerge3(String a, String b, String c, String output) { void testMerge3(String a, String b, String c, String output) {
var listA = list(parse(a)); for (boolean primaryKey : List.of(false, true)) {
var listB = list(parse(b)); var listA = list(primaryKey, parse(a));
var listC = list(parse(c)); var listB = list(primaryKey, parse(b));
var listC = list(primaryKey, parse(c));
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listA, listB, listC), merge(listA, listB, listC),
"ABC" "ABC primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listA, listC, listB), merge(listA, listC, listB),
"ACB" "ACB primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listA, listC), merge(listB, listA, listC),
"BAC" "BAC primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listC, listA), merge(listB, listC, listA),
"BCA" "BCA primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listC, listA, listB), merge(listC, listA, listB),
"CAB" "CAB primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listC, listB, listA), merge(listC, listB, listA),
"CBA" "CBA primary=" + primaryKey
); );
} }
}
@ParameterizedTest @ParameterizedTest
@CsvSource(value = { @CsvSource(value = {
@ -146,32 +170,34 @@ class LongMergerTest {
"1 2,2 3,,,1 2 2 3", "1 2,2 3,,,1 2 2 3",
}, nullValues = {""}) }, nullValues = {""})
void testMerge4(String a, String b, String c, String d, String output) { void testMerge4(String a, String b, String c, String d, String output) {
var listA = list(parse(a)); for (boolean primaryKey : List.of(false, true)) {
var listB = list(parse(b)); var listA = list(primaryKey, parse(a));
var listC = list(parse(c)); var listB = list(primaryKey, parse(b));
var listD = list(parse(d)); var listC = list(primaryKey, parse(c));
var listD = list(primaryKey, parse(d));
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listA, listB, listC, listD), merge(listA, listB, listC, listD),
"ABCD" "ABCD primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listA, listC, listD), merge(listB, listA, listC, listD),
"BACD" "BACD primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listC, listA, listD), merge(listB, listC, listA, listD),
"BCAD" "BCAD primary=" + primaryKey
); );
assertEquals( assertEquals(
LongStream.of(parse(output)).boxed().toList(), LongStream.of(parse(output)).boxed().toList(),
merge(listB, listC, listD, listA), merge(listB, listC, listD, listA),
"BCDA" "BCDA primary=" + primaryKey
); );
} }
}
private static long[] parse(String in) { private static long[] parse(String in) {
return in == null ? new long[0] : Stream.of(in.split("\\s+")) return in == null ? new long[0] : Stream.of(in.split("\\s+"))

Wyświetl plik

@ -29,6 +29,8 @@ import com.carrotsearch.hppc.IntSet;
import java.util.PriorityQueue; import java.util.PriorityQueue;
import java.util.Random; import java.util.Random;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
/** /**
@ -42,7 +44,7 @@ class LongMinHeapTest {
protected LongMinHeap heap; protected LongMinHeap heap;
void create(int capacity) { void create(int capacity) {
heap = LongMinHeap.newArrayHeap(capacity); heap = LongMinHeap.newArrayHeap(capacity, Integer::compare);
} }
@Test @Test
@ -77,6 +79,31 @@ class LongMinHeapTest {
assertThrows(IllegalStateException.class, () -> heap.push(2, 4L)); assertThrows(IllegalStateException.class, () -> heap.push(2, 4L));
} }
/**
 * Pushes the six ids 0-5 in the given order, all with the same value, into a heap whose tie-breaker reverses the
 * natural id order, and expects them to be polled as 5, 4, 3, 2, 1, 0 regardless of insertion order.
 */
@ParameterizedTest
@CsvSource({
  "0, 1, 2, 3, 4, 5",
  "5, 4, 3, 2, 1, 0",
  "0, 1, 2, 5, 4, 3",
  "0, 1, 5, 2, 4, 3",
  "0, 5, 1, 2, 4, 3",
  "5, 0, 1, 2, 4, 3",
})
void tieBreaker(int a, int b, int c, int d, int e, int f) {
  heap = LongMinHeap.newArrayHeap(6, (id1, id2) -> -Integer.compare(id1, id2));
  // every id gets the same value so that only the tie-breaker determines the order
  for (int id : new int[]{a, b, c, d, e, f}) {
    heap.push(id, 0L);
  }
  // the reversed comparator makes the largest id the head of the heap
  for (int expectedId = 5; expectedId >= 0; expectedId--) {
    assertEquals(expectedId, heap.poll());
  }
}
@Test @Test
void testContains() { void testContains() {
create(4); create(4);