package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.lops.PickByCount;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.class */
/**
 * Spark instruction for {@code qpick}: picks the value at a given quantile (VALUEPICK), the median
 * (MEDIAN), or computes the interquartile mean (IQM) of a matrix input.
 *
 * <p>The input has either one column of values or two columns (value in column 0, weight in
 * column 1). This class never sorts values itself; it reads values by row position. The input is
 * therefore presumably already sorted by value by a preceding instruction. NOTE(review): confirm
 * this against the instruction generator.
 */
public class QuantilePickSPInstruction extends BinarySPInstruction {
    /** The qpick variant executed by {@link #processInstruction(ExecutionContext)}. */
    private final PickByCount.OperationTypes _type;

    /**
     * Sums the values of all rows whose global row position lies in [lower, upper] (inclusive).
     * Two-column inputs contribute value*weight instead of the plain value. Each block emits one
     * partial sum under key (1,1).
     */
    private static class ExtractAndSumFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -584044441055250489L;
        private final long _minRowIndex;
        private final long _maxRowIndex;
        private final int _minPos;
        private final int _maxPos;

        /**
         * @param lower first global row position to include (the row-position convention of
         *              UtilFunctions, presumably 1-based)
         * @param upper last global row position to include
         * @param blen  rows per block
         */
        public ExtractAndSumFunction(long lower, long upper, int blen) {
            _minRowIndex = UtilFunctions.computeBlockIndex(lower, blen);
            _maxRowIndex = UtilFunctions.computeBlockIndex(upper, blen);
            _minPos = UtilFunctions.computeCellInBlock(lower, blen);
            _maxPos = UtilFunctions.computeCellInBlock(upper, blen);
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            MatrixIndexes ix = arg._1();
            MatrixBlock mb = arg._2();

            // Restrict the row range only in the first and last blocks of the selection.
            int rl = (ix.getRowIndex() == _minRowIndex) ? _minPos : 0;
            int ru = (ix.getRowIndex() == _maxRowIndex) ? _maxPos + 1 : mb.getNumRows();

            // 1x2 output; the second column is presumably reserved for the correction term
            // maintained by RDDAggregateUtils.sumStable.
            MatrixBlock ret = new MatrixBlock(1, 2, false);
            ret.setValue(0, 0, (mb.getNumColumns() == 1) ? sum(mb, rl, ru) : sumWeighted(mb, rl, ru));
            return new Tuple2<>(new MatrixIndexes(1L, 1L), ret);
        }

        /** Sums column 0 over rows [rl, ru). */
        private static double sum(MatrixBlock mb, int rl, int ru) {
            double sum = 0.0d;
            for (int i = rl; i < ru; i++) {
                sum += mb.quickGetValue(i, 0);
            }
            return sum;
        }

