package org.apache.sysml.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.FunctionStatement;
import org.apache.sysml.parser.StatementBlock;

/* loaded from: input_file:org/apache/sysml/hops/ipa/IPAPassInlineFunctions.class */
/**
 * Inter-procedural analysis (IPA) pass that inlines small functions into their call sites.
 *
 * <p>A reachable function qualifies for inlining only when all of the following hold:
 * <ul>
 *   <li>its body consists of exactly one last-level statement block;</li>
 *   <li>that block contains no function calls;</li>
 *   <li>the function is either called exactly once, or its block has at most
 *       {@link #MAX_INLINE_OPERATORS} operators (data and literal operators are not counted).</li>
 * </ul>
 *
 * <p>Inlining is applied to each call whose number of inputs and outputs matches the function
 * signature. For such a call, a deep copy of the function's HOP DAG replaces the call operator
 * in the calling statement block:
 * <ul>
 *   <li>transient reads of the formal parameters are rewired to the actual call arguments;</li>
 *   <li>transient writes of the formal outputs are renamed to the call's output variables;</li>
 *   <li>transient writes of any other variable are dropped.</li>
 * </ul>
 *
 * <p>The function's entries are removed from the call graph only if every one of its calls was
 * inlined.
 */
public class IPAPassInlineFunctions extends IPAPass {
    /**
     * Maximum number of operators (excluding data and literal operators) a function body may
     * have in order to be inlined even though the function is called more than once.
     */
    private static final int MAX_INLINE_OPERATORS = 10;

    /** This pass is unconditionally applicable. */
    @Override // org.apache.sysml.hops.ipa.IPAPass
    public boolean isApplicable(FunctionCallGraph functionCallGraph) {
        return true;
    }

