package org.apache.sysml.hops;

import java.util.Locale;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.Nary;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;

/**
 * Hop for operators that take a variable number of inputs ({@link Hop.OpOpN}), such as
 * {@code CBIND}, {@code RBIND}, {@code MIN}, {@code MAX}, {@code LIST}, {@code PRINTF}
 * and {@code EVAL}.
 */
public class NaryOp extends Hop {
    /** The n-ary operator this hop represents. */
    protected Hop.OpOpN _op = null;

    /** Creates an uninitialized hop; used e.g. by {@link #clone()}, which fills in state afterwards. */
    protected NaryOp() {
    }

    /**
     * Creates an n-ary hop and wires it into the hop DAG.
     *
     * @param name hop name
     * @param dt output data type
     * @param vt output value type
     * @param op n-ary operator type
     * @param inputs input hops in operand order; this hop is registered as a parent of each
     */
    public NaryOp(String name, Expression.DataType dt, Expression.ValueType vt, Hop.OpOpN op, Hop... inputs) {
        super(name, dt, vt);
        this._op = op;
        for (int i = 0; i < inputs.length; i++) {
            getInput().add(i, inputs[i]);
            inputs[i].getParent().add(this);
        }
        refreshSizeInformation();
    }

    /** No-op: any number of inputs is accepted. */
    @Override
    public void checkArity() {
    }

    /** @return the n-ary operator type of this hop */
    public Hop.OpOpN getOp() {
        return _op;
    }

    /** Returns the display string {@code m(<op>)}, e.g. {@code m(cbind)}. */
    @Override
    public String getOpString() {
        // Locale.ROOT keeps the output stable under locales with special casing rules
        // (e.g. Turkish, where 'I' would otherwise lowercase to a dotless 'ı').
        return "m(" + _op.name().toLowerCase(Locale.ROOT) + ")";
    }

    /** N-ary operators are not GPU-enabled. */
    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    /**
     * Builds (or returns the cached) {@link Nary} lop for this hop, constructing input lops first.
     *
     * @return the lop for this hop
     * @throws HopsException if the operator has no Nary lop mapping or lop construction fails;
     *         any failure is wrapped with this hop's error location
     */
    @Override
    public Lop constructLops() {
        // Reuse a previously constructed lop.
        if (getLops() != null) {
            return getLops();
        }
        try {
            Lop[] inputLops = new Lop[getInput().size()];
            for (int i = 0; i < getInput().size(); i++) {
                inputLops[i] = getInput().get(i).constructLops();
            }
            Nary.OperationType opType = HopsOpOpNLops.get(_op);
            if (opType == null) {
                throw new HopsException("Unknown Nary Lop type for '" + _op + "'");
            }
            Nary lop = new Nary(opType, getDataType(), getValueType(), inputLops, optFindExecType());
            setOutputDimensions(lop);
            setLineNumbers(lop);
            setLops(lop);
            constructAndSetLopsDataFlowProperties();
            return getLops();
        } catch (Exception e) {
            throw new HopsException(printErrorLocation() + "error constructing Lops for NaryOp -- \n ", e);
        }
    }

    /** Execution type is chosen in {@link #optFindExecType()}; not all exec types are allowed. */
    @Override
    public boolean allowsAllExecTypes() {
        return false;
    }

    /**
     * Computes the default memory estimates. For {@code EVAL}, the estimates are then overridden
     * with a fixed 4-byte output/total estimate and no intermediate memory.
     */
    @Override
    public void computeMemEstimate(MemoTable memo) {
        super.computeMemEstimate(memo);
        if (_op == Hop.OpOpN.EVAL) {
            // NOTE(review): 4 bytes presumably stands for OptimizerUtils.INT_SIZE — confirm.
            _memEstimate = 4.0d;
            _outputMemEstimate = 4.0d;
            _processingMemEstimate = 0.0d;
        }
    }

