/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.strategy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.AllToAllBlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
import org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils;
import org.apache.flink.streaming.api.graph.StreamGraphContext;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
import org.apache.flink.streaming.runtime.partitioner.ForwardForConsecutiveHashPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
import org.apache.flink.table.runtime.strategy.AdaptiveJoinOptimizationUtils;
import org.apache.flink.table.runtime.strategy.BaseAdaptiveJoinOperatorOptimizationStrategy;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Optimization strategy that inspects the bytes produced for the inputs of adaptive join nodes
 * and, when an input side contains subpartitions larger than a skew threshold, rewrites the
 * stream edges related to that side.
 *
 * <p>For an input side chosen for optimization, the strategy:
 *
 * <ul>
 *   <li>marks that side's input edges as not intra-input key correlated, and
 *   <li>replaces {@link ForwardForConsecutiveHashPartitioner}s on the join's output edges with
 *       their underlying hash partitioner.
 * </ul>
 *
 * <p>NOTE(review): presumably this lets the adaptive batch scheduler split skewed partitions of
 * that side; confirm against the scheduler implementation.
 *
 * <p>Subpartition byte statistics are accumulated per join node while its upstream producers
 * finish. They are released once all upstream nodes have finished, or as soon as the node is
 * found to be ineligible for the optimization.
 */
