package org.apache.sysml.runtime.controlprogram.context;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.debug.DMLFrame;
import org.apache.sysml.debug.DMLProgramCounter;
import org.apache.sysml.debug.DebugState;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaData;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.utils.GPUStatistics;
import org.apache.sysml.utils.Statistics;

/**
 * Runtime context for executing a DML {@link Program}.
 *
 * <p>Bundles the symbol table ({@link LocalVariableMap}) with the program being executed,
 * the list of available {@link GPUContext}s and - only when
 * {@link DMLScript#ENABLE_DEBUG_MODE} is set at construction time - the debugger's
 * {@link DebugState}. Most accessors resolve variables by name and throw a
 * {@link DMLRuntimeException} when a variable is missing or has an unexpected type.
 */
public class ExecutionContext {
    protected static final Log LOG = LogFactory.getLog(ExecutionContext.class.getName());

    /** Program being executed; may be null. */
    protected Program _prog = null;

    /** Symbol table; null if the context was created without one. */
    protected LocalVariableMap _variables;

    /** Debugger state; only assigned when debug mode is enabled at construction. */
    protected DebugState _dbState = null;

    /** GPU contexts; initially empty, replaced via {@link #setGPUContexts(List)}. */
    protected List<GPUContext> _gpuContexts = new ArrayList<>();

    /** Creates a context with a fresh, empty symbol table and no program. */
    protected ExecutionContext() {
        this(true, null);
    }

    /**
     * Creates an execution context.
     *
     * @param allocateVariableMap if true, a new empty symbol table is created; otherwise the
     *                            symbol table stays null until {@link #setVariables} is called
     * @param prog                the program to execute; may be null
     */
    public ExecutionContext(boolean allocateVariableMap, Program prog) {
        _variables = allocateVariableMap ? new LocalVariableMap() : null;
        _prog = prog;
        if (DMLScript.ENABLE_DEBUG_MODE) {
            _dbState = DebugState.getInstance();
        }
    }

    public Program getProgram() {
        return _prog;
    }

    public void setProgram(Program prog) {
        _prog = prog;
    }

    /** @return the symbol table; may be null if none was allocated */
    public LocalVariableMap getVariables() {
        return _variables;
    }

    public void setVariables(LocalVariableMap vars) {
        _variables = vars;
    }

    /**
     * Returns the GPU context at the given position.
     *
     * @param index position in the list of GPU contexts
     * @return the GPU context, or null if {@code index} is out of range
     */
    public GPUContext getGPUContext(int index) {
        // Explicit bounds check rather than catching IndexOutOfBoundsException.
        return (index >= 0 && index < _gpuContexts.size()) ? _gpuContexts.get(index) : null;
    }

    /** Replaces the GPU contexts; the given list is stored by reference, not copied. */
    public void setGPUContexts(List<GPUContext> gpuContexts) {
        _gpuContexts = gpuContexts;
    }

    /** @return the live (uncopied) list of GPU contexts */
    public List<GPUContext> getGPUContexts() {
        return _gpuContexts;
    }

    public int getNumGPUContexts() {
        return _gpuContexts.size();
    }

    /** @return the value bound to {@code name}, or null if there is none */
    public Data getVariable(String name) {
        return _variables.get(name);
    }

    /**
     * Resolves an operand. Scalar operands go through {@link #getScalarInput(CPOperand)},
     * which also materializes literals; all other operands are looked up by name.
     */
    public Data getVariable(CPOperand operand) {
        if (operand.getDataType().isScalar()) {
            return getScalarInput(operand);
        }
        return getVariable(operand.getName());
    }

    public void setVariable(String name, Data value) {
        _variables.put(name, value);
    }

    public boolean containsVariable(String name) {
        return _variables.keySet().contains(name);
    }

    /** @return the removed value, as returned by the symbol table's {@code remove} */
    public Data removeVariable(String name) {
        return _variables.remove(name);
    }

    /** Replaces the metadata of an existing variable (the variable must exist). */
    public void setMetaData(String varName, MetaData md) {
        _variables.get(varName).setMetaData(md);
    }

    /** @return the metadata of an existing variable (the variable must exist) */
    public MetaData getMetaData(String varName) {
        return _variables.get(varName).getMetaData();
    }

