diff --git a/flatmap-core/src/main/java/com/onthegomap/flatmap/collections/LongLongMap.java b/flatmap-core/src/main/java/com/onthegomap/flatmap/collections/LongLongMap.java index dc1cfcd2..7859e634 100644 --- a/flatmap-core/src/main/java/com/onthegomap/flatmap/collections/LongLongMap.java +++ b/flatmap-core/src/main/java/com/onthegomap/flatmap/collections/LongLongMap.java @@ -85,7 +85,7 @@ public interface LongLongMap extends Closeable { String rate = Format.formatNumeric(entries * NANOSECONDS_PER_SECOND / (end - start), false) + "/s"; System.err.println("Loaded " + entries + " in " + Duration.ofNanos(end - start).toSeconds() + "s (" + rate + ")"); writeRate.set(rate); - }).awaitAndLog(loggers, Duration.ofSeconds(10), Duration.ofSeconds(10)); + }).awaitAndLog(loggers, Duration.ofSeconds(10)); map.get(1); System.err.println("Storage: " + Format.formatStorage(map.fileSize(), false)); diff --git a/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Topology.java b/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Topology.java index 19ba6ea2..5c7d5202 100644 --- a/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Topology.java +++ b/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Topology.java @@ -1,55 +1,39 @@ package com.onthegomap.flatmap.worker; +import static com.onthegomap.flatmap.worker.Worker.joinFutures; + import com.onthegomap.flatmap.monitoring.ProgressLoggers; import com.onthegomap.flatmap.monitoring.Stats; import java.time.Duration; import java.util.Collection; import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.function.Consumer; import java.util.function.Supplier; public record Topology( String name, - com.onthegomap.flatmap.worker.Topology previous, + Topology previous, WorkQueue inputQueue, - Worker worker + Worker worker, + CompletableFuture done ) { public static Empty start(String prefix, Stats stats) { return new Empty(prefix, stats); } - // track time since last log and stagger initial log interval for each step to keep logs - // coming at consistent intervals - private void doAwaitAndLog(ProgressLoggers loggers, Duration logInterval, long startNanos) { - if (previous != null) { - previous.doAwaitAndLog(loggers, logInterval, startNanos); - if (inputQueue != null) { - inputQueue.close(); - } - } - if (worker != null) { - long elapsedSoFar = System.nanoTime() - startNanos; - Duration sinceLastLog = Duration.ofNanos(elapsedSoFar % logInterval.toNanos()); - Duration untilNextLog = logInterval.minus(sinceLastLog); - worker.awaitAndLog(loggers, untilNextLog, logInterval); - } - } - public void awaitAndLog(ProgressLoggers loggers, Duration logInterval) { - doAwaitAndLog(loggers, logInterval, System.nanoTime()); + loggers.awaitAndLog(done, logInterval); loggers.log(); } public void await() { - if (previous != null) { - previous.await(); - if (inputQueue != null) { - inputQueue.close(); - } - } - if (worker != null) { - worker.await(); + try { + done.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); } } @@ -146,13 +130,21 @@ public record Topology( private Topology build() { var previousTopology = previous == null || previous.worker == null ? null : previous.build(); - return new Topology<>(name, previousTopology, inputQueue, worker); + var doneFuture = worker != null ? worker.done() : CompletableFuture.completedFuture(true); + if (previousTopology != null) { + doneFuture = joinFutures(doneFuture, previousTopology.done); + } + if (outputQueue != null) { + doneFuture = doneFuture.thenRun(outputQueue::close); + } + return new Topology<>(name, previousTopology, inputQueue, worker, doneFuture); } public Topology sinkTo(String name, int threads, SinkStep step) { var previousTopology = build(); var worker = new Worker(prefix + "_" + name, stats, threads, () -> step.run(outputQueue.threadLocalReader())); - return new Topology<>(name, previousTopology, outputQueue, worker); + var doneFuture = joinFutures(worker.done(), previousTopology.done); + return new Topology<>(name, previousTopology, outputQueue, worker, doneFuture); } public Topology sinkToConsumer(String name, int threads, Consumer step) { diff --git a/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Worker.java b/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Worker.java index 600a9182..a35ec6fb 100644 --- a/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Worker.java +++ b/flatmap-core/src/main/java/com/onthegomap/flatmap/worker/Worker.java @@ -2,11 +2,16 @@ package com.onthegomap.flatmap.worker; import com.onthegomap.flatmap.monitoring.ProgressLoggers; import com.onthegomap.flatmap.monitoring.Stats; +import java.io.IOException; +import java.io.UncheckedIOException; import java.time.Duration; -import java.util.concurrent.ExecutorService; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ThreadFactory; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; @@ -15,9 +20,8 @@ import org.slf4j.LoggerFactory; public class Worker { private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); - private final ExecutorService es; private final String prefix; - private final Stats stats; + private final CompletableFuture done; private static class NamedThreadFactory implements ThreadFactory { @@ -49,52 +53,77 @@ public class Worker { public Worker(String prefix, Stats stats, int threads, RunnableThatThrows task) { this.prefix = prefix; - this.stats = stats; stats.gauge(prefix + "_threads", threads); - es = Executors.newFixedThreadPool(threads, new NamedThreadFactory(prefix)); + var es = Executors.newFixedThreadPool(threads, new NamedThreadFactory(prefix)); + List> results = new ArrayList<>(); for (int i = 0; i < threads; i++) { - es.submit(() -> { + results.add(CompletableFuture.runAsync(() -> { String id = Thread.currentThread().getName(); LOGGER.trace("Starting worker"); try { task.run(); } catch (Throwable e) { System.err.println("Worker " + id + " died"); - e.printStackTrace(); - System.exit(1); + throwRuntimeException(e); } finally { LOGGER.trace("Finished worker"); } - }); + }, es)); } es.shutdown(); + done = joinFutures(results); } public String getPrefix() { return prefix; } - public void awaitAndLog(ProgressLoggers loggers, Duration initialLogInterval, Duration logInterval) { - try { - if (!es.awaitTermination(initialLogInterval.toNanos(), TimeUnit.NANOSECONDS)) { - loggers.log(); - while (!es.awaitTermination(logInterval.toNanos(), TimeUnit.NANOSECONDS)) { - loggers.log(); + public static CompletableFuture joinFutures(CompletableFuture... futures) { + return joinFutures(List.of(futures)); + } + + public static CompletableFuture joinFutures(Collection> futures) { + CompletableFuture result = new CompletableFuture<>(); + for (CompletableFuture f : futures) { + f.whenComplete((res, ex) -> { + if (ex != null) { + result.completeExceptionally(ex); } - } - } catch (InterruptedException e) { - throw new RuntimeException(e); + }); } + CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).thenAccept(result::complete); + return result; + } + + public CompletableFuture done() { + return done; + } + + public void awaitAndLog(ProgressLoggers loggers, Duration logInterval) { + loggers.awaitAndLog(done(), logInterval); } public void await() { try { - es.awaitTermination(365, TimeUnit.DAYS); - } catch (InterruptedException e) { + done().get(); + } catch (ExecutionException | InterruptedException e) { throw new RuntimeException(e); } } + private static void throwRuntimeException(Throwable exception) { + if (exception instanceof RuntimeException runtimeException) { + throw runtimeException; + } else if (exception instanceof IOException ioe) { + throw new UncheckedIOException(ioe); + } else if (exception instanceof Error error) { + throw error; + } else if (exception instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException(exception); + } + public interface RunnableThatThrows { void run() throws Exception; diff --git a/flatmap-core/src/test/java/com/onthegomap/flatmap/FlatMapTest.java b/flatmap-core/src/test/java/com/onthegomap/flatmap/FlatMapTest.java index 6a884bb6..daebefa0 100644 --- a/flatmap-core/src/test/java/com/onthegomap/flatmap/FlatMapTest.java +++ b/flatmap-core/src/test/java/com/onthegomap/flatmap/FlatMapTest.java @@ -3,6 +3,7 @@ package com.onthegomap.flatmap; import static com.onthegomap.flatmap.TestUtils.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import com.graphhopper.reader.ReaderElement; @@ -754,6 +755,19 @@ public class FlatMapTest { ), results.tiles); } + @Test + public void testExceptionWhileProcessingOsm() { + assertThrows(RuntimeException.class, () -> runWithOsmElements( + Map.of("threads", "1"), + List.of( + with(new ReaderNode(1, 0, 0), t -> t.setTag("attr", "value")) + ), + (in, features) -> { + throw new Error(); + } + )); + } + @Test public void testOsmLine() throws Exception { var results = runWithOsmElements( diff --git a/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/TopologyTest.java b/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/TopologyTest.java index 156413db..4cc66c3c 100644 --- a/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/TopologyTest.java +++ b/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/TopologyTest.java @@ -1,6 +1,7 @@ package com.onthegomap.flatmap.worker; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import com.onthegomap.flatmap.monitoring.ProgressLoggers; import com.onthegomap.flatmap.monitoring.Stats; @@ -11,6 +12,8 @@ import java.util.Set; import java.util.TreeSet; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; public class TopologyTest { @@ -86,4 +89,38 @@ public class TopologyTest { assertEquals(Set.of(1, 2, 3, 4), result); } + + @ParameterizedTest + @Timeout(10) + @ValueSource(ints = {1, 2, 3}) + public void testThrowingExceptionInTopologyHandledGracefully(int failureStage) { + class ExpectedException extends RuntimeException {} + Set result = Collections.synchronizedSet(new TreeSet<>()); + var topology = Topology.start("test", stats) + .fromGenerator("reader", (next) -> { + if (failureStage == 1) { + throw new ExpectedException(); + } + next.accept(0); + next.accept(1); + }).addBuffer("reader_queue", 1) + .addWorker("process", 1, (prev, next) -> { + if (failureStage == 2) { + throw new ExpectedException(); + } + Integer item; + while ((item = prev.get()) != null) { + next.accept(item * 2 + 1); + next.accept(item * 2 + 2); + } + }).addBuffer("writer_queue", 1) + .sinkToConsumer("writer", 1, item -> { + if (failureStage == 3) { + throw new ExpectedException(); + } + }); + + assertThrows(RuntimeException.class, + () -> topology.await());//awaitAndLog(new ProgressLoggers("test"), Duration.ofSeconds(1))); + } } diff --git a/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/WorkerTest.java b/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/WorkerTest.java new file mode 100644 index 00000000..7de9863c --- /dev/null +++ b/flatmap-core/src/test/java/com/onthegomap/flatmap/worker/WorkerTest.java @@ -0,0 +1,25 @@ +package com.onthegomap.flatmap.worker; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.onthegomap.flatmap.monitoring.Stats; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +public class WorkerTest { + + @Test + @Timeout(10) + public void testExceptionHandled() { + AtomicInteger counter = new AtomicInteger(0); + var worker = new Worker("prefix", Stats.inMemory(), 4, () -> { + if (counter.incrementAndGet() == 1) { + throw new Error(); + } else { + Thread.sleep(5000); + } + }); + assertThrows(RuntimeException.class, worker::await); + } +}