package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;

/**
 * Inter-procedural pass that eliminates dead code in the main program and in every
 * function body.
 *
 * <p>Each statement-block sequence is processed by a single backwards pass that tracks
 * the set of variables still read by later blocks ("live" variables). In blocks for
 * which {@link HopRewriteUtils#isLastLevelStatementBlock} returns true, the pass removes
 * two kinds of DAG roots:
 * <ul>
 *   <li>transient writes of variables that are not live, and</li>
 *   <li>calls to side-effect-free functions (per the {@link FunctionCallGraph}) none of
 *       whose output variables are live. These calls are also removed from the call
 *       graph.</li>
 * </ul>
 * Blocks that are not last-level are never pruned. They only contribute the variables
 * they read, including reads in while/for/if predicates and in nested bodies.
 */
public class IPAPassEliminateDeadCode extends IPAPass {

    /** This pass is always applicable. */
    @Override
    public boolean isApplicable(FunctionCallGraph functionCallGraph) {
        return true;
    }

    /**
     * Runs dead-code elimination over the main program and over each function body.
     *
     * @param dMLProgram program whose statement blocks and functions are rewritten in place
     * @param functionCallGraph call graph, updated when a function call is removed
     * @param functionCallSizeInfo not used by this pass
     */
    @Override
    public void rewriteProgram(DMLProgram dMLProgram, FunctionCallGraph functionCallGraph, FunctionCallSizeInfo functionCallSizeInfo) {
        // Main program: no variable is live after its last statement block.
        findAndRemoveDeadCode(dMLProgram.getStatementBlocks(), new HashSet<>(), functionCallGraph);

        // Functions: the declared output parameters are live at the end of the body.
        for (FunctionStatementBlock fsb : dMLProgram.getFunctionStatementBlocks()) {
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            Set<String> outputs = new HashSet<>();
            fstmt.getOutputParams().forEach(param -> outputs.add(param.getName()));
            findAndRemoveDeadCode(fstmt.getBody(), outputs, functionCallGraph);
        }
    }

    /**
     * Backwards pass over a statement-block sequence that prunes dead roots from
     * last-level blocks.
     *
     * @param list statement blocks in program order; last-level blocks are modified in place
     * @param set names of variables live after the last block; copied, never modified
     * @param functionCallGraph call graph, updated when a function call is removed
     */
    private static void findAndRemoveDeadCode(List<StatementBlock> list, Set<String> set, FunctionCallGraph functionCallGraph) {
        // Work on a private copy so the caller's set is not mutated.
        Set<String> liveVars = new HashSet<>(set);
        for (int i = list.size() - 1; i >= 0; i--) {
            StatementBlock sb = list.get(i);
            if (HopRewriteUtils.isLastLevelStatementBlock(sb)) {
                pruneDeadRoots(sb, liveVars, functionCallGraph);
            }
            // Whatever this block (still) reads is live for all earlier blocks.
            liveVars.addAll(rCollectReadVariableNames(sb, new HashSet<>()));
        }
    }

