package org.apache.sysml.yarn.ropt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.io.FileUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.cost.CostEstimationWrapper;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaData;
import org.apache.sysml.yarn.DMLYarnClient;
import org.apache.sysml.yarn.ropt.YarnOptimizerUtils;

/* loaded from: input_file:org/apache/sysml/yarn/ropt/ResourceOptimizer.class */
/**
 * Cost-based optimizer that chooses YARN memory budgets for a compiled DML program.
 *
 * <p>For every enumerated CP (control program) heap budget, the program is compiled into a flat
 * list of program blocks. Blocks whose costs cannot depend on the MR budget are pruned. Every
 * remaining block is recompiled and costed for every enumerated MR heap budget, keeping the
 * cheapest budget per block. Finally, the whole program is recompiled with these per-block
 * budgets and costed; the CP budget with the lowest overall estimate wins.
 *
 * <p>Compilation installs the budgets globally via {@link InfrastructureAnalyzer} setters, and
 * the compile/cost counters are static.
 */
public class ResourceOptimizer {
    /** Lower bound (bytes, 512MB) for the CP budgets enumerated by {@link #optimizeResourceConfig}. */
    public static final long MIN_CP_BUDGET = 536870912;

    // Optimizer feature switches.
    // NOTE(review): none of these is referenced by the code in this class; confirm whether
    // external callers rely on them before removing.
    public static final boolean INCLUDE_PREDICATES = true;
    public static final boolean PRUNING_SMALL = true;
    public static final boolean PRUNING_UNKNOWN = true;
    public static final boolean COSTS_MAX_PARALLELISM = true;
    public static final boolean COST_INDIVIDUAL_BLOCKS = true;

    /**
     * Ratio of physical container memory to JVM heap used throughout this class
     * (physical = heap * MEM_FACTOR). NOTE(review): presumably mirrors the factor used by
     * DMLYarnClient when sizing containers; confirm the two stay in sync.
     */
    private static final double MEM_FACTOR = 1.5d;

    private static final Log LOG = LogFactory.getLog(ResourceOptimizer.class);

    /** Number of program-block (re)compilations since the last {@link #initStatistics()}. */
    private static long _cntCompilePB = 0;

    /** Number of program-block costings since the last {@link #initStatistics()}. */
    private static long _cntCostPB = 0;