        /** Sums column0 * column1 (value * weight) over rows [rl, ru). */
        private static double sumWeighted(MatrixBlock mb, int rl, int ru) {
            double sum = 0.0d;
            for (int i = rl; i < ru; i++) {
                sum += mb.quickGetValue(i, 0) * mb.quickGetValue(i, 1);
            }
            return sum;
        }
    }

    /**
     * Per-partition scan that locates weighted quantiles. Only partitions listed in the quantile
     * partition ids are scanned.
     *
     * <p>Starting from the partition's cumulative-weight offset, each row's weight (column 1) is
     * accumulated. The first row whose running sum reaches a quantile's integer key emits
     * (quantileIndex, {globalRowPosition, runningSum - exactKey, value}).
     */
    public static class ExtractWeightedQuantileFunction implements Function2<Integer, Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, Iterator<Tuple2<Integer, double[]>>> {
        private static final long serialVersionUID = 4879975971050093739L;
        private final MatrixCharacteristics _mc;
        /** Exact (fractional) cumulative-weight key per quantile. */
        private final double[] _qdKeys;
        /** Ceiling of {@link #_qdKeys}; the threshold the running sum must reach. */
        private final long[] _qiKeys;
        /** Partition id holding each quantile (negative if not located). */
        private final int[] _qPIDs;
        /** Cumulative weight of all partitions preceding each quantile's partition. */
        private final double[] _offsets;

        public ExtractWeightedQuantileFunction(MatrixCharacteristics mc, double[] qdKeys, long[] qiKeys, int[] qPIDs, double[] offsets) {
            _mc = mc;
            _qdKeys = qdKeys;
            _qiKeys = qiKeys;
            _qPIDs = qPIDs;
            _offsets = offsets;
        }

        public Iterator<Tuple2<Integer, double[]>> call(Integer pid, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> blocks) throws Exception {
            final int p = pid;
            // Early abort for partitions that hold no quantile.
            if (!ArrayUtils.contains(_qPIDs, p)) {
                return Collections.emptyIterator();
            }

            // Collect the indexes of all quantiles located in this partition.
            int qlen = 0;
            for (int q : _qPIDs) {
                if (q == p) {
                    qlen++;
                }
            }
            int[] qix = new int[qlen];
            for (int i = 0, pos = 0; i < _qPIDs.length; i++) {
                if (_qPIDs[i] == p) {
                    qix[pos++] = i;
                }
            }

            // Work on a local copy: a matched key is disabled (set to MAX_VALUE) so it fires only
            // once, and the function's own state must not change between calls.
            long[] keys = _qiKeys.clone();
            double offset = _offsets[qix[0]];
            List<Tuple2<Integer, double[]>> ret = new ArrayList<>();
            while (blocks.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock> blk = blocks.next();
                MatrixIndexes ix = blk._1();
                MatrixBlock mb = blk._2();
                for (int i = 0; i < mb.getNumRows(); i++) {
                    double cum = offset + mb.quickGetValue(i, 1);
                    for (int j = 0; j < qlen; j++) {
                        int q = qix[j];
                        if (cum >= keys[q]) {
                            double rowPos = UtilFunctions.computeCellIndex(ix.getRowIndex(), _mc.getRowsPerBlock(), i);
                            ret.add(new Tuple2<>(q, new double[] {rowPos, cum - _qdKeys[q], mb.quickGetValue(i, 0)}));
                            keys[q] = Long.MAX_VALUE;
                        }
                    }
                    offset = cum;
                }
            }
            return ret.iterator();
        }
    }

    /** Keeps only the blocks whose block-row index covers global rows [lower, upper]. */
    private static class FilterFunction implements Function<Tuple2<MatrixIndexes, MatrixBlock>, Boolean> {
        private static final long serialVersionUID = -8249102381116157388L;
        private final long _minRowIndex;
        private final long _maxRowIndex;

        public FilterFunction(long lower, long upper, int blen) {
            _minRowIndex = UtilFunctions.computeBlockIndex(lower, blen);
            _maxRowIndex = UtilFunctions.computeBlockIndex(upper, blen);
        }

        public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            long rix = arg._1().getRowIndex();
            return rix >= _minRowIndex && rix <= _maxRowIndex;
        }
    }

    /**
     * Emits a single (partitionId, weightSum) pair per partition. The sum adds up
     * {@code MatrixBlock.sumWeightForQuantile()} over the partition's blocks.
     */
    public static class SumWeightsFunction implements Function2<Integer, Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, Iterator<Tuple2<Integer, Double>>> {
        private static final long serialVersionUID = 7169831202450745373L;

        private SumWeightsFunction() {
        }

        public Iterator<Tuple2<Integer, Double>> call(Integer pid, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> blocks) throws Exception {
            double sum = 0.0d;
            while (blocks.hasNext()) {
                sum += blocks.next()._2().sumWeightForQuantile();
            }
            return Collections.singletonList(new Tuple2<>(pid, sum)).iterator();
        }
    }

    private QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        this(op, in, null, out, type, inmem, opcode, istr);
    }

    /** Note: the {@code inmem} flag is accepted for parser compatibility but is not stored. */
    private QuantilePickSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        super(SPInstruction.SPType.QPick, op, in1, in2, out, opcode, istr);
        _type = type;
    }

    /**
     * Parses a qpick instruction string.
     *
     * <p>Accepted operand layouts (after the opcode):
     * <ul>
     *   <li>3 operands: in1, in2, out; always IQM, not in-memory</li>
     *   <li>4 operands: in, out, type, inmem</li>
     *   <li>5 operands: in1, in2, out, type, inmem</li>
     * </ul>
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not qpick or the operand count is unsupported
     */
    public static QuantilePickSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(PickByCount.OPCODE)) {
            throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickCPInstruction: " + str);
        }
        switch (parts.length) {
            case 4:
                return new QuantilePickSPInstruction(null, new CPOperand(parts[1]), new CPOperand(parts[2]), new CPOperand(parts[3]),
                    PickByCount.OperationTypes.IQM, false, opcode, str);
            case 5:
                return new QuantilePickSPInstruction(null, new CPOperand(parts[1]), new CPOperand(parts[2]),
                    PickByCount.OperationTypes.valueOf(parts[3]), Boolean.parseBoolean(parts[4]), opcode, str);
            case 6:
                return new QuantilePickSPInstruction(null, new CPOperand(parts[1]), new CPOperand(parts[2]), new CPOperand(parts[3]),
                    PickByCount.OperationTypes.valueOf(parts[4]), Boolean.parseBoolean(parts[5]), opcode, str);
            default:
                // Previously returned null, which only surfaced later as an NPE at the call site.
                throw new DMLRuntimeException("Invalid number of operands (" + (parts.length - 1)
                    + ") while parsing a QuantilePickSPInstruction: " + str);
        }
    }

    /**
     * Executes the qpick variant on the RDD bound to input1 and stores the scalar result in the
     * output variable.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
        MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName());

        double result;
        switch (_type) {
            case VALUEPICK:
                // Summary layout for one quantile: [total, pos, part, value]
                result = getWeightedQuantileSummary(in, mc, ec.getScalarInput(input2).getDoubleValue())[3];
                break;
            case MEDIAN:
                result = getWeightedQuantileSummary(in, mc, 0.5d)[3];
                break;
            case IQM:
                result = computeInterQuartileMean(in, mc);
                break;
            default:
                throw new DMLRuntimeException("Unsupported qpick operation type: " + _type);
        }
        ec.setScalarOutput(output.getName(), new DoubleObject(result));
    }

    /** Computes the interquartile mean from the 25%/75% summary plus a distributed sum of the rows between them. */
    private static double computeInterQuartileMean(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc) {
        // Summary layout for two quantiles: [total, pos25, pos75, part25, part75, val25, val75]
        double[] summary = getWeightedQuantileSummary(in, mc, 0.25d, 0.75d);
        long lower = (long) Math.ceil(summary[1]) + 1;
        long upper = (long) Math.ceil(summary[2]);
        int blen = mc.getRowsPerBlock();

        double sum = RDDAggregateUtils.sumStable(in
            .filter(new FilterFunction(lower, upper, blen))
            .mapToPair(new ExtractAndSumFunction(lower, upper, blen))).getValue(0, 0);

        return MatrixBlock.computeIQMCorrection(sum, summary[0], summary[3], summary[5], summary[4], summary[6]);
    }

    /**
     * Computes a summary for the given quantiles.
     *
     * <p>For k quantiles the result has length 3k+1:
     * <ul>
     *   <li>[0]: total weight (two-column input) or row count</li>
     *   <li>[1..k]: positions</li>
     *   <li>[k+1..2k]: fractional corrections</li>
     *   <li>[2k+1..3k]: the values at the quantiles</li>
     * </ul>
     */
    private static double[] getWeightedQuantileSummary(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc, Double... quantiles) {
        double[] ret = new double[(3 * quantiles.length) + 1];
        if (mc.getCols() == 2) {
            fillWeightedSummary(in, mc, quantiles, ret);
        } else {
            fillUnweightedSummary(in, mc, quantiles, ret);
        }
        return ret;
    }

    /**
     * Fills {@code ret} for a two-column (value, weight) input.
     *
     * <p>Pass 1 computes per-partition weight sums. These are used to locate the partition of each
     * quantile. Pass 2 scans only those partitions for the exact rows.
     */
    private static void fillWeightedSummary(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc, Double[] quantiles, double[] ret) {
        int qlen = quantiles.length;
        JavaPairRDD<MatrixIndexes, MatrixBlock> sorted = in.sortByKey();
        List<Tuple2<Integer, Double>> partWeights = sorted.mapPartitionsWithIndex(new SumWeightsFunction(), false).collect();
        ret[0] = partWeights.stream().mapToDouble(t -> t._2()).sum();

        double[] qdKeys = new double[qlen];
        long[] qiKeys = new long[qlen];
        int[] qPIDs = new int[qlen];
        double[] qOffsets = new double[qlen];
        // Partition 0 is a valid id, so -1 (not 0) marks "partition not yet located". With a 0
        // sentinel, a quantile found in partition 0 was overwritten by the next partition.
        Arrays.fill(qPIDs, -1);
        for (int i = 0; i < qlen; i++) {
            qdKeys[i] = quantiles[i] * ret[0];
            qiKeys[i] = (long) Math.ceil(qdKeys[i]);
        }

        // Assign each quantile to the first partition whose cumulative weight reaches its key.
        double cumSum = 0.0d;
        for (Tuple2<Integer, Double> pw : partWeights) {
            double next = cumSum + pw._2();
            for (int i = 0; i < qlen; i++) {
                if (qPIDs[i] < 0 && next >= qiKeys[i]) {
                    qPIDs[i] = pw._1();
                    qOffsets[i] = cumSum;
                }
            }
            cumSum = next;
        }

        List<Tuple2<Integer, double[]>> picks = sorted.mapPartitionsWithIndex(
            new ExtractWeightedQuantileFunction(mc, qdKeys, qiKeys, qPIDs, qOffsets), false).collect();
        for (Tuple2<Integer, double[]> pick : picks) {
            int q = pick._1();
            double[] vals = pick._2();
            ret[q + 1] = vals[0];
            ret[q + qlen + 1] = vals[1];
            ret[q + (2 * qlen) + 1] = vals[2];
        }
    }

    /**
     * Fills {@code ret} for a single-column input. For each quantile, pos = q * rows, the correction
     * is ceil(pos) - pos, and the value is read at row ceil(pos).
     */
    private static void fillUnweightedSummary(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc, Double[] quantiles, double[] ret) {
        int qlen = quantiles.length;
        long rows = mc.getRows();
        ret[0] = rows;
        for (int i = 0; i < qlen; i++) {
            double pos = quantiles[i] * rows;
            ret[i + 1] = pos;
            ret[i + qlen + 1] = Math.ceil(pos) - pos;
            ret[i + (2 * qlen) + 1] = lookupKey(in, (long) Math.ceil(pos), mc.getRowsPerBlock());
        }
    }

    /**
     * Returns the column-0 value at global row position {@code key}, using an RDD lookup of the
     * block at block-column 1.
     *
     * @throws DMLRuntimeException if the block is missing or too small for the requested row
     */
    private static double lookupKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, long key, int blen) {
        long rix = UtilFunctions.computeBlockIndex(key, blen);
        int pos = UtilFunctions.computeCellInBlock(key, blen);
        List<MatrixBlock> blocks = in.lookup(new MatrixIndexes(rix, 1L));
        if (blocks.isEmpty()) {
            throw new DMLRuntimeException("Invalid key lookup in empty list.");
        }
        MatrixBlock mb = blocks.get(0);
        if (mb.getNumRows() <= pos) {
            throw new DMLRuntimeException("Invalid key lookup for " + pos + " in block of size " + mb.getNumRows() + "x" + mb.getNumColumns());
        }
        return mb.quickGetValue(pos, 0);
    }
}
