package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.nn.layers.low_rank_affine.Backward_output;
import org.apache.sysml.scripts.nn.layers.low_rank_affine.Init_output;

/* loaded from: input_file:org/apache/sysml/scripts/nn/layers/Low_rank_affine.class */
/**
 * MLContext wrapper around the DML script {@code scripts/nn/layers/low_rank_affine.dml}.
 *
 * <p>Constructing an instance loads the full script text as this {@link Script}'s source. Each
 * of {@link #init}, {@link #forward} and {@link #backward} builds and executes a fresh
 * {@link Script} that {@code source(...)}s the DML file and calls the same-named DML function,
 * binding the Java arguments to the DML input names. The {@code *__docs()} / {@code *__source()}
 * methods return the DML doc header / full source of each function as plain text.
 */
public class Low_rank_affine extends Script {

    /** Absolute classpath location of the wrapped DML script. */
    private static final String SCRIPT_RESOURCE = "/scripts/nn/layers/low_rank_affine.dml";

    /**
     * Creates the wrapper and loads the text of {@code low_rank_affine.dml} from the classpath
     * as this script's source string.
     *
     * <p>Loading is best-effort: if an I/O error occurs while reading, its stack trace is
     * printed and whatever text was read up to that point is used.
     *
     * @throws NullPointerException if the DML resource is not on the classpath
     */
    public Low_rank_affine() {
        InputStream stream = Script.class.getResourceAsStream(SCRIPT_RESOURCE);
        Objects.requireNonNull(stream, "DML script resource not found on classpath: " + SCRIPT_RESOURCE);
        StringBuilder sb = new StringBuilder();
        // try-with-resources guarantees the stream is closed (previously it leaked).
        try (Reader reader = new InputStreamReader(stream, StandardCharsets.UTF_8)) {
            char[] buf = new char[1024];
            int read;
            while ((read = reader.read(buf)) != -1) {
                sb.append(buf, 0, read);
            }
        } catch (IOException e) {
            // The catch is outside the loop so a persistent read error stops loading
            // instead of being retried forever.
            e.printStackTrace();
        }
        setScriptString(sb.toString());
    }

    /**
     * Executes the DML {@code init(D, M, R)} function and returns its outputs.
     *
     * @param d value bound to DML input {@code D} (declared {@code int} in the DML: input
     *     feature dimensionality)
     * @param m value bound to DML input {@code M} (declared {@code int}: number of neurons)
     * @param r value bound to DML input {@code R} (declared {@code int}: factor rank)
     * @return the DML outputs {@code U}, {@code V} and {@code b}
     */
    public Init_output init(Object d, Object m, Object r) {
        Script script = new Script("source('scripts/nn/layers/low_rank_affine.dml') as mlcontextns;[U, V, b] = mlcontextns::init(D, M, R);");
        script.in("D", d).in("M", m).in("R", r).out("U").out("V").out("b");
        MLResults results = script.execute();
        return new Init_output(results.getMatrix("U"), results.getMatrix("V"), results.getMatrix("b"));
    }

    /** Returns the signature and doc comment of the DML {@code init} function as text. */
    public String init__docs() {
        return "init = function(int D, int M, int R)\n    return (matrix[double] U, matrix[double] V, matrix[double] b) {\n  /*\n   * Initialize the parameters of this layer.\n   *\n   * Note: This is just a convenience function, and parameters\n   * may be initialized manually if needed.\n   *\n   * We use the heuristic by He et al., which limits the magnification\n   * of inputs/gradients during forward/backward passes by scaling\n   * unit-Gaussian weights by a factor of sqrt(2/n), under the\n   * assumption of relu neurons.\n   *  - http://arxiv.org/abs/1502.01852\n   *\n   * Inputs:\n   *  - D: Dimensionality of the input features (number of features).\n   *  - M: Number of neurons in this layer.\n   *  - R: Rank of U,V matrices such that R << min(D, M).\n   *\n   * Outputs:\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   */\n";
    }

    /** Returns the full DML source of the {@code init} function as text. */
    public String init__source() {
        return "init = function(int D, int M, int R)\n    return (matrix[double] U, matrix[double] V, matrix[double] b) {\n  /*\n   * Initialize the parameters of this layer.\n   *\n   * Note: This is just a convenience function, and parameters\n   * may be initialized manually if needed.\n   *\n   * We use the heuristic by He et al., which limits the magnification\n   * of inputs/gradients during forward/backward passes by scaling\n   * unit-Gaussian weights by a factor of sqrt(2/n), under the\n   * assumption of relu neurons.\n   *  - http://arxiv.org/abs/1502.01852\n   *\n   * Inputs:\n   *  - D: Dimensionality of the input features (number of features).\n   *  - M: Number of neurons in this layer.\n   *  - R: Rank of U,V matrices such that R << min(D, M).\n   *\n   * Outputs:\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   */\n  U = rand(rows=D, cols=R, pdf=\"normal\") * sqrt(2.0/D)\n  V = rand(rows=R, cols=M, pdf=\"normal\") * sqrt(2.0/R)\n  b = matrix(0, rows=1, cols=M)\n}\n";
    }

