package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.codegen.SpoofCompiler;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/opt/ProgramRecompiler.class */
/**
 * Static helpers (package {@code parfor.opt}) for regenerating or patching runtime program blocks.
 *
 * <p>Most helpers walk a compile-time {@link StatementBlock} tree and the matching runtime
 * {@link ProgramBlock} tree in lock-step through if/while/for nesting. They cover four tasks:
 * <ul>
 *   <li>creating runtime program blocks for a list of statement blocks;</li>
 *   <li>forcing, or releasing, the execution type of indexing HOPs over a named variable, then
 *       recompiling the affected instructions;</li>
 *   <li>finding scalar variables that can be reused as constants, and substituting them as
 *       literals;</li>
 *   <li>small utilities such as detecting function calls in a program block.</li>
 * </ul>
 */
public class ProgramRecompiler {

    /**
     * Creates runtime program blocks for the given statement blocks.
     *
     * <p>Lops are constructed for every statement block first, and the runtime program blocks
     * are created afterwards. If code generation is enabled with runtime integration, each
     * created program block is also passed to the translator's codegen step.
     *
     * @param program   runtime program the created blocks are attached to
     * @param arrayList statement blocks to translate. The translator is built from the DML
     *                  program of the first block, so all blocks are assumed to belong to that
     *                  program.
     * @return one program block per statement block, in input order. The list is empty if
     *         {@code arrayList} is {@code null} or empty.
     */
    public static ArrayList<ProgramBlock> generatePartitialRuntimeProgram(Program program, ArrayList<StatementBlock> arrayList) {
        ArrayList<ProgramBlock> blocks = new ArrayList<>();
        if (arrayList == null || arrayList.isEmpty()) {
            // Nothing to translate; also avoids an IndexOutOfBoundsException on get(0) below.
            return blocks;
        }

        DMLConfig config = ConfigurationManager.getDMLConfig();
        DMLTranslator translator = new DMLTranslator(arrayList.get(0).getDMLProg());

        // First pass: lops for all statement blocks.
        for (StatementBlock sb : arrayList) {
            translator.constructLops(sb);
        }
        // Second pass: runtime program blocks, in the same order as the input.
        for (StatementBlock sb : arrayList) {
            blocks.add(translator.createRuntimeProgramBlock(program, sb, config));
        }

        if (ConfigurationManager.isCodegenEnabled() && SpoofCompiler.INTEGRATION == SpoofCompiler.IntegrationType.RUNTIME) {
            for (ProgramBlock pb : blocks) {
                translator.codgenHopsDAG(pb);
            }
        }
        return blocks;
    }

