package org.apache.sysml.udf.lib;

import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;

/* loaded from: input_file:org/apache/sysml/udf/lib/SGDNesterovUpdate.class */
/**
 * External UDF that performs one SGD step with Nesterov momentum.
 *
 * <p>Function inputs, by position:
 * <ol start="0">
 *   <li>matrix {@code X} (presumably the current parameters)</li>
 *   <li>matrix {@code dX} (presumably the gradient of X)</li>
 *   <li>scalar {@code lr} (presumably the learning rate)</li>
 *   <li>scalar {@code mu} (presumably the momentum coefficient)</li>
 *   <li>matrix {@code v} (the previous velocity)</li>
 *   <li>scalar {@code lambda} (presumably an L2-regularization weight)</li>
 * </ol>
 *
 * <p>Outputs: index 0 is the updated X, index 1 is the updated v, computed element-wise as
 * <pre>
 *   v' = mu * v - lr * dX - (lr * lambda) * X
 *   X' = X - mu * v + (1 + mu) * v'
 * </pre>
 * X, dX and v must all have the same dimensions. Both outputs are materialized as dense blocks.
 */
public class SGDNesterovUpdate extends PackageFunction {
    private static final long serialVersionUID = -3905212831582648882L;

    // Positional indices of the function inputs read by execute().
    private static final int IN_X = 0;
    private static final int IN_DX = 1;
    private static final int IN_LR = 2;
    private static final int IN_MU = 3;
    private static final int IN_V = 4;
    private static final int IN_LAMBDA = 5;

    // Matrix inputs in the order execute() acquires them, so exactly the locked ones get released.
    private static final int[] MATRIX_INPUTS = {IN_X, IN_DX, IN_V};

    private Matrix updatedX;
    private Matrix updatedV;
    // Supplies random suffixes for the temporary output matrix names.
    private final Random rand = new Random();

    @Override // org.apache.sysml.udf.PackageFunction
    public int getNumFunctionOutputs() {
        return 2;
    }

    /**
     * Returns one of the outputs produced by {@link #execute()}.
     *
     * @param i 0 for the updated X, 1 for the updated v
     * @return the requested output matrix
     * @throws RuntimeException if {@code i} is neither 0 nor 1
     */
    @Override // org.apache.sysml.udf.PackageFunction
    public FunctionParameter getFunctionOutput(int i) {
        if (i == 0) {
            return this.updatedX;
        }
        if (i == 1) {
            return this.updatedV;
        }
        throw new RuntimeException("Invalid function output being requested");
    }

    /**
     * Tells whether a block can be read directly through its dense value array.
     *
     * @param matrixBlock the block to inspect
     * @return true if the block is not in sparse format and has an allocated dense block
     */
    boolean isDense(MatrixBlock matrixBlock) {
        return !matrixBlock.isInSparseFormat() && matrixBlock.getDenseBlock() != null;
    }

