package org.apache.sysml.scripts.algorithms;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.algorithms.kmeans_predict.Get_best_assignments_output;

/* loaded from: input_file:org/apache/sysml/scripts/algorithms/Kmeans_predict.class */
/**
 * MLContext wrapper around the DML script {@code scripts/algorithms/Kmeans-predict.dml}.
 *
 * <p>Constructing an instance loads the full script text from the classpath into this
 * {@link Script}. The {@link #get_best_assignments(Object)} method runs the DML function of the
 * same name in a separate, freshly built {@link Script}.
 */
public class Kmeans_predict extends Script {

    /** Path of the DML script, relative to the classpath root (as used by DML {@code source()}). */
    private static final String SCRIPT_PATH = "scripts/algorithms/Kmeans-predict.dml";

    /** Absolute classpath resource name of the DML script. */
    private static final String SCRIPT_RESOURCE = "/" + SCRIPT_PATH;

    /** Size of the character buffer used when reading the script resource. */
    private static final int READ_BUFFER_SIZE = 1024;

    /**
     * DML source text of {@code get_best_assignments}, shared by the {@code __docs} and
     * {@code __source} accessors (they return identical text).
     */
    private static final String GET_BEST_ASSIGNMENTS_SOURCE =
            "get_best_assignments = function (Matrix[double] counts)\n"
            + "return (Matrix[double] row_ids, Matrix[double] col_ids, Matrix[double] margins, \n"
            + "        Matrix[double] max_counts, Matrix[double] rounded_percentages)\n"
            + "{\n"
            + "    margins = rowSums (counts);\n"
            + "    select_positive = removeEmpty (target = diag (margins > 0), margin = \"rows\");\n"
            + "    row_ids = select_positive %*% seq (1, nrow (margins), 1);\n"
            + "    pos_counts = select_positive %*% counts;\n"
            + "    pos_margins = select_positive %*% margins;\n"
            + "    max_counts = rowMaxs (pos_counts);\n"
            + "    is_max_count = (pos_counts == max_counts);\n"
            + "    aggr_is_max_count = t(cumsum (t(is_max_count)));\n"
            + "    col_ids = rowSums (aggr_is_max_count == 0) + 1;\n"
            + "    rounded_percentages = round (1000000.0 * max_counts / pos_margins) / 10000.0;\n"
            + "}\n";

    /**
     * Creates the wrapper and sets its script string to the contents of the bundled DML file.
     *
     * @throws NullPointerException if the DML resource is not on the classpath
     */
    public Kmeans_predict() {
        setScriptString(readScriptResource(SCRIPT_RESOURCE));
    }

    /**
     * Reads a classpath resource (resolved via {@link Script}'s class) as UTF-8 text.
     *
     * <p>Reading is best-effort: an {@link IOException} is reported via {@code printStackTrace}
     * and whatever text was read up to that point is returned. (The previous inline loop retried
     * forever on a persistent read error.) The reader is always closed.
     *
     * @param resourceName absolute resource name, e.g. {@code "/scripts/..."}
     * @return the resource text, possibly truncated if an I/O error occurred
     * @throws NullPointerException if the resource cannot be found
     */
    private static String readScriptResource(String resourceName) {
        InputStream in = Script.class.getResourceAsStream(resourceName);
        Objects.requireNonNull(in, "DML script resource not found on classpath: " + resourceName);

        StringBuilder text = new StringBuilder();
        try (Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            char[] buffer = new char[READ_BUFFER_SIZE];
            int count;
            while ((count = reader.read(buffer)) > 0) {
                text.append(buffer, 0, count);
            }
        } catch (IOException e) {
            // Best-effort, matching the original reporting style; stop reading instead of retrying.
            e.printStackTrace();
        }
        return text.toString();
    }

    /**
     * Runs the DML function {@code get_best_assignments(counts)} from the bundled script.
     *
     * <p>A new {@link Script} is built that {@code source()}s the DML file. It binds {@code obj}
     * as the DML input {@code counts}, executes, and collects the five output matrices.
     *
     * @param obj value bound to the DML input {@code counts}; presumably anything MLContext
     *     accepts as a matrix input (TODO confirm the accepted types against MLContext docs)
     * @return the outputs {@code row_ids}, {@code col_ids}, {@code margins}, {@code max_counts}
     *     and {@code rounded_percentages}, in that constructor order
     */
    public Get_best_assignments_output get_best_assignments(Object obj) {
        Script script = new Script("source('" + SCRIPT_PATH + "') as mlcontextns;"
                + "[row_ids, col_ids, margins, max_counts, rounded_percentages] = "
                + "mlcontextns::get_best_assignments(counts);");
        script.in("counts", obj)
                .out("row_ids")
                .out("col_ids")
                .out("margins")
                .out("max_counts")
                .out("rounded_percentages");
        MLResults results = script.execute();
        return new Get_best_assignments_output(
                results.getMatrix("row_ids"),
                results.getMatrix("col_ids"),
                results.getMatrix("margins"),
                results.getMatrix("max_counts"),
                results.getMatrix("rounded_percentages"));
    }

    /**
     * Returns the documentation text for {@code get_best_assignments}. This is the function's full
     * DML source, identical to {@link #get_best_assignments__source()}.
     *
     * @return DML source text of {@code get_best_assignments}
     */
    public String get_best_assignments__docs() {
        return GET_BEST_ASSIGNMENTS_SOURCE;
    }

    /**
     * Returns the DML source text of {@code get_best_assignments}.
     *
     * @return DML source text of {@code get_best_assignments}
     */
    public String get_best_assignments__source() {
        return GET_BEST_ASSIGNMENTS_SOURCE;
    }
}
