package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixOuterAgg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/mr/UaggOuterChainInstruction.class */
/**
 * MR instruction that applies a binary operator between a streamed input block and a
 * broadcast (distributed-cache) input, and reduces the result with a unary aggregate.
 *
 * <p>Which input is streamed and which is read from the distributed cache is decided by
 * {@link #isRightCached()}: the second input is broadcast for column/full reductions or when
 * {@link LibMatrixOuterAgg} has no specialized kernel for the operator pair; otherwise the
 * first input is broadcast.
 */
public class UaggOuterChainInstruction extends BinaryInstruction implements IDistributedCacheConsumer {
    private final AggregateUnaryOperator _uaggOp;
    private final AggregateOperator _aggOp;
    private final BinaryOperator _bOp;

    // Scratch buffers reused across blocks by the generic (non-specialized) path.
    private final MatrixValue _tmpVal1 = new MatrixBlock();
    private final MatrixValue _tmpVal2 = new MatrixBlock();

    // Broadcast vector, loaded lazily on first use by the specialized path and then reused;
    // _bvi is only populated for rowIndexMax/rowIndexMin aggregates.
    private double[] _bv = null;
    private int[] _bvi = null;

    private UaggOuterChainInstruction(BinaryOperator bop, AggregateUnaryOperator uaggop, AggregateOperator aggop,
            byte in1, byte in2, byte out, String istr) {
        super(MRInstruction.MRType.UaggOuterChain, null, in1, in2, out, istr);
        _uaggOp = uaggop;
        _aggOp = aggop;
        _bOp = bop;
        instString = istr;
    }

    /**
     * Parses an instruction string into a {@code UaggOuterChainInstruction}.
     *
     * <p>Reads parts[1]..parts[5] as: unary-aggregate opcode, binary opcode, input1 index,
     * input2 index, output index. The incremental aggregate operator is derived from the
     * unary-aggregate opcode.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     */
    public static UaggOuterChainInstruction parseInstruction(String str) {
        InstructionUtils.checkNumFields(str, 5);
        String[] parts = InstructionUtils.getInstructionParts(str);
        AggregateUnaryOperator uaggop = InstructionUtils.parseBasicAggregateUnaryOperator(parts[1]);
        BinaryOperator bop = InstructionUtils.parseBinaryOperator(parts[2]);
        byte in1 = Byte.parseByte(parts[3]);
        byte in2 = Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);

