package org.apache.sysml.runtime.instructions.spark.utils;

import java.io.IOException;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.hadoop.io.Text;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.mllib.util.NumericParser;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer;
import org.apache.sysml.runtime.util.FastStringTokenizer;
import scala.Tuple2;
import scala.collection.JavaConversions;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.class */
public class RDDConverterUtilsExt {

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$AddRowID.class */
    /**
     * Spark map function that appends a row identifier column to a row.
     *
     * <p>The input is a pair of (row, index); the output row contains all values of
     * the input row followed by {@code index + 1} stored as a {@link Double}.
     */
    public static class AddRowID implements Function<Tuple2<Row, Long>, Row> {
        private static final long serialVersionUID = -3733816995375745659L;

        /**
         * Builds the widened row.
         *
         * @param tuple2 pair of the source row and its index
         * @return a new row with the source values plus the trailing identifier
         */
        public Row call(Tuple2<Row, Long> tuple2) throws Exception {
            Row source = tuple2._1;
            int width = source.length();
            Object[] values = new Object[width + 1];
            int col = 0;
            while (col < width) {
                values[col] = source.get(col);
                col++;
            }
            // The increment happens in long arithmetic before the conversion to double.
            long shiftedIndex = tuple2._2.longValue() + 1;
            values[width] = Double.valueOf((double) shiftedIndex);
            return RowFactory.create(values);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$IJVToBinaryBlockFunctionHelper.class */
    /**
     * Serializable helper that turns a stream of (row, column, value) cells into
     * binary matrix blocks by staging the cells in a {@link ReblockBuffer}.
     */
    public static class IJVToBinaryBlockFunctionHelper implements Serializable {
        private static final long serialVersionUID = -7952801318564745821L;
        private static final int BUFFER_SIZE = 4000000;
        // All sizes start at -1 and are overwritten by the constructor.
        private int _bufflen = -1;
        private long _rlen = -1L;
        private long _clen = -1L;
        private int _brlen = -1;
        private int _bclen = -1;

        /**
         * Captures matrix and block dimensions from the given characteristics.
         *
         * @param matrixCharacteristics characteristics whose dimensions must be known
         * @throws DMLRuntimeException if {@code dimsKnown()} reports false
         */
        public IJVToBinaryBlockFunctionHelper(MatrixCharacteristics matrixCharacteristics) {
            if (!matrixCharacteristics.dimsKnown()) {
                throw new DMLRuntimeException("The dimensions need to be known in given MatrixCharacteristics for given input RDD");
            }
            _rlen = matrixCharacteristics.getRows();
            _clen = matrixCharacteristics.getCols();
            _brlen = matrixCharacteristics.getRowsPerBlock();
            _bclen = matrixCharacteristics.getColsPerBlock();
            // Buffer length is the cell count, capped at the partition size constant.
            long cellCount = _rlen * _clen;
            _bufflen = (int) Math.min(cellCount, DistributedCacheInput.PARTITION_SIZE);
        }

        /**
         * Parses one space-separated text line into a cell.
         *
         * @param text line holding two long tokens followed by a double token
         * @return the parsed cell, or {@code null} when the line starts with {@code %}
         */
        public Tuple2<MatrixIndexes, MatrixCell> textToMatrixCell(Text text) {
            FastStringTokenizer tokens = new FastStringTokenizer(' ');
            String line = text.toString();
            if (line.startsWith("%")) {
                return null;
            }
            tokens.reset(line);
            // Tokens are consumed strictly in the order: first index, second index, value.
            long rowIndex = tokens.nextLong();
            long colIndex = tokens.nextLong();
            MatrixIndexes indexes = new MatrixIndexes(rowIndex, colIndex);
            MatrixCell cell = new MatrixCell(tokens.nextDouble());
            return new Tuple2<>(indexes, cell);
        }

        /**
         * Wraps the {@code i}, {@code j} and {@code value} of a matrix entry as a cell.
         *
         * @param matrixEntry entry to convert
         * @return the equivalent (indexes, cell) pair
         */
        public Tuple2<MatrixIndexes, MatrixCell> matrixEntryToMatrixCell(MatrixEntry matrixEntry) {
            MatrixIndexes indexes = new MatrixIndexes(matrixEntry.i(), matrixEntry.j());
            MatrixCell cell = new MatrixCell(matrixEntry.value());
            return new Tuple2<>(indexes, cell);
        }

        /**
         * Drains an iterator of input elements into binary blocks.
         *
         * @param obj an {@link Iterator} whose element type matches the converter type
         * @param rDDConverterTypes selects how each element is turned into a cell
         * @return all blocks flushed from the staging buffer
         * @throws Exception if the converter type is not one of the handled constants
         */
        Iterable<Tuple2<MatrixIndexes, MatrixBlock>> convertToBinaryBlock(Object obj, RDDConverterTypes rDDConverterTypes) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> blocks = new ArrayList<>();
            ReblockBuffer buffer = new ReblockBuffer(_bufflen, _rlen, _clen, _brlen, _bclen);
            Iterator<?> source = (Iterator<?>) obj;
            while (source.hasNext()) {
                Tuple2<MatrixIndexes, MatrixCell> cell;
                if (rDDConverterTypes == RDDConverterTypes.MATRIXENTRY_TO_MATRIXCELL) {
                    cell = matrixEntryToMatrixCell((MatrixEntry) source.next());
                } else if (rDDConverterTypes == RDDConverterTypes.TEXT_TO_MATRIX_CELL) {
                    cell = textToMatrixCell((Text) source.next());
                } else {
                    throw new Exception("Invalid converter for IJV data:" + rDDConverterTypes.toString());
                }
                if (cell == null) {
                    // Skipped input (e.g. a text line rejected by textToMatrixCell).
                    continue;
                }
                // Make room before appending once the buffer has reached its capacity.
                if (buffer.getSize() >= buffer.getCapacity()) {
                    flushBufferToList(buffer, blocks);
                }
                MatrixIndexes indexes = cell._1;
                buffer.appendCell(indexes.getRowIndex(), indexes.getColumnIndex(), cell._2.getValue());
            }
            // Emit whatever is still staged after the input is exhausted.
            flushBufferToList(buffer, blocks);
            return blocks;
        }

        /**
         * Flushes the buffer to binary blocks and appends each converted block to the list.
         *
         * @param reblockBuffer staging buffer to flush
         * @param arrayList destination receiving the converted blocks
         */
        private static void flushBufferToList(ReblockBuffer reblockBuffer, ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> arrayList) throws IOException, DMLRuntimeException {
            reblockBuffer.flushBufferToBinaryBlocks().stream()
                .map(flushed -> SparkUtils.fromIndexedMatrixBlock(flushed))
                .forEach(block -> arrayList.add(block));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$MatrixEntryToBinaryBlockFunction.class */
    /**
     * Partition-level function that converts {@link MatrixEntry} elements into
     * (index, block) pairs by delegating to {@link IJVToBinaryBlockFunctionHelper}.
     */
    public static class MatrixEntryToBinaryBlockFunction implements PairFlatMapFunction<Iterator<MatrixEntry>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 4907483236186747224L;
        private IJVToBinaryBlockFunctionHelper helper = null;

        /**
         * @param matrixCharacteristics characteristics handed to the helper's constructor
         */
        public MatrixEntryToBinaryBlockFunction(MatrixCharacteristics matrixCharacteristics) {
            helper = new IJVToBinaryBlockFunctionHelper(matrixCharacteristics);
        }

        /**
         * Converts all entries of one partition.
         *
         * @param it entries of the partition
         * @return iterator over the blocks produced by the helper
         */
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<MatrixEntry> it) throws Exception {
            Iterable<Tuple2<MatrixIndexes, MatrixBlock>> blocks =
                helper.convertToBinaryBlock(it, RDDConverterTypes.MATRIXENTRY_TO_MATRIXCELL);
            return blocks.iterator();
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$RDDConverterTypes.class */
    /**
     * Selects how {@code IJVToBinaryBlockFunctionHelper.convertToBinaryBlock} interprets
     * the elements of its input iterator.
     */
    public enum RDDConverterTypes {
        // Elements are cast to Hadoop Text and parsed via textToMatrixCell.
        TEXT_TO_MATRIX_CELL,
        // Elements are cast to MatrixEntry and converted via matrixEntryToMatrixCell.
        MATRIXENTRY_TO_MATRIXCELL
    }

    /**
     * Converts a {@link CoordinateMatrix} into an RDD of binary matrix blocks.
     *
     * @param javaSparkContext context used to create the empty-block RDD when needed
     * @param coordinateMatrix source matrix whose entries are converted
     * @param matrixCharacteristics dimensions and block sizes of the matrix
     * @param z when true, and the characteristics report that empty blocks might exist,
     *          the empty-block RDD is unioned in before merging
     * @return the converted blocks after {@code RDDAggregateUtils.mergeByKey}
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(JavaSparkContext javaSparkContext, CoordinateMatrix coordinateMatrix, MatrixCharacteristics matrixCharacteristics, boolean z) {
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = coordinateMatrix.entries()
            .toJavaRDD()
            .mapPartitionsToPair(new MatrixEntryToBinaryBlockFunction(matrixCharacteristics));
        // Short-circuit: the characteristics are only consulted when z is set.
        boolean addEmptyBlocks = z && matrixCharacteristics.mightHaveEmptyBlocks();
        if (addEmptyBlocks) {
            JavaPairRDD<MatrixIndexes, MatrixBlock> emptyBlocks = SparkUtils.getEmptyBlockRDD(javaSparkContext, matrixCharacteristics);
            blocks = blocks.union(emptyBlocks);
        }
        return RDDAggregateUtils.mergeByKey(blocks, false);
    }

    /**
     * Convenience overload accepting a Scala {@link SparkContext}; wraps it in a
     * {@link JavaSparkContext} and delegates to the {@code JavaSparkContext} overload.
     *
     * @param sparkContext Scala Spark context to wrap
     * @param coordinateMatrix source matrix whose entries are converted
     * @param matrixCharacteristics dimensions and block sizes of the matrix
     * @param z forwarded unchanged to the delegate (previously this flag was ignored
     *          and {@code true} was always passed)
     * @return the converted blocks
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(SparkContext sparkContext, CoordinateMatrix coordinateMatrix, MatrixCharacteristics matrixCharacteristics, boolean z) {
        return coordinateMatrixToBinaryBlock(new JavaSparkContext(sparkContext), coordinateMatrix, matrixCharacteristics, z);
    }

    /**
     * Selects the named columns from a dataset.
     *
     * @param dataset dataset to project
     * @param arrayList column names; the first name and the remaining names are passed
     *                  separately to {@code Dataset.select}
     * @return the result of {@code dataset.select(first, rest)}
     */
    public static Dataset<Row> projectColumns(Dataset<Row> dataset, ArrayList<String> arrayList) {
        int count = arrayList.size();
        // For an empty list the copy below is skipped and get(0) raises, as before.
        ArrayList<String> remaining = new ArrayList<>();
        if (count > 1) {
            remaining.addAll(arrayList.subList(1, count));
        }
        String first = arrayList.get(0);
        return dataset.select(first, JavaConversions.asScalaBuffer(remaining).toList());
    }

    /**
     * Converts a dense byte-encoded array into a matrix block (dense output).
     *
     * @param bArr raw bytes of the doubles
     * @param j number of rows; narrowed to int with a plain cast
     * @param j2 number of columns; narrowed to int with a plain cast
     * @return the matrix block produced by the int/boolean overload
     */
    public static MatrixBlock convertPy4JArrayToMB(byte[] bArr, long j, long j2) {
        int numRows = (int) j;
        int numCols = (int) j2;
        return convertPy4JArrayToMB(bArr, numRows, numCols, false);
    }

    /**
     * Converts a dense byte-encoded array into a matrix block (dense output).
     *
     * @param bArr raw bytes of the doubles
     * @param i number of rows
     * @param i2 number of columns
     * @return the matrix block produced by the overload taking a sparse flag
     */
    public static MatrixBlock convertPy4JArrayToMB(byte[] bArr, int i, int i2) {
        // This overload always requests the non-sparse conversion.
        final boolean sparse = false;
        return convertPy4JArrayToMB(bArr, i, i2, sparse);
    }

    /**
     * Long-argument convenience overload of the COO conversion.
     *
     * @param bArr bytes of the double values
     * @param bArr2 bytes of the first int index passed to {@code setValue}
     * @param bArr3 bytes of the second int index passed to {@code setValue}
     * @param j number of rows; narrowed to int with a plain cast
     * @param j2 number of columns; narrowed to int with a plain cast
     * @param j3 number of entries to read; narrowed to int with a plain cast
     * @return the matrix block produced by the int overload
     */
    public static MatrixBlock convertSciPyCOOToMB(byte[] bArr, byte[] bArr2, byte[] bArr3, long j, long j2, long j3) {
        int numRows = (int) j;
        int numCols = (int) j2;
        int numEntries = (int) j3;
        return convertSciPyCOOToMB(bArr, bArr2, bArr3, numRows, numCols, numEntries);
    }

    /**
     * Builds a sparse matrix block from three byte arrays of coordinate-format data.
     *
     * <p>All three arrays are decoded in the platform's native byte order. For each of
     * the {@code i3} entries one int is read from {@code bArr2}, one int from
     * {@code bArr3} and one double from {@code bArr}, and passed to {@code setValue}.
     *
     * @param bArr bytes of the double values
     * @param bArr2 bytes of the first int index passed to {@code setValue}
     * @param bArr3 bytes of the second int index passed to {@code setValue}
     * @param i number of rows of the block
     * @param i2 number of columns of the block
     * @param i3 number of entries to read
     * @return the populated block after non-zero recomputation and sparsity examination
     */
    public static MatrixBlock convertSciPyCOOToMB(byte[] bArr, byte[] bArr2, byte[] bArr3, int i, int i2, int i3) {
        MatrixBlock result = new MatrixBlock(i, i2, true);
        result.allocateSparseRowsBlock(false);
        ByteBuffer values = ByteBuffer.wrap(bArr).order(ByteOrder.nativeOrder());
        ByteBuffer firstIndexes = ByteBuffer.wrap(bArr2).order(ByteOrder.nativeOrder());
        ByteBuffer secondIndexes = ByteBuffer.wrap(bArr3).order(ByteOrder.nativeOrder());
        int entry = 0;
        while (entry < i3) {
            // Each buffer advances independently, one element per entry.
            int first = firstIndexes.getInt();
            int second = secondIndexes.getInt();
            double value = values.getDouble();
            result.setValue(first, second, value);
            entry++;
        }
        result.recomputeNonZeros();
        result.examSparsity();
        return result;
    }

    /**
     * Long-argument convenience overload of the dense array conversion.
     *
     * @param bArr raw bytes of the doubles
     * @param j number of rows; narrowed to int with a plain cast
     * @param j2 number of columns; narrowed to int with a plain cast
     * @param z sparse flag forwarded unchanged to the int overload
     * @return the matrix block produced by the int overload
     */
    public static MatrixBlock convertPy4JArrayToMB(byte[] bArr, long j, long j2, boolean z) {
        int numRows = (int) j;
        int numCols = (int) j2;
        return convertPy4JArrayToMB(bArr, numRows, numCols, z);
    }

    /**
     * Creates a matrix block of the given size and allocates its storage.
     *
     * @param i number of rows
     * @param i2 number of columns
     * @param z sparse flag passed to the {@code MatrixBlock} constructor
     * @return the block after {@code allocateBlock()} has been called on it
     */
    public static MatrixBlock allocateDenseOrSparse(int i, int i2, boolean z) {
        final MatrixBlock allocated = new MatrixBlock(i, i2, z);
        allocated.allocateBlock();
        return allocated;
    }

    /**
     * Long-argument overload: validates the dimensions and delegates to the int overload.
     *
     * <p>The previous body passed the {@code long} arguments straight through, which
     * resolved back to this same overload and recursed until the stack overflowed.
     * The arguments are now narrowed explicitly so the int overload is selected.
     *
     * @param j number of rows
     * @param j2 number of columns
     * @param z sparse flag passed to the {@code MatrixBlock} constructor
     * @return the allocated block
     * @throws DMLRuntimeException if either dimension exceeds the supported limit
     */
    public static MatrixBlock allocateDenseOrSparse(long j, long j2, boolean z) {
        // Never let the limit exceed the int range, so the narrowing casts below are lossless.
        long limit = Math.min(OptimizerUtils.MAX_NUMCELLS_CP_DENSE, Integer.MAX_VALUE);
        if (j > limit || j2 > limit) {
            throw new DMLRuntimeException("Dimensions of matrix are too large to be passed via NumPy/SciPy:" + j + " X " + j2);
        }
        return allocateDenseOrSparse((int) j, (int) j2, z);
    }

    /**
     * All-int convenience overload; widens every numeric argument and delegates to the
     * all-long overload.
     *
     * <p>The previous body forwarded the int arguments unchanged, which resolved back
     * to this same overload and recursed until the stack overflowed.
     *
     * @param matrixBlock block passed as the source argument of the copy
     * @param i row-block index
     * @param matrixBlock2 block on which {@code copy} is invoked
     * @param i2 number of rows per block
     * @param i3 total number of rows
     * @param i4 total number of columns
     */
    public static void copyRowBlocks(MatrixBlock matrixBlock, int i, MatrixBlock matrixBlock2, int i2, int i3, int i4) {
        copyRowBlocks(matrixBlock, (long) i, matrixBlock2, (long) i2, (long) i3, (long) i4);
    }

    /**
     * Overload with a long row-block index and int sizes; widens the sizes and
     * delegates to the all-long overload.
     *
     * <p>The previous body forwarded the arguments unchanged, which resolved back to
     * this same overload and recursed until the stack overflowed.
     *
     * @param matrixBlock block passed as the source argument of the copy
     * @param j row-block index
     * @param matrixBlock2 block on which {@code copy} is invoked
     * @param i number of rows per block
     * @param i2 total number of rows
     * @param i3 total number of columns
     */
    public static void copyRowBlocks(MatrixBlock matrixBlock, long j, MatrixBlock matrixBlock2, int i, int i2, int i3) {
        copyRowBlocks(matrixBlock, j, matrixBlock2, (long) i, (long) i2, (long) i3);
    }

    /**
     * Overload with an int row-block index and long sizes; widens the index and
     * delegates to the all-long overload.
     *
     * <p>The previous body forwarded the arguments unchanged, which resolved back to
     * this same overload and recursed until the stack overflowed.
     *
     * @param matrixBlock block passed as the source argument of the copy
     * @param i row-block index
     * @param matrixBlock2 block on which {@code copy} is invoked
     * @param j number of rows per block
     * @param j2 total number of rows
     * @param j3 total number of columns
     */
    public static void copyRowBlocks(MatrixBlock matrixBlock, int i, MatrixBlock matrixBlock2, long j, long j2, long j3) {
        copyRowBlocks(matrixBlock, (long) i, matrixBlock2, j, j2, j3);
    }

    /**
     * Invokes {@code matrixBlock2.copy} with {@code matrixBlock} as its block argument,
     * over a row range derived from a block index.
     *
     * <p>The row range starts at {@code j * j2} and ends at the smaller of
     * {@code (j + 1) * j2 - 1} and {@code j3 - 1}; the column range is
     * {@code 0} to {@code j4 - 1}. All bounds are narrowed to int with plain casts.
     *
     * @param matrixBlock block passed as the source argument of the copy
     * @param j row-block index
     * @param matrixBlock2 block on which {@code copy} is invoked
     * @param j2 number of rows per block
     * @param j3 total number of rows
     * @param j4 total number of columns
     */
    public static void copyRowBlocks(MatrixBlock matrixBlock, long j, MatrixBlock matrixBlock2, long j2, long j3, long j4) {
        long blockFirstRow = j * j2;
        long blockLastRow = ((j + 1) * j2) - 1;
        long matrixLastRow = j3 - 1;
        int rowLower = (int) blockFirstRow;
        // The final row block may be shorter than j2 rows.
        int rowUpper = (int) Math.min(blockLastRow, matrixLastRow);
        int colUpper = (int) (j4 - 1);
        matrixBlock2.copy(rowLower, rowUpper, 0, colUpper, matrixBlock, false);
    }

    /**
     * Refreshes the bookkeeping of a matrix block: recomputes its non-zero count and
     * then runs its sparsity examination. The name suggests it is meant to be called
     * after {@code copyRowBlocks} has filled the block.
     *
     * @param matrixBlock block to post-process in place
     */
    public static void postProcessAfterCopying(MatrixBlock matrixBlock) {
        // Non-zeros are recomputed first, matching the order used by the other converters in this class.
        matrixBlock.recomputeNonZeros();
        matrixBlock.examSparsity();
    }

    /**
     * Converts a byte-encoded dense array of doubles into a matrix block.
     *
     * <p>The bytes are decoded in the platform's native byte order and {@code i * i2}
     * doubles are read from the start of {@code bArr}.
     *
     * @param bArr raw bytes holding at least {@code i * i2} doubles
     * @param i number of rows
     * @param i2 number of columns
     * @param z sparse flag; only {@code false} is supported
     * @return the initialized block after non-zero recomputation and sparsity examination
     * @throws DMLRuntimeException if {@code z} is true or the cell count exceeds
     *         {@code OptimizerUtils.MAX_NUMCELLS_CP_DENSE}
     */
    public static MatrixBlock convertPy4JArrayToMB(byte[] bArr, int i, int i2, boolean z) {
        // Reject the unsupported mode before allocating anything.
        if (z) {
            throw new DMLRuntimeException("Convertion to sparse format not supported");
        }
        // Widen before multiplying: an int product can wrap around and slip past the limit check.
        long numCells = (long) i * i2;
        if (numCells > OptimizerUtils.MAX_NUMCELLS_CP_DENSE) {
            throw new DMLRuntimeException("Dense NumPy array of size " + numCells + " cannot be converted to MatrixBlock");
        }
        double[] dArr = new double[(int) numCells];
        // Bulk read; raises BufferUnderflowException when bArr holds too few bytes,
        // the same exception type the element-wise getDouble() calls raised.
        ByteBuffer.wrap(bArr).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(dArr);
        MatrixBlock matrixBlock = new MatrixBlock(i, i2, z, -1L);
        matrixBlock.init(dArr, i, i2);
        matrixBlock.recomputeNonZeros();
        matrixBlock.examSparsity();
        return matrixBlock;
    }

    /**
     * Serializes a matrix block into a byte array of doubles in native byte order.
     *
     * <p>A block in sparse format is converted to dense in place first, so the
     * argument may be modified by this call.
     *
     * @param matrixBlock block to serialize
     * @return {@code rows * columns * 8} bytes; all zero for an empty block
     * @throws DMLRuntimeException if the block has more than {@code Integer.MAX_VALUE / 8}
     *         cells, or if a non-empty block exposes no dense values
     */
    public static byte[] convertMBtoPy4JDenseArr(MatrixBlock matrixBlock) {
        if (matrixBlock.isInSparseFormat()) {
            matrixBlock.sparseToDense();
        }
        // Widen before multiplying: an int product can wrap around and slip past the limit check.
        long numCells = (long) matrixBlock.getNumRows() * matrixBlock.getNumColumns();
        if (numCells > Integer.MAX_VALUE / 8) {
            throw new DMLRuntimeException("MatrixBlock of size " + numCells + " cannot be converted to dense numpy array");
        }
        byte[] bArr = new byte[(int) (numCells * 8)];
        double[] denseBlockValues = matrixBlock.getDenseBlockValues();
        if (matrixBlock.isEmptyBlock()) {
            // A new byte[] is zero-filled and 0.0d encodes as eight zero bytes in any
            // byte order, so no explicit writes are needed.
            return bArr;
        }
        if (denseBlockValues == null) {
            throw new DMLRuntimeException("Error while dealing with empty blocks.");
        }
        // Defensive bound: never write more doubles than the output array has room for.
        int count = (int) Math.min(denseBlockValues.length, numCells);
        ByteBuffer.wrap(bArr).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(denseBlockValues, 0, count);
        return bArr;
    }

    /**
     * Returns a copy of the dataset with an additional trailing identifier column.
     *
     * <p>The new column is a non-nullable double field named {@code str}; its values
     * are produced by {@link AddRowID} from the {@code zipWithIndex} position of each row.
     *
     * @param dataset source dataset
     * @param sparkSession session used to create the resulting data frame
     * @param str name of the identifier column
     * @return data frame with the original columns followed by the identifier column
     */
    public static Dataset<Row> addIDToDataFrame(Dataset<Row> dataset, SparkSession sparkSession, String str) {
        StructField[] existing = dataset.schema().fields();
        int width = existing.length;
        StructField[] extended = new StructField[width + 1];
        System.arraycopy(existing, 0, extended, 0, width);
        extended[width] = DataTypes.createStructField(str, DataTypes.DoubleType, false);
        StructType schema = new StructType(extended);
        return sparkSession.createDataFrame(dataset.rdd().toJavaRDD().zipWithIndex().map(new AddRowID()), schema);
    }

    /**
     * Converts a data frame whose string column holds textual vectors into a data
     * frame of ML vectors.
     *
     * <p>Each non-null cell must be a {@code String}. It is trimmed, up to two levels of
     * enclosing {@code ( )} or {@code [ ]} are removed, whitespace around commas is
     * dropped, and the result is parsed as a bracketed list of doubles. Null cells stay
     * null. Every field of the resulting schema is a nullable {@link VectorUDT} field
     * carrying the original field name.
     *
     * @param sparkSession session used to create the resulting data frame
     * @param dataset source data frame; rows may have at most one column
     * @return data frame with the parsed vectors
     * @throws DMLRuntimeException (raised from the map function during execution) if a
     *         row has more than one column, a cell is not a string, or parsing fails
     */
    public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> dataset) {
        StructField[] fields = dataset.schema().fields();
        StructField[] structFieldArr = new StructField[fields.length];
        for (int i = 0; i < fields.length; i++) {
            structFieldArr[i] = DataTypes.createStructField(fields[i].name(), new VectorUDT(), true);
        }
        return sparkSession.createDataFrame(dataset.rdd().toJavaRDD().zipWithIndex().map(new Function<Tuple2<Row, Long>, Row>() {
            private static final long serialVersionUID = -4733816995375745659L;

            public Row call(Tuple2<Row, Long> tuple2) throws Exception {
                Row row = tuple2._1;
                int numCols = row.length();
                if (numCols > 1) {
                    throw new DMLRuntimeException("The row must have at most one column");
                }
                ArrayList<Object> converted = new ArrayList<>(numCols);
                for (int col = 0; col < numCols; col++) {
                    Object cell = row.get(col);
                    if (cell == null) {
                        converted.add(null);
                        continue;
                    }
                    if (!(cell instanceof String)) {
                        throw new DMLRuntimeException("Only String is supported");
                    }
                    StringBuilder text = new StringBuilder(((String) cell).trim());
                    // Strip up to two levels of enclosing brackets. The counter is separate
                    // from the column index, which the earlier code reused by mistake.
                    for (int level = 0; level < 2; level++) {
                        int last = text.length() - 1;
                        if (last < 1) {
                            // Fewer than two characters: no bracket pair left to strip
                            // (also avoids charAt on an empty buffer).
                            break;
                        }
                        char open = text.charAt(0);
                        char close = text.charAt(last);
                        if ((open == '(' && close == ')') || (open == '[' && close == ']')) {
                            text.deleteCharAt(0);
                            text.setLength(text.length() - 1);
                        }
                    }
                    String bracketed = "[" + text.toString().replaceAll(" *, *", ",") + "]";
                    try {
                        converted.add(Vectors.dense((double[]) NumericParser.parse(bracketed)));
                    } catch (Exception e) {
                        throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                    }
                }
                return RowFactory.create(converted.toArray());
            }
        }).rdd(), DataTypes.createStructType(structFieldArr));
    }
}