    /**
     * Marks the indexing HOPs that read variable {@code str} and recompiles the affected
     * instructions, recursing through if/while/for blocks.
     *
     * <p>The statement block and program block are walked in lock-step:
     * <ul>
     *   <li>If/while predicates and for from/to/increment expressions are recompiled in place.</li>
     *   <li>Child blocks are paired by index, up to the shorter of the two child lists.</li>
     *   <li>For any other pairing, the block's own HOP DAGs are processed, and the program
     *       block's instructions are replaced only if a matching indexing HOP was found.</li>
     * </ul>
     *
     * @param statementBlock   compile-time block to inspect
     * @param programBlock     runtime block whose instructions may be replaced
     * @param str              name of the variable that is the first input of the targeted
     *                         indexing HOPs
     * @param executionContext its variables are passed to HOP recompilation
     * @param z                {@code true} to force CP (or CP_FILE) execution of matching
     *                         indexing HOPs; {@code false} to release the forced execution type
     *                         and clear their memory estimates
     * @throws DMLRuntimeException if marking or recompiling a HOP DAG fails
     */
    public static void rFindAndRecompileIndexingHOP(StatementBlock statementBlock, ProgramBlock programBlock, String str, ExecutionContext executionContext, boolean z) {
        if (programBlock instanceof IfProgramBlock && statementBlock instanceof IfStatementBlock) {
            IfProgramBlock ifPb = (IfProgramBlock) programBlock;
            IfStatementBlock ifSb = (IfStatementBlock) statementBlock;
            IfStatement ifStmt = (IfStatement) ifSb.getStatement(0);
            if (ifSb.getPredicateHops() != null) {
                ifPb.setPredicate(rFindAndRecompileIndexingHOP(ifSb.getPredicateHops(), ifPb.getPredicate(), str, executionContext, z));
            }
            rFindAndRecompileChildBlocks(ifStmt.getIfBody(), ifPb.getChildBlocksIfBody(), str, executionContext, z);
            if (ifPb.getChildBlocksElseBody() != null) {
                rFindAndRecompileChildBlocks(ifStmt.getElseBody(), ifPb.getChildBlocksElseBody(), str, executionContext, z);
            }
        } else if (programBlock instanceof WhileProgramBlock && statementBlock instanceof WhileStatementBlock) {
            WhileProgramBlock whilePb = (WhileProgramBlock) programBlock;
            WhileStatementBlock whileSb = (WhileStatementBlock) statementBlock;
            WhileStatement whileStmt = (WhileStatement) whileSb.getStatement(0);
            if (whileSb.getPredicateHops() != null) {
                whilePb.setPredicate(rFindAndRecompileIndexingHOP(whileSb.getPredicateHops(), whilePb.getPredicate(), str, executionContext, z));
            }
            rFindAndRecompileChildBlocks(whileStmt.getBody(), whilePb.getChildBlocks(), str, executionContext, z);
        } else if (programBlock instanceof ForProgramBlock && statementBlock instanceof ForStatementBlock) {
            ForProgramBlock forPb = (ForProgramBlock) programBlock;
            ForStatementBlock forSb = (ForStatementBlock) statementBlock;
            ForStatement forStmt = (ForStatement) forSb.getStatement(0);
            if (forSb.getFromHops() != null) {
                forPb.setFromInstructions(rFindAndRecompileIndexingHOP(forSb.getFromHops(), forPb.getFromInstructions(), str, executionContext, z));
            }
            if (forSb.getToHops() != null) {
                forPb.setToInstructions(rFindAndRecompileIndexingHOP(forSb.getToHops(), forPb.getToInstructions(), str, executionContext, z));
            }
            if (forSb.getIncrementHops() != null) {
                forPb.setIncrementInstructions(rFindAndRecompileIndexingHOP(forSb.getIncrementHops(), forPb.getIncrementInstructions(), str, executionContext, z));
            }
            rFindAndRecompileChildBlocks(forStmt.getBody(), forPb.getChildBlocks(), str, executionContext, z);
        } else {
            rFindAndRecompileGenericBlock(statementBlock, programBlock, str, executionContext, z);
        }
    }

    /**
     * Recurses into paired child blocks, matched by index up to the shorter of the two lists.
     */
    private static void rFindAndRecompileChildBlocks(List<StatementBlock> statementBlocks, List<ProgramBlock> programBlocks, String varName, ExecutionContext ec, boolean force) {
        int count = Math.min(statementBlocks.size(), programBlocks.size());
        for (int i = 0; i < count; i++) {
            rFindAndRecompileIndexingHOP(statementBlocks.get(i), programBlocks.get(i), varName, ec, force);
        }
    }