    /** @return true if {@code varName} is bound to a {@link MatrixObject} */
    public boolean isMatrixObject(String varName) {
        return getVariable(varName) instanceof MatrixObject;
    }

    public MatrixObject getMatrixObject(CPOperand input) {
        return getMatrixObject(input.getName());
    }

    /**
     * @param varName variable name
     * @return the matrix bound to {@code varName}
     * @throws DMLRuntimeException if the variable is missing or is not a matrix
     */
    public MatrixObject getMatrixObject(String varName) {
        Data dat = getVariable(varName);
        if (dat == null) {
            throw new DMLRuntimeException("Variable '" + varName + "' does not exist in the symbol table.");
        }
        if (!(dat instanceof MatrixObject)) {
            throw new DMLRuntimeException("Variable '" + varName + "' is not a matrix.");
        }
        return (MatrixObject) dat;
    }

    /** @return true if {@code varName} is bound to a {@link FrameObject} */
    public boolean isFrameObject(String varName) {
        return getVariable(varName) instanceof FrameObject;
    }

    public FrameObject getFrameObject(CPOperand input) {
        return getFrameObject(input.getName());
    }

    /**
     * @param varName variable name
     * @return the frame bound to {@code varName}
     * @throws DMLRuntimeException if the variable is missing or is not a frame
     */
    public FrameObject getFrameObject(String varName) {
        Data dat = getVariable(varName);
        if (dat == null) {
            throw new DMLRuntimeException("Variable '" + varName + "' does not exist in the symbol table.");
        }
        if (!(dat instanceof FrameObject)) {
            throw new DMLRuntimeException("Variable '" + varName + "' is not a frame.");
        }
        return (FrameObject) dat;
    }

    public CacheableData<?> getCacheableData(CPOperand input) {
        return getCacheableData(input.getName());
    }

    /**
     * @param varName variable name
     * @return the matrix or frame bound to {@code varName}
     * @throws DMLRuntimeException if the variable is missing or is not cacheable data
     */
    public CacheableData<?> getCacheableData(String varName) {
        Data dat = getVariable(varName);
        if (dat == null) {
            throw new DMLRuntimeException("Variable '" + varName + "' does not exist in the symbol table.");
        }
        if (!(dat instanceof CacheableData)) {
            throw new DMLRuntimeException("Variable '" + varName + "' is not a matrix or frame.");
        }
        return (CacheableData<?>) dat;
    }

    /** Calls {@code release()} on the matrix or frame bound to {@code varName}. */
    public void releaseCacheableData(String varName) {
        getCacheableData(varName).release();
    }

    public MatrixCharacteristics getMatrixCharacteristics(String varName) {
        return getMetaData(varName).getMatrixCharacteristics();
    }

    /**
     * Same as {@link #getMatrixInput(String)}, additionally recording the acquisition time
     * under {@code opcode} when fine-grained statistics are enabled.
     *
     * @param varName matrix variable name
     * @param opcode  statistics key; when null no timing is recorded
     * @return the acquired matrix block
     */
    public MatrixBlock getMatrixInput(String varName, String opcode) {
        boolean timed = opcode != null && DMLScript.STATISTICS && DMLScript.FINEGRAINED_STATISTICS;
        long startTime = timed ? System.nanoTime() : 0L;
        MatrixBlock mb = getMatrixInput(varName);
        if (timed) {
            long elapsed = System.nanoTime() - startTime;
            // Sparse and dense acquisitions are accounted under separate timers.
            if (mb.isInSparseFormat()) {
                GPUStatistics.maintainCPMiscTimes(opcode, CPInstruction.MISC_TIMER_GET_SPARSE_MB, elapsed);
            } else {
                GPUStatistics.maintainCPMiscTimes(opcode, CPInstruction.MISC_TIMER_GET_DENSE_MB, elapsed);
            }
        }
        return mb;
    }

    /**
     * Acquires the matrix block of {@code varName} via {@code acquireRead()}. Callers are
     * presumably expected to pair this with {@link #releaseMatrixInput(String)}.
     */
    public MatrixBlock getMatrixInput(String varName) {
        return getMatrixObject(varName).acquireRead();
    }

