package org.apache.sysml.parser;

import java.util.HashMap;
import org.apache.hadoop.fs.Path;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.LanguageException;

/* loaded from: input_file:org/apache/sysml/parser/BinaryExpression.class */
/**
 * Parse-tree node for a binary operation of the form {@code left <op> right}.
 *
 * <p>Holds two operand expressions and an {@link Expression.BinaryOp} opcode. During
 * {@link #validateExpression} it validates both operands and may substitute known constants for
 * plain variable references. It then builds an output {@link DataIdentifier} whose data type,
 * value type and dimensions are derived from the operands.
 */
public class BinaryExpression extends Expression {
    /** Left-hand operand; may be replaced by a constant during validation. */
    private Expression _left;
    /** Right-hand operand; may be replaced by a constant during validation. */
    private Expression _right;
    /** The operator applied to the two operands. */
    private Expression.BinaryOp _opcode;

    /**
     * Returns a new expression with the same opcode whose operands are the operands' own rewrites.
     *
     * <p>The copy starts out with this node's parse info. {@link #setLeft} and {@link #setRight}
     * then overwrite it with the rewritten operands' parse info.
     *
     * @param str argument forwarded unchanged to both operands' {@code rewriteExpression}
     *     (presumably a variable-name prefix — confirm against {@code Expression})
     * @return a fresh {@code BinaryExpression}; this instance is not modified
     */
    @Override // org.apache.sysml.parser.Expression
    public Expression rewriteExpression(String str) {
        BinaryExpression binaryExpression = new BinaryExpression(this._opcode, this);
        binaryExpression.setLeft(this._left.rewriteExpression(str));
        binaryExpression.setRight(this._right.rewriteExpression(str));
        return binaryExpression;
    }

    /**
     * Creates a binary expression with placeholder location information.
     *
     * <p>The filename is set to {@code "MAIN SCRIPT"}, all line/column positions to 0, and the
     * text to {@code null}.
     *
     * @param binaryOp the operator of this expression
     */
    public BinaryExpression(Expression.BinaryOp binaryOp) {
        this._opcode = binaryOp;
        setFilename("MAIN SCRIPT");
        setBeginLine(0);
        setBeginColumn(0);
        setEndLine(0);
        setEndColumn(0);
        setText(null);
    }

    /**
     * Creates a binary expression located at the given parse position.
     *
     * @param binaryOp the operator of this expression
     * @param parseInfo source-location information copied onto this expression
     */
    public BinaryExpression(Expression.BinaryOp binaryOp, ParseInfo parseInfo) {
        this._opcode = binaryOp;
        setParseInfo(parseInfo);
    }

    /** Returns the operator of this expression. */
    public Expression.BinaryOp getOpCode() {
        return this._opcode;
    }

    /**
     * Sets the left operand. When it is non-null, its parse info is also copied onto this node.
     *
     * @param expression the new left operand; may be {@code null}
     */
    public void setLeft(Expression expression) {
        this._left = expression;
        if (this._left != null) {
            setParseInfo(this._left);
        }
    }

    /**
     * Sets the right operand. When it is non-null, its parse info is also copied onto this node.
     *
     * <p>Because {@link #setLeft} does the same, the resulting parse info comes from whichever
     * non-null operand was set last.
     *
     * @param expression the new right operand; may be {@code null}
     */
    public void setRight(Expression expression) {
        this._right = expression;
        if (this._right != null) {
            setParseInfo(this._right);
        }
    }

    /** Returns the left operand (possibly a substituted constant after validation). */
    public Expression getLeft() {
        return this._left;
    }

    /** Returns the right operand (possibly a substituted constant after validation). */
    public Expression getRight() {
        return this._right;
    }

    /**
     * Validates both operands and computes this expression's output identifier.
     *
     * @param hashMap data identifiers keyed by name, forwarded to the operands' validation
     * @param hashMap2 constant identifiers keyed by variable name; used for constant substitution
     * @param z "conditional" flag forwarded to {@code raiseValidateError}. Constant substitution
     *     only happens when it is {@code false} (presumably: not inside a conditional branch —
     *     confirm)
     */
    @Override // org.apache.sysml.parser.Expression
    public void validateExpression(HashMap<String, DataIdentifier> hashMap, HashMap<String, ConstIdentifier> hashMap2, boolean z) {
        // UDF calls are rejected as direct operands; this error is raised with conditional=false,
        // independent of z.
        if ((this._left instanceof FunctionCallIdentifier) || (this._right instanceof FunctionCallIdentifier)) {
            raiseValidateError("User-defined function calls not supported in binary expressions.", false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION);
        }
        this._left.validateExpression(hashMap, hashMap2, z);
        this._right.validateExpression(hashMap, hashMap2, z);
        if (!z) {
            this._left = propagateConstant(this._left, hashMap2);
            this._right = propagateConstant(this._right, hashMap2);
        }

        DataIdentifier output = new DataIdentifier(getTempName());
        output.setParseInfo(this);
        output.setDataType(computeDataType(getLeft(), getRight(), true));
        Expression.ValueType valueType = computeValueType(getLeft(), getRight(), true);
        // Power and division are typed DOUBLE regardless of the operands' value types.
        if (getOpCode() == Expression.BinaryOp.POW || getOpCode() == Expression.BinaryOp.DIV) {
            valueType = Expression.ValueType.DOUBLE;
        }
        output.setValueType(valueType);
        checkAndSetDimensions(output, z);
        if (getOpCode() == Expression.BinaryOp.MATMULT) {
            validateMatrixMultiplication(output, z);
        }
        setOutput(output);
    }

