package org.apache.sysml.api.ml;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.MatrixMetadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import scala.Predef$;
import scala.Tuple2;
import scala.reflect.ScalaSignature;

/* compiled from: BaseSystemMLRegressor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00154q!\u0001\u0002\u0011\u0002\u0007\u0005QB\u0001\u000eCCN,7+_:uK6lEJU3he\u0016\u001c8o\u001c:N_\u0012,GN\u0003\u0002\u0004\t\u0005\u0011Q\u000e\u001c\u0006\u0003\u000b\u0019\t1!\u00199j\u0015\t9\u0001\"A\u0003tsNlGN\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0004\u00019!\u0002CA\b\u0013\u001b\u0005\u0001\"\"A\t\u0002\u000bM\u001c\u0017\r\\1\n\u0005M\u0001\"AB!osJ+g\r\u0005\u0002\u0016-5\t!!\u0003\u0002\u0018\u0005\tQ\")Y:f'f\u001cH/Z7N\u0019\u0016\u001bH/[7bi>\u0014Xj\u001c3fY\")\u0011\u0004\u0001C\u00015\u00051A%\u001b8ji\u0012\"\u0012a\u0007\t\u0003\u001fqI!!\b\t\u0003\tUs\u0017\u000e\u001e\u0005\u0006?\u0001!\t\u0001I\u0001\u000eE\u0006\u001cX\r\u0016:b]N4wN]7\u0015\t\u0005B#F\r\t\u0003E\u0015r!aD\u0012\n\u0005\u0011\u0002\u0012A\u0002)sK\u0012,g-\u0003\u0002'O\t11\u000b\u001e:j]\u001eT!\u0001\n\t\t\u000b%r\u0002\u0019A\u0011\u0002\ra{f-\u001b7f\u0011\u0015Yc\u00041\u0001-\u0003\t\u00198\r\u0005\u0002.a5\taF\u0003\u00020\u0011\u0005)1\u000f]1sW&\u0011\u0011G\f\u0002\r'B\f'o[\"p]R,\u0007\u0010\u001e\u0005\u0006gy\u0001\r!I\u0001\u000eaJ,G-[2uS>tg+\u0019:\t\u000b}\u0001A\u0011A\u001b\u0015\tY\u0002%i\u0011\t\u0003oyj\u0011\u0001\u000f\u0006\u0003si\nA\u0001Z1uC*\u00111\bP\u0001\u0007[\u0006$(/\u001b=\u000b\u0005u2\u0011a\u0002:v]RLW.Z\u0005\u0003\u007fa\u00121\"T1ue&D(\t\\8dW\")\u0011\t\u000ea\u0001m\u0005\t\u0001\fC\u0003,i\u0001\u0007A\u0006C\u00034i\u0001\u0007\u0011\u0005C\u0003 
\u0001\u0011\u0005Q\t\u0006\u0003G5\u000e$\u0007CA$X\u001d\tAEK\u0004\u0002J%:\u0011!*\u0015\b\u0003\u0017Bs!\u0001T(\u000e\u00035S!A\u0014\u0007\u0002\rq\u0012xn\u001c;?\u0013\u0005Y\u0011BA\u0005\u000b\u0013\ty\u0003\"\u0003\u0002T]\u0005\u00191/\u001d7\n\u0005U3\u0016a\u00029bG.\fw-\u001a\u0006\u0003':J!\u0001W-\u0003\u0013\u0011\u000bG/\u0019$sC6,'BA+W\u0011\u0015YF\t1\u0001]\u0003\t!g\r\u0005\u0002^A:\u0011QCX\u0005\u0003?\n\tAbU2sSB$8/\u0016;jYNL!!\u00192\u0003\u001bM\u0003\u0018M]6ECR\fG+\u001f9f\u0015\ty&\u0001C\u0003,\t\u0002\u0007A\u0006C\u00034\t\u0002\u0007\u0011\u0005")
/* loaded from: input_file:org/apache/sysml/api/ml/BaseSystemMLRegressorModel.class */
/**
 * Prediction-side contract for SystemML regressor models.
 *
 * <p>Three {@code baseTransform} overloads are declared, one per input kind
 * (a file path, an in-memory {@link MatrixBlock}, or a Spark {@link Dataset}).
 * The nested {@link Cclass} holds the static implementations that the Scala
 * compiler generated for the trait's method bodies.
 */