    /**
     * Updates the row/column counts of a matrix variable, keeping its block sizes and
     * input/output formats. No-op if the dimensions are already equal.
     *
     * @throws DMLRuntimeException if the existing metadata is not a {@link MetaDataFormat}
     */
    public void setMetaData(String varName, long nrows, long ncols) {
        MatrixObject mo = getMatrixObject(varName);
        if (mo.getNumRows() == nrows && mo.getNumColumns() == ncols) {
            return;
        }
        MetaData oldMetaData = mo.getMetaData();
        if (!(oldMetaData instanceof MetaDataFormat)) {
            throw new DMLRuntimeException("Metadata not available");
        }
        MetaDataFormat oldFormat = (MetaDataFormat) oldMetaData;
        MatrixCharacteristics mc = new MatrixCharacteristics(nrows, ncols,
            (int) mo.getNumRowsPerBlock(), (int) mo.getNumColumnsPerBlock());
        mo.setMetaData(new MetaDataFormat(mc, oldFormat.getOutputInfo(), oldFormat.getInputInfo()));
    }

    /**
     * Reconciles an existing dimension with a requested one. If either value is negative
     * (presumably "unknown") or both are equal, the larger value is returned.
     *
     * @throws DMLRuntimeException if both are non-negative and differ
     */
    private static long validateDimensions(long existing, long requested) {
        if (existing < 0 || requested < 0 || existing == requested) {
            return Math.max(existing, requested);
        }
        throw new DMLRuntimeException("Incorrect dimensions:" + existing + " != " + requested);
    }

    /**
     * Allocates a dense GPU output for {@code varName} on the first GPU context and resets
     * its non-zero count to -1.
     *
     * @return the matrix and the flag returned by {@code acquireDeviceModifyDense()}
     */
    public Pair<MatrixObject, Boolean> getDenseMatrixOutputForGPUInstruction(String varName, long numRows, long numCols) {
        MatrixObject mo = allocateGPUMatrixObject(varName, numRows, numCols);
        boolean allocated = mo.getGPUObject(getGPUContext(0)).acquireDeviceModifyDense();
        mo.getMatrixCharacteristics().setNonZeros(-1L);
        return new Pair<>(mo, Boolean.valueOf(allocated));
    }

    /**
     * Allocates a sparse GPU output for {@code varName} on the first GPU context. The
     * non-zero count is set to {@code nnz} before the device buffer is acquired.
     *
     * @return the matrix and the flag returned by {@code acquireDeviceModifySparse()}
     */
    public Pair<MatrixObject, Boolean> getSparseMatrixOutputForGPUInstruction(String varName, long numRows, long numCols, long nnz) {
        MatrixObject mo = allocateGPUMatrixObject(varName, numRows, numCols);
        mo.getMatrixCharacteristics().setNonZeros(nnz);
        boolean allocated = mo.getGPUObject(getGPUContext(0)).acquireDeviceModifySparse();
        return new Pair<>(mo, Boolean.valueOf(allocated));
    }

    /**
     * Validates/updates the dimensions of {@code varName}, ensures it has a GPU object on
     * the first GPU context and adds a write lock to it.
     *
     * @throws DMLRuntimeException if the requested and existing dimensions conflict
     */
    public MatrixObject allocateGPUMatrixObject(String varName, long numRows, long numCols) {
        MatrixObject mo = getMatrixObject(varName);
        long newRows = -1;
        long newCols = -1;
        DMLRuntimeException rowError = null;
        DMLRuntimeException colError = null;
        try {
            newRows = validateDimensions(mo.getNumRows(), numRows);
        } catch (DMLRuntimeException e) {
            rowError = e;
        }
        try {
            newCols = validateDimensions(mo.getNumColumns(), numCols);
        } catch (DMLRuntimeException e) {
            colError = e;
        }
        if (rowError != null || colError != null) {
            // The column error stays the primary cause; when both dimensions are invalid the
            // row error is attached as suppressed so neither is lost.
            DMLRuntimeException cause = (colError != null) ? colError : rowError;
            if (colError != null && rowError != null) {
                cause.addSuppressed(rowError);
            }
            throw new DMLRuntimeException("Incorrect dimensions given to allocateGPUMatrixObject: ["
                + numRows + "," + numCols + "], [" + mo.getNumRows() + "," + mo.getNumColumns() + "]", cause);
        }
        if (newRows != mo.getNumRows() || newCols != mo.getNumColumns()) {
            mo.getMatrixCharacteristics().setDimension(newRows, newCols);
        }
        GPUContext gCtx = getGPUContext(0);
        if (mo.getGPUObject(gCtx) == null) {
            mo.setGPUObject(gCtx, gCtx.createGPUObject(mo));
        }
        mo.getGPUObject(gCtx).addWriteLock();
        return mo;
    }

