package org.apache.sysml.hops.codegen.cplan;

import java.util.Arrays;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/hops/codegen/cplan/CNodeTernary.class */
/**
 * Code-generation plan node for a ternary operation over three input nodes. The concrete
 * operation is selected by a {@link TernaryType}, whose source template is instantiated by
 * {@link #codegen(boolean)}.
 */
public class CNodeTernary extends CNode {
    /** The operation performed by this node; fixed at construction time. */
    private final TernaryType _type;

    /**
     * Kinds of ternary operations. Each constant maps to a one-line source template (see
     * {@link #getTemplate(boolean)}) containing placeholders: {@code %TMP%} for the output
     * variable, {@code %IN1%}..{@code %IN3%} for the three inputs, and {@code %IN1v%} /
     * {@code %IN1i%} in the sparse variant of {@link #LOOKUP_RC1}.
     */
    public enum TernaryType {
        PLUS_MULT,
        MINUS_MULT,
        BIASADD,
        BIASMULT,
        REPLACE,
        REPLACE_NAN,
        IFELSE,
        LOOKUP_RC1,
        LOOKUP_RVECT1;

        /**
         * Checks whether {@code str} is exactly the name of one of this enum's constants.
         *
         * @param str candidate constant name (case-sensitive); {@code null} yields {@code false}
         * @return {@code true} if a constant with that exact name exists
         */
        public static boolean contains(String str) {
            return Arrays.stream(values())
                .map(TernaryType::name)
                .anyMatch(name -> name.equals(str));
        }

        /**
         * Returns the source template line for this operation type.
         *
         * @param sparse only affects {@link #LOOKUP_RC1}: when {@code true}, the template reads
         *     the first input through its {@code %IN1v%}/{@code %IN1i%} placeholders together
         *     with {@code ai} and {@code alen}; otherwise it uses {@code %IN1%}, {@code %IN2%}
         *     and {@code rix}
         * @return a newline-terminated template line
         * @throws RuntimeException if this type has no template (not reachable for the
         *     constants currently declared)
         */
        public String getTemplate(boolean sparse) {
            switch (this) {
                case PLUS_MULT:
                    return "    double %TMP% = %IN1% + %IN2% * %IN3%;\n";
                case MINUS_MULT:
                    return "    double %TMP% = %IN1% - %IN2% * %IN3%;\n";
                case BIASADD:
                    return "    double %TMP% = %IN1% + getValue(%IN2%, cix/%IN3%);\n";
                case BIASMULT:
                    return "    double %TMP% = %IN1% * getValue(%IN2%, cix/%IN3%);\n";
                case REPLACE:
                    // NaN never compares equal to itself, hence the explicit double-NaN check.
                    return "    double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) && Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
                case REPLACE_NAN:
                    return "    double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
                case IFELSE:
                    return "    double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
                case LOOKUP_RC1:
                    if (sparse) {
                        return "    double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n";
                    }
                    return "    double %TMP% = getValue(%IN1%, %IN2%, rix, %IN3%-1);\n";
                case LOOKUP_RVECT1:
                    return "    double[] %TMP% = getVector(%IN1%, %IN2%, rix, %IN3%-1);\n";
                default:
                    throw new RuntimeException("Invalid ternary type: " + toString());
            }
        }

        /**
         * Returns {@code true} only for {@link #LOOKUP_RVECT1}, the single type whose template
         * declares a {@code double[]} result instead of a {@code double}.
         */
        public boolean isVectorPrimitive() {
            return LOOKUP_RVECT1 == this;
        }
    }

    /**
     * Creates a ternary node over the given inputs (stored in this order) and derives the
     * output dimensions from {@code type}.
     *
     * @param in1 first input
     * @param in2 second input
     * @param in3 third input
     * @param type the ternary operation to perform
     */
    public CNodeTernary(CNode in1, CNode in2, CNode in3, TernaryType type) {
        _inputs.add(in1);
        _inputs.add(in2);
        _inputs.add(in3);
        // The type (and, for LOOKUP_RVECT1, the first input) must be in place before
        // setOutputDims() reads them.
        _type = type;
        setOutputDims();
    }

    /** Returns the ternary operation performed by this node. */
    public TernaryType getType() {
        return _type;
    }