    /**
     * Computes the updated velocity and parameters and stores them as the function outputs.
     * Input matrices are always released, even if the computation fails.
     *
     * @throws RuntimeException wrapping any IOException raised while materializing an output
     * @throws IllegalArgumentException if dX or v does not have the same dimensions as X
     */
    @Override // org.apache.sysml.udf.PackageFunction
    public void execute() {
        // Parse scalars before locking any matrix, so a malformed scalar cannot leave inputs acquired.
        double lr = parseScalar(IN_LR);
        double mu = parseScalar(IN_MU);
        double lambda = parseScalar(IN_LAMBDA);

        int acquired = 0;
        try {
            MatrixBlock x = ((Matrix) getFunctionInput(IN_X)).getMatrixObject().acquireRead();
            acquired++;
            MatrixBlock dX = ((Matrix) getFunctionInput(IN_DX)).getMatrixObject().acquireRead();
            acquired++;
            MatrixBlock v = ((Matrix) getFunctionInput(IN_V)).getMatrixObject().acquireRead();
            acquired++;

            // Every element-wise loop below indexes all operands by the same linear offset.
            checkSameShape("dX", dX, x);
            checkSameShape("v", v, x);

            this.updatedV = new Matrix("tmp_" + this.rand.nextLong(), v.getNumRows(), v.getNumColumns(), Matrix.ValueType.Double);
            MatrixBlock vNewBlock = allocateDenseMatrixBlock(this.updatedV);
            // Capture the values before handing the block over; the X update reads them afterwards.
            double[] vNew = vNewBlock.getDenseBlockValues();
            computeUpdatedV(x, dX, v, lr, mu, lambda, vNewBlock);
            this.updatedV.setMatrixDoubleArray(vNewBlock, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);

            this.updatedX = new Matrix("tmp_" + this.rand.nextLong(), x.getNumRows(), x.getNumColumns(), Matrix.ValueType.Double);
            MatrixBlock xNewBlock = allocateDenseMatrixBlock(this.updatedX);
            computeUpdatedX(x, v, mu, vNew, xNewBlock);
            this.updatedX.setMatrixDoubleArray(xNewBlock, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        } catch (IOException e) {
            throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
        } finally {
            // Release only the inputs whose acquireRead() succeeded.
            for (int k = 0; k < acquired; k++) {
                ((Matrix) getFunctionInput(MATRIX_INPUTS[k])).getMatrixObject().release();
            }
        }
    }

    /** Parses the scalar function input at {@code index} as a double. */
    private double parseScalar(int index) {
        return Double.parseDouble(((Scalar) getFunctionInput(index)).getValue());
    }

    /** Throws IllegalArgumentException unless {@code actual} has the same dimensions as {@code x}. */
    private static void checkSameShape(String name, MatrixBlock actual, MatrixBlock x) {
        if (actual.getNumRows() != x.getNumRows() || actual.getNumColumns() != x.getNumColumns()) {
            throw new IllegalArgumentException("SGDNesterovUpdate: input " + name + " is "
                    + actual.getNumRows() + "x" + actual.getNumColumns() + " but X is "
                    + x.getNumRows() + "x" + x.getNumColumns());
        }
    }

    /**
     * Fills the zero-initialized dense {@code out} with mu * v - lr * dX - (lr * lambda) * X
     * and sets its non-zero count.
     */
    private void computeUpdatedV(MatrixBlock x, MatrixBlock dX, MatrixBlock v,
            double lr, double mu, double lambda, MatrixBlock out) {
        double[] outVals = out.getDenseBlockValues();
        if (isDense(v) && isDense(dX) && isDense(x)) {
            // Fast path: one fused pass over three dense operands.
            double[] vVals = v.getDenseBlockValues();
            double[] dXVals = dX.getDenseBlockValues();
            double[] xVals = x.getDenseBlockValues();
            long nnz = 0;
            for (int k = 0; k < outVals.length; k++) {
                outVals[k] = ((mu * vVals[k]) - (lr * dXVals[k])) - ((lr * lambda) * xVals[k]);
                if (outVals[k] != 0.0d) {
                    nnz++;
                }
            }
            out.setNonZeros(nnz);
        } else {
            // General path: accumulate each scaled operand, whatever its format.
            multiplyByConstant(v, mu, outVals);
            multiplyByConstant(dX, -lr, outVals);
            multiplyByConstant(x, (-lr) * lambda, outVals);
            out.recomputeNonZeros();
        }
    }

    /**
     * Fills the zero-initialized dense {@code out} with X - mu * v + (1 + mu) * vNew
     * and sets its non-zero count.
     */
    private void computeUpdatedX(MatrixBlock x, MatrixBlock v, double mu, double[] vNew, MatrixBlock out) {
        double[] outVals = out.getDenseBlockValues();
        double vNewScale = mu + 1.0d;
        if (isDense(x) && isDense(v)) {
            double[] xVals = x.getDenseBlockValues();
            double[] vVals = v.getDenseBlockValues();
            long nnz = 0;
            for (int k = 0; k < outVals.length; k++) {
                outVals[k] = (xVals[k] - (mu * vVals[k])) + (vNewScale * vNew[k]);
                if (outVals[k] != 0.0d) {
                    nnz++;
                }
            }
            out.setNonZeros(nnz);
        } else if (isDense(v)) {
            // X is sparse or empty: scatter it first, then finish with one dense pass.
            copy(x, outVals);
            double[] vVals = v.getDenseBlockValues();
            long nnz = 0;
            for (int k = 0; k < outVals.length; k++) {
                outVals[k] = outVals[k] + ((-mu) * vVals[k]) + (vNewScale * vNew[k]);
                if (outVals[k] != 0.0d) {
                    nnz++;
                }
            }
            out.setNonZeros(nnz);
        } else {
            copy(x, outVals);
            multiplyByConstant(v, -mu, outVals);
            multiplyByConstant(vNew, 1.0d + mu, outVals);
            out.recomputeNonZeros();
        }
    }

    /** Allocates an empty dense block with the dimensions of {@code matrix}. */
    private static MatrixBlock allocateDenseMatrixBlock(Matrix matrix) {
        MatrixBlock matrixBlock = new MatrixBlock((int) matrix.getNumRows(), (int) matrix.getNumCols(), false);
        matrixBlock.allocateDenseBlock();
        return matrixBlock;
    }

    /** Adds {@code constant * in[k]} to {@code out[k]} for every index of {@code out}. */
    private static void multiplyByConstant(double[] in, double constant, double[] out) {
        for (int k = 0; k < out.length; k++) {
            out[k] += in[k] * constant;
        }
    }

    /**
     * Adds {@code constant} times every cell of {@code in} to the row-major dense array {@code out}.
     * A dense-format block without allocated values is treated as all zeros.
     */
    private static void multiplyByConstant(MatrixBlock in, double constant, double[] out) {
        if (in.isInSparseFormat()) {
            int numCols = in.getNumColumns();
            Iterator<IJV> cells = in.getSparseBlockIterator();
            while (cells.hasNext()) {
                IJV cell = cells.next();
                // Row-major linear offset; the previous i * j mapped cells to the wrong positions.
                out[cell.getI() * numCols + cell.getJ()] += cell.getV() * constant;
            }
            return;
        }
        double[] values = in.getDenseBlockValues();
        if (values != null) {
            for (int k = 0; k < out.length; k++) {
                out[k] += values[k] * constant;
            }
        }
    }

    /**
     * Writes the cells of {@code in} into the row-major dense array {@code out}.
     * For a sparse block only the stored non-zeros are written.
     */
    private static void copy(MatrixBlock in, double[] out) {
        if (in.isInSparseFormat()) {
            int numCols = in.getNumColumns();
            Iterator<IJV> cells = in.getSparseBlockIterator();
            while (cells.hasNext()) {
                IJV cell = cells.next();
                // Row-major linear offset; the previous i * j mapped cells to the wrong positions.
                out[cell.getI() * numCols + cell.getJ()] = cell.getV();
            }
            return;
        }
        double[] values = in.getDenseBlockValues();
        if (values != null) {
            System.arraycopy(values, 0, out, 0, out.length);
        }
    }
}