    /**
     * Enumerates CP and MR memory budgets and returns the configuration with the lowest
     * estimated program cost.
     *
     * @param prog   top-level program blocks of the program to optimize
     * @param cc     cluster configuration supplying min/max allocation sizes and average cores
     * @param cptype grid enumeration strategy for the CP budget
     * @param mrtype grid enumeration strategy for the MR budget
     * @return the cheapest resource configuration found
     * @throws DMLRuntimeException if enumeration, compilation or costing fails
     */
    public static synchronized ResourceConfig optimizeResourceConfig(ArrayList<ProgramBlock> prog,
            YarnClusterConfig cc, YarnOptimizerUtils.GridEnumType cptype,
            YarnOptimizerUtils.GridEnumType mrtype) {
        try {
            Timing timing = new Timing(true);
            initStatistics();

            // Budgets are JVM heap sizes: scale the physical allocation limits down by
            // MEM_FACTOR, and never go below MIN_CP_BUDGET for the CP.
            long maxCP = (long) (YarnOptimizerUtils.toB(cc.getMaxAllocationMB()) / MEM_FACTOR);
            long minCP = (long) Math.max(
                    YarnOptimizerUtils.toB(cc.getMinAllocationMB()) / MEM_FACTOR, MIN_CP_BUDGET);
            long minMR = YarnOptimizerUtils.computeMinContraint(minCP, maxCP, cc.getAvgNumCores());
            long maxMR = maxCP;

            ArrayList<Long> cpPoints = enumerateGridPoints(prog, minCP, maxCP, cptype);
            ArrayList<Long> mrPoints = enumerateGridPoints(prog, minMR, maxMR, mrtype);

            ResourceConfig opt = new ResourceConfig(prog, minMR);
            double optCost = Double.MAX_VALUE;

            for (long rc : cpPoints) {
                // Baseline compile under (rc, minMR), then drop blocks whose MR budget is irrelevant.
                ArrayList<ProgramBlock> blocks = compileProgram(prog, null, rc, minMR);
                ArrayList<ProgramBlock> pruned = pruneProgramBlocks(blocks);
                LOG.debug("Enum (rc=" + rc + "): |B|=" + blocks.size() + ", |Bp|=" + pruned.size());

                // Local search: cheapest MR budget per pruned block under this CP budget.
                double[][] localMemo = initLocalMemoTable(pruned, minMR);
                optimizeLocalMRResources(pruned, rc, mrPoints, localMemo);

                // Global costing of the whole program with the per-block optima applied.
                double[][] globalMemo = initGlobalMemoTable(blocks, pruned, localMemo, minMR);
                recompileProgramBlocks(blocks, rc, globalMemo);
                double cost = getProgramCosts(blocks.get(0).getProgram());
                if (cost < optCost) {
                    opt.setCPResource(rc);
                    opt.setMRResources(blocks, globalMemo);
                    optCost = cost;
                    LOG.debug("Enum (rc=" + rc + "): found new opt w/ cost=" + cost);
                }
            }

            LOG.info("Optimization summary:");
            LOG.info("-- optimal plan (rc, rm): " + YarnOptimizerUtils.toMB(opt.getCPResource()) + "MB, "
                    + YarnOptimizerUtils.toMB(opt.getMaxMRResource()) + "MB");
            LOG.info("-- costs of optimal plan: " + optCost);
            LOG.info("-- # of block compiles:   " + _cntCompilePB);
            LOG.info("-- # of block costings:   " + _cntCostPB);
            LOG.info("-- optimization time:     " + String.format("%.3f", timing.stop() / 1000.0d) + " sec.");
            LOG.info("-- optimal plan details:  " + opt.serialize());
            return opt;
        } catch (DMLRuntimeException e) {
            // already the declared failure type; do not wrap it a second time
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * For each pruned block, recompiles and costs it under every MR budget in {@code mrPoints}.
     * Writes the cheapest budget and its cost into {@code localMemo[i]} when it beats the
     * current entry.
     */
    private static void optimizeLocalMRResources(ArrayList<ProgramBlock> pruned, long rc,
            ArrayList<Long> mrPoints, double[][] localMemo) {
        for (int i = 0; i < pruned.size(); i++) {
            ProgramBlock pb = pruned.get(i);
            for (long rm : mrPoints) {
                recompileProgramBlock(pb, rc, rm);
                double cost = getProgramCosts(pb);
                if (cost < localMemo[i][1]) {
                    localMemo[i][0] = rm;
                    localMemo[i][1] = cost;
                }
            }
        }
    }

    /**
     * Compiles {@code prog} into a flat list of program blocks and recompiles each block with
     * the CP budget and per-block MR budgets stored in {@code rc}.
     *
     * @param prog top-level program blocks
     * @param rc   resource configuration to apply
     * @return the flattened, recompiled program blocks
     */
    public static ArrayList<ProgramBlock> compileProgram(ArrayList<ProgramBlock> prog, ResourceConfig rc) {
        ArrayList<ProgramBlock> blocks = compileProgram(prog, null, rc.getCPResource(), rc.getMaxMRResource());
        recompileProgramBlocks(blocks, rc.getCPResource(), rc.getMRResourcesMemo());
        return blocks;
    }

    /**
     * Recursively compiles {@code prog}, appending every compiled block to {@code out}.
     *
     * <p>On the top-level call ({@code out == null}), a new result list is created and the
     * given budgets are installed in {@link InfrastructureAnalyzer} before compiling.
     *
     * @param prog program blocks to compile
     * @param out  accumulator list, or {@code null} on the top-level call
     * @param cp   CP memory budget in bytes
     * @param mr   MR map/reduce memory budget in bytes
     * @return the accumulator list
     */
    private static ArrayList<ProgramBlock> compileProgram(ArrayList<ProgramBlock> prog,
            ArrayList<ProgramBlock> out, double cp, double mr) {
        if (out == null) {
            out = new ArrayList<>();
            InfrastructureAnalyzer.setLocalMaxMemory((long) cp);
            InfrastructureAnalyzer.setRemoteMaxMemoryMap((long) mr);
            InfrastructureAnalyzer.setRemoteMaxMemoryReduce((long) mr);
            OptimizerUtils.resetDefaultSize();
        }
        for (ProgramBlock pb : prog) {
            compileProgram(pb, out, cp, mr);
        }
        return out;
    }

    /**
     * Compiles a single program block, recursing into children of control-flow blocks.
     *
     * <p>Function blocks are only traversed, never added. Loop and if blocks are added when
     * they have a statement block (and, for while/if, predicate hops). Every other block is
     * recompiled from its statement block's hops and then added.
     */
    private static ArrayList<ProgramBlock> compileProgram(ProgramBlock pb, ArrayList<ProgramBlock> out,
            double cp, double mr) {
        if (pb instanceof FunctionProgramBlock) {
            compileProgram(((FunctionProgramBlock) pb).getChildBlocks(), out, cp, mr);
        } else if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock();
            if (wsb != null && wsb.getPredicateHops() != null) {
                wpb.setPredicate(recompilePredicate(wsb.getPredicateHops()));
                out.add(wpb);
                _cntCompilePB++;
            }
            compileProgram(wpb.getChildBlocks(), out, cp, mr);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
            if (isb != null && isb.getPredicateHops() != null) {
                ipb.setPredicate(recompilePredicate(isb.getPredicateHops()));
                out.add(ipb);
                _cntCompilePB++;
            }
            compileProgram(ipb.getChildBlocksIfBody(), out, cp, mr);
            compileProgram(ipb.getChildBlocksElseBody(), out, cp, mr);
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
            if (fsb != null) {
                if (fsb.getFromHops() != null) {
                    fpb.setFromInstructions(recompilePredicate(fsb.getFromHops()));
                }
                if (fsb.getToHops() != null) {
                    fpb.setToInstructions(recompilePredicate(fsb.getToHops()));
                }
                if (fsb.getIncrementHops() != null) {
                    fpb.setIncrementInstructions(recompilePredicate(fsb.getIncrementHops()));
                }
                out.add(fpb);
                _cntCompilePB++;
            }
            compileProgram(fpb.getChildBlocks(), out, cp, mr);
        } else {
            StatementBlock sb = pb.getStatementBlock();
            pb.setInstructions(Recompiler.recompileHopsDag(sb, sb.getHops(), new LocalVariableMap(), null, false, false, 0L));
            out.add(pb);
            _cntCompilePB++;
        }
        return out;
    }

    /** Recompiles a predicate or loop-bound hop DAG against a fresh, empty variable map. */
    private static ArrayList<Instruction> recompilePredicate(Hop hops) {
        return Recompiler.recompileHopsDag(hops, new LocalVariableMap(), null, false, false, 0L);
    }

    /**
     * Recompiles {@code blocks.get(i)} with CP budget {@code cp} and MR budget
     * {@code (long) memo[i][0]}, for every i.
     */
    private static void recompileProgramBlocks(ArrayList<ProgramBlock> blocks, long cp, double[][] memo) {
        for (int i = 0; i < blocks.size(); i++) {
            recompileProgramBlock(blocks.get(i), cp, (long) memo[i][0]);
        }
    }

    /**
     * Installs the given budgets globally, then recompiles one block without recursing into
     * children. Loop/if blocks recompile only their predicate or bound instructions.
     * MR job instructions in the result are annotated via {@link #annotateMRJobInstructions}.
     */
    private static void recompileProgramBlock(ProgramBlock pb, long cp, long mr) {
        InfrastructureAnalyzer.setLocalMaxMemory(cp);
        InfrastructureAnalyzer.setRemoteMaxMemoryMap(mr);
        InfrastructureAnalyzer.setRemoteMaxMemoryReduce(mr);
        OptimizerUtils.resetDefaultSize();
        if (pb instanceof WhileProgramBlock) {
            WhileProgramBlock wpb = (WhileProgramBlock) pb;
            WhileStatementBlock wsb = (WhileStatementBlock) pb.getStatementBlock();
            if (wsb != null && wsb.getPredicateHops() != null) {
                wpb.setPredicate(annotateMRJobInstructions(recompilePredicate(wsb.getPredicateHops()), cp, mr));
            }
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatementBlock isb = (IfStatementBlock) ipb.getStatementBlock();
            if (isb != null && isb.getPredicateHops() != null) {
                ipb.setPredicate(annotateMRJobInstructions(recompilePredicate(isb.getPredicateHops()), cp, mr));
            }
        } else if (pb instanceof ForProgramBlock) {
            ForProgramBlock fpb = (ForProgramBlock) pb;
            ForStatementBlock fsb = (ForStatementBlock) fpb.getStatementBlock();
            if (fsb != null) {
                if (fsb.getFromHops() != null) {
                    fpb.setFromInstructions(annotateMRJobInstructions(recompilePredicate(fsb.getFromHops()), cp, mr));
                }
                if (fsb.getToHops() != null) {
                    fpb.setToInstructions(annotateMRJobInstructions(recompilePredicate(fsb.getToHops()), cp, mr));
                }
                if (fsb.getIncrementHops() != null) {
                    fpb.setIncrementInstructions(annotateMRJobInstructions(recompilePredicate(fsb.getIncrementHops()), cp, mr));
                }
            }
        } else {
            StatementBlock sb = pb.getStatementBlock();
            pb.setInstructions(annotateMRJobInstructions(
                    Recompiler.recompileHopsDag(sb, sb.getHops(), new LocalVariableMap(), null, false, false, 0L), cp, mr));
        }
        _cntCompilePB++;
    }

    /**
     * Replaces, in place, every {@link MRJobInstruction} in {@code inst} with an
     * {@link MRJobResourceInstruction}.
     *
     * <p>Each replacement gets its max-MR-tasks value set to
     * {@code (numNodes * maxAllocationBytes - alloc(cp)) / alloc(mr)}. This is presumably the
     * cluster memory left after the CP container, divided by one MR container.
     *
     * @return {@code inst} itself ({@code null} if {@code inst} was {@code null})
     * @throws DMLRuntimeException if building an annotated instruction fails
     */
    private static ArrayList<Instruction> annotateMRJobInstructions(ArrayList<Instruction> inst, long cp, long mr) {
        if (inst == null) {
            return null;
        }
        try {
            for (int i = 0; i < inst.size(); i++) {
                Instruction current = inst.get(i);
                if (current instanceof MRJobInstruction) {
                    MRJobResourceInstruction annotated = new MRJobResourceInstruction((MRJobInstruction) current);
                    annotated.setMaxMRTasks(
                            ((YarnClusterAnalyzer.getNumNodes() * YarnClusterAnalyzer.getMaxAllocationBytes())
                                    - DMLYarnClient.computeMemoryAllocation(cp))
                                    / DMLYarnClient.computeMemoryAllocation(mr));
                    inst.set(i, annotated);
                }
            }
        } catch (DMLRuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
        return inst;
    }

    /**
     * Estimates the cost of a single block in isolation. Matrix inputs read by its hops are
     * registered as placeholder matrix objects with the hops' dimensions.
     */
    private static double getProgramCosts(ProgramBlock pb) {
        LocalVariableMap vars = new LocalVariableMap();
        collectReadVariables(pb.getStatementBlock().getHops(), vars);
        ExecutionContext ec = ExecutionContextFactory.createContext(false, (Program) null);
        ec.setVariables(vars);
        double cost = CostEstimationWrapper.getTimeEstimate(pb, ec, false);
        _cntCostPB++;
        return cost;
    }

    /** Estimates the cost of an entire program using a fresh execution context. */
    private static double getProgramCosts(Program program) {
        double cost = CostEstimationWrapper.getTimeEstimate(program, ExecutionContextFactory.createContext());
        _cntCostPB++;
        return cost;
    }

    /** Resets the visit status of {@code roots} and collects their matrix reads into {@code vars}. */
    private static void collectReadVariables(ArrayList<Hop> roots, LocalVariableMap vars) {
        if (roots == null) {
            return;
        }
        Hop.resetVisitStatus(roots);
        for (Hop root : roots) {
            collectReadVariables(root, vars);
        }
    }

    /**
     * Post-order DAG traversal. Registers every matrix transient/persistent read as a
     * placeholder {@link MatrixObject} (path "/tmp") carrying the hop's dimensions, block
     * sizes and nnz.
     */
    private static void collectReadVariables(Hop hop, LocalVariableMap vars) {
        // Skip hops already processed in this traversal so shared sub-DAGs are visited once.
        if (hop == null || hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            collectReadVariables(input, vars);
        }
        if (hop instanceof DataOp && hop.getDataType() == Expression.DataType.MATRIX) {
            DataOp dop = (DataOp) hop;
            if (dop.getDataOpType() == Hop.DataOpTypes.TRANSIENTREAD
                    || dop.getDataOpType() == Hop.DataOpTypes.PERSISTENTREAD) {
                MatrixCharacteristics mc = new MatrixCharacteristics(hop.getDim1(), hop.getDim2(),
                        hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz());
                vars.put(hop.getName(), new MatrixObject(Expression.ValueType.DOUBLE, "/tmp", new MetaData(mc)));
            }
        }
        hop.setVisited();
    }

    /**
     * Returns the blocks of {@code blocks} that (a) contain an MR job instruction and (b) have
     * at least one MR hop whose dimensions and input dimensions are all known. Input order is
     * preserved, which {@link #initGlobalMemoTable} relies on.
     */
    private static ArrayList<ProgramBlock> pruneProgramBlocks(ArrayList<ProgramBlock> blocks) {
        ArrayList<ProgramBlock> withMR = new ArrayList<>();
        for (ProgramBlock pb : blocks) {
            if (OptTreeConverter.containsMRJobInstruction(pb.getInstructions(), false, true)) {
                withMR.add(pb);
            }
        }
        ArrayList<ProgramBlock> pruned = new ArrayList<>();
        for (ProgramBlock pb : withMR) {
            if (!pruneHasOnlyUnknownMR(pb)) {
                pruned.add(pb);
            }
        }
        return pruned;
    }

    /**
     * Returns true if every MR hop in the block's relevant hop DAGs involves unknown dimensions
     * (itself or an input). Also returns true when there are no MR hops. Absent (null)
     * predicate or loop-bound DAGs count as "only unknown".
     */
    private static boolean pruneHasOnlyUnknownMR(ProgramBlock pb) {
        if (pb instanceof WhileProgramBlock) {
            Hop pred = ((WhileStatementBlock) pb.getStatementBlock()).getPredicateHops();
            resetVisitStatusIfPresent(pred);
            return pruneHasOnlyUnknownMR(pred);
        }
        if (pb instanceof IfProgramBlock) {
            Hop pred = ((IfStatementBlock) pb.getStatementBlock()).getPredicateHops();
            resetVisitStatusIfPresent(pred);
            return pruneHasOnlyUnknownMR(pred);
        }
        if (pb instanceof ForProgramBlock) {
            ForStatementBlock fsb = (ForStatementBlock) pb.getStatementBlock();
            Hop from = fsb.getFromHops();
            Hop to = fsb.getToHops();
            Hop incr = fsb.getIncrementHops();
            // From/to/increment may legitimately be absent (compileProgram checks each for null).
            resetVisitStatusIfPresent(from);
            resetVisitStatusIfPresent(to);
            resetVisitStatusIfPresent(incr);
            return pruneHasOnlyUnknownMR(from) && pruneHasOnlyUnknownMR(to) && pruneHasOnlyUnknownMR(incr);
        }
        return pruneHasOnlyUnknownMR(pb.getStatementBlock().getHops());
    }

    /** Resets the visit status of {@code hop}'s DAG if {@code hop} is non-null. */
    private static void resetVisitStatusIfPresent(Hop hop) {
        if (hop != null) {
            hop.resetVisitStatus();
        }
    }

    /**
     * List variant of {@link #pruneHasOnlyUnknownMR(Hop)}. Note that a {@code null} list yields
     * {@code false}, so the block is kept.
     */
    private static boolean pruneHasOnlyUnknownMR(ArrayList<Hop> roots) {
        if (roots == null) {
            return false;
        }
        Hop.resetVisitStatus(roots);
        boolean onlyUnknown = true;
        for (Hop root : roots) {
            // non-short-circuit: every root is traversed so all hops get marked visited
            onlyUnknown &= pruneHasOnlyUnknownMR(root);
        }
        return onlyUnknown;
    }

    /**
     * Post-order traversal returning true when every MR hop reachable from {@code hop}
     * (not yet visited) has unknown dimensions itself or in one of its inputs.
     */
    private static boolean pruneHasOnlyUnknownMR(Hop hop) {
        if (hop == null || hop.isVisited()) {
            return true;
        }
        boolean onlyUnknown = true;
        for (Hop input : hop.getInput()) {
            onlyUnknown &= pruneHasOnlyUnknownMR(input);
        }
        if (hop.getExecType() == LopProperties.ExecType.MR) {
            boolean anyUnknown = !hop.dimsKnown();
            for (Hop input : hop.getInput()) {
                anyUnknown |= !input.dimsKnown();
            }
            onlyUnknown &= anyUnknown;
        }
        hop.setVisited();
        return onlyUnknown;
    }

    /**
     * Enumerates memory grid points in {@code [min, max]} using the requested strategy.
     *
     * @throws DMLRuntimeException for an unsupported enumeration type
     */
    private static ArrayList<Long> enumerateGridPoints(ArrayList<ProgramBlock> prog, long min, long max,
            YarnOptimizerUtils.GridEnumType type) {
        GridEnumeration grid;
        switch (type) {
            case EQUI_GRID:
                grid = new GridEnumerationEqui(prog, min, max);
                break;
            case EXP_GRID:
                grid = new GridEnumerationExp(prog, min, max);
                break;
            case MEM_EQUI_GRID:
                grid = new GridEnumerationMemory(prog, min, max);
                break;
            case HYBRID_MEM_EXP_GRID:
                grid = new GridEnumerationHybrid(prog, min, max);
                break;
            default:
                throw new DMLRuntimeException("Unsupported grid enumeration type: " + type);
        }
        ArrayList<Long> points = grid.enumerateGridPoints();
        LOG.debug("Gen: min=" + YarnOptimizerUtils.toMB(min) + ", max=" + YarnOptimizerUtils.toMB(max)
                + ", npoints=" + points.size());
        return points;
    }

    /**
     * Creates the local memo table {@code [budget, cost]} with one row per pruned block.
     * Each row starts at the minimum MR budget {@code min}, with the cost estimate of the
     * block's entire program as the baseline cost.
     */
    private static double[][] initLocalMemoTable(ArrayList<ProgramBlock> pruned, double min) {
        int n = pruned.size();
        double[][] memo = new double[n][2];
        for (int i = 0; i < n; i++) {
            ExecutionContext ec = ExecutionContextFactory.createContext();
            memo[i][0] = min;
            memo[i][1] = CostEstimationWrapper.getTimeEstimate(pruned.get(i).getProgram(), ec);
        }
        return memo;
    }

    /**
     * Expands the local memo table onto all compiled blocks.
     *
     * <p>Blocks that also appear in {@code pruned} (matched by identity, relying on
     * {@code pruned} being an order-preserving subsequence of {@code blocks}) receive their
     * locally optimal MR budget. All other blocks receive {@code min}. The cost column is set
     * to -1 for every row.
     */
    private static double[][] initGlobalMemoTable(ArrayList<ProgramBlock> blocks, ArrayList<ProgramBlock> pruned,
            double[][] localMemo, double min) {
        int n = blocks.size();
        int np = pruned.size();
        double[][] memo = new double[n][2];
        for (int i = 0; i < n; i++) {
            memo[i][0] = min;
            memo[i][1] = -1.0d;
        }
        int j = 0;
        for (int i = 0; i < n && j < np; i++) {
            if (blocks.get(i) == pruned.get(j)) {
                memo[i][0] = localMemo[j][0];
                j++;
            }
        }
        return memo;
    }

    /** Resets the block compile and costing counters. */
    public static void initStatistics() {
        _cntCompilePB = 0L;
        _cntCostPB = 0L;
    }

    /**
     * Converts a JVM heap size to a physical container size ({@code ceil(heap * 1.5)}).
     *
     * @param jvmMem           JVM heap size in bytes
     * @param applyMinContainer if true, the result is raised to at least the cluster's minimum
     *                          MR container size
     * @return physical memory in bytes
     */
    public static long jvmToPhy(long jvmMem, boolean applyMinContainer) {
        long phy = (long) Math.ceil(jvmMem * MEM_FACTOR);
        if (applyMinContainer) {
            long minContainer = YarnClusterAnalyzer.getMinMRContarinerPhyMB() * FileUtils.ONE_MB;
            if (phy < minContainer) {
                return minContainer;
            }
        }
        return phy;
    }

    /** Converts a memory budget to the JVM heap size that yields it ({@code ceil(budget / MEM_UTIL_FACTOR)}). */
    public static long budgetToJvm(double budget) {
        return (long) Math.ceil(budget / OptimizerUtils.MEM_UTIL_FACTOR);
    }

    /**
     * Converts a physical container size to a memory budget
     * ({@code phy / 1.5 * MEM_UTIL_FACTOR}).
     *
     * @throws IOException never thrown by this computation; kept in the signature so existing
     *                     callers that catch it still compile
     */
    public static double phyToBudget(long phy) throws IOException {
        return (phy / MEM_FACTOR) * OptimizerUtils.MEM_UTIL_FACTOR;
    }
}