public interface BaseSystemMLRegressorModel extends BaseSystemMLEstimatorModel {

    /**
     * Static implementations of the trait methods (decompiled from
     * {@code BaseSystemMLRegressorModel$class}; renamed because {@code $class}
     * is not a valid Java class name).
     *
     * <p>Every method takes the model instance as its first argument and uses a
     * freshly created {@link MLContext} that is passed to
     * {@code updateML} before any script is executed.
     */
    public abstract class Cclass {

        /**
         * File-based prediction: reads matrix {@code X} through the script built
         * from {@code dmlRead("X", str)}, runs the model's prediction script on
         * it, and executes the script built from {@code dmlWrite("X")} with the
         * prediction result bound to {@code X}.
         *
         * @param baseSystemMLRegressorModel model whose scripts are executed
         * @param str                        argument forwarded to {@code dmlRead}
         *                                   (presumably the input file path; confirm
         *                                   against {@code dmlRead})
         * @param sparkContext               Spark context used to create the {@link MLContext}
         * @param str2                       name of the variable fetched from the
         *                                   prediction results
         * @return always the literal {@code "output.mtx"}
         */
        public static String baseTransform(BaseSystemMLRegressorModel baseSystemMLRegressorModel, String str, SparkContext sparkContext, String str2) {
            MLContext ml = new MLContext(sparkContext);
            baseSystemMLRegressorModel.updateML(ml);

            // Step 1: load the input matrix "X".
            Script readScript = ScriptFactory.dml(baseSystemMLRegressorModel.dmlRead("X", str)).out("X");
            MLResults readResults = ml.execute(readScript);

            Tuple2<Script, String> scriptAndInputName = baseSystemMLRegressorModel.getPredictionScript(false);

            // The write script is created before the prediction runs; this keeps
            // the call order of the original nested expression.
            Script writeScript = ScriptFactory.dml(baseSystemMLRegressorModel.dmlWrite("X"));

            // Step 2: run the prediction script on the loaded matrix.
            Script boundPrediction = scriptAndInputName._1().in(scriptAndInputName._2(), readResults.getMatrix("X"));
            Matrix predicted = ml.execute(boundPrediction).getMatrix(str2);

            // Step 3: hand the prediction to the write script as "X".
            ml.execute(writeScript.in("X", predicted));
            return "output.mtx";
        }

        /**
         * In-memory prediction: runs the model's prediction script with
         * {@code matrixBlock} bound to the script's input name and returns the
         * requested result variable as a {@link MatrixBlock}.
         *
         * @param baseSystemMLRegressorModel model whose prediction script is executed
         * @param matrixBlock                input data bound to the prediction script
         * @param sparkContext               Spark context used to create the {@link MLContext}
         * @param str                        name of the variable fetched from the results
         * @return the prediction, which has exactly one column
         * @throws RuntimeException if the result does not have exactly one column
         */
        public static MatrixBlock baseTransform(BaseSystemMLRegressorModel baseSystemMLRegressorModel, MatrixBlock matrixBlock, SparkContext sparkContext, String str) {
            MLContext ml = new MLContext(sparkContext);
            baseSystemMLRegressorModel.updateML(ml);

            Tuple2<Script, String> scriptAndInputName = baseSystemMLRegressorModel.getPredictionScript(true);
            Script boundPrediction = scriptAndInputName._1().in(scriptAndInputName._2(), matrixBlock);
            MLResults results = ml.execute(boundPrediction);
            MatrixBlock prediction = results.getMatrix(str).toMatrixBlock();

            if (prediction.getNumColumns() == 1) {
                return prediction;
            }
            throw new RuntimeException("Expected prediction to be a column vector");
        }