    /**
     * Executes the DML {@code forward(X, U, V, b)} function, which computes
     * {@code X %*% U %*% V + b}.
     *
     * @param x value bound to DML input {@code X}
     * @param u value bound to DML input {@code U}
     * @param v value bound to DML input {@code V}
     * @param b value bound to DML input {@code b}
     * @return the DML output matrix {@code out}
     */
    public Matrix forward(Object x, Object u, Object v, Object b) {
        Script script = new Script("source('scripts/nn/layers/low_rank_affine.dml') as mlcontextns;out = mlcontextns::forward(X, U, V, b);");
        script.in("X", x).in("U", u).in("V", v).in("b", b).out("out");
        return script.execute().getMatrix("out");
    }

    /** Returns the signature and doc comment of the DML {@code forward} function as text. */
    public String forward__docs() {
        return "forward = function(matrix[double] X, matrix[double] U, matrix[double] V, matrix[double] b)\n    return (matrix[double] out) {\n  /*\n   * Computes the forward pass for a low-rank affine (fully-connected) layer\n   * with M neurons.  The input data has N examples, each with D\n   * features.\n   *\n   * Inputs:\n   *  - X: Inputs, of shape (N, D).\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   *\n   * Outputs:\n   *  - out: Outputs, of shape (N, M).\n   */\n";
    }

    /** Returns the full DML source of the {@code forward} function as text. */
    public String forward__source() {
        return "forward = function(matrix[double] X, matrix[double] U, matrix[double] V, matrix[double] b)\n    return (matrix[double] out) {\n  /*\n   * Computes the forward pass for a low-rank affine (fully-connected) layer\n   * with M neurons.  The input data has N examples, each with D\n   * features.\n   *\n   * Inputs:\n   *  - X: Inputs, of shape (N, D).\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   *\n   * Outputs:\n   *  - out: Outputs, of shape (N, M).\n   */\n  out = X %*% U %*% V + b\n}\n";
    }

    /**
     * Executes the DML {@code backward(dout, X, U, V, b)} function and returns its gradients.
     *
     * @param dout value bound to DML input {@code dout} (upstream gradient)
     * @param x value bound to DML input {@code X}
     * @param u value bound to DML input {@code U}
     * @param v value bound to DML input {@code V}
     * @param b value bound to DML input {@code b}
     * @return the DML outputs {@code dX}, {@code dU}, {@code dV} and {@code db}
     */
    public Backward_output backward(Object dout, Object x, Object u, Object v, Object b) {
        Script script = new Script("source('scripts/nn/layers/low_rank_affine.dml') as mlcontextns;[dX, dU, dV, db] = mlcontextns::backward(dout, X, U, V, b);");
        script.in("dout", dout).in("X", x).in("U", u).in("V", v).in("b", b).out("dX").out("dU").out("dV").out("db");
        MLResults results = script.execute();
        return new Backward_output(results.getMatrix("dX"), results.getMatrix("dU"), results.getMatrix("dV"), results.getMatrix("db"));
    }

    /** Returns the signature and doc comment of the DML {@code backward} function as text. */
    public String backward__docs() {
        return "backward = function(matrix[double] dout, matrix[double] X,\n                    matrix[double] U, matrix[double] V, matrix[double] b)\n    return (matrix[double] dX, matrix[double] dU, matrix[double] dV, matrix[double] db) {\n  /*\n   * Computes the backward pass for a low-rank fully-connected (affine) layer\n   * with M neurons.\n   *\n   * Inputs:\n   *  - dout: Gradient wrt `out` from upstream, of shape (N, M).\n   *  - X: Inputs, of shape (N, D).\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   *\n   * Outputs:\n   *  - dX: Gradient wrt `X`, of shape (N, D).\n   *  - dU: Gradient wrt `U`, of shape (D, R).\n   *  - dV: Gradient wrt `V`, of shape (R, M).\n   *  - db: Gradient wrt `b`, of shape (1, M).\n   */\n";
    }

    /** Returns the full DML source of the {@code backward} function as text. */
    public String backward__source() {
        return "backward = function(matrix[double] dout, matrix[double] X,\n                    matrix[double] U, matrix[double] V, matrix[double] b)\n    return (matrix[double] dX, matrix[double] dU, matrix[double] dV, matrix[double] db) {\n  /*\n   * Computes the backward pass for a low-rank fully-connected (affine) layer\n   * with M neurons.\n   *\n   * Inputs:\n   *  - dout: Gradient wrt `out` from upstream, of shape (N, M).\n   *  - X: Inputs, of shape (N, D).\n   *  - U: LHS factor matrix for weights, of shape (D, R).\n   *  - V: RHS factor matrix for weights, of shape (R, M).\n   *  - b: Biases, of shape (1, M).\n   *\n   * Outputs:\n   *  - dX: Gradient wrt `X`, of shape (N, D).\n   *  - dU: Gradient wrt `U`, of shape (D, R).\n   *  - dV: Gradient wrt `V`, of shape (R, M).\n   *  - db: Gradient wrt `b`, of shape (1, M).\n   */\n  dX = dout %*% t(V) %*% t(U)\n  \n  # If out = Z %*% L, then dL = t(Z) %*% dout\n  # Substituting Z = X %*% U and L = V, we get\n  dV = t(U) %*% t(X) %*% dout\n    \n  dU = t(X) %*% dout %*% t(V)\n  \n  db = colSums(dout)\n}\n";
    }
}