    /**
     * Ensures {@code varName} has a GPU object on the first GPU context and calls
     * {@code acquireDeviceRead(opcode)} on it.
     */
    public MatrixObject getMatrixInputForGPUInstruction(String varName, String opcode) {
        GPUContext gCtx = getGPUContext(0);
        MatrixObject mo = getMatrixObject(varName);
        // Defensive: getMatrixObject already throws on missing variables, but a subclass
        // override could still return null.
        if (mo == null) {
            throw new DMLRuntimeException("No matrix object available for variable:" + varName);
        }
        if (mo.getGPUObject(gCtx) == null) {
            mo.setGPUObject(gCtx, gCtx.createGPUObject(mo));
        }
        mo.getGPUObject(gCtx).acquireDeviceRead(opcode);
        return mo;
    }

    /** Calls {@code release()} on the matrix bound to {@code varName}. */
    public void releaseMatrixInput(String varName) {
        getMatrixObject(varName).release();
    }

    /**
     * Same as {@link #releaseMatrixInput(String)}, additionally recording the release time
     * under {@code opcode} when fine-grained statistics are enabled.
     */
    public void releaseMatrixInput(String varName, String opcode) {
        boolean timed = opcode != null && DMLScript.STATISTICS && DMLScript.FINEGRAINED_STATISTICS;
        long startTime = timed ? System.nanoTime() : 0L;
        releaseMatrixInput(varName);
        if (timed) {
            GPUStatistics.maintainCPMiscTimes(opcode, CPInstruction.MISC_TIMER_RELEASE_INPUT_MB, System.nanoTime() - startTime);
        }
    }

    /** Releases the GPU input held by {@code varName} on the first GPU context. */
    public void releaseMatrixInputForGPUInstruction(String varName) {
        getMatrixObject(varName).getGPUObject(getGPUContext(0)).releaseInput();
    }

    /** Acquires the frame block of {@code varName} via {@code acquireRead()}. */
    public FrameBlock getFrameInput(String varName) {
        return getFrameObject(varName).acquireRead();
    }

    /** Calls {@code release()} on the frame bound to {@code varName}. */
    public void releaseFrameInput(String varName) {
        getFrameObject(varName).release();
    }

    public ScalarObject getScalarInput(CPOperand input) {
        return getScalarInput(input.getName(), input.getValueType(), input.isLiteral());
    }

    /**
     * Resolves a scalar input.
     *
     * @param name      the literal text if {@code isLiteral}, otherwise the variable name
     * @param vt        value type used to parse a literal
     * @param isLiteral whether {@code name} is a literal value
     * @return a new scalar for literals, otherwise the scalar bound to {@code name}
     * @throws DMLRuntimeException if the variable is missing or is not a scalar
     */
    public ScalarObject getScalarInput(String name, Expression.ValueType vt, boolean isLiteral) {
        if (isLiteral) {
            return ScalarObjectFactory.createScalarObject(vt, name);
        }
        Data dat = getVariable(name);
        if (dat == null) {
            throw new DMLRuntimeException("Unknown variable: " + name);
        }
        // Report a type mismatch like the other typed getters instead of a ClassCastException.
        if (!(dat instanceof ScalarObject)) {
            throw new DMLRuntimeException("Variable '" + name + "' is not a scalar.");
        }
        return (ScalarObject) dat;
    }

    public void setScalarOutput(String varName, ScalarObject so) {
        setVariable(varName, so);
    }

    /**
     * @param varName variable name
     * @return the list bound to {@code varName}
     * @throws DMLRuntimeException if the variable is missing or is not a list
     */
    public ListObject getListObject(String varName) {
        Data dat = getVariable(varName);
        if (dat == null) {
            throw new DMLRuntimeException("Variable '" + varName + "' does not exist in the symbol table.");
        }
        if (!(dat instanceof ListObject)) {
            throw new DMLRuntimeException("Variable '" + varName + "' is not a list.");
        }
        return (ListObject) dat;
    }