    /**
     * Generates source code for the three inputs (in input order), followed by this node's own
     * template line. Returns an empty string if this node was already generated; otherwise the
     * node is marked as generated.
     *
     * @param sparse whether sparse code is requested. The sparse template is selected only when
     *     this flag is set AND the first input is a non-literal {@code CNodeData} whose variable
     *     name starts with the prefix tested below.
     * @return the generated code for the inputs followed by this node's line
     */
    @Override
    public String codegen(boolean sparse) {
        if (isGenerated()) {
            return "";
        }
        StringBuilder out = new StringBuilder();
        for (int idx = 0; idx < 3; idx++) {
            out.append(_inputs.get(idx).codegen(sparse));
        }

        // NOTE(review): this GPU timer-name constant is used here purely as a variable-name
        // prefix test; it looks like a decompiler constant-inlining artifact standing in for a
        // string literal -- confirm the intended prefix before changing it.
        String prefix = GPUInstruction.MISC_TIMER_ALLOCATE;
        CNode first = _inputs.get(0);
        boolean sparseLookup = sparse
            && first instanceof CNodeData
            && first.getVarname().startsWith(prefix)
            && !first.isLiteral();

        String line = _type.getTemplate(sparseLookup).replace("%TMP%", createVarname());
        for (int pos = 1; pos <= 3; pos++) {
            String name = _inputs.get(pos - 1).getVarname();
            boolean hasPrefix = name.startsWith(prefix);
            String valuesName = hasPrefix ? name + "vals" : name;
            String indexesName = hasPrefix ? name + "ix" : name;
            // The v/i placeholders are substituted before the plain %INj% placeholder.
            line = line.replace("%IN" + pos + "v%", valuesName)
                .replace("%IN" + pos + "i%", indexesName)
                .replace("%IN" + pos + "%", name);
        }

        out.append(line);
        _generated = true;
        return out.toString();
    }

    /**
     * Returns a short display label for this node's operation. REPLACE and REPLACE_NAN share
     * the label {@code t(rplc)}. Unrecognized types fall back to {@code super.toString()}.
     */
    @Override
    public String toString() {
        switch (_type) {
            case PLUS_MULT:
                return "t(+*)";
            case MINUS_MULT:
                return "t(-*)";
            case BIASADD:
                return "t(bias+)";
            case BIASMULT:
                return "t(bias*)";
            case REPLACE:
            case REPLACE_NAN:
                return "t(rplc)";
            case IFELSE:
                return "t(ifelse)";
            // NOTE(review): the lookup labels use a "u" prefix unlike the "t" prefix of the
            // other ternary types; kept as-is -- confirm whether this is intended.
            case LOOKUP_RC1:
                return "u(ixrc1)";
            case LOOKUP_RVECT1:
                return "u(ixrv1)";
            default:
                return super.toString();
        }
    }

    /**
     * Sets the output dimensions and data type. LOOKUP_RVECT1 yields a 1 x cols(first input)
     * matrix; every other type yields a 0 x 0 scalar.
     */
    @Override
    public void setOutputDims() {
        switch (_type) {
            case LOOKUP_RVECT1:
                _rows = 1L;
                _cols = _inputs.get(0)._cols;
                _dataType = Expression.DataType.MATRIX;
                break;
            case PLUS_MULT:
            case MINUS_MULT:
            case BIASADD:
            case BIASMULT:
            case REPLACE:
            case REPLACE_NAN:
            case IFELSE:
            case LOOKUP_RC1:
                _rows = 0L;
                _cols = 0L;
                _dataType = Expression.DataType.SCALAR;
                break;
            default:
                break;
        }
    }

    /**
     * Combines the base-node hash with the ternary type. The result is computed lazily and
     * cached in {@code _hash}; a value of 0 means "not yet computed".
     */
    @Override
    public int hashCode() {
        if (_hash != 0) {
            return _hash;
        }
        _hash = UtilFunctions.intHashCode(super.hashCode(), _type.hashCode());
        return _hash;
    }

    /** Two ternary nodes are equal if the base-node comparison succeeds and their types match. */
    @Override
    public boolean equals(Object obj) {
        if (obj instanceof CNodeTernary) {
            CNodeTernary other = (CNodeTernary) obj;
            return super.equals(other) && _type == other._type;
        }
        return false;
    }
}
