package org.apache.sysml.runtime.instructions.spark.utils;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysml.runtime.instructions.spark.data.RowMatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;

/**
 * Spark RDD helpers for aggregating and merging {@link MatrixBlock}s, either across a whole RDD
 * ({@code sumStable}, {@code aggStable}, {@code max}) or per {@link MatrixIndexes} key
 * ({@code sumByKeyStable}, {@code sumCellsByKeyStable}, {@code aggByKeyStable},
 * {@code mergeByKey}, {@code mergeRowsByKey}).
 *
 * <p>The "Stable" variants keep a correction block next to the running aggregate: sums use the
 * {@link KahanPlus} function object, generic aggregates use a correction only when
 * {@code AggregateOperator.correctionExists} is set.
 *
 * <p>Most nested functions update their first argument in place and return it instead of
 * allocating a new result object.
 */
public class RDDAggregateUtils {

    /**
     * Fold function aggregating two blocks with an arbitrary {@link AggregateOperator}.
     *
     * <p>A block with zero rows and zero columns (presumably the {@code new MatrixBlock()} zero
     * value used by {@code aggStable}) means "nothing aggregated yet": the other block is copied
     * into it. Otherwise the left block is updated in place and returned. When the operator
     * requires a correction, one correction block is allocated on first use and kept in this
     * function instance across calls; it is not part of the returned value.
     */
    public static class AggregateSingleBlockFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -3672377410407066396L;
        private final AggregateOperator _op;
        // Correction block shared by all calls on this instance; allocated lazily.
        private MatrixBlock _corr = null;

        public AggregateSingleBlockFunction(AggregateOperator op) {
            this._op = op;
        }

        public MatrixBlock call(MatrixBlock acc, MatrixBlock in) throws Exception {
            // accumulator is still empty: adopt a copy of the incoming block
            if (acc.getNumRows() == 0 && acc.getNumColumns() == 0) {
                acc.copy(in);
                return acc;
            }
            // incoming block is empty: nothing to aggregate
            if (in.getNumRows() == 0 && in.getNumColumns() == 0) {
                return acc;
            }
            if (this._op.correctionExists && this._corr == null) {
                this._corr = new MatrixBlock(acc.getNumRows(), acc.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(acc, this._op.correctionExists ? this._corr : null, in, this._op, true);
            return acc;
        }
    }

    /**
     * combineByKey "createCombiner" that uses the first block of a key as combiner, either as a
     * deep copy (if {@code deep}) or by reference.
     */
    public static class CreateBlockCombinerFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1987501624176848292L;
        private final boolean _deep;

        public CreateBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        public MatrixBlock call(MatrixBlock in) throws Exception {
            return this._deep ? new MatrixBlock(in) : in;
        }
    }

    /** combineByKey "createCombiner" wrapping a cell value into a {@link KahanObject} with zero correction. */
    public static class CreateCellCombinerFunction implements Function<Double, KahanObject> {
        private static final long serialVersionUID = 3697505233057172994L;

        private CreateCellCombinerFunction() {
        }

        public KahanObject call(Double value) throws Exception {
            return new KahanObject(value.doubleValue(), 0.0d);
        }
    }

    /**
     * combineByKey "createCombiner" wrapping the first block of a key (deep-copied if
     * {@code deep}) into a {@link CorrMatrixBlock}. The correction is expected to start out
     * {@code null} (the merge functions allocate it on demand) — NOTE(review): confirm against
     * the one-argument {@code CorrMatrixBlock} constructor.
     */
    public static class CreateCorrBlockCombinerFunction implements Function<MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = -3666451526776017343L;
        private final boolean _deep;