    /**
     * Marks the GPU output of {@code varName} as binary-block formatted and releases it.
     *
     * @throws DMLRuntimeException if no GPU output is allocated on the first GPU context
     */
    public void releaseMatrixOutputForGPUInstruction(String varName) {
        MatrixObject mo = getMatrixObject(varName);
        GPUContext gCtx = getGPUContext(0);
        if (mo.getGPUObject(gCtx) == null || !mo.getGPUObject(gCtx).isAllocated()) {
            throw new DMLRuntimeException("No output is allocated on GPU");
        }
        setMetaData(varName, new MetaDataFormat(mo.getMatrixCharacteristics(),
            OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
        mo.getGPUObject(gCtx).releaseOutput();
    }

    /** Stores {@code outputData} into the existing matrix variable {@code varName}. */
    public void setMatrixOutput(String varName, MatrixBlock outputData) {
        MatrixObject mo = getMatrixObject(varName);
        mo.acquireModify(outputData);
        mo.release();
        setVariable(varName, mo);
    }

    /** Same as {@link #setMatrixOutput(String, MatrixBlock)}; {@code opcode} is currently unused. */
    public void setMatrixOutput(String varName, MatrixBlock outputData, String opcode) {
        setMatrixOutput(varName, outputData);
    }

    /**
     * Same as {@link #setMatrixOutput(String, MatrixBlock)}, but first applies
     * {@code flag} to the matrix if it is an in-place update type.
     */
    public void setMatrixOutput(String varName, MatrixBlock outputData, MatrixObject.UpdateType flag) {
        if (flag.isInPlace()) {
            getMatrixObject(varName).setUpdateType(flag);
        }
        setMatrixOutput(varName, outputData);
    }

    /** Same as {@link #setMatrixOutput(String, MatrixBlock, MatrixObject.UpdateType)}; {@code opcode} is unused. */
    public void setMatrixOutput(String varName, MatrixBlock outputData, MatrixObject.UpdateType flag, String opcode) {
        setMatrixOutput(varName, outputData, flag);
    }

    /** Stores {@code outputData} into the existing frame variable {@code varName}. */
    public void setFrameOutput(String varName, FrameBlock outputData) {
        FrameObject fo = getFrameObject(varName);
        fo.acquireModify(outputData);
        fo.release();
        setVariable(varName, fo);
    }

    /** @return the acquired matrix blocks of all matrix operands, in operand order */
    public List<MatrixBlock> getMatrixInputs(CPOperand[] inputs) {
        return Arrays.stream(inputs)
            .filter(op -> op.isMatrix())
            .map(op -> getMatrixInput(op.getName()))
            .collect(Collectors.toList());
    }

    /** @return the resolved scalars of all scalar operands, in operand order */
    public List<ScalarObject> getScalarInputs(CPOperand[] inputs) {
        return Arrays.stream(inputs)
            .filter(op -> op.isScalar())
            .map(op -> getScalarInput(op))
            .collect(Collectors.toList());
    }

    /** Releases all matrix operands acquired via {@link #getMatrixInputs(CPOperand[])}. */
    public void releaseMatrixInputs(CPOperand[] inputs) {
        for (CPOperand op : inputs) {
            if (op.isMatrix()) {
                releaseMatrixInput(op.getName());
            }
        }
    }

    /**
     * Disables cleanup of all cacheable data referenced by {@code varList}, including
     * cacheable elements of list variables.
     *
     * @param varList variable names to pin
     * @return the previous cleanup flags, in traversal order, to be passed to
     *         {@link #unpinVariables(List, boolean[])}
     */
    public boolean[] pinVariables(List<String> varList) {
        // Pass 1: size the result. One slot per non-list name (only cacheable ones are
        // filled) plus one slot per cacheable element of each list.
        int numLists = 0;
        int numCacheableInLists = 0;
        for (String varName : varList) {
            Data dat = _variables.get(varName);
            if (dat instanceof ListObject) {
                numCacheableInLists += ((ListObject) dat).getNumCacheableData();
                numLists++;
            }
        }
        boolean[] pinStatus = new boolean[(varList.size() - numLists) + numCacheableInLists];

        // Pass 2: record the current flags. This must finish before pass 3 clears any flag;
        // an object reachable through several names or list slots would otherwise be
        // recorded as already pinned and never restored.
        int pos = 0;
        for (String varName : varList) {
            Data dat = _variables.get(varName);
            if (dat instanceof CacheableData) {
                pinStatus[pos++] = ((CacheableData<?>) dat).isCleanupEnabled();
            } else if (dat instanceof ListObject) {
                for (Data item : ((ListObject) dat).getData()) {
                    if (item instanceof CacheableData) {
                        pinStatus[pos++] = ((CacheableData<?>) item).isCleanupEnabled();
                    }
                }
            }
        }

        // Pass 3: disable cleanup.
        for (String varName : varList) {
            Data dat = _variables.get(varName);
            if (dat instanceof CacheableData) {
                ((CacheableData<?>) dat).enableCleanup(false);
            } else if (dat instanceof ListObject) {
                for (Data item : ((ListObject) dat).getData()) {
                    if (item instanceof CacheableData) {
                        ((CacheableData<?>) item).enableCleanup(false);
                    }
                }
            }
        }
        return pinStatus;
    }

    /**
     * Restores the cleanup flags saved by {@link #pinVariables(List)}. The same variable
     * list (and list contents) must be used so the traversal order matches.
     */
    public void unpinVariables(List<String> varList, boolean[] varsState) {
        int pos = 0;
        for (String varName : varList) {
            Data dat = _variables.get(varName);
            if (dat instanceof CacheableData) {
                ((CacheableData<?>) dat).enableCleanup(varsState[pos++]);
            } else if (dat instanceof ListObject) {
                for (Data item : ((ListObject) dat).getData()) {
                    if (item instanceof CacheableData) {
                        ((CacheableData<?>) item).enableCleanup(varsState[pos++]);
                    }
                }
            }
        }
    }

    /** @return a new list containing all variable names in the symbol table */
    public ArrayList<String> getVarList() {
        return new ArrayList<>(_variables.keySet());
    }

    /** @return names of all matrix variables whose {@code isPartitioned()} is true */
    public ArrayList<String> getVarListPartitioned() {
        ArrayList<String> partitioned = new ArrayList<>();
        for (String varName : _variables.keySet()) {
            Data dat = _variables.get(varName);
            if (dat instanceof MatrixObject && ((MatrixObject) dat).isPartitioned()) {
                partitioned.add(varName);
            }
        }
        return partitioned;
    }

    /**
     * Cleans up {@code dat} if it is cacheable data, or each cacheable element if it is a
     * list. Other values, including null, are ignored.
     */
    public final void cleanupDataObject(Data dat) {
        if (dat instanceof CacheableData) {
            cleanupCacheableData((CacheableData<?>) dat);
        } else if (dat instanceof ListObject) {
            for (Data item : ((ListObject) dat).getData()) {
                if (item instanceof CacheableData) {
                    cleanupCacheableData((CacheableData<?>) item);
                }
            }
        }
    }

    /**
     * Clears the data of {@code mo} and deletes its backing HDFS file (and ".mtd" file) if
     * cleanup is enabled and no symbol-table variable still references it.
     *
     * @throws DMLRuntimeException wrapping any failure during cleanup
     */
    public void cleanupCacheableData(CacheableData<?> mo) {
        if (DMLScript.JMLC_MEM_STATISTICS) {
            Statistics.removeCPMemObject(System.identityHashCode(mo));
        }
        boolean fileExists = mo.isHDFSFileExists() && mo.getFileName() != null;
        // Early abort, without scanning the symbol table for references, if nothing to do.
        if (!CacheableData.isCachingActive() && !fileExists) {
            return;
        }
        try {
            if (mo.isCleanupEnabled() && !getVariables().hasReferences(mo)) {
                mo.clearData();
                if (fileExists) {
                    MapReduceTool.deleteFileIfExistOnHDFS(mo.getFileName());
                    MapReduceTool.deleteFileIfExistOnHDFS(mo.getFileName() + ".mtd");
                }
            }
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /** In debug mode, resets the current and previous program counters to main:0:0. */
    public void initDebugProgramCounters() {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            _dbState.pc = new DMLProgramCounter(DMLProgram.DEFAULT_NAMESPACE, "main", 0, 0);
            _dbState.prevPC = new DMLProgramCounter(DMLProgram.DEFAULT_NAMESPACE, "main", 0, 0);
        }
    }

    /** In debug mode, records the current program block number. */
    public void updateDebugState(int index) {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            _dbState.getPC().setProgramBlockNumber(index);
        }
    }

    /**
     * In debug mode, records {@code currInst} as the current instruction and suspends the
     * thread if a breakpoint or stepping command requests it.
     */
    public void updateDebugState(Instruction currInst) {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            _dbState.nextCommand = false;
            _dbState.getPC().setInstID(currInst.getInstID());
            _dbState.getPC().setLineNumber(currInst.getLineNum());
            suspendIfAskedInDebugMode(currInst);
        }
    }

    /** In debug mode, clears the current program counter. */
    public void clearDebugProgramCounters() {
        if (DMLScript.ENABLE_DEBUG_MODE) {
            _dbState.pc = null;
        }
    }

    /** Records the DML stack trace of {@code ex} and requests suspension (requires debug mode). */
    public void handleDebugException(Exception ex) {
        _dbState.getDMLStackTrace(ex);
        _dbState.suspend = true;
    }

    /** Pushes the current frame and starts a new program counter for the called function (requires debug mode). */
    public void handleDebugFunctionEntry(FunctionCallCPInstruction funCallInst) {
        _dbState.pushFrame(getVariables(), _dbState.getPC());
        _dbState.pc = new DMLProgramCounter(funCallInst.getNamespace(), funCallInst.getFunctionName(), 0, 0);
    }

    /** Pops the caller's frame and restores its program counter (requires debug mode). */
    public void handleDebugFunctionExit(FunctionCallCPInstruction funCallInst) {
        DMLFrame callerFrame = _dbState.popFrame();
        _dbState.pc = callerFrame.getPC();
    }

    public DebugState getDebugState() {
        return _dbState;
    }

    /**
     * Evaluates pending stepping commands and, if suspension is requested, prints the
     * current source line and suspends the current thread until the debugger resumes it.
     * Afterwards, the previous program counter is advanced to the current position.
     */
    private void suspendIfAskedInDebugMode(Instruction currInst) {
        if (!DMLScript.ENABLE_DEBUG_MODE) {
            System.err.println("ERROR: The function suspendIfAskedInDebugMode should not be called in non-debug mode.");
        }
        // Check stepping options only if no suspension is already pending.
        if (!_dbState.suspend && _dbState.dbCommand != null) {
            if (_dbState.dbCommand.equalsIgnoreCase("step_instruction")) {
                System.out.format("Step instruction reached at %s.\n", _dbState.getPC().toString());
                _dbState.suspend = true;
            } else if (_dbState.dbCommand.equalsIgnoreCase("step_line")
                    && _dbState.prevPC.getLineNumber() != currInst.getLineNum()
                    && _dbState.prevPC.getLineNumber() != 0) {
                // Line 0 is excluded so that the first line is not stepped into.
                System.out.format("Step reached at %s.\n", _dbState.getPC().toStringWithoutInstructionID());
                _dbState.suspend = true;
            } else if (_dbState.dbCommand.equalsIgnoreCase("step return") && currInst instanceof FunctionCallCPInstruction) {
                FunctionCallCPInstruction funCallInst = (FunctionCallCPInstruction) currInst;
                if (_dbState.dbCommandArg == null || funCallInst.getFunctionName().equalsIgnoreCase(_dbState.dbCommandArg)) {
                    System.out.format("Step return reached at %s.\n", _dbState.getPC().toStringWithoutInstructionID());
                    _dbState.suspend = true;
                }
            }
        }
        if (_dbState.suspend) {
            _dbState.dbCommand = null;
            _dbState.dbCommandArg = null;
            if (currInst.getLineNum() != 0) {
                _dbState.printDMLSourceLine(currInst.getLineNum());
            }
            _dbState.setVariables(getVariables());
            _dbState.nextCommand = true;
            // NOTE(review): Thread.suspend() is deprecated and unsupported on recent JDKs;
            // the debugger would need another blocking mechanism to run there.
            Thread.currentThread().suspend();
            _dbState.nextCommand = false;
        }
        _dbState.suspend = false;
        _dbState.prevPC.setFunctionName(_dbState.getPC().getFunctionName());
        _dbState.prevPC.setProgramBlockNumber(_dbState.getPC().getProgramBlockNumber());
        _dbState.prevPC.setLineNumber(currInst.getLineNum());
    }
}
