/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.strategy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingResultInfo;
import org.apache.flink.runtime.scheduler.adaptivebatch.OperatorsFinished;
import org.apache.flink.streaming.api.graph.StreamGraphContext;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamEdge;
import org.apache.flink.streaming.api.graph.util.ImmutableStreamNode;
import org.apache.flink.streaming.api.graph.util.StreamEdgeUpdateRequestInfo;
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.partitioner.ForwardForUnspecifiedPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.adaptive.AdaptiveJoin;
import org.apache.flink.table.runtime.strategy.AdaptiveJoinOptimizationUtils;
import org.apache.flink.table.runtime.strategy.BaseAdaptiveJoinOperatorOptimizationStrategy;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adaptive batch optimization strategy that decides, at runtime, whether an adaptive join can be
 * converted into a broadcast hash join based on the number of bytes its inputs actually produced.
 *
 * <p>Bytes produced by the upstream results consumed by each adaptive join node are accumulated
 * per node and per input type number ({@code 1} = left input, {@code 2} = right input) across
 * invocations of {@link #tryOptimizeAdaptiveJoin}. As soon as one input is reported finished, is
 * strictly smaller than the configured broadcast threshold, and the join type allows that side to
 * be broadcast, the node's input edges are rewritten so that side is broadcast and the other side
 * is forwarded. When both input sizes are known and neither side can be broadcast, the join is
 * marked as a non-broadcast join instead. Once a decision is recorded for a node it is not
 * revisited.
 */
public class AdaptiveBroadcastJoinOptimizationStrategy
        extends BaseAdaptiveJoinOperatorOptimizationStrategy {
    private static final Logger LOG =
            LoggerFactory.getLogger(AdaptiveBroadcastJoinOptimizationStrategy.class);

    /** Type number of the left join input on the adaptive join node's in-edges. */
    private static final int LEFT_INPUT_TYPE_NUMBER = 1;

    /** Type number of the right join input on the adaptive join node's in-edges. */
    private static final int RIGHT_INPUT_TYPE_NUMBER = 2;

    /** Inputs strictly smaller than this many bytes are candidates for broadcasting. */
    private Long broadcastThreshold;

    /** Adaptive join node id -> (input type number -> bytes produced so far by that input). */
    private Map<Integer, Map<Integer, Long>> aggregatedInputBytesByTypeNumberAndNodeId;

    /** Ids of adaptive join nodes whose broadcast decision has already been made. */
    private Set<Integer> optimizedAdaptiveJoinNodes;

    /**
     * Reads the broadcast threshold from the stream graph configuration and resets all per-node
     * bookkeeping.
     *
     * @param context the stream graph context of the job being optimized
     */
    public void initialize(StreamGraphContext context) {
        ReadableConfig config = context.getStreamGraph().getConfiguration();
        this.broadcastThreshold =
                config.get(OptimizerConfigOptions.TABLE_OPTIMIZER_BROADCAST_JOIN_THRESHOLD);
        this.aggregatedInputBytesByTypeNumberAndNodeId = new HashMap<>();
        this.optimizedAdaptiveJoinNodes = new HashSet<>();
    }

    /**
     * Delegates to the base class's {@code visitDownstreamAdaptiveJoinNode}, which is expected to
     * call {@link #tryOptimizeAdaptiveJoin} for the affected adaptive join nodes (see base class).
     *
     * @param operatorsFinished the operators that just finished, with their result infos
     * @param context the stream graph context
     * @return always {@code true}
     */
    public boolean onOperatorsFinished(
            OperatorsFinished operatorsFinished, StreamGraphContext context) {
        visitDownstreamAdaptiveJoinNode(operatorsFinished, context);
        return true;
    }

    /**
     * Accumulates the bytes produced by the given upstream edges of {@code adaptiveJoinNode} and,
     * if enough information is available, decides whether the join becomes a broadcast join.
     *
     * <p>The left input is tried first, then the right one; the first finished side that is small
     * enough and allowed by the join type is broadcast. If both sizes are known and neither side
     * could be broadcast, the join is marked as a non-broadcast join.
     */
    @Override
    protected void tryOptimizeAdaptiveJoin(
            OperatorsFinished operatorsFinished,
            StreamGraphContext context,
            ImmutableStreamNode adaptiveJoinNode,
            List<ImmutableStreamEdge> upstreamStreamEdges,
            AdaptiveJoin adaptiveJoin) {
        if (!canPerformOptimization(adaptiveJoinNode, context)
                || optimizedAdaptiveJoinNodes.contains(adaptiveJoinNode.getId())) {
            return;
        }
        for (ImmutableStreamEdge upstreamEdge : upstreamStreamEdges) {
            long producedBytes = computeProducedBytes(operatorsFinished, context, upstreamEdge);
            aggregatedInputBytesByTypeNumber(
                    adaptiveJoinNode, upstreamEdge.getTypeNumber(), producedBytes);
        }

        FlinkJoinType joinType = adaptiveJoin.getJoinType();
        Long leftInputSize = null;
        Long rightInputSize = null;
        if (context.checkUpstreamNodesFinished(adaptiveJoinNode, LEFT_INPUT_TYPE_NUMBER)) {
            leftInputSize = getFinishedInputBytes(adaptiveJoinNode, LEFT_INPUT_TYPE_NUMBER);
            if (tryBroadcastInputSide(
                    adaptiveJoinNode, context, adaptiveJoin, joinType, true, leftInputSize)) {
                return;
            }
        }
        if (context.checkUpstreamNodesFinished(adaptiveJoinNode, RIGHT_INPUT_TYPE_NUMBER)) {
            rightInputSize = getFinishedInputBytes(adaptiveJoinNode, RIGHT_INPUT_TYPE_NUMBER);
            if (tryBroadcastInputSide(
                    adaptiveJoinNode, context, adaptiveJoin, joinType, false, rightInputSize)) {
                return;
            }
        }
        if (leftInputSize != null && rightInputSize != null) {
            LOG.debug(
                    "The size of the specified side of the input data for the join node [{}] is too large to be converted into a broadcast hash join. The Join type: {}, Broadcast threshold: {} bytes, Left input size: {} bytes, Right input size: {} bytes.",
                    adaptiveJoinNode.getId(),
                    joinType,
                    broadcastThreshold,
                    leftInputSize,
                    rightInputSize);
            // Both sizes are final and neither side can be broadcast. The second argument is the
            // "left is build" flag (same meaning as in tryBroadcastOptimization), so the smaller
            // input is chosen as build side.
            adaptiveJoin.markAsBroadcastJoin(false, leftInputSize < rightInputSize);
            markAsOptimized(adaptiveJoinNode);
        }
    }

    /**
     * Sums the bytes produced by the upstream source of {@code upstreamEdge} for the intermediate
     * data set that this edge consumes.
     *
     * @throws IllegalStateException if no result info is available for the edge's source node
     */
    private static long computeProducedBytes(
            OperatorsFinished operatorsFinished,
            StreamGraphContext context,
            ImmutableStreamEdge upstreamEdge) {
        IntermediateDataSetID relatedDataSetId =
                context.getConsumedIntermediateDataSetId(upstreamEdge.getEdgeId());
        List<BlockingResultInfo> resultInfos =
                operatorsFinished.getResultInfoMap().get(upstreamEdge.getSourceId());
        // Fail with context instead of a bare NullPointerException; silently treating a missing
        // entry as 0 bytes would wrongly make the input look broadcastable.
        Preconditions.checkState(
                resultInfos != null,
                "No result info found for upstream node [%s] of adaptive join edge [%s].",
                upstreamEdge.getSourceId(),
                upstreamEdge.getEdgeId());
        return resultInfos.stream()
                .filter(resultInfo -> relatedDataSetId.equals(resultInfo.getResultId()))
                .mapToLong(BlockingResultInfo::getNumBytesProduced)
                .sum();
    }

    /**
     * Returns the aggregated bytes recorded for one input of {@code adaptiveJoinNode}.
     *
     * @throws IllegalStateException if nothing was recorded for that input
     */
    private long getFinishedInputBytes(ImmutableStreamNode adaptiveJoinNode, int typeNumber) {
        Map<Integer, Long> bytesByTypeNumber =
                aggregatedInputBytesByTypeNumberAndNodeId.get(adaptiveJoinNode.getId());
        // Look up defensively so that a missing per-node entry surfaces as the descriptive state
        // error below rather than as a NullPointerException.
        Long inputBytes = bytesByTypeNumber == null ? null : bytesByTypeNumber.get(typeNumber);
        Preconditions.checkState(
                inputBytes != null,
                "%s input bytes of adaptive join [%s] is unknown, which is unexpected.",
                typeNumber == LEFT_INPUT_TYPE_NUMBER ? "Left" : "Right",
                adaptiveJoinNode.getId());
        return inputBytes;
    }

    /**
     * Tries to convert the join into a broadcast join that uses the given side as build side.
     *
     * @return {@code true} if the side is eligible and the stream edges were successfully modified
     */
    private boolean tryBroadcastInputSide(
            ImmutableStreamNode adaptiveJoinNode,
            StreamGraphContext context,
            AdaptiveJoin adaptiveJoin,
            FlinkJoinType joinType,
            boolean leftIsBuild,
            long inputBytes) {
        return checkInputSideCanBeBroadcast(joinType, leftIsBuild, inputBytes)
                && tryBroadcastOptimization(
                        adaptiveJoinNode, context, adaptiveJoin, leftIsBuild, inputBytes);
    }

    /**
     * Rewrites the join's input edges for a broadcast join. On success, it marks the join as
     * broadcast and records the decision.
     *
     * @return {@code true} if the stream edges were modified, {@code false} if the modification was
     *     rejected and the join node is left unchanged
     */
    private boolean tryBroadcastOptimization(
            ImmutableStreamNode adaptiveJoinNode,
            StreamGraphContext context,
            AdaptiveJoin adaptiveJoin,
            boolean leftIsBuild,
            long inputBytes) {
        if (tryModifyStreamEdgesForBroadcastJoin(
                adaptiveJoinNode.getInEdges(), context, leftIsBuild)) {
            LOG.info(
                    "The {} input data size of the join node [{}] is small enough, adaptively convert it to a broadcast hash join. Broadcast threshold bytes: {}, actual input bytes: {}.",
                    leftIsBuild ? "left" : "right",
                    adaptiveJoinNode.getId(),
                    broadcastThreshold,
                    inputBytes);
            adaptiveJoin.markAsBroadcastJoin(true, leftIsBuild);
            markAsOptimized(adaptiveJoinNode);
            return true;
        }
        LOG.info(
                "Modification to stream edges for the join node [{}] failed. Keep the join node as is.",
                adaptiveJoinNode.getId());
        return false;
    }

    /**
     * Checks whether one input side may be broadcast. The side must be strictly below the
     * threshold, and the join type must allow broadcasting it.
     *
     * <p>The preserved side of an outer/semi/anti join must not be broadcast. Otherwise each
     * parallel join instance would emit its unmatched rows, producing duplicates.
     *
     * @throws RuntimeException for a join type this strategy does not know about
     */
    private boolean checkInputSideCanBeBroadcast(
            FlinkJoinType joinType, boolean leftIsBuild, long producedBytes) {
        if (producedBytes >= broadcastThreshold) {
            return false;
        }
        switch (joinType) {
            case INNER:
                return true;
            case RIGHT:
                // Right rows are preserved, so only the left side may be broadcast.
                return leftIsBuild;
            case LEFT:
            case SEMI:
            case ANTI:
                // Left rows are preserved/emitted, so only the right side may be broadcast.
                return !leftIsBuild;
            case FULL:
                // Both sides are preserved: neither may be broadcast.
                return false;
            default:
                throw new RuntimeException(String.format("Unexpected join type %s.", joinType));
        }
    }

    /**
     * Returns whether broadcast optimization is applicable to the node. It is not applicable when
     * broadcast joins are disabled in the configuration, or when the node is already a broadcast
     * join. Otherwise the decision is delegated to the base class's automatic-optimization check.
     */
    private boolean canPerformOptimization(
            ImmutableStreamNode adaptiveJoinNode, StreamGraphContext context) {
        if (AdaptiveJoinOptimizationUtils.isBroadcastJoinDisabled(
                        context.getStreamGraph().getConfiguration())
                || AdaptiveJoinOptimizationUtils.isBroadcastJoin(adaptiveJoinNode)) {
            return false;
        }
        return canPerformOptimizationAutomatic(context, adaptiveJoinNode);
    }

    /** Adds {@code producedBytes} to the running total for the node's input {@code typeNumber}. */
    private void aggregatedInputBytesByTypeNumber(
            ImmutableStreamNode adaptiveJoinNode, int typeNumber, long producedBytes) {
        Integer streamNodeId = adaptiveJoinNode.getId();
        aggregatedInputBytesByTypeNumberAndNodeId
                .computeIfAbsent(streamNodeId, ignored -> new HashMap<>())
                .merge(typeNumber, producedBytes, Long::sum);
    }

    /** Records that a decision was made for the node and drops its accumulated input sizes. */
    private void markAsOptimized(ImmutableStreamNode adaptiveJoinNode) {
        optimizedAdaptiveJoinNodes.add(adaptiveJoinNode.getId());
        aggregatedInputBytesByTypeNumberAndNodeId.remove(adaptiveJoinNode.getId());
    }

    /** Builds one edge-update request per edge, each setting {@code outputPartitioner}. */
    private List<StreamEdgeUpdateRequestInfo> generateStreamEdgeUpdateRequestInfos(
            List<ImmutableStreamEdge> modifiedEdges, StreamPartitioner<?> outputPartitioner) {
        List<StreamEdgeUpdateRequestInfo> streamEdgeUpdateRequestInfos =
                new ArrayList<>(modifiedEdges.size());
        for (ImmutableStreamEdge streamEdge : modifiedEdges) {
            streamEdgeUpdateRequestInfos.add(
                    new StreamEdgeUpdateRequestInfo(
                                    streamEdge.getEdgeId(),
                                    streamEdge.getSourceId(),
                                    streamEdge.getTargetId())
                            .withOutputPartitioner(outputPartitioner));
        }
        return streamEdgeUpdateRequestInfos;
    }

    /**
     * Requests the edge rewrite for a broadcast join. Build-side in-edges get a {@link
     * BroadcastPartitioner} and probe-side in-edges get a {@link ForwardForUnspecifiedPartitioner}.
     *
     * @return the result of {@code context.modifyStreamEdge} for the combined requests
     */
    private boolean tryModifyStreamEdgesForBroadcastJoin(
            List<ImmutableStreamEdge> inEdges, StreamGraphContext context, boolean leftIsBuild) {
        int buildTypeNumber = leftIsBuild ? LEFT_INPUT_TYPE_NUMBER : RIGHT_INPUT_TYPE_NUMBER;
        int probeTypeNumber = leftIsBuild ? RIGHT_INPUT_TYPE_NUMBER : LEFT_INPUT_TYPE_NUMBER;
        List<StreamEdgeUpdateRequestInfo> modifiedEdges =
                generateStreamEdgeUpdateRequestInfos(
                        AdaptiveJoinOptimizationUtils.filterEdges(inEdges, buildTypeNumber),
                        new BroadcastPartitioner<>());
        modifiedEdges.addAll(
                generateStreamEdgeUpdateRequestInfos(
                        AdaptiveJoinOptimizationUtils.filterEdges(inEdges, probeTypeNumber),
                        new ForwardForUnspecifiedPartitioner<>()));
        return context.modifyStreamEdge(modifiedEdges);
    }
}