        public CreateCorrBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        public CorrMatrixBlock call(MatrixBlock in) throws Exception {
            MatrixBlock value = this._deep ? new MatrixBlock(in) : in;
            return new CorrMatrixBlock(value);
        }
    }

    /**
     * combineByKey "createCombiner" that allocates a sparse target block of
     * {@code getLen()} rows and copies the single row of the given {@link RowMatrixBlock} into
     * position {@code getRow()}.
     */
    private static class CreateRowBlockCombinerFunction implements Function<RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 2866598914232118425L;

        private CreateRowBlockCombinerFunction() {
        }

        public MatrixBlock call(RowMatrixBlock in) throws Exception {
            MatrixBlock row = in.getValue();
            MatrixBlock out = new MatrixBlock(in.getLen(), row.getNumColumns(), true);
            out.copy(in.getRow(), in.getRow(), 0, row.getNumColumns() - 1, row, false);
            // the target held nothing before, so its nnz equals the row's nnz
            out.setNonZeros(row.getNonZeros());
            out.examSparsity();
            return out;
        }
    }

    /** Extracts the running sum from a {@link KahanObject}; the correction term is dropped. */
    public static class ExtractDoubleCell implements Function<KahanObject, Double> {
        private static final long serialVersionUID = -2873241816558275742L;

        private ExtractDoubleCell() {
        }

        public Double call(KahanObject in) throws Exception {
            return Double.valueOf(in._sum);
        }
    }

    /**
     * Extracts the aggregated value from a {@link CorrMatrixBlock} (dropping its correction)
     * after letting it re-evaluate its sparse/dense representation via {@code examSparsity()}.
     */
    public static class ExtractMatrixBlock implements Function<CorrMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5242158678070843495L;

        private ExtractMatrixBlock() {
        }

        public MatrixBlock call(CorrMatrixBlock in) throws Exception {
            MatrixBlock value = in.getValue();
            value.examSparsity();
            return value;
        }
    }

    /**
     * combineByKey "mergeCombiners" for generic aggregates. The left combiner's value is
     * updated in place. If a correction is required but the left side has none, the right
     * side's correction is reused, or a new dense correction block is allocated.
     */
    public static class MergeAggBlockCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 4803711632648880797L;
        private final AggregateOperator _op;

        public MergeAggBlockCombinerFunction(AggregateOperator op) {
            this._op = op;
        }

        public CorrMatrixBlock call(CorrMatrixBlock acc, CorrMatrixBlock other) throws Exception {
            MatrixBlock value = acc.getValue();
            MatrixBlock otherValue = other.getValue();
            MatrixBlock corr = acc.getCorrection();
            if (corr == null && this._op.correctionExists) {
                corr = other.getCorrection() != null ? other.getCorrection() : new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
            }
            // the correction is only handed to the aggregation if the operator uses one
            OperationsOnMatrixValues.incrementalAggregation(value, this._op.correctionExists ? corr : null, otherValue, this._op, true);
            return new CorrMatrixBlock(value, corr);
        }
    }

    /**
     * combineByKey "mergeValue" for generic aggregates. Folds one more block into the
     * combiner's value in place, allocating a dense correction block on demand when the
     * operator requires one.
     */
    public static class MergeAggBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 389422125491172011L;
        private final AggregateOperator _op;

        public MergeAggBlockValueFunction(AggregateOperator op) {
            this._op = op;
        }

        public CorrMatrixBlock call(CorrMatrixBlock acc, MatrixBlock in) throws Exception {
            MatrixBlock value = acc.getValue();
            MatrixBlock corr = acc.getCorrection();
            if (corr == null && this._op.correctionExists) {
                corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
            }
            // the correction is only handed to the aggregation if the operator uses one
            OperationsOnMatrixValues.incrementalAggregation(value, this._op.correctionExists ? corr : null, in, this._op, true);
            return new CorrMatrixBlock(value, corr);
        }
    }

    /**
     * Merges two equally sized blocks via {@code MatrixBlock.merge}, optionally deep-copying the
     * left block first.
     *
     * @throws DMLRuntimeException if the block dimensions differ, or if the merged block's
     *     non-zero count is not the sum of both inputs' non-zero counts (i.e. the inputs
     *     overlapped)
     */
    public static class MergeBlocksFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -8881019027250258850L;
        private final boolean _deep;

        /** Creates a merge function that deep-copies the left block. */
        public MergeBlocksFunction() {
            this(true);
        }

        public MergeBlocksFunction(boolean deep) {
            this._deep = deep;
        }

        public MatrixBlock call(MatrixBlock left, MatrixBlock right) throws Exception {
            // capture input nnz before the (possibly in-place) merge modifies 'left'
            long leftNnz = left.getNonZeros();
            long rightNnz = right.getNonZeros();
            if (left.getNumRows() != right.getNumRows() || left.getNumColumns() != right.getNumColumns()) {
                throw new DMLRuntimeException("Mismatched block sizes for: " + left.getNumRows() + " " + left.getNumColumns() + " " + right.getNumRows() + " " + right.getNumColumns());
            }
            MatrixBlock out = this._deep ? new MatrixBlock(left) : left;
            out.merge(right, false);
            out.examSparsity();
            if (out.getNonZeros() != leftNnz + rightNnz) {
                throw new DMLRuntimeException("Number of non-zeros does not match: " + out.getNonZeros() + " != " + leftNnz + " + " + rightNnz);
            }
            return out;
        }
    }

    /**
     * combineByKey "mergeValue" that copies the single row of a {@link RowMatrixBlock} into the
     * combiner block at row {@code getRow()} (in place).
     */
    private static class MergeRowBlockValueFunction implements Function2<MatrixBlock, RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -803689998683298516L;

        private MergeRowBlockValueFunction() {
        }

        public MatrixBlock call(MatrixBlock acc, RowMatrixBlock in) throws Exception {
            MatrixBlock row = in.getValue();
            acc.copy(in.getRow(), in.getRow(), 0, row.getNumColumns() - 1, row, true);
            acc.examSparsity();
            return acc;
        }
    }

    /**
     * combineByKey "mergeCombiners" for {@link KahanPlus} sums. The left combiner's value is
     * updated in place. If the left side has no correction, the right side's correction is
     * reused. Otherwise a new dense correction block is allocated, unless the right value is
     * empty or (without deep copy) the left value is empty, in which case the correction stays
     * {@code null}.
     */
    public static class MergeSumBlockCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 7664941774566119853L;
        private final AggregateOperator _op = new AggregateOperator(0.0d, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);
        private final boolean _deep;

        public MergeSumBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        public CorrMatrixBlock call(CorrMatrixBlock acc, CorrMatrixBlock other) throws Exception {
            MatrixBlock value = acc.getValue();
            MatrixBlock otherValue = other.getValue();
            MatrixBlock corr = acc.getCorrection();
            if (corr == null) {
                MatrixBlock otherCorr = other.getCorrection();
                if (otherCorr != null) {
                    corr = otherCorr;
                } else if (!otherValue.isEmptyBlock(false) && (this._deep || !value.isEmptyBlock(false))) {
                    corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
                }
            }
            OperationsOnMatrixValues.incrementalAggregation(value, corr, otherValue, this._op, false, this._deep);
            return acc.set(value, corr);
        }
    }

    /**
     * combineByKey "mergeValue" for {@link KahanPlus} sums. Empty incoming blocks are skipped.
     * Otherwise the block is added into the combiner's value in place, allocating a dense
     * correction block on demand.
     */
    public static class MergeSumBlockValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 3703543699467085539L;
        private final AggregateOperator _op = new AggregateOperator(0.0d, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);
        private final boolean _deep;

        public MergeSumBlockValueFunction(boolean deep) {
            this._deep = deep;
        }

        public CorrMatrixBlock call(CorrMatrixBlock acc, MatrixBlock in) throws Exception {
            if (in.isEmptyBlock(false)) {
                return acc;
            }
            MatrixBlock value = acc.getValue();
            MatrixBlock corr = acc.getCorrection();
            // 'in' is known to be non-empty here, so a correction is always needed
            if (corr == null) {
                corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(value, corr, in, this._op, false, this._deep);
            return acc.set(value, corr);
        }
    }

    /**
     * combineByKey "mergeCombiners" for cell sums. Adds the right sum into the left
     * {@link KahanObject} (in place); the right side's correction term is not used.
     */
    public static class MergeSumCellCombinerFunction implements Function2<KahanObject, KahanObject, KahanObject> {
        private static final long serialVersionUID = 8726716909849119657L;

        private MergeSumCellCombinerFunction() {
        }

        public KahanObject call(KahanObject acc, KahanObject other) throws Exception {
            KahanPlus.getKahanPlusFnObject().execute2(acc, other._sum);
            return acc;
        }
    }

    /** combineByKey "mergeValue" for cell sums: adds one value into the {@link KahanObject} in place. */
    public static class MergeSumCellValueFunction implements Function2<KahanObject, Double, KahanObject> {
        private static final long serialVersionUID = 468335171573184825L;

        private MergeSumCellValueFunction() {
        }

        public KahanObject call(KahanObject acc, Double value) throws Exception {
            KahanPlus.getKahanPlusFnObject().execute2(acc, value.doubleValue());
            return acc;
        }
    }

    /**
     * Fold function summing blocks with {@link KahanPlus}.
     *
     * <p>A block with no rows or no columns means "nothing aggregated yet" (the other block is
     * copied into it) or "nothing to add". The result is either the left block updated in place
     * or, if {@code deep}, a fresh copy of it. One correction block is allocated on first use and
     * kept in this function instance across calls; it is not part of the returned value.
     */
    public static class SumSingleBlockFunction implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1737038715965862222L;
        private final AggregateOperator _op = new AggregateOperator(0.0d, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);
        private final boolean _deep;
        // Correction block shared by all calls on this instance; allocated lazily.
        private MatrixBlock _corr = null;

        public SumSingleBlockFunction(boolean deep) {
            this._deep = deep;
        }

        public MatrixBlock call(MatrixBlock acc, MatrixBlock in) throws Exception {
            if (acc.getNumRows() <= 0 || acc.getNumColumns() <= 0) {
                acc.copy(in);
                return acc;
            }
            if (in.getNumRows() <= 0 || in.getNumColumns() <= 0) {
                return acc;
            }
            if (this._corr == null) {
                this._corr = new MatrixBlock(acc.getNumRows(), acc.getNumColumns(), false);
            }
            MatrixBlock out = this._deep ? new MatrixBlock(acc) : acc;
            OperationsOnMatrixValues.incrementalAggregation(out, this._corr, in, this._op, false);
            return out;
        }
    }

    /** Sums all blocks of {@code in} (ignoring the keys) into a single block; see {@link #sumStable(JavaRDD)}. */
    public static MatrixBlock sumStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return sumStable(in.values());
    }

    /**
     * Sums all blocks of {@code in} into a single block via {@code fold} with an empty
     * {@code new MatrixBlock()} zero value and {@link SumSingleBlockFunction} (no deep copy).
     */
    public static MatrixBlock sumStable(JavaRDD<MatrixBlock> in) {
        return in.fold(new MatrixBlock(), new SumSingleBlockFunction(false));
    }

    /** Per-key {@link KahanPlus} sum, keeping the input partition count and deep-copying first blocks. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return sumByKeyStable(in, in.getNumPartitions(), true);
    }

    /** Per-key {@link KahanPlus} sum, keeping the input partition count. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        return sumByKeyStable(in, in.getNumPartitions(), deep);
    }

    /**
     * Sums all blocks sharing the same key with {@link KahanPlus} and a per-key correction block.
     *
     * @param in input blocks
     * @param numPartitions number of output partitions passed to {@code combineByKey}
     * @param deep whether the first block per key is deep-copied before being used as the
     *     aggregate (also forwarded to {@code incrementalAggregation}); presumably {@code false}
     *     is only safe when the input blocks may be modified
     * @return one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deep) {
        return in.combineByKey(new CreateCorrBlockCombinerFunction(deep), new MergeSumBlockValueFunction(deep), new MergeSumBlockCombinerFunction(deep), numPartitions).mapValues(new ExtractMatrixBlock());
    }

    /** Per-key sum of cell values, keeping the input partition count. */
    public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable(JavaPairRDD<MatrixIndexes, Double> in) {
        return sumCellsByKeyStable(in, in.getNumPartitions());
    }

    /** Sums all cell values sharing the same key using {@link KahanObject} accumulators. */
    public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable(JavaPairRDD<MatrixIndexes, Double> in, int numPartitions) {
        return in.combineByKey(new CreateCellCombinerFunction(), new MergeSumCellValueFunction(), new MergeSumCellCombinerFunction(), numPartitions).mapValues(new ExtractDoubleCell());
    }

    /** Aggregates all blocks of {@code in} (ignoring the keys); see {@link #aggStable(JavaRDD, AggregateOperator)}. */
    public static MatrixBlock aggStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator op) {
        return aggStable(in.values(), op);
    }

    /**
     * Aggregates all blocks of {@code in} with {@code op} via {@code fold} with an empty
     * {@code new MatrixBlock()} zero value and {@link AggregateSingleBlockFunction}.
     */
    public static MatrixBlock aggStable(JavaRDD<MatrixBlock> in, AggregateOperator op) {
        return in.fold(new MatrixBlock(), new AggregateSingleBlockFunction(op));
    }

    /** Per-key aggregation, keeping the input partition count and deep-copying first blocks. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator op) {
        return aggByKeyStable(in, op, in.getNumPartitions(), true);
    }

    /** Per-key aggregation, keeping the input partition count. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator op, boolean deep) {
        return aggByKeyStable(in, op, in.getNumPartitions(), deep);
    }

    /**
     * Aggregates all blocks sharing the same key with {@code op}, maintaining a correction block
     * when {@code op.correctionExists}.
     *
     * @param in input blocks
     * @param op aggregate operator
     * @param numPartitions number of output partitions passed to {@code combineByKey}
     * @param deep whether the first block per key is deep-copied before being used as the aggregate
     * @return one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator op, int numPartitions, boolean deep) {
        return in.combineByKey(new CreateCorrBlockCombinerFunction(deep), new MergeAggBlockValueFunction(op), new MergeAggBlockCombinerFunction(op), numPartitions).mapValues(new ExtractMatrixBlock());
    }

    /**
     * Computes the maximum cell value over all blocks: applies the "uamax" unary aggregate to
     * every block, folds the per-block results with its aggregate operator and returns cell (0,0).
     */
    public static double max(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        AggregateUnaryOperator auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
        JavaRDD<MatrixBlock> partials = (JavaRDD<MatrixBlock>) in.map(new AggregateUnarySPInstruction.RDDUAggFunction2(auop, -1, -1));
        return aggStable(partials, auop.aggOp).quickGetValue(0, 0);
    }

    /** Merges blocks sharing the same key, keeping the input partition count and deep-copying first blocks. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return mergeByKey(in, in.getNumPartitions(), true);
    }

    /** Merges blocks sharing the same key, keeping the input partition count. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        return mergeByKey(in, in.getNumPartitions(), deep);
    }

    /**
     * Merges all blocks sharing the same key into one block using {@link MergeBlocksFunction},
     * which fails if the blocks' non-zeros overlap.
     *
     * @param in input blocks
     * @param numPartitions number of output partitions passed to {@code combineByKey}
     * @param deep whether the first block per key is deep-copied before merging into it
     * @return one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deep) {
        return in.combineByKey(new CreateBlockCombinerFunction(deep), new MergeBlocksFunction(false), new MergeBlocksFunction(false), numPartitions);
    }

    /**
     * Assembles full blocks from individual rows: each {@link RowMatrixBlock} is copied into
     * its row position of the block for its key; partial blocks are combined with
     * {@link MergeBlocksFunction}.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeRowsByKey(JavaPairRDD<MatrixIndexes, RowMatrixBlock> in) {
        return in.combineByKey(new CreateRowBlockCombinerFunction(), new MergeRowBlockValueFunction(), new MergeBlocksFunction(false));
    }
}
