package org.apache.sysml.runtime.controlprogram.parfor;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.io.Writable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.util.PairWritableBlock;
import org.apache.sysml.runtime.instructions.spark.data.DatasetObject;
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.class */
public class RemoteDPParForSpark {
    protected static final Log LOG = LogFactory.getLog(RemoteDPParForSpark.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark$DataFrameToRowBinaryBlockFunction.class */
    public static class DataFrameToRowBinaryBlockFunction implements PairFunction<Tuple2<Row, Long>, Long, Writable> {
        private static final long serialVersionUID = -3162404379379461523L;
        private final long _clen;
        private final boolean _containsID;
        private final boolean _isVector;

        public DataFrameToRowBinaryBlockFunction(long j, boolean z, boolean z2) {
            this._clen = j;
            this._containsID = z;
            this._isVector = z2;
        }

        public Tuple2<Long, Writable> call(Tuple2<Row, Long> tuple2) throws Exception {
            long longValue = ((Long) tuple2._2()).longValue() + 1;
            int i = this._containsID ? 1 : 0;
            Object _1 = this._isVector ? ((Row) tuple2._1()).get(i) : tuple2._1();
            MatrixBlock matrixBlock = new MatrixBlock(1, (int) this._clen, _1 instanceof SparseVector);
            if (this._isVector) {
                SparseVector sparseVector = (Vector) _1;
                if (sparseVector instanceof SparseVector) {
                    SparseVector sparseVector2 = sparseVector;
                    int numNonzeros = sparseVector2.numNonzeros();
                    for (int i2 = 0; i2 < numNonzeros; i2++) {
                        matrixBlock.appendValue(0, sparseVector2.indices()[i2], sparseVector2.values()[i2]);
                    }
                } else {
                    for (int i3 = 0; i3 < this._clen; i3++) {
                        matrixBlock.appendValue(0, i3, sparseVector.apply(i3));
                    }
                }
            } else {
                Row row = (Row) _1;
                for (int i4 = i; i4 < i + this._clen; i4++) {
                    matrixBlock.appendValue(0, i4 - i, UtilFunctions.getDouble(row.get(i4)));
                }
            }
            matrixBlock.examSparsity();
            return new Tuple2<>(Long.valueOf(longValue), new PairWritableBlock(new MatrixIndexes(1L, 1L), matrixBlock));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark$PseudoGrouping.class */
    public static class PseudoGrouping implements Function<Tuple2<Long, Writable>, Tuple2<Long, Iterable<Writable>>> {
        private static final long serialVersionUID = 2016614593596923995L;

        private PseudoGrouping() {
        }

        public Tuple2<Long, Iterable<Writable>> call(Tuple2<Long, Writable> tuple2) {
            return new Tuple2<>(tuple2._1(), Collections.singletonList(tuple2._2()));
        }
    }

    /**
     * Runs a data-partitioned parfor job on Spark and collects its results.
     *
     * <p>The matrix named by {@code str2} is partitioned, optionally grouped by
     * partition key, and processed by a {@link RemoteDPParForSparkWorker}. The
     * task and iteration accumulators are read only after the job has run.
     *
     * @param j                 not used by this method
     * @param str               passed to the worker; presumably the iteration variable name
     * @param str2              name of the partitioned matrix variable
     * @param str3              passed to the worker; presumably the serialized program
     * @param hashMap           passed to the worker; presumably generated class bytes by name
     * @param str4              not used by this method
     * @param matrixObject      not used by this method; the matrix is looked up via {@code str2}
     * @param executionContext  execution context; must be a {@link SparkExecutionContext}
     * @param partitionFormat   data partitioning format
     * @param outputInfo        output format handed to the partitioner and the worker
     * @param z                 flag passed through to the worker
     * @param z2                flag passed through to the worker
     * @param i                 lower bound for the number of partitions used when grouping
     * @return the job return holding task count, iteration count and result variables
     */
    public static RemoteParForJobReturn runJob(long j, String str, String str2, String str3, HashMap<String, byte[]> hashMap, String str4, MatrixObject matrixObject, ExecutionContext executionContext, ParForProgramBlock.PartitionFormat partitionFormat, OutputInfo outputInfo, boolean z, boolean z2, int i) {
        long startNanos = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        JavaSparkContext sc = sec.getSparkContext();
        MatrixObject partitionedMatrix = sec.getMatrixObject(str2);
        MatrixCharacteristics mc = partitionedMatrix.getMatrixCharacteristics();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(str2);

        LongAccumulator taskCounter = sc.sc().longAccumulator("tasks");
        LongAccumulator iterationCounter = sc.sc().longAccumulator("iterations");

        int numReducers = Math.max(i, Math.min(SparkUtils.getNumPreferredPartitions(mc, in), (int) partitionFormat.getNumParts(mc)));
        RemoteDPParForSparkWorker worker = new RemoteDPParForSparkWorker(str3, hashMap, str2, str, z2, mc, z, partitionFormat, outputInfo, taskCounter, iterationCounter);
        JavaPairRDD<Long, Writable> partitionedInput = getPartitionedInput(sec, str2, outputInfo, partitionFormat);

        // Run the job BEFORE touching the accumulators: Spark evaluates lazily, so
        // the counters are only populated once collect() has executed the tasks.
        List<Tuple2<Long, String>> out = (requiresGrouping(partitionFormat, partitionedMatrix) ? partitionedInput.groupByKey(numReducers) : partitionedInput.map(new PseudoGrouping())).mapPartitionsToPair(worker).collect();

        RemoteParForJobReturn result = new RemoteParForJobReturn(true, taskCounter.value().intValue(), iterationCounter.value().intValue(), RemoteParForUtils.getResults(out, LOG));

        Statistics.incrementNoOfCompiledSPInst();
        Statistics.incrementNoOfExecutedSPInst();
        if (DMLScript.STATISTICS) {
            Statistics.maintainCPHeavyHitters("ParFor-DPESP", System.nanoTime() - startNanos);
        }
        return result;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the partitioned input RDD, keyed by partition index, for the given
     * matrix variable.
     *
     * <p>Three sources are used, in this order of preference:
     * <ol>
     *   <li>the original data frame, when {@link #hasInputDataSet} holds; each
     *       row is converted directly into a single-row block;</li>
     *   <li>when grouping is required and the matrix handle is a checkpoint that
     *       is not cached, the RDD of the checkpoint's first lineage child
     *       (presumably to avoid materializing the checkpoint);</li>
     *   <li>otherwise the binary-block RDD of the variable.</li>
     * </ol>
     * The latter two are split by a {@link DataPartitionerRemoteSparkMapper}.
     *
     * @param sparkExecutionContext Spark execution context holding the variable
     * @param str                   name of the matrix variable
     * @param outputInfo            output format handed to the partitioner
     * @param partitionFormat       data partitioning format
     * @return pairs of partition key and partition data
     */
    private static JavaPairRDD<Long, Writable> getPartitionedInput(SparkExecutionContext sparkExecutionContext, String str, OutputInfo outputInfo, ParForProgramBlock.PartitionFormat partitionFormat) {
        InputInfo inputInfo = InputInfo.BinaryBlockInputInfo;
        MatrixObject matrixObject = sparkExecutionContext.getMatrixObject(str);
        MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics();

        if (hasInputDataSet(partitionFormat, matrixObject)) {
            DatasetObject datasetObject = (DatasetObject) matrixObject.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0);
            Dataset<Row> dataset = datasetObject.getDataset();
            // Take the row index from the ID column if present, else from row order.
            JavaPairRDD<Row, Long> indexedRows = datasetObject.containsID()
                    ? dataset.javaRDD().mapToPair(new RDDConverterUtils.DataFrameExtractIDFunction(dataset.schema().fieldIndex(RDDConverterUtils.DF_ID_COLUMN)))
                    : dataset.javaRDD().zipWithIndex();
            // Argument order must match the constructor: (clen, containsID, isVector).
            return indexedRows.mapToPair(new DataFrameToRowBinaryBlockFunction(mc.getCols(), datasetObject.containsID(), datasetObject.isVectorBased()));
        }

        boolean grouping = requiresGrouping(partitionFormat, matrixObject);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sparkExecutionContext.getBinaryBlockRDDHandleForVariable(str);
        if (grouping && matrixObject.getRDDHandle().isCheckpointRDD() && !sparkExecutionContext.isRDDCached(in.id())) {
            // The lineage child only exposes an untyped RDD; the variable is a
            // binary-block matrix, so its RDD holds index/block pairs.
            @SuppressWarnings("unchecked")
            JavaPairRDD<MatrixIndexes, MatrixBlock> uncheckpointed = (JavaPairRDD<MatrixIndexes, MatrixBlock>) ((RDDObject) matrixObject.getRDDHandle().getLineageChilds().get(0)).getRDD();
            in = uncheckpointed;
        }
        return in.flatMapToPair(new DataPartitionerRemoteSparkMapper(mc, inputInfo, outputInfo, partitionFormat._dpf, partitionFormat._N));
    }

    /**
     * Decides whether partitioned data must be grouped by partition key.
     *
     * <p>Grouping is required when a row-oriented format (row-wise or
     * row-block-wise) meets a matrix with more than one column block, or a
     * column-oriented format (column-wise or column-block-wise) meets a matrix
     * with more than one row block. It is never required when the input can be
     * read from the original data frame.
     *
     * @param partitionFormat data partitioning format
     * @param matrixObject    matrix to be partitioned
     * @return {@code true} if a group-by-key is needed
     */
    private static boolean requiresGrouping(ParForProgramBlock.PartitionFormat partitionFormat, MatrixObject matrixObject) {
        MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics();

        boolean rowOriented = partitionFormat == ParForProgramBlock.PartitionFormat.ROW_WISE
                || partitionFormat._dpf == ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N;
        if (rowOriented && mc.getNumColBlocks() > 1) {
            return !hasInputDataSet(partitionFormat, matrixObject);
        }

        boolean columnOriented = partitionFormat == ParForProgramBlock.PartitionFormat.COLUMN_WISE
                || partitionFormat._dpf == ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N;
        if (columnOriented && mc.getNumRowBlocks() > 1) {
            return !hasInputDataSet(partitionFormat, matrixObject);
        }

        return false;
    }

    /**
     * Checks whether the matrix can be read straight from an original data frame.
     *
     * <p>This holds only for row-wise partitioning when the matrix handle is a
     * checkpoint RDD with exactly one lineage child, which in turn has exactly one
     * lineage child that is a {@link DatasetObject}.
     *
     * @param partitionFormat data partitioning format
     * @param matrixObject    matrix to inspect
     * @return {@code true} if the lineage ends in a single dataset
     */
    private static boolean hasInputDataSet(ParForProgramBlock.PartitionFormat partitionFormat, MatrixObject matrixObject) {
        if (partitionFormat != ParForProgramBlock.PartitionFormat.ROW_WISE) {
            return false;
        }
        if (!matrixObject.getRDDHandle().isCheckpointRDD()) {
            return false;
        }
        if (matrixObject.getRDDHandle().getLineageChilds().size() != 1) {
            return false;
        }
        if (matrixObject.getRDDHandle().getLineageChilds().get(0).getLineageChilds().size() != 1) {
            return false;
        }
        return matrixObject.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0) instanceof DatasetObject;
    }
}