    /**
     * Inlines every qualifying reachable function into each of its call sites whose number of
     * inputs and outputs matches the function signature.
     *
     * @param dMLProgram program providing the function statement blocks
     * @param functionCallGraph call graph providing the call sites of each function; a
     *        function's calls are removed from it once all of them have been inlined
     * @param functionCallSizeInfo not used by this pass
     * @throws HopsException if a call passes a named argument the function does not declare
     */
    @Override // org.apache.sysml.hops.ipa.IPAPass
    public void rewriteProgram(DMLProgram dMLProgram, FunctionCallGraph functionCallGraph, FunctionCallSizeInfo functionCallSizeInfo) {
        for (String fkey : functionCallGraph.getReachableFunctions()) {
            FunctionStatement fstmt = (FunctionStatement) dMLProgram.getFunctionStatementBlock(fkey).getStatement(0);
            if (!isInlineCandidate(fstmt, functionCallGraph, fkey)) {
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("IPA: Inline function '" + fkey + "'");
            }
            ArrayList<Hop> funcHops = fstmt.getBody().get(0).getHops();
            List<FunctionOp> calls = functionCallGraph.getFunctionCalls(fkey);
            List<StatementBlock> callSBs = functionCallGraph.getFunctionCallsSB(fkey);
            boolean allInlined = true;
            for (int i = 0; i < calls.size(); i++) {
                FunctionOp call = calls.get(i);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("-- inline '" + fkey + "' at line " + call.getBeginLine());
                }
                // Only calls that bind exactly as many inputs/outputs as the signature declares
                // are inlined; any other call keeps the function alive in the call graph.
                boolean arityMatches = call.getInput().size() == fstmt.getInputParams().size()
                    && call.getOutputVariableNames().length == fstmt.getOutputParams().size();
                if (arityMatches) {
                    inlineFunctionCall(call, fstmt, funcHops, callSBs.get(i));
                } else {
                    allInlined = false;
                }
            }
            if (allInlined) {
                functionCallGraph.removeFunctionCalls(fkey);
            }
        }
    }

    /**
     * Checks whether a function is structurally eligible for inlining.
     *
     * <p>The function must have a single last-level statement block without function calls.
     * It must also be called exactly once or have at most {@link #MAX_INLINE_OPERATORS}
     * operators.
     *
     * @param fstmt function to check
     * @param fgraph call graph used to look up the function's call sites
     * @param fkey function key of {@code fstmt} in {@code fgraph}
     * @return true if the function may be inlined
     */
    private static boolean isInlineCandidate(FunctionStatement fstmt, FunctionCallGraph fgraph, String fkey) {
        if (fstmt.getBody().size() != 1) {
            return false;
        }
        StatementBlock body = fstmt.getBody().get(0);
        if (!HopRewriteUtils.isLastLevelStatementBlock(body) || containsFunctionOp(body.getHops())) {
            return false;
        }
        // The call-site lookup happens only after the structural checks pass, as before.
        return fgraph.getFunctionCalls(fkey).size() == 1
            || countOperators(body.getHops()) <= MAX_INLINE_OPERATORS;
    }

    /**
     * Replaces one function call operator in its statement block with a fresh copy of the
     * function body.
     *
     * @param call function call operator to replace
     * @param fstmt called function
     * @param funcHops root HOPs of the function's single statement block; deep-copied, so the
     *        function itself is not modified
     * @param callSB statement block whose root HOPs contain {@code call}
     * @throws HopsException if the call names an argument the function does not declare
     */
    private static void inlineFunctionCall(FunctionOp call, FunctionStatement fstmt, ArrayList<Hop> funcHops, StatementBlock callSB) {
        ArrayList<Hop> hops = Recompiler.deepCopyHopsDag(funcHops);

        // Step 1: rewire transient reads of the formal parameters to the actual arguments.
        HashMap<String, Hop> inMap = new HashMap<>();
        String[] inNames = call.getInputVariableNames();
        for (int j = 0; j < call.getInput().size(); j++) {
            String name = inNames[j];
            if (fstmt.getInputParam(name) == null) {
                throw new HopsException("Non-existing named function argument: '" + name + "' in function call '" + call.getFunctionKey() + "' (line " + call.getBeginLine() + ").");
            }
            inMap.put(name, call.getInput().get(j));
        }
        replaceTransientReads(hops, inMap);

        // Step 2: rename transient writes of the formal outputs to the caller's output names.
        HashMap<String, String> outMap = new HashMap<>();
        String[] outNames = call.getOutputVariableNames();
        for (int j = 0; j < outNames.length; j++) {
            outMap.put(fstmt.getOutputParams().get(j).getName(), outNames[j]);
        }
        // Remove via the iterator: ArrayList.remove(int) inside an index-incrementing loop
        // would skip the element that shifts into the removed slot.
        Iterator<Hop> it = hops.iterator();
        while (it.hasNext()) {
            Hop hop = it.next();
            if (HopRewriteUtils.isData(hop, Hop.DataOpTypes.TRANSIENTWRITE)) {
                hop.setName(outMap.get(hop.getName()));
                if (hop.getName() == null) {
                    // Write of a variable that is not a function output: drop it.
                    it.remove();
                }
            }
        }

        // Step 3: swap the call operator for the copied function DAG.
        callSB.getHops().remove(call);
        callSB.getHops().addAll(hops);
    }

    /**
     * Checks whether any DAG reachable from the given roots contains a {@link FunctionOp}.
     * Visit status is reset before and after the check.
     *
     * @param arrayList root HOPs, may be null or empty
     * @return true if a function call operator is found
     */
    private static boolean containsFunctionOp(ArrayList<Hop> arrayList) {
        if (arrayList == null || arrayList.isEmpty()) {
            return false;
        }
        Hop.resetVisitStatus(arrayList);
        boolean containsOp = HopRewriteUtils.containsOp(arrayList, FunctionOp.class);
        Hop.resetVisitStatus(arrayList);
        return containsOp;
    }

    /**
     * Counts the distinct operators reachable from the given roots, ignoring
     * {@link DataOp} and {@link LiteralOp} nodes. Shared sub-DAGs are counted once.
     *
     * @param arrayList root HOPs, may be null or empty
     * @return number of counted operators
     */
    private static int countOperators(ArrayList<Hop> arrayList) {
        if (arrayList == null || arrayList.isEmpty()) {
            return 0;
        }
        Hop.resetVisitStatus(arrayList);
        int count = 0;
        for (Hop root : arrayList) {
            count += rCountOperators(root);
        }
        Hop.resetVisitStatus(arrayList);
        return count;
    }

    /**
     * Recursive helper of {@link #countOperators}. It marks nodes as visited so that shared
     * inputs are counted only once.
     */
    private static int rCountOperators(Hop hop) {
        if (hop.isVisited()) {
            return 0;
        }
        int count = ((hop instanceof DataOp) || (hop instanceof LiteralOp)) ? 0 : 1;
        for (Hop input : hop.getInput()) {
            count += rCountOperators(input);
        }
        hop.setVisited();
        return count;
    }

    /**
     * Replaces every transient read in the DAGs rooted at {@code arrayList} with the HOP bound
     * to its variable name in {@code hashMap}. Visit status is reset before and after.
     *
     * @param arrayList root HOPs of the (copied) function body
     * @param hashMap variable name to replacement HOP
     */
    private static void replaceTransientReads(ArrayList<Hop> arrayList, HashMap<String, Hop> hashMap) {
        Hop.resetVisitStatus(arrayList);
        for (Hop root : arrayList) {
            rReplaceTransientReads(root, hashMap);
        }
        Hop.resetVisitStatus(arrayList);
    }

    /**
     * Recursive helper of {@link #replaceTransientReads}. Each child is processed before it is
     * checked for replacement. A transient read shared by several parents is therefore
     * rewired in every parent, even though it is only descended into once.
     *
     * <p>NOTE(review): this assumes every transient read in the function body names a key of
     * {@code hashMap}. An unbound name would pass {@code null} to replaceChildReference — confirm.
     */
    private static void rReplaceTransientReads(Hop hop, HashMap<String, Hop> hashMap) {
        if (hop.isVisited()) {
            return;
        }
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop child = hop.getInput().get(i);
            rReplaceTransientReads(child, hashMap);
            if (HopRewriteUtils.isData(child, Hop.DataOpTypes.TRANSIENTREAD)) {
                HopRewriteUtils.replaceChildReference(hop, child, hashMap.get(child.getName()));
            }
        }
        hop.setVisited();
    }
}
