/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlJsonArrayAggAggFunction;
import org.apache.calcite.sql.fun.SqlJsonObjectAggAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;

/**
 * Planner rule that rewrites JSON aggregate calls ({@code JSON_OBJECTAGG}, {@code JSON_ARRAYAGG}) so
 * that their value arguments are first passed through {@link
 * BuiltInFunctionDefinitions#JSON_STRING}.
 *
 * <p>The wrapped values are appended to the aggregate's input as extra fields (via {@link
 * RelBuilder#projectPlus}), and each JSON aggregate call is redirected to read those new fields
 * instead of the raw columns:
 *
 * <ul>
 *   <li>{@code JSON_OBJECTAGG}: only the value argument (argument index 1) is wrapped; the key
 *       argument keeps pointing at the original column.
 *   <li>{@code JSON_ARRAYAGG}: every distinct argument is wrapped.
 * </ul>
 *
 * <p>Non-JSON aggregate calls of the same aggregate are kept unchanged. The rewritten aggregate is
 * tagged with a marker hint, and the operand predicate skips aggregates carrying that hint, so the
 * rule does not match its own output again.
 */
@Internal
public class WrapJsonAggFunctionArgumentsRule
        extends RelRule<WrapJsonAggFunctionArgumentsRule.Config> {

    /** Hint attached to aggregates whose JSON aggregate calls have already been rewritten. */
    private static final RelHint MARKER_HINT = RelHint.builder("JSON_AGGREGATE_WRAPPED").build();

    public static final RelOptRule INSTANCE =
            Config.EMPTY.as(Config.class).onJsonAggregateFunctions().toRule();

    /**
     * Creates the rule from its configuration.
     *
     * @param config rule configuration, normally produced by {@link
     *     Config#onJsonAggregateFunctions()}
     */
    public WrapJsonAggFunctionArgumentsRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final LogicalAggregate aggregate = call.rel(0);
        final List<AggregateCall> aggCalls = aggregate.getAggCallList();
        final RelNode aggInput = aggregate.getInput();
        final int inputCount = aggInput.getRowType().getFieldCount();

        // Every input field that some JSON aggregate call needs wrapped, listed once each in
        // first-seen order. After projection the i-th entry lives at field (inputCount + i).
        final List<Integer> wrappedFields =
                aggCalls.stream()
                        .filter(WrapJsonAggFunctionArgumentsRule::isJsonAggregation)
                        .flatMap(aggCall -> getAffectedArgs(aggCall).stream())
                        .distinct()
                        .collect(Collectors.toList());

        final RelBuilder relBuilder = call.builder().push(aggInput);
        addProjections(aggregate.getCluster(), relBuilder, wrappedFields);

        // Remap each JSON call with its own mapping: a field wrapped for one call (e.g. a
        // JSON_ARRAYAGG argument) must not be redirected for another call that uses the same
        // field unwrapped (e.g. as a JSON_OBJECTAGG key). Non-JSON calls keep their original field
        // references, which stay valid because projectPlus only appends fields.
        // NOTE(review): AggregateCall#transform also remaps the call's filter argument and
        // collation when they reference one of the wrapped fields.
        final List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size());
        for (AggregateCall aggCall : aggCalls) {
            if (isJsonAggregation(aggCall)) {
                final Mappings.TargetMapping argsMapping =
                        getAggArgsMapping(inputCount, wrappedFields, getAffectedArgs(aggCall));
                newAggCalls.add(aggCall.transform(argsMapping));
            } else {
                newAggCalls.add(aggCall);
            }
        }

        final LogicalAggregate newAggregate =
                (LogicalAggregate)
                        aggregate.copy(
                                aggregate.getTraitSet(),
                                relBuilder.build(),
                                aggregate.getGroupSet(),
                                aggregate.getGroupSets(),
                                newAggCalls);

        // Preserve any hints the aggregate already had and add the marker so that the operand
        // predicate rejects the rewritten aggregate.
        final List<RelHint> newHints = new ArrayList<>(aggregate.getHints());
        newHints.add(MARKER_HINT);
        call.transformTo(newAggregate.withHints(newHints));
    }

    /**
     * Returns the input field indices of {@code aggCall} that must be wrapped.
     *
     * @param aggCall a JSON aggregate call
     * @return for {@code JSON_OBJECTAGG}, only the value argument (argument index 1); otherwise
     *     every distinct argument, in argument order
     */
    private static List<Integer> getAffectedArgs(AggregateCall aggCall) {
        if (aggCall.getAggregation() instanceof SqlJsonObjectAggAggFunction) {
            return Collections.singletonList(aggCall.getArgList().get(1));
        }
        return aggCall.getArgList().stream().distinct().collect(Collectors.toList());
    }

    /**
     * Appends one {@code JSON_STRING(field)} projection per entry of {@code wrappedFields} to the
     * relation on top of {@code relBuilder}, keeping all existing fields in front.
     *
     * @param cluster cluster used to resolve the {@code JSON_STRING} function
     * @param relBuilder builder whose top-of-stack relation is the aggregate input
     * @param wrappedFields input field indices to wrap, in the order they should be appended
     */
    private static void addProjections(
            RelOptCluster cluster, RelBuilder relBuilder, List<Integer> wrappedFields) {
        final BridgingSqlFunction operandToStringOperator =
                BridgingSqlFunction.of(cluster, BuiltInFunctionDefinitions.JSON_STRING);
        final List<RexNode> projects =
                wrappedFields.stream()
                        .map(
                                fieldIdx ->
                                        relBuilder.call(
                                                operandToStringOperator,
                                                relBuilder.field(fieldIdx)))
                        .collect(Collectors.toList());
        relBuilder.projectPlus(projects);
    }

    /**
     * Builds the mapping that redirects a single aggregate call's wrapped arguments to their
     * appended {@code JSON_STRING} fields.
     *
     * <p>The mapping is a permutation over {@code inputCount + wrappedFields.size()} positions in
     * which each of {@code callArgs} maps to {@code inputCount + wrappedFields.indexOf(arg)}. Every
     * other original field maps to itself.
     *
     * @param inputCount number of fields of the original aggregate input
     * @param wrappedFields all fields appended by {@link #addProjections}, in append order
     * @param callArgs the subset of {@code wrappedFields} used by the call being remapped
     */
    private static Mappings.TargetMapping getAggArgsMapping(
            int inputCount, List<Integer> wrappedFields, List<Integer> callArgs) {
        final int newCount = inputCount + wrappedFields.size();
        final Mapping argsMapping = Mappings.create(MappingType.BIJECTION, newCount, newCount);
        for (int arg : callArgs) {
            argsMapping.set(arg, inputCount + wrappedFields.indexOf(arg));
        }
        return argsMapping;
    }

    /** Returns whether {@code aggCall} is {@code JSON_OBJECTAGG} or {@code JSON_ARRAYAGG}. */
    private static boolean isJsonAggregation(AggregateCall aggCall) {
        final SqlAggFunction aggregation = aggCall.getAggregation();
        return aggregation instanceof SqlJsonObjectAggAggFunction
                || aggregation instanceof SqlJsonArrayAggAggFunction;
    }

    /** Returns whether this rule has already rewritten {@code aggregate}. */
    private static boolean isAlreadyWrapped(LogicalAggregate aggregate) {
        return aggregate.getHints().stream()
                .anyMatch(hint -> MARKER_HINT.hintName.equals(hint.hintName));
    }

    /** Configuration for {@link WrapJsonAggFunctionArgumentsRule}. */
    public interface Config extends RelRule.Config {

        @Override
        default RelOptRule toRule() {
            return new WrapJsonAggFunctionArgumentsRule(this);
        }

        /**
         * Configures the operand to match a {@link LogicalAggregate} (with any inputs) that has at
         * least one JSON aggregate call and has not yet been rewritten by this rule.
         */
        default Config onJsonAggregateFunctions() {
            final Predicate<LogicalAggregate> jsonAggPredicate =
                    aggregate ->
                            !WrapJsonAggFunctionArgumentsRule.isAlreadyWrapped(aggregate)
                                    && aggregate.getAggCallList().stream()
                                            .anyMatch(
                                                    WrapJsonAggFunctionArgumentsRule
                                                            ::isJsonAggregation);
            final RelRule.OperandTransform aggTransform =
                    operandBuilder ->
                            operandBuilder
                                    .operand(LogicalAggregate.class)
                                    .predicate(jsonAggPredicate)
                                    .anyInputs();
            return withOperandSupplier(aggTransform).as(Config.class);
        }
    }
}

