package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.matrix.data.DnnParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DnnUtils;
import org.apache.sysml.utils.NativeHelper;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.class */
public class DnnCPInstruction extends UnaryCPInstruction {
    // Logger; in this class it is only used for the under-utilization warning in resetNumThreads.
    private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());
    // Becomes true the first time resetNumThreads lowers the thread count; guards the warning so it is logged once.
    // NOTE(review): plain static flag, not volatile/synchronized.
    private static boolean warnedUnderUtilitization = false;
    // Additional input operands beyond input1 (held by the superclass). Which of them are
    // non-null depends on the constructor that was used.
    private final CPOperand _in2;
    private final CPOperand _in3;
    private final CPOperand _in4;
    private final CPOperand _in5;
    private final CPOperand _in6;
    private final CPOperand _in7;
    private final CPOperand _in8;
    // Additional output operands beyond the primary output (held by the superclass);
    // only set by the batch-norm constructor.
    private final CPOperand _out2;
    private final CPOperand _out3;
    private final CPOperand _out4;
    private final CPOperand _out5;
    // Scalar operand lists read by processInstruction via getScalarInput:
    // _input_shape and _filter_shape are read at indices 0..3 (filter index 1 is unused there),
    // _stride and _padding at indices 0..1. All four are null for the bias/relu and batch-norm opcodes.
    private final ArrayList<CPOperand> _input_shape;
    private final ArrayList<CPOperand> _filter_shape;
    private final ArrayList<CPOperand> _stride;
    private final ArrayList<CPOperand> _padding;
    // Thread count forwarded to the LibMatrixDNN calls / DnnParameters; 0 when built by the batch-norm constructor.
    private final int _numThreads;
    // Memory budget used by resetNumThreads to bound the number of threads.
    private final double _intermediateMemoryBudget;

    /**
     * General constructor for the opcodes that carry up to three inputs plus the
     * stride / padding / shape operand lists. All other non-batch-norm constructors
     * delegate to this one.
     *
     * @param in first input operand (stored by the superclass as input1)
     * @param in2 second input operand, may be {@code null}
     * @param in3 third input operand, may be {@code null}
     * @param out output operand (stored by the superclass)
     * @param stride stride operands, may be {@code null}
     * @param padding padding operands, may be {@code null}
     * @param input_shape input shape operands, may be {@code null}
     * @param filter_shape filter shape operands, may be {@code null}
     * @param numThreads thread count forwarded to the DNN library calls
     * @param intermediateMemoryBudget memory budget consulted by {@code resetNumThreads}
     * @param opcode instruction opcode
     * @param istr instruction string passed through to the superclass
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) {
        super(CPInstruction.CPType.Dnn, null, in, out, opcode, istr);

        // inputs available through this constructor
        this._in2 = in2;
        this._in3 = in3;

        // operands that only the batch-norm constructor provides
        this._in4 = null;
        this._in5 = null;
        this._in6 = null;
        this._in7 = null;
        this._in8 = null;
        this._out2 = null;
        this._out3 = null;
        this._out4 = null;
        this._out5 = null;

        // geometry operand lists
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
        this._stride = stride;
        this._padding = padding;

        // execution settings
        this._numThreads = numThreads;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    public DnnCPInstruction(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, String str, String str2, int i, double d) {
        this(cPOperand, cPOperand2, (CPOperand) null, cPOperand3, (ArrayList<CPOperand>) null, (ArrayList<CPOperand>) null, (ArrayList<CPOperand>) null, (ArrayList<CPOperand>) null, i, d, str, str2);
        if (!str.equals("bias_add") && !str.equals("relu_backward") && !str.equals("bias_multiply")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + str);
        }
    }

    /**
     * Single-input constructor with geometry operands; {@code parseInstruction} uses it
     * for the forward pooling opcodes. The second and third inputs are left {@code null}.
     *
     * @param in input operand
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string
     * @param stride stride operands
     * @param padding padding operands
     * @param input_shape input shape operands
     * @param filter_shape filter shape operands
     * @param numThreads thread count
     * @param intermediateMemoryBudget memory budget
     */
    private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, (CPOperand) null, (CPOperand) null, out,
                stride, padding, input_shape, filter_shape,
                numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /**
     * Two-input constructor with geometry operands; {@code parseInstruction} uses it for
     * the pooling-backward, {@code conv2d} and {@code conv2d_backward_*} opcodes. The third
     * input is left {@code null}.
     *
     * @param in first input operand
     * @param in2 second input operand
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string
     * @param stride stride operands
     * @param padding padding operands
     * @param input_shape input shape operands
     * @param filter_shape filter shape operands
     * @param numThreads thread count
     * @param intermediateMemoryBudget memory budget
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, (CPOperand) null, out,
                stride, padding, input_shape, filter_shape,
                numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /**
     * Three-input constructor with geometry operands; {@code parseInstruction} uses it for
     * {@code conv2d_bias_add}. It only reorders its arguments for the general constructor.
     *
     * @param in first input operand
     * @param in2 second input operand
     * @param in3 third input operand
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string
     * @param stride stride operands
     * @param padding padding operands
     * @param input_shape input shape operands
     * @param filter_shape filter shape operands
     * @param numThreads thread count
     * @param intermediateMemoryBudget memory budget
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
        this(in, in2, in3, out,
                stride, padding, input_shape, filter_shape,
                numThreads, intermediateMemoryBudget, opcode, istr);
    }

    /**
     * Constructor for the batch-norm opcodes, which take up to eight inputs and up to five
     * outputs but no stride / padding / shape operands. The thread count is fixed to 0.
     *
     * @param in first input operand (stored by the superclass as input1)
     * @param in2 second input operand
     * @param in3 third input operand
     * @param in4 fourth input operand
     * @param in5 fifth input operand
     * @param in6 sixth input operand
     * @param in7 seventh input operand, may be {@code null}
     * @param in8 eighth input operand, may be {@code null}
     * @param out primary output operand (stored by the superclass)
     * @param out2 second output operand
     * @param out3 third output operand
     * @param out4 fourth output operand, may be {@code null}
     * @param out5 fifth output operand, may be {@code null}
     * @param opcode instruction opcode
     * @param istr instruction string
     * @param intermediateMemoryBudget memory budget stored on the instruction
     * @throws DMLRuntimeException declared by the original signature
     */
    public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8, CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(CPInstruction.CPType.Dnn, null, in, out, opcode, istr);

        // additional inputs
        this._in2 = in2;
        this._in3 = in3;
        this._in4 = in4;
        this._in5 = in5;
        this._in6 = in6;
        this._in7 = in7;
        this._in8 = in8;

        // additional outputs
        this._out2 = out2;
        this._out3 = out3;
        this._out4 = out4;
        this._out5 = out5;

        // batch norm carries no geometry operands
        this._input_shape = null;
        this._filter_shape = null;
        this._stride = null;
        this._padding = null;

        this._numThreads = 0;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Parses an instruction string into a {@link DnnCPInstruction}.
     *
     * <p>The string is split with {@code InstructionUtils.getInstructionPartsWithValueType};
     * part 0 is the opcode and selects the field layout. For every opcode the number of
     * fields is validated with {@code InstructionUtils.checkNumFields} before the operands
     * are read. Opcodes are matched case-insensitively.
     *
     * <p>Layouts (indices into the split parts):
     * <ul>
     *   <li>forward pooling: in(1), stride(2-3), padding(4-5), input shape(6-9),
     *       filter shape(10-13), out(14), threads(15), budget(16)</li>
     *   <li>pooling backward / conv2d / conv2d_backward_*: in(1), in2(2), stride(3-4),
     *       padding(5-6), input shape(7-10), filter shape(11-14), out(15), threads(16),
     *       budget(17)</li>
     *   <li>conv2d_bias_add: in(1), in2(2), in3(3), stride(4-5), padding(6-7),
     *       input shape(8-11), filter shape(12-15), out(16), threads(17), budget(18)</li>
     *   <li>bias_add / relu_backward / bias_multiply: in(1), in2(2), out(3), threads(4),
     *       budget(5)</li>
     *   <li>batch_norm2d: thirteen operands (1-13); batch_norm2d_backward: nine operands (1-9)</li>
     * </ul>
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one handled by this class
     */
    public static DnnCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];

        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            ArrayList<CPOperand> stride = parseOperandList(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperandList(parts, 4, 2);
            ArrayList<CPOperand> inputShape = parseOperandList(parts, 6, 4);
            ArrayList<CPOperand> filterShape = parseOperandList(parts, 10, 4);
            return new DnnCPInstruction(in, out, opcode, str, stride, padding, inputShape, filterShape,
                    Integer.parseInt(parts[15]), Double.parseDouble(parts[16]));
        }

        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward")
                || opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            ArrayList<CPOperand> stride = parseOperandList(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperandList(parts, 5, 2);
            ArrayList<CPOperand> inputShape = parseOperandList(parts, 7, 4);
            ArrayList<CPOperand> filterShape = parseOperandList(parts, 11, 4);
            return new DnnCPInstruction(in, in2, out, opcode, str, stride, padding, inputShape, filterShape,
                    Integer.parseInt(parts[16]), Double.parseDouble(parts[17]));
        }

        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 18);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            ArrayList<CPOperand> stride = parseOperandList(parts, 4, 2);
            ArrayList<CPOperand> padding = parseOperandList(parts, 6, 2);
            ArrayList<CPOperand> inputShape = parseOperandList(parts, 8, 4);
            ArrayList<CPOperand> filterShape = parseOperandList(parts, 12, 4);
            return new DnnCPInstruction(in, in2, in3, out, opcode, str, stride, padding, inputShape, filterShape,
                    Integer.parseInt(parts[17]), Double.parseDouble(parts[18]));
        }

        // relu_backward is matched case-insensitively like every other opcode in this method.
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 5);
            return new DnnCPInstruction(new CPOperand(parts[1]), new CPOperand(parts[2]), new CPOperand(parts[3]), opcode, str,
                    Integer.parseInt(parts[4]), Double.parseDouble(parts[5]));
        }

        if (opcode.equalsIgnoreCase("batch_norm2d")) {
            InstructionUtils.checkNumFields(parts, 13);
            // eight inputs (1-8) followed by five outputs (9-13)
            return new DnnCPInstruction(
                    new CPOperand(parts[1]), new CPOperand(parts[2]), new CPOperand(parts[3]), new CPOperand(parts[4]),
                    new CPOperand(parts[5]), new CPOperand(parts[6]), new CPOperand(parts[7]), new CPOperand(parts[8]),
                    new CPOperand(parts[9]), new CPOperand(parts[10]), new CPOperand(parts[11]), new CPOperand(parts[12]),
                    new CPOperand(parts[13]), opcode, str, 0.0d);
        }

        if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
            InstructionUtils.checkNumFields(parts, 9);
            // six inputs (1-6) followed by three outputs (7-9); the unused slots stay null
            return new DnnCPInstruction(
                    new CPOperand(parts[1]), new CPOperand(parts[2]), new CPOperand(parts[3]), new CPOperand(parts[4]),
                    new CPOperand(parts[5]), new CPOperand(parts[6]), null, null,
                    new CPOperand(parts[7]), new CPOperand(parts[8]), new CPOperand(parts[9]), null,
                    null, opcode, str, 0.0d);
        }

        throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
    }

    /**
     * Builds an operand list from {@code count} consecutive instruction parts.
     *
     * @param parts the split instruction parts
     * @param from index of the first part to convert
     * @param count number of consecutive parts to convert
     * @return a new list holding one {@link CPOperand} per part, in order
     */
    private static ArrayList<CPOperand> parseOperandList(String[] parts, int from, int count) {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int idx = from; idx < from + count; idx++) {
            operands.add(new CPOperand(parts[idx]));
        }
        return operands;
    }

    /**
     * Reads the scalar operand at position {@code i} of {@code arrayList} from the execution
     * context and returns its long value narrowed to {@code int}.
     *
     * @param executionContext context used to resolve the scalar
     * @param arrayList operand list to index into
     * @param i position of the operand in the list
     * @return the scalar's long value cast to {@code int}
     */
    private static int getScalarInput(ExecutionContext executionContext, ArrayList<CPOperand> arrayList, int i) {
        CPOperand operand = arrayList.get(i);
        long value = executionContext
                .getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral())
                .getLongValue();
        return (int) value;
    }

    /**
     * Executes {@code relu_backward}: pins both matrix inputs, delegates to
     * {@code LibMatrixDNN.reluBackward} unless one of them is empty, then releases the inputs
     * and binds the result to the output variable.
     *
     * @param executionContext context providing the inputs and receiving the output
     */
    public void processReluBackwardInstruction(ExecutionContext executionContext) {
        MatrixBlock first = executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());
        MatrixBlock second = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());

        // Output has the first input's dimensions; it is created sparse if either input is sparse.
        boolean sparseOutput = first.isInSparseFormat() || second.isInSparseFormat();
        MatrixBlock result = new MatrixBlock(first.getNumRows(), first.getNumColumns(), sparseOutput);

        // If either input is empty the unallocated block is returned as-is.
        boolean anyEmpty = first.isEmpty() || second.isEmpty();
        if (!anyEmpty) {
            result.allocateBlock();
            LibMatrixDNN.reluBackward(first, second, result, this._numThreads);
        }

        executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        executionContext.setMatrixOutput(getOutputVariableName(), result, getExtendedOpcode());
    }

    /**
     * Executes {@code bias_add}. The bias operand ({@code _in2}) must have exactly one column.
     *
     * <ul>
     *   <li>bias empty, input empty: an unallocated sparse block of the input's dimensions</li>
     *   <li>bias empty, input non-empty: a copy of the input</li>
     *   <li>otherwise: a dense block filled by {@code LibMatrixDNN.biasAdd}</li>
     * </ul>
     *
     * @param executionContext context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias matrix does not have exactly one column
     */
    public void processBiasAddInstruction(ExecutionContext executionContext) {
        MatrixBlock data = executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());
        MatrixBlock bias = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());

        int biasColumns = bias.getNumColumns();
        if (biasColumns != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + biasColumns);
        }

        MatrixBlock result;
        if (!bias.isEmpty()) {
            result = new MatrixBlock(data.getNumRows(), data.getNumColumns(), false);
            result.allocateDenseBlock();
            LibMatrixDNN.biasAdd(data, bias, result, this._numThreads);
        } else if (data.isEmpty()) {
            result = new MatrixBlock(data.getNumRows(), data.getNumColumns(), true);
        } else {
            result = new MatrixBlock(data);
        }

        executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        executionContext.setMatrixOutput(getOutputVariableName(), result, getExtendedOpcode());
    }

    /**
     * Executes {@code bias_multiply}. The bias operand ({@code _in2}) must have exactly one
     * column. An empty bias yields an unallocated sparse block of the input's dimensions;
     * otherwise the output is allocated in the input's format and filled by
     * {@code LibMatrixDNN.biasMultiply}.
     *
     * @param executionContext context providing the inputs and receiving the output
     * @throws DMLRuntimeException if the bias matrix does not have exactly one column
     */
    public void processBiasMultiplyInstruction(ExecutionContext executionContext) {
        MatrixBlock data = executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());
        MatrixBlock bias = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());

        int biasColumns = bias.getNumColumns();
        if (biasColumns != 1) {
            throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + biasColumns);
        }

        MatrixBlock result;
        if (!bias.isEmpty()) {
            MatrixBlock shell = new MatrixBlock(data.getNumRows(), data.getNumColumns(), data.isInSparseFormat());
            result = shell.allocateBlock();
            LibMatrixDNN.biasMultiply(data, bias, result, this._numThreads);
        } else {
            result = new MatrixBlock(data.getNumRows(), data.getNumColumns(), true);
        }

        executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        executionContext.setMatrixOutput(getOutputVariableName(), result, getExtendedOpcode());
    }

    /**
     * Executes {@code batch_norm2d}: pins five matrix inputs (input1, _in2.._in5), reads one
     * string scalar (_in6) and two double scalars (_in7, _in8), allocates five dense outputs
     * and delegates to {@code LibMatrixDNN.batchNorm2D}. Inputs are released afterwards and
     * the outputs are bound to output, _out2.._out5.
     *
     * <p>NOTE(review): by batch-norm convention operands 2-5 are presumably scale, bias,
     * running mean and running variance, and the scalars a phase flag, epsilon and momentum -
     * confirm against {@code LibMatrixDNN.batchNorm2D}.
     *
     * @param executionContext context providing the inputs and receiving the outputs
     */
    public void processBatchNorm2dInstruction(ExecutionContext executionContext) {
        MatrixBlock m1 = executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());
        MatrixBlock m2 = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
        MatrixBlock m3 = executionContext.getMatrixInput(this._in3.getName(), getExtendedOpcode());
        MatrixBlock m4 = executionContext.getMatrixInput(this._in4.getName(), getExtendedOpcode());
        MatrixBlock m5 = executionContext.getMatrixInput(this._in5.getName(), getExtendedOpcode());

        String s6 = executionContext.getScalarInput(this._in6.getName(), this._in6.getValueType(), this._in6.isLiteral()).getStringValue();
        double d7 = executionContext.getScalarInput(this._in7.getName(), this._in7.getValueType(), this._in7.isLiteral()).getDoubleValue();
        double d8 = executionContext.getScalarInput(this._in8.getName(), this._in8.getValueType(), this._in8.isLiteral()).getDoubleValue();

        // Dense outputs: the first mirrors input1's dimensions, the rest mirror _in4 / _in5 alternately.
        MatrixBlock r1 = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), false).allocateBlock();
        MatrixBlock r2 = new MatrixBlock(m4.getNumRows(), m4.getNumColumns(), false).allocateBlock();
        MatrixBlock r3 = new MatrixBlock(m5.getNumRows(), m5.getNumColumns(), false).allocateBlock();
        MatrixBlock r4 = new MatrixBlock(m4.getNumRows(), m4.getNumColumns(), false).allocateBlock();
        MatrixBlock r5 = new MatrixBlock(m5.getNumRows(), m5.getNumColumns(), false).allocateBlock();

        LibMatrixDNN.batchNorm2D(m1, m2, m3, m4, m5, s6, d7, d8, r1, r2, r3, r4, r5);

        executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in3.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in4.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in5.getName(), getExtendedOpcode());

        executionContext.setMatrixOutput(this.output.getName(), r1, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out2.getName(), r2, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out3.getName(), r3, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out4.getName(), r4, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out5.getName(), r5, getExtendedOpcode());
    }

    /**
     * Executes {@code batch_norm2d_backward}: pins the matrix inputs input1, _in2, _in3,
     * _in5 and _in6, reads _in4 as a double scalar, allocates three dense outputs and
     * delegates to {@code LibMatrixDNN.batchNorm2DBackward}. Inputs are released afterwards
     * and the outputs are bound to output, _out2 and _out3.
     *
     * @param executionContext context providing the inputs and receiving the outputs
     */
    public void processBatchNorm2dBackwardInstruction(ExecutionContext executionContext) {
        MatrixBlock m1 = executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());
        MatrixBlock m2 = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
        MatrixBlock m3 = executionContext.getMatrixInput(this._in3.getName(), getExtendedOpcode());
        // _in4 is the only scalar operand of this opcode.
        double d4 = executionContext.getScalarInput(this._in4.getName(), this._in4.getValueType(), this._in4.isLiteral()).getDoubleValue();
        MatrixBlock m5 = executionContext.getMatrixInput(this._in5.getName(), getExtendedOpcode());
        MatrixBlock m6 = executionContext.getMatrixInput(this._in6.getName(), getExtendedOpcode());

        // Dense outputs: the first mirrors input1's dimensions, the other two mirror _in3.
        MatrixBlock r1 = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), false).allocateBlock();
        MatrixBlock r2 = new MatrixBlock(m3.getNumRows(), m3.getNumColumns(), false).allocateBlock();
        MatrixBlock r3 = new MatrixBlock(m3.getNumRows(), m3.getNumColumns(), false).allocateBlock();

        LibMatrixDNN.batchNorm2DBackward(m1, m2, m3, d4, m5, m6, r1, r2, r3);

        executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in3.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in5.getName(), getExtendedOpcode());
        executionContext.releaseMatrixInput(this._in6.getName(), getExtendedOpcode());

        executionContext.setMatrixOutput(this.output.getName(), r1, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out2.getName(), r2, getExtendedOpcode());
        executionContext.setMatrixOutput(this._out3.getName(), r3, getExtendedOpcode());
    }

    /**
     * Reports whether the given block is in sparse format, after first converting it to dense
     * when it is sparse and has fewer than 1e7 cells.
     *
     * <p>Side effect: the block may be converted in place via {@code sparseToDense()}.
     *
     * @param matrixBlock block to inspect (and possibly densify)
     * @return {@code true} if the block is still in sparse format afterwards
     */
    private static boolean isFilterSparse(MatrixBlock matrixBlock) {
        // Widen before multiplying: rows * cols as int can overflow and turn negative,
        // which would wrongly pass the size threshold below.
        long numCells = (long) matrixBlock.getNumRows() * matrixBlock.getNumColumns();
        if (matrixBlock.isInSparseFormat() && numCells < 1.0E7d) {
            matrixBlock.sparseToDense();
        }
        return matrixBlock.isInSparseFormat();
    }

    /**
     * Executes this instruction.
     *
     * <p>The bias / relu / batch-norm opcodes are forwarded to their dedicated handlers. All
     * remaining opcodes (pooling, pooling backward, conv2d, conv2d_bias_add,
     * conv2d_backward_filter, conv2d_backward_data) read the padding, stride and shape scalars,
     * build a {@link DnnParameters}, compute the output block and bind it to the output
     * variable. Empty inputs short-circuit to an unallocated sparse output. For the
     * convolution opcodes the native library is used only when it is loaded and the relevant
     * operands are dense.
     *
     * <p>Local names follow the N/C/H/W (input) and K/R/S (filter) convention suggested by the
     * {@code DnnParameters} fields N, K, P and Q that this method reads; the positional
     * meaning of the shape operands is an assumption - confirm against {@code DnnParameters}.
     *
     * @param executionContext context providing the inputs and receiving the output
     * @throws DMLRuntimeException on an unsupported opcode or a mis-shaped bias matrix
     */
    @Override
    public void processInstruction(ExecutionContext executionContext) {
        // Opcodes without stride/padding/shape operands have dedicated handlers.
        if (this.instOpcode.equalsIgnoreCase("bias_add")) {
            processBiasAddInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("bias_multiply")) {
            processBiasMultiplyInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("relu_backward")) {
            processReluBackwardInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d")) {
            processBatchNorm2dInstruction(executionContext);
            return;
        }
        if (this.instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
            processBatchNorm2dBackwardInstruction(executionContext);
            return;
        }

        // avgpooling_backward never reads the first operand, so it is neither pinned nor released.
        final boolean avgPoolingBackward = this.instOpcode.equalsIgnoreCase("avgpooling_backward");
        MatrixBlock input = avgPoolingBackward ? null : executionContext.getMatrixInput(this.input1.getName(), getExtendedOpcode());

        int padH = getScalarInput(executionContext, this._padding, 0);
        int padW = getScalarInput(executionContext, this._padding, 1);
        int strideH = getScalarInput(executionContext, this._stride, 0);
        int strideW = getScalarInput(executionContext, this._stride, 1);
        int N = getScalarInput(executionContext, this._input_shape, 0);
        int C = getScalarInput(executionContext, this._input_shape, 1);
        int H = getScalarInput(executionContext, this._input_shape, 2);
        int W = getScalarInput(executionContext, this._input_shape, 3);
        int K = getScalarInput(executionContext, this._filter_shape, 0);
        int R = getScalarInput(executionContext, this._filter_shape, 2);
        int S = getScalarInput(executionContext, this._filter_shape, 3);
        int P = (int) DnnUtils.getP(H, R, strideH, padH);
        int Q = (int) DnnUtils.getQ(W, S, strideW, padW);

        DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, strideH, strideW, padH, padW, this._numThreads);
        params.enableNative = NativeHelper.isNativeLibraryLoaded();

        MatrixBlock out;
        if (this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling") || this.instOpcode.equalsIgnoreCase("avgpooling")) {
            if (input.isEmpty()) {
                out = new MatrixBlock(N, C * P * Q, true);
            } else {
                out = new MatrixBlock(N, C * P * Q, false).allocateBlock();
                boolean maxPooling = this.instOpcode.equalsIgnoreCase("maxpooling") || this.instOpcode.equalsIgnoreCase("relu_maxpooling");
                LibMatrixDNN.PoolingType poolingType = maxPooling ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                if (this.instOpcode.equalsIgnoreCase("relu_maxpooling")) {
                    params.minValForMaxPoolOperations = 0.0d;
                }
                LibMatrixDNN.pooling(input, out, params, poolingType);
            }
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward") || avgPoolingBackward) {
            MatrixBlock dout = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
            // For average pooling only dout matters; for max pooling either operand being empty gives an empty result.
            boolean emptyResult = avgPoolingBackward ? dout.isEmpty() : (input.isEmpty() || dout.isEmpty());
            if (emptyResult) {
                out = new MatrixBlock(N, C * H * W, true);
            } else {
                out = new MatrixBlock(N, C * H * W, false).allocateBlock();
                boolean maxPooling = this.instOpcode.equalsIgnoreCase("maxpooling_backward") || this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
                LibMatrixDNN.PoolingType poolingType = maxPooling ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
                boolean reluVariant = this.instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
                if (reluVariant) {
                    params.minValForMaxPoolOperations = 0.0d;
                }
                LibMatrixDNN.poolingBackward(input, dout, out, params, reluVariant, poolingType);
            }
            executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d")) {
            resetNumThreads(params, C * R * S, P * Q, sparsityOf(input));
            MatrixBlock filter = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
            if (filter.isEmpty() || input.isEmpty()) {
                out = new MatrixBlock(N, K * P * Q, true);
            } else {
                int outCols = K * P * Q;
                // Sparse output only for an ultra-sparse input without bias whose in-memory size
                // is below the dense estimate of the output.
                boolean sparseOut = input.isUltraSparse(false)
                        && params.bias == null
                        && input.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory((long) N, (long) outCols);
                out = new MatrixBlock(N, outCols, sparseOut).allocateBlock();
                // isFilterSparse may densify the filter; it is only consulted when native is enabled.
                if (!params.enableNative || isFilterSparse(filter) || input.isInSparseFormat()) {
                    LibMatrixDNN.conv2d(input, filter, out, params);
                } else {
                    LibMatrixNative.conv2d(input, filter, out, params);
                }
            }
            executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
            resetNumThreads(params, C * R * S, P * Q, sparsityOf(input));
            // For this opcode the filter is the third operand and the bias the second.
            MatrixBlock filter = executionContext.getMatrixInput(this._in3.getName(), getExtendedOpcode());
            MatrixBlock bias = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
            if (bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
                throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. Expected: [" + params.K + ", 1]");
            }
            boolean convEmpty = filter.isEmpty() || input.isEmpty();
            if (convEmpty && bias.isEmpty()) {
                // empty convolution plus empty bias: empty result
                out = new MatrixBlock(N, K * P * Q, true);
            } else if (convEmpty) {
                // empty convolution, non-empty bias: the result is the bias written into each row
                out = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                for (int n = 0; n < params.N; n++) {
                    DnnUtils.fillBias(bias, out.getDenseBlockValues(), n, n + 1, params.N, params.K, params.P * params.Q);
                }
            } else {
                out = new MatrixBlock(N, K * P * Q, false).allocateBlock();
                if (!bias.isEmpty()) {
                    params.bias = bias;
                }
                if (!params.enableNative || isFilterSparse(filter) || input.isInSparseFormat()) {
                    LibMatrixDNN.conv2d(input, filter, out, params);
                } else {
                    LibMatrixNative.conv2d(input, filter, out, params);
                }
            }
            executionContext.releaseMatrixInput(this._in3.getName(), getExtendedOpcode());
            executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixBlock dout = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
            if (dout.isEmpty() || input.isEmpty()) {
                out = new MatrixBlock(K, C * R * S, true);
            } else {
                out = new MatrixBlock(K, C * R * S, false).allocateBlock();
                if (!params.enableNative || input.isInSparseFormat() || dout.isInSparseFormat()) {
                    LibMatrixDNN.conv2dBackwardFilter(input, dout, out, params);
                } else {
                    LibMatrixNative.conv2dBackwardFilter(input, dout, out, params);
                }
            }
            executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        } else if (this.instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixBlock dout = executionContext.getMatrixInput(this._in2.getName(), getExtendedOpcode());
            if (dout.isEmpty() || input.isEmpty()) {
                out = new MatrixBlock(N, C * H * W, true);
            } else {
                out = new MatrixBlock(N, C * H * W, false).allocateBlock();
                // Here the first operand is the one passed through isFilterSparse.
                if (!params.enableNative || isFilterSparse(input) || dout.isInSparseFormat()) {
                    LibMatrixDNN.conv2dBackwardData(input, dout, out, params);
                } else {
                    LibMatrixNative.conv2dBackwardData(input, dout, out, params);
                }
            }
            executionContext.releaseMatrixInput(this._in2.getName(), getExtendedOpcode());
        } else {
            throw new DMLRuntimeException("Unsupported op code " + this.instOpcode);
        }

        if (!avgPoolingBackward) {
            executionContext.releaseMatrixInput(this.input1.getName(), getExtendedOpcode());
        }
        executionContext.setMatrixOutput(getOutputVariableName(), out, getExtendedOpcode());
    }

    /**
     * Returns the fraction of non-zero cells of {@code block} as a floating-point ratio.
     *
     * <p>The division is done in {@code double} over a {@code long} cell count: an integral
     * division would truncate every ratio below 1 to 0, and an {@code int} product of the
     * dimensions can overflow. A block without cells has sparsity 0.
     *
     * @param block block whose sparsity is computed
     * @return non-zeros divided by rows * columns, or 0 if the block has no cells
     */
    private static double sparsityOf(MatrixBlock block) {
        long cells = (long) block.getNumRows() * block.getNumColumns();
        if (cells == 0) {
            return 0.0d;
        }
        return (double) block.getNonZeros() / cells;
    }

    /**
     * Lowers {@code params.numThreads} so that the estimated per-thread intermediate memory
     * fits into the intermediate memory budget. This only takes effect when
     * {@code DMLScript.USE_ACCELERATOR} is set.
     *
     * <p>The per-thread footprint is estimated with
     * {@code OptimizerUtils.estimateSizeExactSparsity(numRows, numCols, sparsity)}. The
     * resulting limit is clamped to at least one thread: if the budget is smaller than a
     * single thread's footprint the plain quotient would be 0, which is not a usable reduced
     * thread count. A warning is logged the first time the thread count is reduced.
     *
     * @param params parameters whose {@code numThreads} may be lowered
     * @param numRows rows of the per-thread intermediate
     * @param numCols columns of the per-thread intermediate
     * @param sparsity sparsity used for the size estimate
     */
    private void resetNumThreads(DnnParameters params, int numRows, int numCols, double sparsity) {
        if (!DMLScript.USE_ACCELERATOR) {
            return;
        }
        double memPerThread = OptimizerUtils.estimateSizeExactSparsity(numRows, numCols, sparsity);
        int maxThreads = Math.max(1, (int) Math.floor(this._intermediateMemoryBudget / memPerThread));
        if (params.numThreads > maxThreads) {
            params.numThreads = maxThreads;
            if (!warnedUnderUtilitization) {
                LOG.warn("CPU Under-utilization to respect the intermediate memory budget. To avoid this, please try reducing the mini-batch or forcing gpu execution.");
            }
            warnedUnderUtilitization = true;
        }
    }
}