public class AdaptiveSkewedJoinOptimizationStrategy
        extends BaseAdaptiveJoinOperatorOptimizationStrategy {
    private static final Logger LOG =
            LoggerFactory.getLogger(AdaptiveSkewedJoinOptimizationStrategy.class);

    /** Type number of the input edges feeding the left side of a join node. */
    private static final int LEFT_INPUT_TYPE_NUMBER = 1;

    /** Type number of the input edges feeding the right side of a join node. */
    private static final int RIGHT_INPUT_TYPE_NUMBER = 2;

    /**
     * Aggregated subpartition bytes collected so far.
     *
     * <p>Outer key: adaptive join stream node id. Inner key: edge type number (input side).
     * Value: per-subpartition byte counts.
     */
    private Map<Integer, Map<Integer, long[]>> aggregatedProducedBytesByTypeNumberAndNodeId;

    private OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy
            adaptiveSkewedJoinOptimizationStrategy;
    private long skewedThresholdInBytes;
    private double skewedFactor;

    /**
     * Reads the skewed-join options from the stream graph configuration and resets the collected
     * statistics.
     *
     * @param context the stream graph context whose configuration is read
     */
    public void initialize(StreamGraphContext context) {
        ReadableConfig config = context.getStreamGraph().getConfiguration();
        this.aggregatedProducedBytesByTypeNumberAndNodeId = new HashMap<>();
        this.adaptiveSkewedJoinOptimizationStrategy =
                config.get(
                        OptimizerConfigOptions
                                .TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_STRATEGY);
        this.skewedFactor =
                config.get(
                        OptimizerConfigOptions
                                .TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_FACTOR);
        MemorySize skewedThreshold =
                config.get(
                        OptimizerConfigOptions
                                .TABLE_OPTIMIZER_ADAPTIVE_SKEWED_JOIN_OPTIMIZATION_SKEWED_THRESHOLD);
        this.skewedThresholdInBytes = skewedThreshold.getBytes();
    }

    /**
     * Visits the adaptive join nodes downstream of the finished operators via the base class.
     *
     * <p>NOTE(review): the meaning of the returned {@code true} is defined by the caller's
     * contract, which is not visible here.
     *
     * @param operatorsFinished the operators that just finished, with their result infos
     * @param context the stream graph context
     * @return always {@code true}
     * @throws Exception if visiting the downstream nodes fails
     */
    public boolean onOperatorsFinished(
            OperatorsFinished operatorsFinished, StreamGraphContext context) throws Exception {
        visitDownstreamAdaptiveJoinNode(operatorsFinished, context);
        return true;
    }

    /**
     * Accumulates the produced bytes of the given upstream edges for {@code adaptiveJoinNode}.
     *
     * <p>Once all upstream nodes of the join have finished, it applies the skewed join
     * optimization and then frees the node's statistics.
     *
     * @throws IllegalStateException if an upstream result is not an {@link
     *     AllToAllBlockingResultInfo} or no matching result can be found for an edge
     */
    @Override
    void tryOptimizeAdaptiveJoin(
            OperatorsFinished operatorsFinished,
            StreamGraphContext context,
            ImmutableStreamNode adaptiveJoinNode,
            List<ImmutableStreamEdge> upstreamStreamEdges,
            AdaptiveJoin adaptiveJoin) {
        if (!canPerformOptimization(context, adaptiveJoinNode)) {
            freeNodeStatistic(adaptiveJoinNode.getId());
            return;
        }
        for (ImmutableStreamEdge edge : upstreamStreamEdges) {
            BlockingResultInfo resultInfo =
                    getBlockingResultInfo(operatorsFinished, context, edge);
            Preconditions.checkState(
                    resultInfo instanceof AllToAllBlockingResultInfo,
                    "Expected an AllToAllBlockingResultInfo for edge %s of adaptive join [%s], but got %s.",
                    edge.getEdgeId(),
                    adaptiveJoinNode,
                    resultInfo.getClass().getName());
            aggregatedProducedBytesByTypeNumber(
                    adaptiveJoinNode,
                    edge.getTypeNumber(),
                    ((AllToAllBlockingResultInfo) resultInfo).getAggregatedSubpartitionBytes());
        }
        if (context.checkUpstreamNodesFinished(adaptiveJoinNode, null)) {
            try {
                applyAdaptiveSkewedJoinOptimization(
                        context, adaptiveJoinNode, adaptiveJoin.getJoinType());
            } finally {
                // Release the statistics even if the optimization throws, so they cannot leak.
                freeNodeStatistic(adaptiveJoinNode.getId());
            }
        }
    }

    /**
     * Decides whether the skewed join optimization may be attempted for the node.
     *
     * <p>Broadcast joins are never eligible. For {@code AUTO} and {@code FORCED} the decision is
     * delegated to the corresponding inherited check; any other configured strategy disables the
     * optimization.
     */
    private boolean canPerformOptimization(
            StreamGraphContext context, ImmutableStreamNode adaptiveJoinNode) {
        if (AdaptiveJoinOptimizationUtils.isBroadcastJoin(adaptiveJoinNode)) {
            return false;
        }
        if (adaptiveSkewedJoinOptimizationStrategy
                == OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.AUTO) {
            return canPerformOptimizationAutomatic(context, adaptiveJoinNode);
        }
        if (adaptiveSkewedJoinOptimizationStrategy
                == OptimizerConfigOptions.AdaptiveSkewedJoinOptimizationStrategy.FORCED) {
            return canPerformOptimizationForced(context, adaptiveJoinNode);
        }
        return false;
    }

    /**
     * Finds the result info consumed by {@code edge}.
     *
     * <p>The result info is taken from those reported for the edge's source node, matched by
     * intermediate data set id.
     *
     * @throws IllegalStateException if the source node reported no results, or none matches
     */
    private static BlockingResultInfo getBlockingResultInfo(
            OperatorsFinished operatorsFinished,
            StreamGraphContext context,
            ImmutableStreamEdge edge) {
        List<BlockingResultInfo> resultInfos =
                operatorsFinished.getResultInfoMap().get(edge.getSourceId());
        IntermediateDataSetID intermediateDataSetId =
                context.getConsumedIntermediateDataSetId(edge.getEdgeId());
        // A missing entry is reported the same way as a missing match instead of as an NPE.
        if (resultInfos != null) {
            for (BlockingResultInfo result : resultInfos) {
                if (result.getResultId().equals(intermediateDataSetId)) {
                    return result;
                }
            }
        }
        throw new IllegalStateException(
                "No matching BlockingResultInfo found for edge ID: " + edge.getEdgeId());
    }

    /**
     * Adds {@code subpartitionBytes} element-wise to the running totals for the node and type
     * number.
     *
     * <p>The totals array is created on first use, sized to the incoming list.
     *
     * @throws IllegalStateException if the subpartition count differs from earlier contributions
     */
    private void aggregatedProducedBytesByTypeNumber(
            ImmutableStreamNode adaptiveJoinNode, int typeNumber, List<Long> subpartitionBytes) {
        long[] aggregatedSubpartitionBytes =
                aggregatedProducedBytesByTypeNumberAndNodeId
                        .computeIfAbsent(adaptiveJoinNode.getId(), ignored -> new HashMap<>())
                        .computeIfAbsent(
                                typeNumber, ignored -> new long[subpartitionBytes.size()]);
        Preconditions.checkState(
                subpartitionBytes.size() == aggregatedSubpartitionBytes.length,
                "Subpartition count mismatch for input type %s of adaptive join [%s]: expected %s, but got %s.",
                typeNumber,
                adaptiveJoinNode,
                aggregatedSubpartitionBytes.length,
                subpartitionBytes.size());
        for (int i = 0; i < subpartitionBytes.size(); i++) {
            aggregatedSubpartitionBytes[i] += subpartitionBytes.get(i);
        }
    }

    /**
     * Attempts the edge rewrite for each input side that is both allowed by the join type and
     * skewed.
     *
     * <p>A side counts as skewed when at least one of its subpartitions exceeds that side's skew
     * threshold. The join type allows sides as follows:
     *
     * <ul>
     *   <li>{@code RIGHT}: right side only;
     *   <li>{@code INNER}: both sides;
     *   <li>{@code LEFT}/{@code SEMI}/{@code ANTI}: left side only.
     * </ul>
     *
     * @throws IllegalStateException if statistics are missing or the join type is unsupported
     */
    private void applyAdaptiveSkewedJoinOptimization(
            StreamGraphContext context,
            ImmutableStreamNode adaptiveJoinNode,
            FlinkJoinType joinType) {
        Map<Integer, long[]> inputBytesByTypeNumber =
                aggregatedProducedBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId());
        Preconditions.checkState(
                inputBytesByTypeNumber != null,
                "Input bytes of adaptive join [%s] are unknown, which is unexpected.",
                adaptiveJoinNode.toString());
        long[] leftInputSize = inputBytesByTypeNumber.get(LEFT_INPUT_TYPE_NUMBER);
        Preconditions.checkState(
                leftInputSize != null,
                "Left input bytes of adaptive join [%s] is unknown, which is unexpected.",
                adaptiveJoinNode.toString());
        long[] rightInputSize = inputBytesByTypeNumber.get(RIGHT_INPUT_TYPE_NUMBER);
        Preconditions.checkState(
                rightInputSize != null,
                "Right input bytes of adaptive join [%s] is unknown, which is unexpected.",
                adaptiveJoinNode.toString());
        long leftSkewedThreshold = computeSkewedThreshold(leftInputSize);
        long rightSkewedThreshold = computeSkewedThreshold(rightInputSize);

        boolean isLeftOptimizable = false;
        boolean isRightOptimizable = false;
        switch (joinType) {
            case RIGHT:
                isRightOptimizable = true;
                break;
            case INNER:
                isLeftOptimizable = true;
                isRightOptimizable = true;
                break;
            case LEFT:
            case SEMI:
            case ANTI:
                isLeftOptimizable = true;
                break;
            default:
                throw new IllegalStateException(
                        String.format("Unexpected join type %s.", joinType));
        }
        isLeftOptimizable &= existBytesLargerThanThreshold(leftInputSize, leftSkewedThreshold);
        isRightOptimizable &= existBytesLargerThanThreshold(rightInputSize, rightSkewedThreshold);

        if (isLeftOptimizable) {
            boolean isModificationSucceed =
                    tryModifyInputAndOutputEdges(context, adaptiveJoinNode, LEFT_INPUT_TYPE_NUMBER);
            LOG.info(
                    "Apply skewed join optimization {} for left input of node {}.",
                    isModificationSucceed ? "succeeded" : "failed",
                    adaptiveJoinNode);
        }
        if (isRightOptimizable) {
            boolean isModificationSucceed =
                    tryModifyInputAndOutputEdges(
                            context, adaptiveJoinNode, RIGHT_INPUT_TYPE_NUMBER);
            LOG.info(
                    "Apply skewed join optimization {} for right input of node {}.",
                    isModificationSucceed ? "succeeded" : "failed",
                    adaptiveJoinNode);
        }
    }

    /**
     * Computes the skew threshold for one input side.
     *
     * <p>The threshold is derived from the median subpartition size, the configured skewed
     * factor and the configured skewed threshold in bytes.
     */
    private long computeSkewedThreshold(long[] inputSize) {
        return VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold(
                VertexParallelismAndInputInfosDeciderUtils.median(inputSize),
                skewedFactor,
                skewedThresholdInBytes);
    }

    /**
     * Submits a single modification request covering both the input edges of the given type
     * number and the join's output edges.
     *
     * @return the result of {@link StreamGraphContext#modifyStreamEdge}
     */
    private static boolean tryModifyInputAndOutputEdges(
            StreamGraphContext context, ImmutableStreamNode adaptiveJoinNode, int typeNumber) {
        List<StreamEdgeUpdateRequestInfo> modifiedRequests =
                new ArrayList<>(
                        generateCorrelationModificationRequestInfos(
                                AdaptiveJoinOptimizationUtils.filterEdges(
                                        adaptiveJoinNode.getInEdges(), typeNumber)));
        modifiedRequests.addAll(
                generateForwardPartitionerModificationRequestInfos(
                        adaptiveJoinNode.getOutEdges(), context));
        return context.modifyStreamEdge(modifiedRequests);
    }

    /** Builds one request per edge that sets intra-input key correlation to {@code false}. */
    private static List<StreamEdgeUpdateRequestInfo> generateCorrelationModificationRequestInfos(
            List<ImmutableStreamEdge> streamEdges) {
        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>();
        for (ImmutableStreamEdge edge : streamEdges) {
            streamEdgeUpdateRequestInfos.add(
                    new StreamEdgeUpdateRequestInfo(
                                    edge.getEdgeId(),
                                    Integer.valueOf(edge.getSourceId()),
                                    Integer.valueOf(edge.getTargetId()))
                            .withIntraInputKeyCorrelated(false));
        }
        return streamEdgeUpdateRequestInfos;
    }

    /**
     * Builds a request for each forward-for-consecutive-hash edge that replaces its partitioner
     * with the wrapped hash partitioner. Other edges are skipped.
     *
     * @throws NullPointerException if such an edge has no output partitioner
     * @throws IllegalStateException if its partitioner is not a {@link
     *     ForwardForConsecutiveHashPartitioner}
     */
    private static List<StreamEdgeUpdateRequestInfo>
            generateForwardPartitionerModificationRequestInfos(
                    List<ImmutableStreamEdge> streamEdges, StreamGraphContext context) {
        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos = new ArrayList<>();
        for (ImmutableStreamEdge edge : streamEdges) {
            if (!edge.isForwardForConsecutiveHashEdge()) {
                continue;
            }
            StreamPartitioner<?> partitioner =
                    Preconditions.checkNotNull(
                            context.getOutputPartitioner(
                                    edge.getEdgeId(),
                                    Integer.valueOf(edge.getSourceId()),
                                    Integer.valueOf(edge.getTargetId())));
            Preconditions.checkState(
                    partitioner instanceof ForwardForConsecutiveHashPartitioner,
                    "Expected a ForwardForConsecutiveHashPartitioner on edge %s, but got %s.",
                    edge.getEdgeId(),
                    partitioner);
            StreamPartitioner<?> newPartitioner =
                    ((ForwardForConsecutiveHashPartitioner<?>) partitioner).getHashPartitioner();
            streamEdgeUpdateRequestInfos.add(
                    new StreamEdgeUpdateRequestInfo(
                                    edge.getEdgeId(),
                                    Integer.valueOf(edge.getSourceId()),
                                    Integer.valueOf(edge.getTargetId()))
                            .withOutputPartitioner(newPartitioner));
        }
        return streamEdgeUpdateRequestInfos;
    }

    /** Drops all collected statistics for the given join node. */
    private void freeNodeStatistic(Integer nodeId) {
        aggregatedProducedBytesByTypeNumberAndNodeId.remove(nodeId);
    }

    /** Returns whether any entry of {@code inputBytes} is strictly greater than {@code threshold}. */
    private static boolean existBytesLargerThanThreshold(long[] inputBytes, long threshold) {
        for (long byteSize : inputBytes) {
            if (byteSize > threshold) {
                return true;
            }
        }
        return false;
    }
}

