package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;

/* loaded from: input_file:org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.class */
/**
 * Statement-block rewrite that merges adjacent last-level statement blocks of a sequence into a
 * single block by combining their HOP DAGs.
 *
 * <p>Merging is repeated over the sequence until a full pass performs no merge. After each merge
 * the combined DAG is passed through a {@link RewriteCommonSubexpressionElimination} rewriter.
 */
public class RewriteMergeBlockSequence extends StatementBlockRewriteRule {
    /** Rewriter applied to every merged DAG (common subexpression elimination). */
    private final ProgramRewriter rewriter = new ProgramRewriter(new RewriteCommonSubexpressionElimination(true));

    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Single-block entry point. This rule only acts on sequences of blocks, so the given block is
     * returned unchanged, wrapped in a list.
     */
    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        return Arrays.asList(statementBlock);
    }

    /**
     * Repeatedly merges adjacent statement blocks until a full pass over the sequence performs no
     * merge.
     *
     * <p>For a mergeable pair {@code (sb1, sb2)} (see {@link #canMerge}), the hops of {@code sb1}
     * are moved into {@code sb2} and {@code sb1} is dropped from the result. The input list itself
     * is not modified, but the statement blocks it contains are.
     *
     * @param list the statement block sequence; may be {@code null} or empty
     * @param programRewriteStatus not used by this rule
     * @return {@code list} itself if it is {@code null} or empty; otherwise a new list holding the
     *         (possibly merged) blocks in their original order
     */
    @Override // org.apache.sysml.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        if (list == null || list.isEmpty()) {
            return list;
        }
        List<StatementBlock> blocks = new ArrayList<>(list);
        boolean merged = true;
        // binary merging passes until fixpoint
        while (merged) {
            merged = false;
            for (int i = 0; i < blocks.size() - 1; i++) {
                StatementBlock sb1 = blocks.get(i);
                StatementBlock sb2 = blocks.get(i + 1);
                if (canMerge(sb1, sb2)) {
                    mergeInto(sb1, sb2);
                    // sb2 now represents both blocks, so drop sb1. The merged block (now at index i)
                    // is not paired with its successor until the next pass of the outer loop.
                    blocks.remove(i);
                    merged = true;
                }
            }
        }
        return blocks;
    }

    /**
     * Decides whether {@code sb1} and its successor {@code sb2} may be merged.
     *
     * <p>Both blocks must be last-level blocks and neither may be marked as split-DAG. They must not
     * both contain an external function call with side effects. Neither block may contain a function
     * call whose outputs are read or updated by the other block.
     */
    private static boolean canMerge(StatementBlock sb1, StatementBlock sb2) {
        return HopRewriteUtils.isLastLevelStatementBlock(sb1)
            && HopRewriteUtils.isLastLevelStatementBlock(sb2)
            && !sb1.isSplitDag()
            && !sb2.isSplitDag()
            && !(hasExternalFunctionOpRootWithSideEffect(sb1) && hasExternalFunctionOpRootWithSideEffect(sb2))
            && (!hasFunctionOpRoot(sb1) || !hasFunctionIOConflict(sb1, sb2))
            && (!hasFunctionOpRoot(sb2) || !hasFunctionIOConflict(sb2, sb1));
    }

    /**
     * Moves the DAG of {@code sb1} into {@code sb2} and updates the variable sets and source
     * position of {@code sb2}. Afterwards {@code sb1} has an empty hop list.
     *
     * <p>A transient write root of {@code sb1} whose variable is transiently read in {@code sb2} is
     * short-circuited: parents of that transient read consume the written value directly. The
     * transient write is then re-created only if {@code sb2} does not write the variable itself
     * (via a transient write or a function output) and the variable is live-out of {@code sb2}.
     *
     * <p>Other transient write roots of {@code sb1} are kept under that same condition. All other
     * roots of {@code sb1} are always kept. Roots of {@code sb1} are placed before those of
     * {@code sb2}.
     */
    private void mergeInto(StatementBlock sb1, StatementBlock sb2) {
        ArrayList<Hop> sb1Hops = sb1.getHops();
        ArrayList<Hop> sb2Hops = sb2.getHops();
        ArrayList<Hop> orderedHops = new ArrayList<>();

        // collect transient reads and writes (incl. function outputs) of sb2
        Hop.resetVisitStatus(sb2Hops);
        HashMap<String, Hop> treads = new HashMap<>();
        HashMap<String, Hop> twrites = new HashMap<>();
        for (Hop root : sb2Hops) {
            rCollectTransientReadWrites(root, treads, twrites);
        }
        Hop.resetVisitStatus(sb2Hops);

        // append the roots of sb1, rewiring transient write -> transient read pairs
        Hop.resetVisitStatus(sb1Hops);
        for (Hop root : sb1Hops) {
            if (HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE) && treads.containsKey(root.getName())) {
                // NOTE(review): only the transient read last collected for this name is rewired;
                // this assumes sb2 has at most one transient read per variable - confirm.
                Hop tread = treads.get(root.getName());
                Hop input = root.getInput().get(0);
                // copy the parent list, since replaceChildReference modifies it
                for (Hop parent : new ArrayList<>(tread.getParent())) {
                    HopRewriteUtils.replaceChildReference(parent, tread, input);
                }
                HopRewriteUtils.removeAllChildReferences(root);
                if (!twrites.containsKey(root.getName()) && sb2.liveOut().containsVariable(root.getName())) {
                    orderedHops.add(HopRewriteUtils.createDataOp(root.getName(), input, Hop.DataOpTypes.TRANSIENTWRITE));
                }
            } else if (!HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE)
                    || (!twrites.containsKey(root.getName()) && sb2.liveOut().containsVariable(root.getName()))) {
                orderedHops.add(root);
            }
        }
        sb1Hops.clear();
        orderedHops.addAll(sb2Hops);
        sb2.setHops(orderedHops);

        // clean up the merged DAG
        Hop.resetVisitStatus(sb2.getHops());
        this.rewriter.rewriteHopDAG(sb2.getHops(), new ProgramRewriteStatus());

        // merge variable metadata of both blocks into sb2
        sb2.setLiveIn(sb1.liveIn());
        // NOTE(review): this computes union(gen1, gen2) - kill1. If "gen" means "read before
        // updated", sequential composition would be gen1 + (gen2 - kill1); the two differ for
        // variables both read and updated in sb1 - confirm.
        sb2.setGen(VariableSet.minus(VariableSet.union(sb1.getGen(), sb2.getGen()), sb1.getKill()));
        sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
        sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
        sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
        LOG.debug("Applied mergeStatementBlockSequences (blocks of lines " + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
        sb2.setBeginLine(sb1.getBeginLine());
        sb2.setBeginColumn(sb1.getBeginColumn());
    }

    /**
     * Recursively walks the DAG below {@code hop} (inputs first, skipping visited hops) and marks
     * each hop visited.
     *
     * @param hop root of the sub-DAG to walk
     * @param treads receives transient read hops by variable name (a later hop overwrites an earlier
     *        one with the same name)
     * @param twrites receives transient write hops by variable name; output variable names of
     *        function calls are recorded with a {@code null} value
     */
    private static void rCollectTransientReadWrites(Hop hop, HashMap<String, Hop> treads, HashMap<String, Hop> twrites) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            rCollectTransientReadWrites(input, treads, twrites);
        }
        if (HopRewriteUtils.isData(hop, Hop.DataOpTypes.TRANSIENTREAD)) {
            treads.put(hop.getName(), hop);
        } else if (HopRewriteUtils.isData(hop, Hop.DataOpTypes.TRANSIENTWRITE)) {
            twrites.put(hop.getName(), hop);
        } else if (hop instanceof FunctionOp) {
            for (String output : ((FunctionOp) hop).getOutputVariableNames()) {
                twrites.put(output, null);
            }
        }
        hop.setVisited();
    }

    /** Returns whether any root hop of {@code sb} is a {@link FunctionOp}; false for null input. */
    private static boolean hasFunctionOpRoot(StatementBlock sb) {
        if (sb == null || sb.getHops() == null) {
            return false;
        }
        for (Hop root : sb.getHops()) {
            if (root instanceof FunctionOp) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether any root hop of {@code sb} is a call to an external function whose
     * {@link ExternalFunctionStatement} reports side effects; false for null input.
     */
    private static boolean hasExternalFunctionOpRootWithSideEffect(StatementBlock sb) {
        if (sb == null || sb.getHops() == null) {
            return false;
        }
        for (Hop root : sb.getHops()) {
            if (!(root instanceof FunctionOp)) {
                continue;
            }
            FunctionStatementBlock fsb = sb.getDMLProg().getFunctionStatementBlock(((FunctionOp) root).getFunctionKey());
            if (fsb != null
                    && fsb.getStatement(0) instanceof ExternalFunctionStatement
                    && ((ExternalFunctionStatement) fsb.getStatement(0)).hasSideEffects()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether any output variable of a function call root in {@code sb1} is read or
     * updated by {@code sb2}.
     */
    private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2) {
        HashSet<String> outputs = new HashSet<>();
        for (Hop root : sb1.getHops()) {
            if (root instanceof FunctionOp) {
                outputs.addAll(Arrays.asList(((FunctionOp) root).getOutputVariableNames()));
            }
        }
        return sb2.variablesRead().containsAnyName(outputs) || sb2.variablesUpdated().containsAnyName(outputs);
    }
}