    /**
     * Removes dead roots from one last-level statement block. A root is dead if it is a
     * transient write of a non-live variable, or a side-effect-free function call whose
     * outputs are all non-live.
     *
     * @param sb last-level statement block; its hop list is modified in place
     * @param liveVars variables read after this block
     * @param fgraph call graph, updated when a function call is removed
     */
    private static void pruneDeadRoots(StatementBlock sb, Set<String> liveVars, FunctionCallGraph fgraph) {
        List<Hop> roots = sb.getHops();
        if (roots == null) {
            return; // nothing to prune (mirrors the null guard used when collecting reads)
        }
        int j = 0;
        while (j < roots.size()) {
            Hop root = roots.get(j);
            boolean isDeadWrite = HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE)
                && !liveVars.contains(root.getName());
            boolean isDeadCall = isFunctionCallWithUnusedOutputs(root, liveVars, fgraph);
            if (isDeadWrite || isDeadCall) {
                if (isDeadCall) {
                    FunctionOp fop = (FunctionOp) root;
                    fgraph.removeFunctionCall(fop.getFunctionKey(), fop, sb);
                }
                roots.remove(j);
                rRemoveOpFromDAG(root);
                // Do not advance j: the next root has shifted into slot j.
            } else {
                j++;
            }
        }
    }

    /**
     * Returns true if {@code hop} is a call to a side-effect-free function and none of
     * its output variables are contained in {@code set}. This includes a call with no
     * outputs at all.
     */
    private static boolean isFunctionCallWithUnusedOutputs(Hop hop, Set<String> set, FunctionCallGraph functionCallGraph) {
        if (!(hop instanceof FunctionOp)) {
            return false;
        }
        FunctionOp fop = (FunctionOp) hop;
        return functionCallGraph.isSideEffectFreeFunction(fop.getFunctionKey())
            && Arrays.stream(fop.getOutputVariableNames()).noneMatch(set::contains);
    }

    /**
     * Detaches {@code hop} from all of its inputs. Any input left without a parent is
     * removed recursively in the same way.
     */
    private static void rRemoveOpFromDAG(Hop hop) {
        for (Hop input : hop.getInput()) {
            input.getParent().remove(hop);
            if (input.getParent().isEmpty()) {
                rRemoveOpFromDAG(input);
            }
        }
        hop.getInput().clear();
    }

    /**
     * Adds to {@code set} the names of all transient reads in {@code statementBlock}.
     * For while/for/if blocks this covers the predicate hops and all nested bodies.
     * Hop visit flags are reset before traversal and are left set afterwards.
     *
     * @return {@code set}, for chaining
     */
    private static Set<String> rCollectReadVariableNames(StatementBlock statementBlock, Set<String> set) {
        if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) statementBlock;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            collectReadVariableNames(wsb.getPredicateHops(), set);
            collectReadVariableNamesOfBlocks(wstmt.getBody(), set);
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) statementBlock;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            collectReadVariableNames(fsb.getFromHops(), set);
            collectReadVariableNames(fsb.getToHops(), set);
            collectReadVariableNames(fsb.getIncrementHops(), set);
            collectReadVariableNamesOfBlocks(fstmt.getBody(), set);
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) statementBlock;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            collectReadVariableNames(isb.getPredicateHops(), set);
            collectReadVariableNamesOfBlocks(istmt.getIfBody(), set);
            collectReadVariableNamesOfBlocks(istmt.getElseBody(), set); // null-safe
        } else if (statementBlock.getHops() != null) {
            Hop.resetVisitStatus(statementBlock.getHops());
            for (Hop root : statementBlock.getHops()) {
                rCollectReadVariableNames(root, set);
            }
        }
        return set;
    }

    /** Collects read variable names from every block of a (possibly null) body. */
    private static void collectReadVariableNamesOfBlocks(Iterable<StatementBlock> blocks, Set<String> set) {
        if (blocks == null) {
            return;
        }
        for (StatementBlock sb : blocks) {
            rCollectReadVariableNames(sb, set);
        }
    }

    /**
     * Resets the visit status of the DAG rooted at {@code hop}, then collects its
     * transient-read names into {@code set}. A null {@code hop} is ignored.
     *
     * @return {@code set}, for chaining
     */
    private static Set<String> collectReadVariableNames(Hop hop, Set<String> set) {
        if (hop == null) {
            return set;
        }
        hop.resetVisitStatus();
        return rCollectReadVariableNames(hop, set);
    }

    /**
     * Depth-first traversal that adds the name of every unvisited transient-read hop
     * reachable from {@code hop} to {@code set}. Visited hops are marked so that shared
     * sub-DAGs are processed only once.
     *
     * @return {@code set}, for chaining
     */
    private static Set<String> rCollectReadVariableNames(Hop hop, Set<String> set) {
        if (hop.isVisited()) {
            return set;
        }
        for (Hop input : hop.getInput()) {
            rCollectReadVariableNames(input, set);
        }
        if (HopRewriteUtils.isData(hop, Hop.DataOpTypes.TRANSIENTREAD)) {
            set.add(hop.getName());
        }
        hop.setVisited();
        return set;
    }
}
