package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysml.lops.WeightedCrossEntropy;
import org.apache.sysml.lops.WeightedDivMM;
import org.apache.sysml.lops.WeightedDivMMR;
import org.apache.sysml.lops.WeightedSigmoid;
import org.apache.sysml.lops.WeightedSquaredLoss;
import org.apache.sysml.lops.WeightedSquaredLossR;
import org.apache.sysml.lops.WeightedUnaryMM;
import org.apache.sysml.lops.WeightedUnaryMMR;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/mr/QuaternaryInstruction.class */
/**
 * MR instruction for weighted quaternary operations. It takes a main input
 * matrix, two further matrix inputs (referred to here as U and V) and an optional
 * fourth input that is either a matrix or a scalar literal embedded in the
 * instruction string.
 *
 * <p>U and/or V can be served from the distributed cache instead of the regular
 * cached value map; the {@code _cacheU}/{@code _cacheV} flags select this.
 */
public class QuaternaryInstruction extends MRInstruction implements IDistributedCacheConsumer {
    /** Index of the main input. */
    private final byte _input1;
    /** Index of the U input. */
    private final byte _input2;
    /** Index of the V input. */
    private final byte _input3;
    /** Index of the optional fourth input; {@code -1} when it is not a matrix input. */
    private final byte _input4;
    /** Whether U is read from the distributed cache. */
    private final boolean _cacheU;
    /** Whether V is read from the distributed cache. */
    private final boolean _cacheV;

    /**
     * Creates the instruction; instances are obtained via {@link #parseInstruction(String)}.
     *
     * @param operator the quaternary operator to apply
     * @param in1      index of the main input
     * @param in2      index of U
     * @param in3      index of V
     * @param in4      index of the fourth input, or {@code -1}
     * @param out      index of the output
     * @param cacheU   whether U comes from the distributed cache
     * @param cacheV   whether V comes from the distributed cache
     * @param str      the original instruction string
     */
    private QuaternaryInstruction(Operator operator, byte in1, byte in2, byte in3, byte in4, byte out, boolean cacheU, boolean cacheV, String str) {
        super(MRInstruction.MRType.Quaternary, operator, out);
        this.instString = str;
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._input4 = in4;
        this._cacheU = cacheU;
        this._cacheV = cacheV;
    }

    /** @return index of the main input */
    public byte getInput1() {
        return this._input1;
    }

    /** @return index of the U input */
    public byte getInput2() {
        return this._input2;
    }

    /** @return index of the V input */
    public byte getInput3() {
        return this._input3;
    }

    /** @return index of the fourth input, or {@code -1} if none was parsed */
    public byte getInput4() {
        return this._input4;
    }

    /**
     * Derives the output characteristics from the input characteristics and writes
     * them into {@code matrixCharacteristics4}. Block sizes are always taken from
     * the main input. If none of the operator's weight types is set, the output
     * characteristics are left untouched.
     *
     * @param matrixCharacteristics  characteristics of the main input
     * @param matrixCharacteristics2 characteristics of U
     * @param matrixCharacteristics3 characteristics of V
     * @param matrixCharacteristics4 output characteristics, updated in place
     */
    public void computeMatrixCharacteristics(MatrixCharacteristics matrixCharacteristics, MatrixCharacteristics matrixCharacteristics2, MatrixCharacteristics matrixCharacteristics3, MatrixCharacteristics matrixCharacteristics4) {
        QuaternaryOperator qop = (QuaternaryOperator) this.optr;

        if (qop.wtype1 != null || qop.wtype4 != null) {
            // these operator types aggregate to a single 1x1 result
            matrixCharacteristics4.set(1L, 1L, matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock());
        } else if (qop.wtype2 != null || qop.wtype5 != null) {
            // output has the same dimensions as the main input
            matrixCharacteristics4.set(matrixCharacteristics.getRows(), matrixCharacteristics.getCols(), matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock());
        } else if (qop.wtype3 != null) {
            boolean bothCached = this._cacheU && this._cacheV;
            MatrixCharacteristics factor = qop.wtype3.isLeft() ? matrixCharacteristics3 : matrixCharacteristics2;
            // NOTE(review): when not both factors are cached, the third argument is read
            // from the nnz field instead of the column count; presumably the caller
            // stores the original dimension there -- confirm against the producer.
            long rank = bothCached ? factor.getCols() : factor.getNonZeros();
            MatrixCharacteristics derived = qop.wtype3.computeOutputCharacteristics(matrixCharacteristics.getRows(), matrixCharacteristics.getCols(), rank);
            matrixCharacteristics4.set(derived.getRows(), derived.getCols(), matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock());
        }
    }