        /**
         * DataFrame prediction: converts {@code dataset} to binary blocks, runs
         * the model's prediction script on them, and joins the resulting
         * {@code "prediction"} column back onto the ID-augmented input via
         * {@code PredictionUtils.joinUsingID}.
         *
         * @param baseSystemMLRegressorModel model whose prediction script is executed
         * @param dataset                    input rows
         * @param sparkContext               Spark context used to create the {@link MLContext}
         *                                   (the conversion itself uses the context
         *                                   obtained from {@code dataset.rdd()})
         * @param str                        name of the variable fetched from the results
         * @return the result of joining the ID-augmented input with the predictions
         */
        public static Dataset baseTransform(BaseSystemMLRegressorModel baseSystemMLRegressorModel, Dataset dataset, SparkContext sparkContext, String str) {
            MLContext ml = new MLContext(sparkContext);
            baseSystemMLRegressorModel.updateML(ml);

            // Filled in by the conversion below and reused as the matrix metadata.
            MatrixCharacteristics characteristics = new MatrixCharacteristics();
            SparkContext datasetContext = dataset.rdd().sparkContext();
            JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = RDDConverterUtils.dataFrameToBinaryBlock(
                    JavaSparkContext$.MODULE$.fromSparkContext(datasetContext), dataset, characteristics, false, true);

            Tuple2<Script, String> scriptAndInputName = baseSystemMLRegressorModel.getPredictionScript(false);

            // The joiner and the ID-augmented input are obtained before the
            // prediction runs, matching the original argument evaluation order.
            PredictionUtils$ joiner = PredictionUtils$.MODULE$;
            Dataset<Row> inputWithIds = RDDConverterUtilsExt.addIDToDataFrame(
                    dataset, dataset.sparkSession(), RDDConverterUtils.DF_ID_COLUMN);

            Matrix inputMatrix = new Matrix(blocks, new MatrixMetadata(characteristics));
            Script boundPrediction = scriptAndInputName._1().in(scriptAndInputName._2(), inputMatrix);
            MLResults results = ml.execute(boundPrediction);

            // Keep only the ID column and "C1", exposing the latter as "prediction".
            Dataset<Row> predictions = results.getDataFrame(str)
                    .select(RDDConverterUtils.DF_ID_COLUMN, Predef$.MODULE$.wrapRefArray(new String[]{"C1"}))
                    .withColumnRenamed("C1", "prediction");

            return joiner.joinUsingID(inputWithIds, predictions);
        }

        /**
         * Trait initializer emitted by the Scala compiler; intentionally a no-op.
         *
         * @param baseSystemMLRegressorModel the instance being initialized (unused)
         */
        public static void $init$(BaseSystemMLRegressorModel baseSystemMLRegressorModel) {
            // Nothing to initialize.
        }
    }

    /**
     * Runs prediction for input identified by a string.
     *
     * @param str          input identifier (see {@link Cclass} for how it is used)
     * @param sparkContext Spark context used for execution
     * @param str2         name of the prediction variable to fetch
     * @return a string identifying the output
     */
    String baseTransform(String str, SparkContext sparkContext, String str2);

    /**
     * Runs prediction for an in-memory matrix block.
     *
     * @param matrixBlock  input data
     * @param sparkContext Spark context used for execution
     * @param str          name of the prediction variable to fetch
     * @return the prediction as a matrix block
     */
    MatrixBlock baseTransform(MatrixBlock matrixBlock, SparkContext sparkContext, String str);

    /**
     * Runs prediction for a Spark dataset.
     *
     * @param dataset      input rows
     * @param sparkContext Spark context used for execution
     * @param str          name of the prediction variable to fetch
     * @return a dataset of rows containing the prediction output
     */
    Dataset<Row> baseTransform(Dataset<?> dataset, SparkContext sparkContext, String str);
}
