package org.apache.sysml.runtime.instructions;

import java.util.HashMap;
import org.apache.sysml.lops.Append;
import org.apache.sysml.lops.RightIndex;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.AggregateBinaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.AggregateUnaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.ArithmeticBinaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.BuiltinBinaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.BuiltinUnaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MMTSJGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MatrixAppendGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MatrixIndexingGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MatrixMatrixAxpyGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.MatrixReshapeGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.RelationalBinaryGPUInstruction;
import org.apache.sysml.runtime.instructions.gpu.ReorgGPUInstruction;
import org.slf4j.Marker;

/**
 * Parses GPU instruction strings into concrete {@link GPUInstruction} objects.
 *
 * <p>{@link #String2GPUInstructionType} maps every supported opcode to the GPU instruction
 * category whose {@code parseInstruction} factory is used to build the instruction.
 */
public class GPUInstructionParser extends InstructionParser {

    /**
     * Opcode to GPU instruction category. Populated once by the static initializer below.
     * NOTE(review): presumably consulted by {@code InstructionUtils.getGPUType}; the map is
     * mutable and package-visible, so other code must treat it as read-only.
     */
    static final HashMap<String, GPUInstruction.GPUINSTRUCTION_TYPE> String2GPUInstructionType = new HashMap<>();

    static {
        // Opcodes are written as literals. Earlier decompiled code referenced unrelated constants
        // that merely held the same values (slf4j Marker.ANY_NON_NULL_MARKER for "+",
        // ParForStatementBlock.OPT_LOG for "log").
        register(GPUInstruction.GPUINSTRUCTION_TYPE.Dnn,
            "relu_backward", "conv2d", "conv2d_bias_add", "conv2d_backward_filter",
            "conv2d_backward_data", "maxpooling", "maxpooling_backward", "avgpooling",
            "avgpooling_backward", "bias_add", "bias_multiply", "channel_sums", "lstm",
            "lstm_backward", "batch_norm2d", "batch_norm2d_backward", "batch_norm2d_test",
            "batch_norm2d_train", "update_nesterov_x");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.AggregateBinary, "ba+*");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.MMTSJ, "tsmm");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.Reorg, "r'");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.MatrixReshape, "rshape");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.Append, Append.OPCODE);
        register(GPUInstruction.GPUINSTRUCTION_TYPE.ArithmeticBinary,
            "+", "-", "*", "/", "%%", "%/%", "^", "1-*", "^2", "*2", "-nz", "+*", "-*");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.BuiltinUnary,
            "exp", "log", "abs", "sqrt", "round", "floor", "ceil", "sin", "cos", "tan",
            "sinh", "cosh", "tanh", "asin", "acos", "atan", "sign", "sigmoid", "softmax");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.BuiltinBinary, "solve", "min", "max");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.AggregateUnary,
            "ua+", "uak+", "uar+", "uark+", "uac+", "uack+", "ua*",
            "uamean", "uarmean", "uacmean",
            "uamax", "uarmax", "uacmax",
            "uamin", "uarmin", "uacmin",
            "uasqk+", "uarsqk+", "uacsqk+",
            "uavar", "uarvar", "uacvar");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.RelationalBinary,
            "==", "!=", "<", ">", "<=", ">=");
        register(GPUInstruction.GPUINSTRUCTION_TYPE.MatrixIndexing, RightIndex.OPCODE);
    }

    /**
     * Maps each of {@code opcodes} to {@code type} in {@link #String2GPUInstructionType}.
     *
     * @param type    GPU instruction category the opcodes belong to
     * @param opcodes opcodes to register, inserted in the given order
     */
    private static void register(GPUInstruction.GPUINSTRUCTION_TYPE type, String... opcodes) {
        for (String opcode : opcodes) {
            String2GPUInstructionType.put(opcode, type);
        }
    }

    /**
     * Derives the GPU instruction type of {@code str} and parses it.
     *
     * @param str serialized instruction
     * @return the parsed instruction, or {@code null} if {@code str} is null or empty
     * @throws DMLRuntimeException if no GPU type can be derived for {@code str}, or if
     *                             parsing yields no instruction
     */
    public static GPUInstruction parseSingleInstruction(String str) {
        if (str == null || str.isEmpty()) {
            return null;
        }
        GPUInstruction.GPUINSTRUCTION_TYPE gPUType = InstructionUtils.getGPUType(str);
        if (gPUType == null) {
            throw new DMLRuntimeException("Unable to derive GPU instruction type for instruction: " + str);
        }
        GPUInstruction parsed = parseSingleInstruction(gPUType, str);
        if (parsed == null) {
            throw new DMLRuntimeException("Unable to parse instruction: " + str);
        }
        return parsed;
    }

    /**
     * Parses {@code str} with the factory that belongs to the given GPU instruction type.
     *
     * @param gpuinstruction_type category of the instruction. It selects the concrete
     *                            {@code *GPUInstruction.parseInstruction} factory.
     * @param str                 serialized instruction
     * @return the parsed instruction, or {@code null} if {@code str} is null or empty
     * @throws DMLRuntimeException if {@code gpuinstruction_type} is null or not handled
     */
    public static GPUInstruction parseSingleInstruction(GPUInstruction.GPUINSTRUCTION_TYPE gpuinstruction_type, String str) {
        if (str == null || str.isEmpty()) {
            return null;
        }
        if (gpuinstruction_type == null) {
            throw new DMLRuntimeException("The instruction is not GPU-enabled: " + str);
        }
        switch (gpuinstruction_type) {
            case AggregateUnary:
                return AggregateUnaryGPUInstruction.parseInstruction(str);
            case AggregateBinary:
                return AggregateBinaryGPUInstruction.parseInstruction(str);
            case BuiltinUnary:
                return BuiltinUnaryGPUInstruction.parseInstruction(str);
            case BuiltinBinary:
                return BuiltinBinaryGPUInstruction.parseInstruction(str);
            case Append:
                return MatrixAppendGPUInstruction.parseInstruction(str);
            case Dnn:
                return DnnGPUInstruction.parseInstruction(str);
            case MMTSJ:
                return MMTSJGPUInstruction.parseInstruction(str);
            case Reorg:
                return ReorgGPUInstruction.parseInstruction(str);
            case MatrixReshape:
                return MatrixReshapeGPUInstruction.parseInstruction(str);
            case ArithmeticBinary: {
                // "+*" and "-*" share the ArithmeticBinary category but are built by
                // MatrixMatrixAxpyGPUInstruction. Constant-first equals tolerates a null opcode.
                String opCode = InstructionUtils.getOpCode(str);
                if ("+*".equals(opCode) || "-*".equals(opCode)) {
                    return MatrixMatrixAxpyGPUInstruction.parseInstruction(str);
                }
                return ArithmeticBinaryGPUInstruction.parseInstruction(str);
            }
            case RelationalBinary:
                return RelationalBinaryGPUInstruction.parseInstruction(str);
            case MatrixIndexing:
                return MatrixIndexingGPUInstruction.parseInstruction(str);
            default:
                throw new DMLRuntimeException("Invalid GPU Instruction Type: " + gpuinstruction_type);
        }
    }
}