        // Derive the aggregate operator used to combine partial results of the unary aggregate.
        String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(parts[1]);
        PartialAggregate.CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(parts[1]);
        boolean corrExists = corrLoc != PartialAggregate.CorrectionLocationType.NONE;
        AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, Boolean.toString(corrExists), corrLoc.toString());

        return new UaggOuterChainInstruction(bop, uaggop, aop, in1, in2, out, str);
    }

    /**
     * Sets the output characteristics according to the aggregation direction: 1x1 for a full
     * reduction, rows(in1) x 1 for a column reduction, 1 x cols(in2) for a row reduction.
     * Block sizes are taken from the rows-per-block of {@code mcIn1} and cols-per-block of
     * {@code mcIn2}. Any other index function leaves {@code mcOut} untouched.
     */
    public void computeOutputCharacteristics(MatrixCharacteristics mcIn1, MatrixCharacteristics mcIn2, MatrixCharacteristics mcOut) {
        int brlen = mcIn1.getRowsPerBlock();
        int bclen = mcIn2.getColsPerBlock();
        if (_uaggOp.indexFn instanceof ReduceAll) {
            mcOut.set(1L, 1L, brlen, bclen);
        } else if (_uaggOp.indexFn instanceof ReduceCol) {
            mcOut.set(mcIn1.getRows(), 1L, brlen, bclen);
        } else if (_uaggOp.indexFn instanceof ReduceRow) {
            mcOut.set(1L, mcIn2.getCols(), brlen, bclen);
        }
    }

    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues,
            IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) {
        boolean rightCached = isRightCached();
        // The streamed input is the one NOT held in the distributed cache.
        ArrayList<IndexedMatrixValue> blkList = cachedValues.get(rightCached ? input1 : input2);
        if (blkList == null) {
            return;
        }

        // Loop-invariant: operator support and the broadcast input do not change per block.
        boolean specialized = LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
        DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(Byte.valueOf(rightCached ? input2 : input1));

        for (IndexedMatrixValue imv : blkList) {
            if (imv == null) {
                continue;
            }
            MatrixIndexes inIx = imv.getIndexes();
            MatrixValue inVal = imv.getValue();
            IndexedMatrixValue outIMV = cachedValues.holdPlace(output, valueClass);
            MatrixIndexes outIx = outIMV.getIndexes();
            MatrixValue outVal = outIMV.getValue();

            if (specialized) {
                processSpecialized(dcInput, rightCached, inIx, (MatrixBlock) inVal, outIx, (MatrixBlock) outVal);
            } else {
                processGeneric(dcInput, inIx, inVal, outIx, outVal, blockRowFactor, blockColFactor);
            }
        }
    }

    /** Aggregates one streamed block using the specialized {@link LibMatrixOuterAgg} kernels. */
    private void processSpecialized(DistributedCacheInput dcInput, boolean rightCached,
            MatrixIndexes inIx, MatrixBlock inBlk, MatrixIndexes outIx, MatrixBlock outBlk) {
        if (_bv == null) {
            _bv = rightCached ? dcInput.getRowVectorArray() : dcInput.getColumnVectorArray();
            if (LibMatrixOuterAgg.isRowIndexMax(_uaggOp) || LibMatrixOuterAgg.isRowIndexMin(_uaggOp)) {
                _bvi = LibMatrixOuterAgg.prepareRowIndices(_bv.length, _bv, _bOp, _uaggOp);
            } else {
                // Sorted in place; NOTE(review): confirm the cache returns a private copy,
                // since this mutates whatever array getRow/ColumnVectorArray handed back.
                Arrays.sort(_bv);
            }
        }
        LibMatrixOuterAgg.resetOutputMatrix(inIx, inBlk, outIx, outBlk, _uaggOp);
        LibMatrixOuterAgg.aggregateMatrix(inBlk, outBlk, _bv, _bvi, _bOp, _uaggOp);
    }

    /**
     * Aggregates one streamed block against every column block of the broadcast input:
     * binary op, unary aggregate, then incremental aggregation into {@code outVal}.
     */
    private void processGeneric(DistributedCacheInput dcInput, MatrixIndexes inIx, MatrixValue inVal,
            MatrixIndexes outIx, MatrixValue outVal, int blockRowFactor, int blockColFactor) {
        // Divide in floating point so that a trailing partial column block is counted;
        // integer division before ceil() would silently drop it.
        long numColBlocks = (long) Math.ceil((double) dcInput.getNumCols() / dcInput.getNumColsPerBlock());
        MatrixBlock corr = null;
        for (int bix = 1; bix <= numColBlocks; bix++) {
            MatrixValue rhsBlk = dcInput.getDataBlock(1, bix).getValue();
            OperationsOnMatrixValues.performBinaryIgnoreIndexes(inVal, rhsBlk, _tmpVal1, _bOp);
            OperationsOnMatrixValues.performAggregateUnary(inIx, _tmpVal1, outIx, _tmpVal2, _uaggOp, blockRowFactor, blockColFactor);
            // First partial result: size the output and the correction buffer to match it.
            if (corr == null) {
                outVal.reset(_tmpVal2.getNumRows(), _tmpVal2.getNumColumns(), false);
                corr = new MatrixBlock(_tmpVal2.getNumRows(), _tmpVal2.getNumColumns(), false);
            }
            if (_aggOp.correctionExists) {
                OperationsOnMatrixValues.incrementalAggregation(outVal, corr, _tmpVal2, _aggOp, true);
            } else {
                OperationsOnMatrixValues.incrementalAggregation(outVal, null, _tmpVal2, _aggOp, true);
            }
        }
    }

    /**
     * Returns true when input2 is the distributed-cache (broadcast) input and input1 is
     * streamed; false when the roles are reversed. Single source of truth for the side
     * selection used by processing and by the distributed-cache callbacks.
     */
    private boolean isRightCached() {
        return _uaggOp.indexFn instanceof ReduceCol
            || _uaggOp.indexFn instanceof ReduceAll
            || !LibMatrixOuterAgg.isSupportedUaggOp(_uaggOp, _bOp);
    }

    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        return isRightCached()
            ? (index == input2 && index != input1)
            : (index == input1 && index != input2);
    }

    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        indexes.add(Byte.valueOf(isRightCached() ? input2 : input1));
    }
}