    /** Output size estimate for a {@code dim1 x dim2} result with {@code nnz} non-zeros. */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, OptimizerUtils.getSparsity(dim1, dim2, nnz));
    }

    /**
     * Selects the execution type. A forced platform wins; otherwise the choice is memory-based,
     * or CP when dimensions are below the threshold, else the remote backend (Spark or MR).
     * {@code PRINTF}, {@code EVAL} and {@code LIST} are always pinned to CP.
     */
    @Override
    public LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        LopProperties.ExecType remote = OptimizerUtils.isSparkExecutionMode()
            ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (areDimsBelowThreshold()) {
                _etype = LopProperties.ExecType.CP;
            } else {
                _etype = remote;
            }
            // Fall back from CP if dimensions/size are invalid for CP.
            checkAndSetInvalidCPDimsAndSize();
        }
        setRequiresRecompileIfNecessary();
        // These operators are always executed in CP, regardless of the choice above.
        if (_op == Hop.OpOpN.PRINTF || _op == Hop.OpOpN.EVAL || _op == Hop.OpOpN.LIST) {
            _etype = LopProperties.ExecType.CP;
        }
        return _etype;
    }

    /** N-ary operators report no intermediate memory. */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0.0d;
    }

    /**
     * Infers output characteristics from memoized input stats.
     *
     * @return {@code {dim1, dim2, nnz}} for CBIND/RBIND/MIN/MAX/LIST (nnz is -1 where not inferred),
     *         or {@code null} for scalar outputs and other operators
     */
    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        if (getDataType().isScalar()) {
            return null;
        }
        MatrixCharacteristics[] stats = memo.getAllInputStats(getInput());
        switch (_op) {
            case CBIND:
                return new long[]{
                    HopRewriteUtils.getMaxInputDim(stats, true),
                    HopRewriteUtils.getSumValidInputDims(stats, false),
                    HopRewriteUtils.getSumValidInputNnz(stats, true)};
            case RBIND:
                return new long[]{
                    HopRewriteUtils.getSumValidInputDims(stats, true),
                    HopRewriteUtils.getMaxInputDim(stats, false),
                    HopRewriteUtils.getSumValidInputNnz(stats, true)};
            case MIN:
            case MAX:
                return new long[]{
                    HopRewriteUtils.getMaxInputDim((Hop) this, true),
                    HopRewriteUtils.getMaxInputDim((Hop) this, false),
                    -1};
            case LIST:
                return new long[]{getInput().size(), 1, -1};
            default:
                return null;
        }
    }

    /**
     * Recomputes output dimensions (and nnz for CBIND/RBIND) from the current inputs.
     * Scalar MIN/MAX get 0 x 0; LIST is {@code #inputs x 1}; PRINTF/EVAL are left unchanged.
     */
    @Override
    public void refreshSizeInformation() {
        switch (_op) {
            case CBIND:
                setDim1(HopRewriteUtils.getMaxInputDim((Hop) this, true));
                setDim2(HopRewriteUtils.getSumValidInputDims((Hop) this, false));
                setNnz(HopRewriteUtils.getSumValidInputNnz(this));
                return;
            case RBIND:
                setDim1(HopRewriteUtils.getSumValidInputDims((Hop) this, true));
                setDim2(HopRewriteUtils.getMaxInputDim((Hop) this, false));
                setNnz(HopRewriteUtils.getSumValidInputNnz(this));
                return;
            case MIN:
            case MAX:
                setDim1(getDataType().isScalar() ? 0L : HopRewriteUtils.getMaxInputDim((Hop) this, true));
                setDim2(getDataType().isScalar() ? 0L : HopRewriteUtils.getMaxInputDim((Hop) this, false));
                return;
            case LIST:
                setDim1(getInput().size());
                setDim2(1L);
                return;
            case PRINTF:
            case EVAL:
            default:
                return;
        }
    }

    /** Shallow-clones this hop via {@code Hop.clone(Hop, boolean)} and copies the operator type. */
    @Override
    public Object clone() throws CloneNotSupportedException {
        NaryOp copy = new NaryOp();
        copy.clone(this, false);
        copy._op = _op;
        return copy;
    }

    /**
     * Two n-ary hops are equal if they have the same operator and the identical (same-reference)
     * inputs in the same order. A PRINTF hop never compares equal to anything, presumably so that
     * distinct print statements are not merged.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof NaryOp) || _op == Hop.OpOpN.PRINTF) {
            return false;
        }
        NaryOp other = (NaryOp) that;
        boolean same = _op == other._op && getInput().size() == other.getInput().size();
        for (int i = 0; i < getInput().size() && same; i++) {
            same &= getInput().get(i) == other.getInput().get(i);
        }
        return same;
    }
}