    /**
     * Parses the string form of a distributed quaternary instruction.
     *
     * <p>Reducer variants carry two trailing boolean fields stating whether U and V
     * are read from the distributed cache; for all other variants both flags are
     * {@code true}.
     *
     * @param str the instruction string
     * @return the parsed instruction, or {@code null} if the opcode passes the
     *         validity check but matches none of the handled opcode families
     * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode
     */
    public static QuaternaryInstruction parseInstruction(String str) {
        String opCode = InstructionUtils.getOpCode(str);
        if (!InstructionUtils.isDistQuaternaryOpcode(opCode)) {
            throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
        }

        if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opCode) || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opCode)) {
            boolean reducer = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opCode);
            InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
            String[] parts = InstructionUtils.getInstructionParts(str);
            QuaternaryOperator qop = new QuaternaryOperator(WeightedSquaredLoss.WeightsType.valueOf(parts[6]));
            byte in1 = Byte.parseByte(parts[1]);
            byte in2 = Byte.parseByte(parts[2]);
            byte in3 = Byte.parseByte(parts[3]);
            byte in4 = Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            return new QuaternaryInstruction(qop, in1, in2, in3, in4, out, cacheFlag(reducer, parts, 7), cacheFlag(reducer, parts, 8), str);
        }

        if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opCode) || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opCode)) {
            boolean reducer = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opCode);
            InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
            String[] parts = InstructionUtils.getInstructionParts(str);
            // field 1 holds the unary opcode, so the matrix inputs start at field 2
            String unaryOpcode = parts[1];
            QuaternaryOperator qop = new QuaternaryOperator(WeightedUnaryMM.WUMMType.valueOf(parts[6]), unaryOpcode);
            byte in1 = Byte.parseByte(parts[2]);
            byte in2 = Byte.parseByte(parts[3]);
            byte in3 = Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            return new QuaternaryInstruction(qop, in1, in2, in3, (byte) -1, out, cacheFlag(reducer, parts, 7), cacheFlag(reducer, parts, 8), str);
        }

        if (WeightedDivMM.OPCODE.equalsIgnoreCase(opCode) || WeightedDivMMR.OPCODE.equalsIgnoreCase(opCode)) {
            // Detect the reducer variant the same case-insensitive way the branch itself
            // was selected (and the way the sibling branches do); a case-sensitive prefix
            // test would parse a mixed-case reducer opcode with the map-side layout.
            boolean reducer = WeightedDivMMR.OPCODE.equalsIgnoreCase(opCode);
            InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
            String[] parts = InstructionUtils.getInstructionParts(str);
            WeightedDivMM.WDivMMType wtype = WeightedDivMM.WDivMMType.valueOf(parts[6]);
            QuaternaryOperator qop = new QuaternaryOperator(wtype);
            byte in1 = Byte.parseByte(parts[1]);
            byte in2 = Byte.parseByte(parts[2]);
            byte in3 = Byte.parseByte(parts[3]);
            // with a scalar fourth operand, field 4 is a literal rather than an input index
            byte in4 = wtype.hasScalar() ? (byte) -1 : Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            return new QuaternaryInstruction(qop, in1, in2, in3, in4, out, cacheFlag(reducer, parts, 7), cacheFlag(reducer, parts, 8), str);
        }

        // remaining opcode families: weighted sigmoid and weighted cross entropy
        boolean reducer = opCode.startsWith("red");
        // the cross-entropy layout has one extra field before the output index
        int extra = opCode.endsWith(WeightedCrossEntropy.OPCODE_CP) ? 1 : 0;
        InstructionUtils.checkNumFields(str, (reducer ? 7 : 5) + extra);
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte out = Byte.parseByte(parts[4 + extra]);
        boolean cacheU = cacheFlag(reducer, parts, 6 + extra);
        boolean cacheV = cacheFlag(reducer, parts, 7 + extra);
        if (opCode.endsWith(WeightedSigmoid.OPCODE_CP)) {
            return new QuaternaryInstruction(new QuaternaryOperator(WeightedSigmoid.WSigmoidType.valueOf(parts[5])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
        }
        if (opCode.endsWith(WeightedCrossEntropy.OPCODE_CP)) {
            return new QuaternaryInstruction(new QuaternaryOperator(WeightedCrossEntropy.WCeMMType.valueOf(parts[6])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
        }
        return null;
    }

    /** Reads a cache flag from {@code parts[pos]} for reducer variants; always {@code true} otherwise. */
    private static boolean cacheFlag(boolean reducer, String[] parts, int pos) {
        return !reducer || Boolean.parseBoolean(parts[pos]);
    }

    /**
     * Tells whether the given input index is consumed exclusively through the
     * distributed cache: it must be a cached U or V and must not also be used as
     * the main or the fourth input.
     *
     * @param str unused
     * @param b   the input index to test
     */
    @Override
    public boolean isDistCacheOnlyIndex(String str, byte b) {
        if (b == this._input1 || b == this._input4) {
            return false;
        }
        return (this._cacheU && b == this._input2) || (this._cacheV && b == this._input3);
    }

    /**
     * Appends the indexes of the inputs that are read from the distributed cache.
     *
     * @param str       unused
     * @param arrayList list to append to
     */
    @Override
    public void addDistCacheIndex(String str, ArrayList<Byte> arrayList) {
        if (this._cacheU) {
            arrayList.add(Byte.valueOf(this._input2));
        }
        if (this._cacheV) {
            arrayList.add(Byte.valueOf(this._input3));
        }
    }

    /** @return the input indexes; the fourth one is included only if the operator has four inputs */
    @Override
    public byte[] getInputIndexes() {
        if (((QuaternaryOperator) this.optr).hasFourInputs()) {
            return new byte[]{this._input1, this._input2, this._input3, this._input4};
        }
        return new byte[]{this._input1, this._input2, this._input3};
    }

    /** @return the input indexes followed by the output index */
    @Override
    public byte[] getAllIndexes() {
        if (((QuaternaryOperator) this.optr).hasFourInputs()) {
            return new byte[]{this._input1, this._input2, this._input3, this._input4, this.output};
        }
        return new byte[]{this._input1, this._input2, this._input3, this.output};
    }

    /**
     * Applies the quaternary operator to every cached block of the main input and
     * stores each result under the output index.
     *
     * @param cls                 value class used when reserving an output slot
     * @param cachedValueMap      cached inputs and outputs, keyed by index
     * @param indexedMatrixValue  temporary value used as output when the output
     *                            index equals the main input index
     * @param indexedMatrixValue2 unused
     * @param i                   unused
     * @param i2                  unused
     */
    @Override
    public void processInstruction(Class<? extends MatrixValue> cls, CachedValueMap cachedValueMap, IndexedMatrixValue indexedMatrixValue, IndexedMatrixValue indexedMatrixValue2, int i, int i2) {
        QuaternaryOperator qop = (QuaternaryOperator) this.optr;
        ArrayList<IndexedMatrixValue> blocks = cachedValueMap.get(this._input1);
        if (blocks == null) {
            return;
        }

        for (IndexedMatrixValue in : blocks) {
            if (in == null) {
                continue;
            }
            MatrixIndexes inIx = in.getIndexes();
            MatrixBlock inVal = (MatrixBlock) in.getValue();

            // write into the temporary value when the output would overwrite the main input
            IndexedMatrixValue out = (this.output == this._input1) ? indexedMatrixValue : cachedValueMap.holdPlace(this.output, cls);
            MatrixIndexes outIx = out.getIndexes();
            MatrixValue outVal = out.getValue();

            // fourth operand: a cached matrix if present, otherwise the scalar literal
            // stored in field 4 of the instruction string
            MatrixValue in4Val = null;
            if (this._input4 != -1) {
                IndexedMatrixValue in4 = cachedValueMap.getFirst(this._input4);
                if (in4 != null) {
                    in4Val = in4.getValue();
                }
            }
            if (in4Val == null && qop.hasFourInputs()) {
                in4Val = new MatrixBlock(Double.parseDouble(InstructionUtils.getInstructionParts(this.instString)[4]));
            }

            // U is looked up by the block's row index, V by its column index
            MatrixValue uVal = this._cacheU
                    ? MRBaseForCommonInstructions.dcValues.get(Byte.valueOf(this._input2)).getDataBlock((int) inIx.getRowIndex(), 1).getValue()
                    : cachedValueMap.getFirst(this._input2).getValue();
            MatrixValue vVal = this._cacheV
                    ? MRBaseForCommonInstructions.dcValues.get(Byte.valueOf(this._input3)).getDataBlock((int) inIx.getColumnIndex(), 1).getValue()
                    : cachedValueMap.getFirst(this._input3).getValue();

            // if the column counts disagree, V is transposed before use
            if (uVal.getNumColumns() != vVal.getNumColumns()) {
                MatrixBlock transposed = new MatrixBlock(vVal.getNumColumns(), vVal.getNumRows(), vVal.isInSparseFormat());
                vVal = LibMatrixReorg.reorg((MatrixBlock) vVal, transposed, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
            }

            inVal.quaternaryOperations(qop, (MatrixBlock) uVal, (MatrixBlock) vVal, (MatrixBlock) in4Val, (MatrixBlock) outVal);

            if (qop.wtype1 != null || qop.wtype4 != null) {
                // single aggregated result block
                outIx.setIndexes(1L, 1L);
            } else if (qop.wtype2 != null || qop.wtype5 != null || (qop.wtype3 != null && qop.wtype3.isBasic())) {
                // result is aligned with the input block
                outIx.setIndexes(inIx);
            } else {
                // result is a single block column, positioned by the input's column (left) or row index
                outIx.setIndexes(qop.wtype3.isLeft() ? inIx.getColumnIndex() : inIx.getRowIndex(), 1L);
            }

            // the temporary value is not yet part of the cache and must be added explicitly
            if (out == indexedMatrixValue) {
                cachedValueMap.add(this.output, out);
            }
        }
    }
}
