public StreamObserver reverseArtifactRetrievalService()

in runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java [300:535]


  /**
   * Returns an observer that stages artifacts by pulling them from the client over a bidirectional
   * stream in which this service sends the requests and the client sends the responses.
   *
   * <p>The exchange is driven by a small state machine:
   *
   * <ul>
   *   <li>{@code START}: the first message carries the staging token; resolution is requested for
   *       the first environment registered under that token.
   *   <li>{@code RESOLVE}: the resolved artifacts for the current environment arrive; a {@code
   *       GetArtifact} request is sent for each of them.
   *   <li>{@code GET}: the first chunk of the next requested artifact arrives; a {@code
   *       StoreArtifact} task is started for it and handling continues as for {@code GETCHUNK}.
   *   <li>{@code GETCHUNK}: chunks are handed to the store task until {@code is_last}, after which
   *       the next artifact (or environment) is processed.
   *   <li>{@code DONE}/{@code ERROR}: terminal; the outgoing stream has already been closed.
   * </ul>
   *
   * <p>{@code GetArtifact} responses are matched to requests in FIFO order, so the client must
   * answer them in the order they were sent.
   *
   * @param responseObserver the stream on which requests are sent to the client
   * @return the observer consuming the client's responses
   */
  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> reverseArtifactRetrievalService(
      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {

    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {

      /** The maximum number of parallel threads to use to stage. */
      public static final int THREAD_POOL_SIZE = 10;

      /** The maximum number of bytes to buffer across all writes before throttling. */
      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB

      IdGenerator idGenerator = IdGenerators.incrementingLongs();

      String stagingToken;
      // Artifacts registered for stagingToken, keyed by environment id.
      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
      // Futures of the staged artifacts, keyed by environment id.
      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
      ExecutorService stagingExecutor;
      // Bounds the bytes buffered across all in-flight store tasks.
      OverflowingSemaphore totalPendingBytes;

      State state = State.START;
      // Environments not yet resolved; the head is the one whose resolution was last requested.
      Queue<String> pendingResolves;
      String currentEnvironment;
      // Artifacts of currentEnvironment requested from the client whose transfer has not started.
      Queue<RunnerApi.ArtifactInformation> pendingGets;
      // Chunks of the artifact being received; an empty ByteString marks the end of the artifact.
      BlockingQueue<ByteString> currentOutput;

      @Override
      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = "fallthrough intended")
      // May be called by different threads for the same request; synchronized for memory
      // synchronization.
      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper responseWrapper) {
        switch (state) {
          case START:
            stagingToken = responseWrapper.getStagingToken();
            LOG.info("Staging artifacts for {}.", stagingToken);
            toResolve = toStage.get(stagingToken);
            if (toResolve == null) {
              failStaging(
                  responseObserver,
                  new StatusException(
                      Status.INVALID_ARGUMENT.withDescription(
                          "Unknown staging token " + stagingToken)));
              return;
            }
            stagedFutures = new ConcurrentHashMap<>();
            pendingResolves = new ArrayDeque<>();
            pendingResolves.addAll(toResolve.keySet());
            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
            resolveNextEnvironment(responseObserver);
            break;

          case RESOLVE:
            {
              currentEnvironment = pendingResolves.remove();
              stagedFutures.put(currentEnvironment, new ArrayList<>());
              pendingGets = new ArrayDeque<>();
              for (RunnerApi.ArtifactInformation artifact :
                  responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
                Optional<RunnerApi.ArtifactInformation> fetched = getLocal(artifact);
                if (fetched.isPresent()) {
                  stagedFutures
                      .get(currentEnvironment)
                      .add(CompletableFuture.completedFuture(fetched.get()));
                } else {
                  pendingGets.add(artifact);
                  responseObserver.onNext(
                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
                          .setGetArtifact(
                              ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
                          .build());
                }
              }
              LOG.info(
                  "Getting {} artifacts for {}.{}.",
                  pendingGets.size(),
                  stagingToken,
                  currentEnvironment);
              if (pendingGets.isEmpty()) {
                resolveNextEnvironment(responseObserver);
              } else {
                state = State.GET;
              }
              break;
            }

          case GET:
            try {
              RunnerApi.ArtifactInformation currentArtifact = pendingGets.remove();
              String name = createFilename(currentEnvironment, currentArtifact);
              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
              currentOutput = new ArrayBlockingQueue<ByteString>(100);
              stagedFutures
                  .get(currentEnvironment)
                  .add(
                      stagingExecutor.submit(
                          new StoreArtifact(
                              stagingToken,
                              name,
                              currentArtifact,
                              currentOutput,
                              totalPendingBytes)));
            } catch (Exception exn) {
              LOG.error("Error submitting.", exn);
              // Must not fall through: no store task exists to consume this artifact's chunks.
              failStaging(responseObserver, exn);
              return;
            }
            state = State.GETCHUNK;
            // fall through

          case GETCHUNK:
            try {
              ByteString chunk = responseWrapper.getGetArtifactResponse().getData();
              if (chunk.size() > 0) { // Make sure we don't accidentally send the EOF value.
                totalPendingBytes.aquire(chunk.size());
                currentOutput.put(chunk);
              }
              if (responseWrapper.getIsLast()) {
                currentOutput.put(ByteString.EMPTY); // The EOF value.
                if (pendingGets.isEmpty()) {
                  resolveNextEnvironment(responseObserver);
                } else {
                  state = State.GET;
                  LOG.debug("Waiting for {}", pendingGets.peek());
                }
              }
            } catch (Exception exn) {
              if (exn instanceof InterruptedException) {
                Thread.currentThread().interrupt();
              }
              LOG.error("Error receiving artifact chunk.", exn);
              // Report to the client too; otherwise it would wait forever for completion.
              failStaging(responseObserver, exn);
            }
            break;

          case DONE:
          case ERROR:
            // The outgoing stream is already closed; erroring it again would itself fail.
            LOG.warn("Ignoring message for {} received in state {}.", stagingToken, state);
            break;

          default:
            failStaging(
                responseObserver,
                new StatusException(
                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + state)));
        }
      }

      /**
       * Requests resolution of the artifacts of the environment at the head of {@code
       * pendingResolves} and moves to {@code RESOLVE}, or finishes staging if none remain.
       */
      private void resolveNextEnvironment(
          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
        if (pendingResolves.isEmpty()) {
          finishStaging(responseObserver);
        } else {
          state = State.RESOLVE;
          LOG.info("Resolving artifacts for {}.{}.", stagingToken, pendingResolves.peek());
          responseObserver.onNext(
              ArtifactApi.ArtifactRequestWrapper.newBuilder()
                  .setResolveArtifact(
                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
                          .addAllArtifacts(toResolve.get(pendingResolves.peek())))
                  .build());
        }
      }

      /**
       * Waits for every store task, publishes the staged artifacts under {@code stagingToken}, and
       * completes the outgoing stream. On any failure, reports the error to the client instead.
       */
      private void finishStaging(
          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
        LOG.debug("Finishing staging for {}.", stagingToken);
        Map<String, List<RunnerApi.ArtifactInformation>> staged = new HashMap<>();
        try {
          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> entry :
              stagedFutures.entrySet()) {
            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
            for (Future<RunnerApi.ArtifactInformation> future : entry.getValue()) {
              envStaged.add(future.get());
            }
            staged.put(entry.getKey(), envStaged);
          }
        } catch (Exception exn) {
          if (exn instanceof InterruptedException) {
            Thread.currentThread().interrupt();
          }
          LOG.error("Error staging artifacts", exn);
          // Also stops any still-running store tasks so the pool's threads are not leaked.
          failStaging(responseObserver, exn);
          return;
        }
        ArtifactStagingService.this.staged.put(stagingToken, staged);
        stagingExecutor.shutdown();
        state = State.DONE;
        LOG.info("Artifacts fully staged for {}.", stagingToken);
        responseObserver.onCompleted();
      }

      /**
       * Moves to the terminal {@code ERROR} state, stops any in-flight store tasks, and reports
       * {@code exn} on the outgoing stream.
       */
      private void failStaging(
          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver, Throwable exn) {
        // Set before calling out so later messages are ignored even if onError throws.
        state = State.ERROR;
        if (stagingExecutor != null) {
          stagingExecutor.shutdownNow();
        }
        responseObserver.onError(exn);
      }

      /**
       * Return an alternative artifact if we do not need to get this over the artifact API, or
       * possibly at all.
       *
       * <p>Currently always returns empty, so every artifact is fetched from the client.
       */
      private Optional<RunnerApi.ArtifactInformation> getLocal(
          RunnerApi.ArtifactInformation artifact) {
        return Optional.empty();
      }

      /**
       * Attempts to provide a reasonable filename for the artifact.
       *
       * <p>The result has the form {@code <id>-<environment>-<base>}, clipped to 100 characters.
       * {@code id} comes from a monotonically increasing generator and provides uniqueness.
       * {@code base} is the last run of characters from {@code [A-Za-z0-9_.-]} in the artifact's
       * staged name, file path, or URL. Any other characters in {@code environment} are replaced
       * with {@code _}, so the result contains no path separators.
       *
       * @param environment the environment id; clipped to 25 characters
       * @param artifact the artifact itself
       * @return a flat filename for the artifact
       * @throws RuntimeException if the artifact's role or type payload cannot be parsed
       */
      private String createFilename(String environment, RunnerApi.ArtifactInformation artifact) {
        String path;
        try {
          if (artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN)) {
            path =
                RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
                    .getStagedName();
          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
            path = RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
            path = RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
          } else {
            path = "artifact";
          }
        } catch (InvalidProtocolBufferException exn) {
          throw new RuntimeException(exn);
        }
        // Limit to the last non-empty run of filename-safe characters. In particular, this
        // excludes all path separators. (`-` is last in the class so it is a literal.)
        List<String> components =
            Splitter.onPattern("[^A-Za-z0-9_.-]").omitEmptyStrings().splitToList(path);
        // E.g. a path ending in a separator has no usable component.
        String base = components.isEmpty() ? "artifact" : components.get(components.size() - 1);
        String safeEnvironment = environment.replaceAll("[^A-Za-z0-9_.-]", "_");
        return clip(
            String.format("%s-%s-%s", idGenerator.getId(), clip(safeEnvironment, 25), base), 100);
      }

      /** Returns {@code s} truncated to at most {@code maxLength} characters. */
      private String clip(String s, int maxLength) {
        return s.length() < maxLength ? s : s.substring(0, maxLength);
      }

      /** Handles a client-side stream failure by stopping any in-flight store tasks. */
      @Override
      public synchronized void onError(Throwable throwable) {
        // The executor does not exist if the stream fails before a valid staging token arrives.
        if (stagingExecutor != null) {
          stagingExecutor.shutdownNow();
        }
        LOG.error("Error staging artifacts", throwable);
        state = State.ERROR;
      }

      /**
       * Verifies that the client only half-closes after staging finished or already failed. A
       * failure has already been reported to the client, so it is accepted.
       */
      @Override
      public synchronized void onCompleted() {
        Preconditions.checkState(
            state == State.DONE || state == State.ERROR,
            "Staging stream for %s completed in state %s",
            stagingToken,
            state);
      }
    };
  }