/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.ml;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
import org.apache.flink.streaming.api.functions.async.ResultFuture;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;

/**
 * Invokes a generated ML-predict {@link AsyncFunction} for each input row and emits every row the
 * function returns joined to that input row via {@link JoinedRowData} (input row on the left,
 * returned row on the right, row kind taken from the input row).
 *
 * <p>To avoid allocating a wrapper per record, the runner pools {@code asyncBufferCapacity + 1}
 * reusable {@link JoinedRowResultFuture}s in a bounded blocking queue. {@link #asyncInvoke} takes
 * a future from the pool (blocking while the pool is empty). The future puts itself back once it
 * has been completed, either normally or exceptionally.
 *
 * <p>Registered gauges:
 *
 * <ul>
 *   <li>{@code ai_queue_length}: number of pooled futures currently taken out of the pool.
 *   <li>{@code ai_queue_capacity}: the configured {@code asyncBufferCapacity}.
 *   <li>{@code ai_queue_usage_ratio}: {@code ai_queue_length / ai_queue_capacity}. This can
 *       exceed 1.0 because the pool holds one more future than the configured capacity.
 * </ul>
 */
public class AsyncMLPredictRunner extends AbstractAsyncFunctionRunner<RowData> {

    /** Configured async buffer capacity; the pool holds one more future than this. */
    private final int asyncBufferCapacity;

    /** Pool of idle result futures; created in {@link #open}. */
    private transient BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;

    /**
     * @param generatedFetcher generated async function invoked once per input row
     * @param asyncBufferCapacity configured async buffer capacity; the result-future pool is sized
     *     {@code asyncBufferCapacity + 1}
     */
    public AsyncMLPredictRunner(
            GeneratedFunction<AsyncFunction<RowData, RowData>> generatedFetcher,
            int asyncBufferCapacity) {
        super(generatedFetcher);
        this.asyncBufferCapacity = asyncBufferCapacity;
    }

    @Override
    public void open(OpenContext openContext) throws Exception {
        super.open(openContext);
        int poolSize = asyncBufferCapacity + 1;
        resultFutureBuffer = new ArrayBlockingQueue<>(poolSize);
        for (int i = 0; i < poolSize; i++) {
            resultFutureBuffer.add(new JoinedRowResultFuture(resultFutureBuffer));
        }
        registerMetric(getRuntimeContext().getMetricGroup());
    }

    /**
     * Takes a pooled future, binds it to {@code input} and {@code resultFuture}, and hands it to
     * the generated fetcher. Any failure, including an interrupt while waiting for a free pooled
     * future, completes {@code resultFuture} exceptionally.
     */
    @Override
    public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture) throws Exception {
        try {
            JoinedRowResultFuture buffer = resultFutureBuffer.take();
            buffer.reset(input, resultFuture);
            fetcher.asyncInvoke(input, buffer);
        } catch (Throwable t) {
            if (t instanceof InterruptedException) {
                // take() cleared the flag; restore it so the task thread still sees the interrupt.
                Thread.currentThread().interrupt();
            }
            // NOTE(review): if the fetcher throws synchronously, the taken buffer is not returned
            // here, because it may already have been handed to a callback that will still
            // complete (and thereby release) it.
            resultFuture.completeExceptionally(t);
        }
    }

    /** Number of pooled futures currently taken out of the pool (i.e. in flight). */
    private int inFlightCount() {
        return asyncBufferCapacity + 1 - resultFutureBuffer.size();
    }

    private void registerMetric(MetricGroup metricGroup) {
        metricGroup.gauge("ai_queue_length", this::inFlightCount);
        metricGroup.gauge("ai_queue_capacity", () -> asyncBufferCapacity);
        metricGroup.gauge(
                "ai_queue_usage_ratio", () -> (double) inFlightCount() / asyncBufferCapacity);
    }

    /**
     * Reusable {@link ResultFuture} that joins each returned row with the bound input row, forwards
     * the result to the real output, and then returns itself to the pool exactly once per {@link
     * #reset}.
     */
    private static final class JoinedRowResultFuture implements ResultFuture<RowData> {

        /** Pool this future returns itself to once it has been completed. */
        private final BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;

        private ResultFuture<RowData> realOutput;
        private RowData leftRow;

        /** True between {@link #reset} and the first release; guarded by {@code this}. */
        private boolean inUse;

        public JoinedRowResultFuture(BlockingQueue<JoinedRowResultFuture> resultFutureBuffer) {
            this.resultFutureBuffer = resultFutureBuffer;
        }

        /** Binds this pooled future to a new input row and the downstream result future. */
        public synchronized void reset(RowData row, ResultFuture<RowData> realOutput) {
            this.realOutput = realOutput;
            this.leftRow = row;
            this.inUse = true;
        }

        @Override
        public void complete(Collection<RowData> result) {
            try {
                ArrayList<RowData> outRows = new ArrayList<>(result.size());
                for (RowData rightRow : result) {
                    outRows.add(new JoinedRowData(leftRow.getRowKind(), leftRow, rightRow));
                }
                realOutput.complete(outRows);
            } finally {
                // Release only after the output has been handed downstream, so this instance is
                // not reset for another row while leftRow/realOutput are still in use.
                release();
            }
        }

        @Override
        public void completeExceptionally(Throwable error) {
            try {
                realOutput.completeExceptionally(error);
            } finally {
                // Without this, every failed request would permanently shrink the pool and
                // eventually block asyncInvoke() forever in take().
                release();
            }
        }

        public void complete(CollectionSupplier<RowData> supplier) {
            throw new UnsupportedOperationException();
        }

        /** Returns this future to the pool; subsequent calls before the next reset are no-ops. */
        private void release() {
            synchronized (this) {
                if (!inUse) {
                    return;
                }
                inUse = false;
            }
            // offer() neither blocks nor reacts to interrupts. Each pooled future owns one slot in
            // the queue and is enqueued at most once per reset, so there is always room.
            if (!resultFutureBuffer.offer(this)) {
                throw new IllegalStateException(
                        "Result future pool is full; a pooled future was returned twice.");
            }
        }
    }
}