    /**
     * Handles a block that is not a matching if/while/for pair.
     *
     * <p>The block's HOP DAGs are marked, and the program block's instructions are replaced
     * only if a matching indexing HOP was found. Blocks without HOPs are left untouched.
     */
    private static void rFindAndRecompileGenericBlock(StatementBlock statementBlock, ProgramBlock programBlock, String varName, ExecutionContext ec, boolean force) {
        try {
            ArrayList<Hop> hops = statementBlock.getHops();
            if (hops == null) {
                // Nothing to mark or recompile (previously this surfaced as a wrapped NPE).
                return;
            }
            // Reset once for all roots so shared sub-DAGs are visited only once.
            Hop.resetVisitStatus(hops);
            boolean changed = false;
            for (Hop root : hops) {
                changed |= rMarkIndexingHOPs(root, varName, force);
            }
            if (changed) {
                programBlock.setInstructions(Recompiler.recompileHopsDag(statementBlock, hops, ec.getVariables(), null, true, false, 0L));
            }
        } catch (DMLRuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Collects the scalar variables of {@code localVariableMap} that may be reused as constants
     * for {@code statementBlock}.
     *
     * @param dMLProgram       program whose top-level statement blocks are searched
     * @param statementBlock   target block, compared by identity
     * @param localVariableMap candidate variables
     * @return a new map with every scalar entry accepted by
     *         {@link #isApplicableForReuseVariable(DMLProgram, StatementBlock, String)}
     */
    public static LocalVariableMap getReusableScalarVariables(DMLProgram dMLProgram, StatementBlock statementBlock, LocalVariableMap localVariableMap) {
        LocalVariableMap reusable = new LocalVariableMap();
        for (String varName : localVariableMap.keySet()) {
            Data value = localVariableMap.get(varName);
            if (value instanceof ScalarObject && isApplicableForReuseVariable(dMLProgram, statementBlock, varName)) {
                reusable.put(varName, value);
            }
        }
        return reusable;
    }

    /**
     * Recursively replaces variable references with literals from {@code localVariableMap} in
     * all HOP DAGs of {@code statementBlock}.
     *
     * <p>The replacement covers if/while predicates, for from/to/increment expressions, and
     * the HOPs of every nested block.
     *
     * @param statementBlock   root block to rewrite in place
     * @param localVariableMap values to substitute
     */
    public static void replaceConstantScalarVariables(StatementBlock statementBlock, LocalVariableMap localVariableMap) {
        if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock ifSb = (IfStatementBlock) statementBlock;
            IfStatement ifStmt = (IfStatement) ifSb.getStatement(0);
            replacePredicateLiterals(ifSb.getPredicateHops(), localVariableMap);
            for (StatementBlock child : ifStmt.getIfBody()) {
                replaceConstantScalarVariables(child, localVariableMap);
            }
            for (StatementBlock child : ifStmt.getElseBody()) {
                replaceConstantScalarVariables(child, localVariableMap);
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock whileSb = (WhileStatementBlock) statementBlock;
            WhileStatement whileStmt = (WhileStatement) whileSb.getStatement(0);
            replacePredicateLiterals(whileSb.getPredicateHops(), localVariableMap);
            for (StatementBlock child : whileStmt.getBody()) {
                replaceConstantScalarVariables(child, localVariableMap);
            }
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock forSb = (ForStatementBlock) statementBlock;
            ForStatement forStmt = (ForStatement) forSb.getStatement(0);
            replacePredicateLiterals(forSb.getFromHops(), localVariableMap);
            replacePredicateLiterals(forSb.getToHops(), localVariableMap);
            replacePredicateLiterals(forSb.getIncrementHops(), localVariableMap);
            for (StatementBlock child : forStmt.getBody()) {
                replaceConstantScalarVariables(child, localVariableMap);
            }
        } else {
            ArrayList<Hop> hops = statementBlock.getHops();
            if (hops != null) {
                Hop.resetVisitStatus(hops);
                for (Hop root : hops) {
                    Recompiler.rReplaceLiterals(root, localVariableMap, true);
                }
            }
        }
    }

    /** Replaces literals in a single predicate DAG; {@code null} predicates are ignored. */
    private static void replacePredicateLiterals(Hop hop, LocalVariableMap localVariableMap) {
        if (hop != null) {
            hop.resetVisitStatus();
            Recompiler.rReplaceLiterals(hop, localVariableMap, true);
        }
    }

    /**
     * Decides whether variable {@code str} can be reused as a constant for the for block
     * {@code statementBlock}.
     *
     * <p>Returns {@code true} if some top-level block of the program reaches the target for
     * block (compared by identity) through if/while/for nesting. In addition, neither the
     * target nor any enclosing block on that path may report {@code str} among its updated
     * variables.
     *
     * @param dMLProgram     program whose top-level statement blocks are searched
     * @param statementBlock target block. Only {@link ForStatementBlock} instances can match.
     * @param str            variable name
     * @return whether the variable is applicable for reuse
     */
    public static boolean isApplicableForReuseVariable(DMLProgram dMLProgram, StatementBlock statementBlock, String str) {
        boolean applicable = false;
        for (StatementBlock root : dMLProgram.getStatementBlocks()) {
            applicable |= isApplicableForReuseVariable(root, statementBlock, str);
        }
        return applicable;
    }

    /**
     * Recursive helper.
     *
     * <p>Returns true if {@code statementBlock2} is {@code statementBlock} itself (when it is a
     * for block) or is reachable inside it. Neither this block nor any nested block on the path
     * may update {@code str}. Plain (non-control-flow) blocks always yield {@code false}.
     */
    private static boolean isApplicableForReuseVariable(StatementBlock statementBlock, StatementBlock statementBlock2, String str) {
        boolean reachable = false;
        if (statementBlock instanceof IfStatementBlock) {
            IfStatement ifStmt = (IfStatement) statementBlock.getStatement(0);
            for (StatementBlock child : ifStmt.getIfBody()) {
                reachable |= isApplicableForReuseVariable(child, statementBlock2, str);
            }
            for (StatementBlock child : ifStmt.getElseBody()) {
                reachable |= isApplicableForReuseVariable(child, statementBlock2, str);
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatement whileStmt = (WhileStatement) statementBlock.getStatement(0);
            for (StatementBlock child : whileStmt.getBody()) {
                reachable |= isApplicableForReuseVariable(child, statementBlock2, str);
            }
        } else if (statementBlock instanceof ForStatementBlock) {
            if (statementBlock == statementBlock2) {
                reachable = true;
            } else {
                ForStatement forStmt = (ForStatement) statementBlock.getStatement(0);
                for (StatementBlock child : forStmt.getBody()) {
                    reachable |= isApplicableForReuseVariable(child, statementBlock2, str);
                }
            }
        }
        // A variable updated anywhere on the path to the target disqualifies reuse.
        return reachable && !statementBlock.variablesUpdated().containsVariable(str);
    }

    /**
     * Checks whether a program block contains at least one function call.
     *
     * <p>If/while/for blocks are searched only through their child blocks; their predicate and
     * from/to/increment instructions are not inspected. Other blocks are searched through their
     * own instruction list.
     *
     * @param programBlock block to search
     * @return {@code true} if a {@link FunctionCallCPInstruction} is found
     */
    public static boolean containsAtLeastOneFunction(ProgramBlock programBlock) {
        if (programBlock instanceof IfProgramBlock) {
            IfProgramBlock ifPb = (IfProgramBlock) programBlock;
            // The else body is null-checked, consistent with rFindAndRecompileIndexingHOP.
            return anyContainsFunction(ifPb.getChildBlocksIfBody())
                || anyContainsFunction(ifPb.getChildBlocksElseBody());
        }
        if (programBlock instanceof WhileProgramBlock) {
            return anyContainsFunction(((WhileProgramBlock) programBlock).getChildBlocks());
        }
        if (programBlock instanceof ForProgramBlock) {
            return anyContainsFunction(((ForProgramBlock) programBlock).getChildBlocks());
        }
        ArrayList<Instruction> instructions = programBlock.getInstructions();
        if (instructions == null) {
            return false;
        }
        for (Instruction inst : instructions) {
            if (inst instanceof FunctionCallCPInstruction) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if any block in the (possibly null) list contains a function call. */
    private static boolean anyContainsFunction(List<ProgramBlock> blocks) {
        if (blocks == null) {
            return false;
        }
        for (ProgramBlock pb : blocks) {
            if (containsAtLeastOneFunction(pb)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Marks matching indexing HOPs in one predicate/expression DAG and recompiles it if needed.
     *
     * @return freshly recompiled instructions if a matching indexing HOP was found, otherwise
     *         the unchanged {@code arrayList}
     * @throws DMLRuntimeException if marking or recompilation fails
     */
    private static ArrayList<Instruction> rFindAndRecompileIndexingHOP(Hop hop, ArrayList<Instruction> arrayList, String str, ExecutionContext executionContext, boolean z) {
        try {
            hop.resetVisitStatus();
            if (rMarkIndexingHOPs(hop, str, z)) {
                return Recompiler.recompileHopsDag(hop, executionContext.getVariables(), null, true, false, 0L);
            }
            return arrayList;
        } catch (DMLRuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Dispatches to the force or release marking.
     *
     * @return true if any indexing HOP over {@code varName} was found
     */
    private static boolean rMarkIndexingHOPs(Hop hop, String varName, boolean force) {
        return force ? rFindAndSetCPIndexingHOP(hop, varName) : rFindAndReleaseIndexingHOP(hop, varName);
    }

    /** Returns true if {@code hop} is an {@link IndexingOp} whose first input is named {@code varName}. */
    private static boolean isIndexingOf(Hop hop, String varName) {
        return hop instanceof IndexingOp && hop.getInput().get(0).getName().equals(varName);
    }

    /**
     * Forces the execution type of matching indexing HOPs, visiting each DAG node at most once.
     *
     * <p>A matching HOP gets CP if its memory estimate is below the local memory budget, and
     * CP_FILE otherwise.
     *
     * @return true if at least one matching indexing HOP was found
     */
    private static boolean rFindAndSetCPIndexingHOP(Hop hop, String str) {
        if (hop.isVisited()) {
            return false;
        }
        boolean found = false;
        if (isIndexingOf(hop, str)) {
            LopProperties.ExecType execType = hop.getMemEstimate() < OptimizerUtils.getLocalMemBudget()
                ? LopProperties.ExecType.CP
                : LopProperties.ExecType.CP_FILE;
            hop.setForcedExecType(execType);
            found = true;
        }
        ArrayList<Hop> inputs = hop.getInput();
        if (inputs != null) {
            for (Hop in : inputs) {
                found |= rFindAndSetCPIndexingHOP(in, str);
            }
        }
        hop.setVisited();
        return found;
    }

    /**
     * Releases the forced execution type of matching indexing HOPs and clears their memory
     * estimates, visiting each DAG node at most once.
     *
     * @return true if at least one matching indexing HOP was found
     */
    private static boolean rFindAndReleaseIndexingHOP(Hop hop, String str) {
        if (hop.isVisited()) {
            return false;
        }
        boolean found = false;
        if (isIndexingOf(hop, str)) {
            hop.setForcedExecType(null);
            hop.clearMemEstimate();
            found = true;
        }
        ArrayList<Hop> inputs = hop.getInput();
        if (inputs != null) {
            for (Hop in : inputs) {
                found |= rFindAndReleaseIndexingHOP(in, str);
            }
        }
        hop.setVisited();
        return found;
    }

    /**
     * Builds a one-element instruction list with a CP scalar-integer {@code +} instruction.
     *
     * <p>The instruction reads operands {@code str} and {@code str2} and writes output
     * {@code str}. Presumably this means {@code str = str + str2}, assuming BinaryCPInstruction's
     * (in1, in2, out) operand order — TODO confirm.
     *
     * @param str  name of the first operand and of the output scalar
     * @param str2 name of the second operand
     * @return a new list containing the parsed instruction
     */
    protected static ArrayList<Instruction> createNestedParallelismToInstructionSet(String str, String str2) {
        ArrayList<Instruction> instructions = new ArrayList<>();
        instructions.add(BinaryCPInstruction.parseInstruction("CP°+°" + str + "·SCALAR·INT°" + str2 + "·SCALAR·INT°" + str + "·SCALAR·INT"));
        return instructions;
    }
}
