package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;

/* loaded from: input_file:org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.class */
/**
 * Inter-procedural pass that propagates scalar literals and replaces variable reads with them.
 *
 * <p>Step 1 walks the top-level statement blocks of the main program. Before each block is
 * processed, every variable it updates is dropped from the set of known constants. Remaining
 * constants are substituted into the block's hops. Literal values written via transient writes
 * in last-level statement blocks are then recorded for use by subsequent blocks.
 *
 * <p>Step 2 propagates literal arguments into reachable functions. For every input parameter
 * that {@link FunctionCallSizeInfo#isSafeLiteral} marks as safe, the literal passed at the
 * first recorded call site is substituted into the function body. (NOTE(review): this relies on
 * the safe-literal analysis guaranteeing that all call sites agree on that literal — confirm.)
 */
public class IPAPassPropagateReplaceLiterals extends IPAPass {

    /**
     * This pass has no preconditions.
     *
     * @param fgraph the function call graph (unused)
     * @return always {@code true}
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Runs literal propagation over the main program, then into reachable functions.
     *
     * @param prog       the program whose hop DAGs are rewritten in place
     * @param fgraph     call graph used to find reachable functions and their call sites
     * @param fcallSizes per-function information about which inputs are safe literals
     */
    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        propagateLiteralsInMainProgram(prog);
        propagateLiteralsIntoFunctions(prog, fgraph, fcallSizes);
    }

    /** Step 1: propagates literal assignments across the top-level blocks of the main program. */
    private static void propagateLiteralsInMainProgram(DMLProgram prog) {
        LocalVariableMap constants = new LocalVariableMap();
        for (StatementBlock sb : prog.getStatementBlocks()) {
            // Anything this block may overwrite can no longer be treated as a constant.
            constants.removeAllIn(sb.variablesUpdated().getVariableNames());
            rReplaceLiterals(sb, constants);

            // Conservative: only straight-line (last-level) blocks contribute new constants;
            // assignments nested inside control flow are never recorded.
            if (HopRewriteUtils.isLastLevelStatementBlock(sb)) {
                for (Hop root : sb.getHops()) {
                    if (HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE)
                            && root.getInput().get(0) instanceof LiteralOp) {
                        LiteralOp lit = (LiteralOp) root.getInput().get(0);
                        constants.put(root.getName(), ScalarObjectFactory.createScalarObject(lit));
                    }
                }
            }
        }
    }

    /** Step 2: substitutes safe literal call arguments into the bodies of reachable functions. */
    private static void propagateLiteralsIntoFunctions(DMLProgram prog, FunctionCallGraph fgraph,
            FunctionCallSizeInfo fcallSizes) {
        for (String fkey : fgraph.getReachableFunctions()) {
            List<FunctionOp> calls = fgraph.getFunctionCalls(fkey);
            // Robustness: a function may have no (remaining) recorded call sites.
            if (calls == null || calls.isEmpty() || !fcallSizes.hasSafeLiterals(fkey)) {
                continue;
            }
            FunctionOp first = calls.get(0);
            FunctionStatement fstmt = (FunctionStatement) prog.getFunctionStatementBlock(fkey).getStatement(0);
            List<DataIdentifier> finputs = fstmt.getInputParams();

            // Bind each safe-literal parameter name to the literal passed at the first call site;
            // call-site inputs are matched to declared parameters by position.
            LocalVariableMap callVars = new LocalVariableMap();
            for (int i = 0; i < finputs.size(); i++) {
                if (fcallSizes.isSafeLiteral(fkey, i)) {
                    LiteralOp lit = (LiteralOp) first.getInput().get(i);
                    callVars.put(finputs.get(i).getName(),
                        ScalarObjectFactory.createScalarObject(lit.getValueType(), lit));
                }
            }

            for (StatementBlock sb : fstmt.getBody()) {
                rReplaceLiterals(sb, callVars);
            }
        }
    }

    /**
     * Recursively replaces reads of known constants in a statement block and its nested bodies.
     *
     * <p>Variables updated by {@code sb} are removed from {@code constants} first. This mutates
     * the caller's map, so later siblings also no longer see them as constants.
     *
     * @param sb        statement block to rewrite (while/if/for blocks are descended into)
     * @param constants known constant values; entries for variables updated by {@code sb} are removed
     */
    private static void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) {
        // Presumably variablesUpdated() also covers nested loop/branch bodies, which is what makes
        // substituting into loop predicates safe below — verify against StatementBlock.
        constants.removeAllIn(sb.variablesUpdated().getVariableNames());

        if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            replaceLiterals(wsb.getPredicateHops(), constants);
            for (StatementBlock child : wstmt.getBody()) {
                rReplaceLiterals(child, constants);
            }
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            replaceLiterals(isb.getPredicateHops(), constants);
            for (StatementBlock child : istmt.getIfBody()) {
                rReplaceLiterals(child, constants);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                rReplaceLiterals(child, constants);
            }
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            replaceLiterals(fsb.getFromHops(), constants);
            replaceLiterals(fsb.getToHops(), constants);
            replaceLiterals(fsb.getIncrementHops(), constants);
            for (StatementBlock child : fstmt.getBody()) {
                rReplaceLiterals(child, constants);
            }
        } else {
            replaceLiterals(sb.getHops(), constants);
        }
    }

    /**
     * Replaces constant reads in every DAG rooted at {@code roots}; a {@code null} list is a no-op.
     *
     * @throws HopsException if literal replacement fails (non-Hops failures are wrapped)
     */
    private static void replaceLiterals(ArrayList<Hop> roots, LocalVariableMap constants) {
        if (roots == null) {
            return;
        }
        try {
            // Visit flags are reset before and after so the recursive rewrite sees a clean DAG
            // and leaves it clean for later passes.
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                Recompiler.rReplaceLiterals(root, constants, true);
            }
            Hop.resetVisitStatus(roots);
        } catch (HopsException e) {
            // Already the exception type this pass reports; don't wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new HopsException(e);
        }
    }

    /**
     * Replaces constant reads in the DAG rooted at {@code root}; a {@code null} root is a no-op.
     *
     * @throws HopsException if literal replacement fails (non-Hops failures are wrapped)
     */
    private static void replaceLiterals(Hop root, LocalVariableMap constants) {
        if (root == null) {
            return;
        }
        try {
            root.resetVisitStatus();
            Recompiler.rReplaceLiterals(root, constants, true);
            root.resetVisitStatus();
        } catch (HopsException e) {
            // Already the exception type this pass reports; don't wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new HopsException(e);
        }
    }
}
