diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 8b0e4ee433fa..781fa91f1a03 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -31,6 +31,8 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.InstantCoder; @@ -308,21 +310,24 @@ public void splitRestriction( } @NewTracker - public RestrictionTracker[]> restrictionTracker( + public RestrictionTracker>[]> restrictionTracker( @Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) { return new BoundedSourceAsSDFRestrictionTracker<>(restriction, pipelineOptions); } @ProcessElement public void processElement( - RestrictionTracker[]> tracker, + RestrictionTracker>[]> tracker, OutputReceiver receiver) throws IOException { @SuppressWarnings( "rawtypes") // most straightforward way of creating array with type parameter - TimestampedValue[] out = new TimestampedValue[1]; + Supplier>[] out = new Supplier[1]; while (tracker.tryClaim(out)) { - receiver.outputWithTimestamp(out[0].getValue(), out[0].getTimestamp()); + TimestampedValue currentValue = out[0].get(); + if (currentValue != null) { + receiver.outputWithTimestamp(currentValue.getValue(), currentValue.getTimestamp()); + } } } @@ -337,40 +342,44 @@ public Coder restrictionCoder() { */ private static class BoundedSourceAsSDFRestrictionTracker< BoundedSourceT extends BoundedSource, T> - extends RestrictionTracker[]> implements HasProgress { + extends RestrictionTracker>[]> implements HasProgress { private final BoundedSourceT initialRestriction; private final PipelineOptions pipelineOptions; private BoundedSource.@Nullable BoundedReader currentReader = null; private boolean claimedAll; + private boolean splitActive; + private Supplier> readNextElementSupplier; BoundedSourceAsSDFRestrictionTracker( BoundedSourceT initialRestriction, PipelineOptions pipelineOptions) { this.initialRestriction = initialRestriction; this.pipelineOptions = pipelineOptions; + this.readNextElementSupplier = new Supplier<@Nullable TimestampedValue>() { + @Override + public @Nullable TimestampedValue get() { + try { + readOrThrow(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } } @Override - public boolean tryClaim(TimestampedValue[] position) { + public boolean tryClaim(Supplier>[] position) { + if (splitActive) { + // notify split to continue + // wait for split to complete + } if (claimedAll) { return false; } - try { - return tryClaimOrThrow(position); - } catch (IOException e) { - if (currentReader != null) { - try { - currentReader.close(); - } catch (IOException closeException) { - e.addSuppressed(closeException); - } finally { - currentReader = null; - } - } - throw new RuntimeException(e); - } + position[0] = readNextElementSupplier; + return true; } - private boolean tryClaimOrThrow(TimestampedValue[] position) throws IOException { + private @Nullable TimestampedValue readOrThrow() throws IOException { BoundedSource.BoundedReader currentReader = this.currentReader; if (currentReader == null) { BoundedSource.BoundedReader newReader = @@ -378,12 +387,10 @@ private boolean tryClaimOrThrow(TimestampedValue[] position) throws IOExcepti if (!newReader.start()) { claimedAll = true; newReader.close(); - return false; + return null; } - position[0] = - TimestampedValue.of(newReader.getCurrent(), newReader.getCurrentTimestamp()); this.currentReader = newReader; - return true; + return TimestampedValue.of(newReader.getCurrent(), newReader.getCurrentTimestamp()); } if (!currentReader.advance()) { @@ -393,12 +400,12 @@ private boolean tryClaimOrThrow(TimestampedValue[] position) throws IOExcepti } finally { this.currentReader = null; } - return false; + return null; } - position[0] = - TimestampedValue.of(currentReader.getCurrent(), currentReader.getCurrentTimestamp()); - return true; + return TimestampedValue.of( + currentReader.getCurrent(), + currentReader.getCurrentTimestamp()); } @SuppressWarnings("Finalize")