/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableProjectSemiAntiJoinTransposeRule;
import org.immutables.value.Value;

@Value.Enclosing
public class ProjectSemiAntiJoinTransposeRule
extends RelRule<ProjectSemiAntiJoinTransposeRuleConfig> {
    public static final ProjectSemiAntiJoinTransposeRule INSTANCE = ProjectSemiAntiJoinTransposeRuleConfig.DEFAULT.toRule();

    private ProjectSemiAntiJoinTransposeRule(ProjectSemiAntiJoinTransposeRuleConfig config) {
        super(config);
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        LogicalJoin join = (LogicalJoin)call.rel(1);
        JoinRelType joinType = join.getJoinType();
        return joinType == JoinRelType.SEMI || joinType == JoinRelType.ANTI;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        int leftFieldCount;
        int allInputFieldCount;
        LogicalProject project = (LogicalProject)call.rel(0);
        LogicalJoin join = (LogicalJoin)call.rel(1);
        ImmutableBitSet joinCondFields = RelOptUtil.InputFinder.bits(join.getCondition());
        ImmutableBitSet projectFields = RelOptUtil.InputFinder.bits(project.getProjects(), null);
        ImmutableBitSet allNeededFields = projectFields.isEmpty() ? joinCondFields.union(ImmutableBitSet.of(0)) : joinCondFields.union(projectFields);
        if (allNeededFields.equals(ImmutableBitSet.range(0, allInputFieldCount = (leftFieldCount = join.getLeft().getRowType().getFieldCount()) + join.getRight().getRowType().getFieldCount()))) {
            return;
        }
        ImmutableBitSet leftNeededFields = ImmutableBitSet.range(0, leftFieldCount).intersect(allNeededFields);
        ImmutableBitSet rightNeededFields = ImmutableBitSet.range(leftFieldCount, allInputFieldCount).intersect(allNeededFields);
        RelNode newLeftInput = this.createNewJoinInput(call.builder(), join.getLeft(), leftNeededFields, 0);
        RelNode newRightInput = this.createNewJoinInput(call.builder(), join.getRight(), rightNeededFields, leftFieldCount);
        Mappings.TargetMapping mapping = Mappings.target(i -> allNeededFields.indexOf(i), allInputFieldCount, allNeededFields.cardinality());
        Join newJoin = this.createNewJoin(join, mapping, newLeftInput, newRightInput);
        List<RexNode> newProjects = this.createNewProjects(project, newJoin, mapping);
        RelNode topProject = call.builder().push(newJoin).project(newProjects, project.getRowType().getFieldNames()).build();
        call.transformTo(topProject);
    }

    private RelNode createNewJoinInput(RelBuilder relBuilder, RelNode originInput, ImmutableBitSet inputNeededFields, int offset) {
        RexBuilder rexBuilder = originInput.getCluster().getRexBuilder();
        RelDataTypeFactory.FieldInfoBuilder typeBuilder = relBuilder.getTypeFactory().builder();
        ArrayList<RexInputRef> newProjects = new ArrayList<RexInputRef>();
        ArrayList<String> newFieldNames = new ArrayList<String>();
        for (int i : inputNeededFields.toList()) {
            newProjects.add(rexBuilder.makeInputRef(originInput, i - offset));
            newFieldNames.add(originInput.getRowType().getFieldNames().get(i - offset));
            ((RelDataTypeFactory.Builder)typeBuilder).add(originInput.getRowType().getFieldList().get(i - offset));
        }
        return relBuilder.push(originInput).project(newProjects, newFieldNames).build();
    }

    private Join createNewJoin(Join originJoin, Mappings.TargetMapping mapping, RelNode newLeftInput, RelNode newRightInput) {
        RexNode newCondition = this.rewriteJoinCondition(originJoin, mapping);
        return LogicalJoin.create(newLeftInput, newRightInput, Collections.emptyList(), newCondition, originJoin.getVariablesSet(), originJoin.getJoinType());
    }

    private RexNode rewriteJoinCondition(final Join originJoin, final Mappings.TargetMapping mapping) {
        final RexBuilder rexBuilder = originJoin.getCluster().getRexBuilder();
        RexShuttle rexShuttle = new RexShuttle(){

            @Override
            public RexNode visitInputRef(RexInputRef ref) {
                int leftFieldCount = originJoin.getLeft().getRowType().getFieldCount();
                RelDataType fieldType = ref.getIndex() < leftFieldCount ? originJoin.getLeft().getRowType().getFieldList().get(ref.getIndex()).getType() : originJoin.getRight().getRowType().getFieldList().get(ref.getIndex() - leftFieldCount).getType();
                return rexBuilder.makeInputRef(fieldType, mapping.getTarget(ref.getIndex()));
            }
        };
        return originJoin.getCondition().accept(rexShuttle);
    }

    private List<RexNode> createNewProjects(Project originProject, final RelNode newInput, final Mappings.TargetMapping mapping) {
        final RexBuilder rexBuilder = originProject.getCluster().getRexBuilder();
        RexShuttle projectShuffle = new RexShuttle(){

            @Override
            public RexNode visitInputRef(RexInputRef ref) {
                return rexBuilder.makeInputRef(newInput, mapping.getTarget(ref.getIndex()));
            }
        };
        return originProject.getProjects().stream().map(p -> p.accept(projectShuffle)).collect(Collectors.toList());
    }

    @Value.Immutable(singleton=false)
    public static interface ProjectSemiAntiJoinTransposeRuleConfig
    extends RelRule.Config {
        public static final ProjectSemiAntiJoinTransposeRuleConfig DEFAULT = ImmutableProjectSemiAntiJoinTransposeRule.ProjectSemiAntiJoinTransposeRuleConfig.builder().build().withOperandSupplier(b0 -> b0.operand(LogicalProject.class).inputs(b1 -> b1.operand(LogicalJoin.class).anyInputs())).withDescription("ProjectSemiAntiJoinTransposeRule");

        @Override
        default public ProjectSemiAntiJoinTransposeRule toRule() {
            return new ProjectSemiAntiJoinTransposeRule(this);
        }
    }
}