    /**
     * Returns the constant bound to {@code expression}'s name if it is a {@link DataIdentifier}
     * present in {@code constants}; otherwise returns {@code expression} unchanged.
     */
    private static Expression propagateConstant(Expression expression, HashMap<String, ConstIdentifier> constants) {
        if (expression instanceof DataIdentifier) {
            String name = ((DataIdentifier) expression).getName();
            if (constants.containsKey(name)) {
                return constants.get(name);
            }
        }
        return expression;
    }

    /**
     * Checks that the inner dimensions of a matrix multiplication agree. Sets the output
     * dimensions to {@code rows(left) x cols(right)}.
     *
     * <p>Note: the operands' data types are not verified to be MATRIX here.
     */
    private void validateMatrixMultiplication(DataIdentifier output, boolean conditional) {
        Identifier left = getLeft().getOutput();
        Identifier right = getRight().getOutput();
        // A dimension of -1 skips the comparison (presumably "unknown size" — confirm).
        if (left.getDim2() != -1 && right.getDim1() != -1 && left.getDim2() != right.getDim1()) {
            raiseValidateError("invalid dimensions for matrix multiplication (k1=" + left.getDim2() + ", k2=" + right.getDim1() + ")", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
        output.setDimensions(left.getDim1(), right.getDim2());
    }

    /**
     * Copies dimensions from the first matrix operand (left preferred) onto {@code dataIdentifier}.
     *
     * <p>For element-wise ops (see {@link #isSameDimensionBinaryOp}) with two matrix operands
     * whose dimensions are both known, it reports a mismatch. A differing right-operand dimension
     * equal to 1 (or less) is tolerated.
     *
     * @param dataIdentifier output identifier whose dimensions are set
     * @param z conditional flag forwarded to {@code raiseValidateError}
     */
    private void checkAndSetDimensions(DataIdentifier dataIdentifier, boolean z) {
        Identifier leftOut = getLeft().getOutput();
        Identifier rightOut = getRight().getOutput();
        boolean leftIsMatrix = leftOut.getDataType() == Expression.DataType.MATRIX;
        boolean rightIsMatrix = rightOut.getDataType() == Expression.DataType.MATRIX;
        Identifier first = leftIsMatrix ? leftOut : (rightIsMatrix ? rightOut : null);
        Identifier second = (leftIsMatrix && rightIsMatrix) ? rightOut : null;

        if (first != null && second != null && isSameDimensionBinaryOp(getOpCode())
                && first.dimsKnown() && second.dimsKnown()) {
            boolean rowsMismatch = first.getDim1() != second.getDim1() && second.getDim1() > 1;
            boolean colsMismatch = first.getDim2() != second.getDim2() && second.getDim2() > 1;
            if (rowsMismatch || colsMismatch) {
                raiseValidateError("Mismatch in dimensions for operation '" + getText() + "'. " + first + " is " + first.getDim1() + "x" + first.getDim2() + " and " + second + " is " + second.getDim1() + "x" + second.getDim2() + ".", z);
            }
        }
        if (first != null) {
            dataIdentifier.setDimensions(first.getDim1(), first.getDim2());
        }
    }

    /**
     * Renders the expression as {@code (left op right)}. String-literal operands are quoted, and a
     * missing operand renders as {@code null} instead of throwing.
     */
    @Override
    public String toString() {
        return "(" + operandToString(this._left) + " " + String.valueOf(this._opcode) + " " + operandToString(this._right) + ")";
    }

    /** Renders one operand, wrapping {@link StringIdentifier} values in double quotes. */
    private static String operandToString(Expression operand) {
        if (operand instanceof StringIdentifier) {
            return "\"" + operand.toString() + "\"";
        }
        return String.valueOf(operand);
    }

    /** Returns the union of the variables read by the left and right operands. */
    @Override // org.apache.sysml.parser.Expression
    public VariableSet variablesRead() {
        VariableSet variableSet = new VariableSet();
        variableSet.addVariables(this._left.variablesRead());
        variableSet.addVariables(this._right.variablesRead());
        return variableSet;
    }

    /** Returns the union of the variables updated by the left and right operands. */
    @Override // org.apache.sysml.parser.Expression
    public VariableSet variablesUpdated() {
        VariableSet variableSet = new VariableSet();
        variableSet.addVariables(this._left.variablesUpdated());
        variableSet.addVariables(this._right.variablesUpdated());
        return variableSet;
    }

    /**
     * Tells whether {@code binaryOp} is element-wise, i.e. its matrix operands are checked for
     * matching dimensions.
     *
     * @param binaryOp the operator to test; {@code null} yields {@code false}
     * @return {@code true} for PLUS, MINUS, MULT, DIV, MODULUS, INTDIV and POW
     */
    public static boolean isSameDimensionBinaryOp(Expression.BinaryOp binaryOp) {
        return binaryOp == Expression.BinaryOp.PLUS || binaryOp == Expression.BinaryOp.MINUS || binaryOp == Expression.BinaryOp.MULT || binaryOp == Expression.BinaryOp.DIV || binaryOp == Expression.BinaryOp.MODULUS || binaryOp == Expression.BinaryOp.INTDIV || binaryOp == Expression.BinaryOp.POW;
    }
}
