package org.apache.sysml.scripts.algorithms;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.scripts.algorithms.cox.Ensure_trust_bound_output;

/* loaded from: input_file:org/apache/sysml/scripts/algorithms/Cox.class */
/**
 * MLContext {@link Script} wrapper around {@code scripts/algorithms/Cox.dml}.
 *
 * <p>Constructing an instance loads the full DML script text from the classpath. The
 * per-function methods each build a small script that sources {@code Cox.dml} and invokes a
 * single DML function with the supplied inputs.
 */
public class Cox extends Script {

    /** Absolute classpath location of the DML script backing this wrapper. */
    private static final String SCRIPT_RESOURCE = "/scripts/algorithms/Cox.dml";

    /** DML source of {@code ensure_trust_bound}; returned by both its docs and source accessors. */
    private static final String ENSURE_TRUST_BOUND_SOURCE = "ensure_trust_bound =\n    function (double x, double a, double b, double c)\n    return (double x_new, boolean is_violated)\n{\n    if (a * x^2 + b * x + c > 0)\n    {\n        is_violated = TRUE;\n        rad = sqrt (b ^ 2 - 4 * a * c);\n        if (b >= 0) {\n            x_new = - (2 * c) / (b + rad);\n        } else {\n            x_new = - (b - rad) / (2 * a);\n        }\n    } else {\n        is_violated = FALSE;\n        x_new = x;\n    }\n}\n";

    /** DML source of {@code update_trust_bound}; returned by both its docs and source accessors. */
    private static final String UPDATE_TRUST_BOUND_SOURCE = "update_trust_bound =\n    function (double delta,\n              double sb_distance,\n              double so_exact,\n              double so_linear_approx,\n              double so_quadratic_approx)\n    return   (double delta)\n{\n    sigma1 = 0.25;\n    sigma2 = 0.5;\n    sigma3 = 4.0;\n\n    if (so_exact <= so_linear_approx) {\n       alpha = sigma3;\n    } else {\n       alpha = max (sigma1, - 0.5 * so_linear_approx / (so_exact - so_linear_approx));\n    }\n\n    rho = so_exact / so_quadratic_approx;\n    if (rho < 0.0001) {\n        delta = min (max (alpha, sigma1) * sb_distance, sigma2 * delta);\n    } else { if (rho < 0.25) {\n        delta = max (sigma1 * delta, min (alpha * sb_distance, sigma2 * delta));\n    } else { if (rho < 0.75) {\n        delta = max (sigma1 * delta, min (alpha * sb_distance, sigma3 * delta));\n    } else {\n        delta = max (delta, min (alpha * sb_distance, sigma3 * delta));\n    }}} \n}\n";

    /**
     * Creates the wrapper and sets its script string to the contents of {@code Cox.dml}.
     *
     * @throws NullPointerException if the DML resource is not on the classpath
     * @throws UncheckedIOException if reading the resource fails
     */
    public Cox() {
        setScriptString(readResource(SCRIPT_RESOURCE));
    }

    /**
     * Reads a classpath resource fully as UTF-8 text.
     *
     * <p>The stream is always closed. A read failure is rethrown rather than retried, because
     * retrying a failing read in a loop could spin forever.
     *
     * @param path absolute resource path, resolved via {@code Script.class}
     * @return the resource contents
     * @throws NullPointerException if the resource does not exist
     * @throws UncheckedIOException if an I/O error occurs while reading
     */
    private static String readResource(String path) {
        InputStream stream = Objects.requireNonNull(
                Script.class.getResourceAsStream(path),
                "DML script resource not found on classpath: " + path);
        try (Reader reader = new InputStreamReader(stream, StandardCharsets.UTF_8)) {
            StringBuilder text = new StringBuilder();
            char[] buffer = new char[1024];
            int count;
            while ((count = reader.read(buffer)) != -1) {
                text.append(buffer, 0, count);
            }
            return text.toString();
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DML script resource " + path, e);
        }
    }

    /**
     * Invokes the DML function {@code ensure_trust_bound(x, a, b, c)} from {@code Cox.dml}.
     *
     * <p>Per the DML source: if {@code a*x^2 + b*x + c > 0}, {@code is_violated} is TRUE and
     * {@code x_new} is computed from the quadratic's roots. Otherwise {@code x_new = x}.
     *
     * @param obj DML input {@code x}
     * @param obj2 DML input {@code a}
     * @param obj3 DML input {@code b}
     * @param obj4 DML input {@code c}
     * @return the DML outputs {@code x_new} (double) and {@code is_violated} (boolean)
     */
    public Ensure_trust_bound_output ensure_trust_bound(Object obj, Object obj2, Object obj3, Object obj4) {
        Script script = new Script("source('scripts/algorithms/Cox.dml') as mlcontextns;[x_new, is_violated] = mlcontextns::ensure_trust_bound(x, a, b, c);");
        // Input names must match the DML call above; "a" was previously spelled via an unrelated
        // GPU timer constant that happened to equal "a".
        script.in("x", obj).in("a", obj2).in("b", obj3).in("c", obj4).out("x_new").out("is_violated");
        MLResults results = script.execute();
        return new Ensure_trust_bound_output(results.getDouble("x_new"), results.getBoolean("is_violated"));
    }

    /** @return documentation for {@code ensure_trust_bound} (identical to its DML source). */
    public String ensure_trust_bound__docs() {
        return ENSURE_TRUST_BOUND_SOURCE;
    }

    /** @return the DML source text of {@code ensure_trust_bound}. */
    public String ensure_trust_bound__source() {
        return ENSURE_TRUST_BOUND_SOURCE;
    }

    /**
     * Invokes the DML function {@code update_trust_bound} from {@code Cox.dml}.
     *
     * @param obj DML input {@code delta}
     * @param obj2 DML input {@code sb_distance}
     * @param obj3 DML input {@code so_exact}
     * @param obj4 DML input {@code so_linear_approx}
     * @param obj5 DML input {@code so_quadratic_approx}
     * @return the updated {@code delta} produced by the DML function
     */
    public double update_trust_bound(Object obj, Object obj2, Object obj3, Object obj4, Object obj5) {
        Script script = new Script("source('scripts/algorithms/Cox.dml') as mlcontextns;delta = mlcontextns::update_trust_bound(delta, sb_distance, so_exact, so_linear_approx, so_quadratic_approx);");
        script.in("delta", obj).in("sb_distance", obj2).in("so_exact", obj3).in("so_linear_approx", obj4).in("so_quadratic_approx", obj5).out("delta");
        return script.execute().getDouble("delta");
    }

    /** @return documentation for {@code update_trust_bound} (identical to its DML source). */
    public String update_trust_bound__docs() {
        return UPDATE_TRUST_BOUND_SOURCE;
    }

    /** @return the DML source text of {@code update_trust_bound}. */
    public String update_trust_bound__source() {
        return UPDATE_TRUST_BOUND_SOURCE;
    }
}
