package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.commons.io.FileUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.class */
public class GPUMemoryManager {
    private static final boolean DEBUG_MEMORY_LEAK = false;
    protected final GPUMemoryAllocator allocator;
    private static final double WARN_UTILIZATION_FACTOR = 0.7d;
    protected static final Log LOG = LogFactory.getLog(GPUMemoryManager.class.getName());
    private static final int[] DEBUG_MEMORY_LEAK_STACKTRACE_DEPTH = {5, 6, 7, 8, 9, 10};
    protected final HashMap<Pointer, PointerInfo> allPointers = new HashMap<>();
    protected final GPUMatrixMemoryManager matrixMemoryManager = new GPUMatrixMemoryManager(this);
    protected final GPULazyCudaFreeMemoryManager lazyCudaFreeMemoryManager = new GPULazyCudaFreeMemoryManager(this);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager$CustomPointer.class */
    /**
     * Thin {@link Pointer} subclass that makes the native address readable through a
     * public accessor. Within this file it is only used by
     * {@link EvictionPolicyBasedComparator} to order dense pointers by address.
     */
    public static class CustomPointer extends Pointer {
        /**
         * Wraps an existing pointer by delegating to the {@code Pointer(Pointer)} constructor.
         *
         * @param pointer the pointer to wrap
         */
        public CustomPointer(Pointer pointer) {
            super(pointer);
        }

        /**
         * Returns the value reported by the superclass {@code getNativePointer()}.
         *
         * @return the native pointer value as a {@code long}
         */
        public long getNativePointer() {
            // Must call the superclass explicitly; an unqualified call would recurse into this method.
            return super.getNativePointer();
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager$EvictionPolicyBasedComparator.class */
    /**
     * Orders {@link GPUObject}s for eviction according to {@code DMLScript.GPU_EVICTION_POLICY}.
     *
     * <p>Locked objects always sort before unlocked ones. {@link GPUMemoryManager#malloc} evicts
     * from the <em>end</em> of a list sorted with this comparator, so elements that compare
     * "greater" are evicted first.
     */
    public static class EvictionPolicyBasedComparator implements Comparator<GPUObject> {
        /** Number of bytes the pending allocation needs; used by the size-aware orderings. */
        private final long neededSize;

        /**
         * @param j the size in bytes of the allocation that triggered eviction
         */
        public EvictionPolicyBasedComparator(long j) {
            this.neededSize = j;
        }

        /**
         * Compares two objects by how their device size relates to {@link #neededSize}.
         *
         * <p>If at least one object is smaller than the needed size, the deltas are compared in
         * ascending order; when both are large enough, the deltas are compared in descending order.
         */
        private int minEvictCompare(GPUObject gPUObject, GPUObject gPUObject2) {
            long sizeOnDevice = gPUObject.getSizeOnDevice() - this.neededSize;
            long sizeOnDevice2 = gPUObject2.getSizeOnDevice() - this.neededSize;
            return (sizeOnDevice < 0 || sizeOnDevice2 < 0) ? Long.compare(sizeOnDevice, sizeOnDevice2) : Long.compare(sizeOnDevice2, sizeOnDevice);
        }

        /**
         * {@inheritDoc}
         *
         * <p>Ordering rules, in priority order:
         * <ol>
         *   <li>two locked objects are equal; a locked object sorts before an unlocked one;</li>
         *   <li>{@code MIN_EVICT}: delegate to {@link #minEvictCompare};</li>
         *   <li>{@code ALIGN_MEMORY}: objects without a dense pointer sort before those with one;
         *       two dense objects are ordered by native address; two objects without dense
         *       pointers fall back to {@link #minEvictCompare};</li>
         *   <li>any other policy: descending timestamp.</li>
         * </ol>
         */
        @Override // java.util.Comparator
        public int compare(GPUObject gPUObject, GPUObject gPUObject2) {
            boolean firstLocked = gPUObject.isLocked();
            boolean secondLocked = gPUObject2.isLocked();
            if (firstLocked && secondLocked) {
                return 0;
            }
            if (firstLocked) {
                return -1;
            }
            if (secondLocked) {
                return 1;
            }
            if (DMLScript.GPU_EVICTION_POLICY != DMLScript.EvictionPolicy.ALIGN_MEMORY) {
                return DMLScript.GPU_EVICTION_POLICY == DMLScript.EvictionPolicy.MIN_EVICT ? minEvictCompare(gPUObject, gPUObject2) : Long.compare(gPUObject2.timestamp.get(), gPUObject.timestamp.get());
            }
            boolean firstHasNoDense = gPUObject.isDensePointerNull();
            boolean secondHasNoDense = gPUObject2.isDensePointerNull();
            if (!firstHasNoDense && !secondHasNoDense) {
                long firstAddress = new CustomPointer(gPUObject.getDensePointer()).getNativePointer();
                long secondAddress = new CustomPointer(gPUObject2.getDensePointer()).getNativePointer();
                // Long.compare yields 0 for equal addresses; the previous "<= ? -1 : 1" never
                // returned 0 and therefore violated the Comparator contract (sgn(compare(a, b))
                // must equal -sgn(compare(b, a))), which can make sorting throw.
                return Long.compare(firstAddress, secondAddress);
            }
            if (firstHasNoDense && secondHasNoDense) {
                return minEvictCompare(gPUObject, gPUObject2);
            }
            // Exactly one side lacks a dense pointer; that side sorts first.
            return firstHasNoDense ? -1 : 1;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager$PointerInfo.class */
    /**
     * Bookkeeping record kept for every pointer tracked in {@code allPointers}.
     */
    public static class PointerInfo {
        /**
         * Allocation stack trace read by the leak-debugging report. No code visible in this
         * class assigns it, so it stays {@code null} here.
         */
        private StackTraceElement[] stackTraceElements;

        /** Size of the allocation in bytes, fixed at construction time. */
        private final long sizeInBytes;

        /**
         * @param j size of the allocation in bytes
         */
        public PointerInfo(long j) {
            sizeInBytes = j;
        }

        /**
         * @return the size in bytes given to the constructor
         */
        public long getSizeInBytes() {
            return sizeInBytes;
        }
    }

    /**
     * @return the matrix memory manager created together with this instance
     */
    public GPUMatrixMemoryManager getGPUMatrixMemoryManager() {
        return matrixMemoryManager;
    }

    /**
     * @return the lazy-free memory manager created together with this instance
     */
    public GPULazyCudaFreeMemoryManager getGPULazyCudaFreeMemoryManager() {
        return lazyCudaFreeMemoryManager;
    }

    /**
     * Collects the tracked pointers that are neither reported by the matrix memory manager
     * nor held by the lazy-free manager.
     *
     * @return a new set with every key of {@code allPointers} that is not in either manager
     */
    private Set<Pointer> getNonMatrixLockedPointers() {
        // NOTE(review): this adds to the set returned by getPointers(); presumably that is a
        // fresh collection rather than internal state of the matrix manager — confirm.
        Set<Pointer> accountedFor = this.matrixMemoryManager.getPointers();
        Set<Pointer> lazilyFreed = this.lazyCudaFreeMemoryManager.getAllPointers();
        accountedFor.addAll(lazilyFreed);
        Set<Pointer> tracked = this.allPointers.keySet();
        return nonIn(tracked, accountedFor);
    }

    /**
     * Looks up the recorded allocation size of a pointer.
     *
     * @param pointer the pointer to look up
     * @return the recorded size in bytes, or {@code -1} if the pointer is not tracked
     */
    public long getSizeAllocatedGPUPointer(Pointer pointer) {
        // allPointers only ever stores non-null PointerInfo values, so a null result means "absent".
        PointerInfo info = this.allPointers.get(pointer);
        return info == null ? -1L : info.getSizeInBytes();
    }

    public GPUMemoryManager(GPUContext gPUContext) {
        if (DMLScript.GPU_MEMORY_ALLOCATOR.equals("cuda")) {
            this.allocator = new CudaMemoryAllocator();
        } else {
            if (!DMLScript.GPU_MEMORY_ALLOCATOR.equals("unified_memory")) {
                throw new RuntimeException("Unsupported value (" + DMLScript.GPU_MEMORY_ALLOCATOR + ") for the configuration " + DMLConfig.GPU_MEMORY_ALLOCATOR + ". Supported values are cuda, unified_memory.");
            }
            this.allocator = new UnifiedMemoryAllocator();
        }
        JCuda.cudaMemGetInfo(new long[]{0}, new long[]{0});
        if (r0[0] < WARN_UTILIZATION_FACTOR * r0[0]) {
            LOG.warn("Potential under-utilization: GPU memory - Total: " + (r0[0] * 1.0E-6d) + " MB, Available: " + (r0[0] * 1.0E-6d) + " MB on " + gPUContext + ". This can happen if there are other processes running on the GPU at the same time.");
        } else {
            LOG.info("GPU memory - Total: " + (r0[0] * 1.0E-6d) + " MB, Available: " + (r0[0] * 1.0E-6d) + " MB on " + gPUContext);
        }
        if (GPUContextPool.initialGPUMemBudget() > OptimizerUtils.getLocalMemBudget()) {
            LOG.warn("Potential under-utilization: GPU memory (" + GPUContextPool.initialGPUMemBudget() + ") > driver memory budget (" + OptimizerUtils.getLocalMemBudget() + "). Consider increasing the driver memory budget.");
        }
    }

    /**
     * Attempts a single device allocation through the configured allocator and records it in
     * {@code allPointers}. A {@link CudaException} is treated as "allocation failed" and is
     * reported to the caller as {@code null} instead of being propagated.
     *
     * @param pointer the pointer object the allocator should fill in
     * @param j       number of bytes to allocate
     * @param str     short description used in log messages; {@code null} suppresses logging
     * @return {@code pointer} on success, or {@code null} if the allocator threw a
     *         {@link CudaException}
     */
    private Pointer cudaMallocNoWarn(Pointer pointer, long j, String str) {
        long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        boolean verbose = str != null && (DMLScript.PRINT_GPU_MEMORY_INFO || LOG.isTraceEnabled());
        try {
            this.allocator.allocate(pointer, j);
            this.allPointers.put(pointer, new PointerInfo(j));
            if (DMLScript.STATISTICS) {
                long elapsed = System.nanoTime() - nanoTime;
                GPUStatistics.cudaAllocSuccessTime.add(elapsed);
                GPUStatistics.cudaAllocSuccessCount.increment();
                GPUStatistics.cudaAllocTime.add(elapsed);
                GPUStatistics.cudaAllocCount.increment();
            }
            if (verbose) {
                LOG.info("Success: " + str + ":" + byteCountToDisplaySize(j));
            }
            return pointer;
        } catch (CudaException e) {
            // Deliberately not rethrown: callers use the null return to try the next strategy.
            if (DMLScript.STATISTICS) {
                // Read the clock once so the "failed" and "total" timers receive the same
                // value, mirroring the success path above.
                long elapsed = System.nanoTime() - nanoTime;
                GPUStatistics.cudaAllocFailedTime.add(elapsed);
                GPUStatistics.cudaAllocFailedCount.increment();
                GPUStatistics.cudaAllocTime.add(elapsed);
                GPUStatistics.cudaAllocCount.increment();
            }
            if (verbose) {
                LOG.info("Failed: " + str + ":" + byteCountToDisplaySize(j));
                LOG.info("GPU Memory info " + str + ":" + toString());
            }
            return null;
        }
    }

    /**
     * Formats one frame of a stack trace as {@code ->Class<sep>method(File:line)}, where
     * {@code <sep>} is {@code Path.CUR_DIR}.
     *
     * @param stackTraceElementArr the captured stack trace; may be {@code null}
     * @param i                    index of the frame to format
     * @return the formatted frame, or just {@code "->"} when the trace is {@code null} or has
     *         no frame at index {@code i}
     */
    private String getCallerInfo(StackTraceElement[] stackTraceElementArr, int i) {
        // PointerInfo.stackTraceElements is never populated by the code in this class, so a
        // null trace is the normal case and must not raise a NullPointerException.
        if (stackTraceElementArr == null || stackTraceElementArr.length <= i) {
            return "->";
        }
        StackTraceElement frame = stackTraceElementArr[i];
        return "->" + frame.getClassName() + Path.CUR_DIR + frame.getMethodName() + "(" + frame.getFileName() + ":" + frame.getLineNumber() + ")";
    }

    /**
     * Renders a byte count in human-readable form using binary (1024-based) units.
     *
     * <p>Values below {@code FileUtils.ONE_KB} (including negative values) are printed as
     * {@code "<n> bytes"}; larger values are printed with three decimals and a
     * {@code K/M/G/T/P/E} prefix, e.g. {@code "1.500 KB"}.
     *
     * @param j the number of bytes
     * @return the formatted size
     */
    private String byteCountToDisplaySize(long j) {
        if (j < FileUtils.ONE_KB) {
            return j + " bytes";
        }
        // 'E' covers sizes from 2^60 up to Long.MAX_VALUE; without it charAt() would throw
        // StringIndexOutOfBoundsException for exponent 6.
        final String unitPrefixes = "KMGTPE";
        // 6.931471805599453 == ln(1024), so this is floor(log_1024(j)).
        int log = (int) (Math.log(j) / 6.931471805599453d);
        // Clamp against floating-point rounding at the boundaries so the index is always valid.
        log = Math.max(1, Math.min(log, unitPrefixes.length()));
        return String.format("%.3f %sB", Double.valueOf(j / Math.pow(1024.0d, log)), Character.valueOf(unitPrefixes.charAt(log - 1)));
    }

    /**
     * Allocates {@code j} bytes of device memory and zero-fills it with {@code cudaMemset}.
     *
     * <p>The method tries a sequence of increasingly aggressive strategies and stops at the
     * first one that yields a pointer:
     * <ol>
     *   <li>reuse a pointer from the lazy-free manager ({@code getRmvarPointer});</li>
     *   <li>allocate a new pointer if the allocator reports it can;</li>
     *   <li>free one lazily-freed pointer obtained via {@code getRmvarPointerMinSize} and retry;</li>
     *   <li>clear all lazily-freed pointers and retry;</li>
     *   <li>evict/clear the single unlocked GPU object with the smallest worst-case contiguous
     *       size that is still at least {@code j}, and retry;</li>
     *   <li>evict/clear unlocked GPU objects one by one in the order given by
     *       {@link EvictionPolicyBasedComparator}, retrying once enough memory is estimated
     *       to be available;</li>
     *   <li>clear all unlocked GPU objects and retry.</li>
     * </ol>
     *
     * @param str opcode/instruction name forwarded to the helpers and used for statistics; it is
     *            passed on as-is, including {@code null}
     * @param j   number of bytes to allocate; must not be negative
     * @return a pointer to {@code j} zeroed bytes
     * @throws DMLRuntimeException if {@code j} is negative or every strategy fails
     */
    public Pointer malloc(String str, long j) {
        if (j < 0) {
            throw new DMLRuntimeException("Cannot allocate memory of size " + byteCountToDisplaySize(j));
        }
        // Step 1: try to reuse a pointer cached by the lazy-free manager.
        Pointer rmvarPointer = this.lazyCudaFreeMemoryManager.getRmvarPointer(str, j);
        // A fresh Pointer object is only needed when no cached pointer was found; every later
        // cudaMallocNoWarn call is guarded by rmvarPointer == null, so it is never used while null.
        Pointer pointer = rmvarPointer == null ? new Pointer() : null;
        // Step 2: plain allocation, if the allocator believes it will fit.
        if (rmvarPointer == null && this.allocator.canAllocate(j)) {
            rmvarPointer = cudaMallocNoWarn(pointer, j, "allocate a new pointer");
        }
        // Step 3: free one cached pointer selected by getRmvarPointerMinSize and allocate again.
        if (rmvarPointer == null) {
            rmvarPointer = this.lazyCudaFreeMemoryManager.getRmvarPointerMinSize(str, j);
            if (rmvarPointer != null) {
                guardedCudaFree(rmvarPointer);
                rmvarPointer = cudaMallocNoWarn(pointer, j, "reuse non-exact match of rmvarGPUPointers");
                if (rmvarPointer == null) {
                    LOG.warn("cudaMalloc failed after clearing one of rmvarGPUPointers.");
                }
            }
        }
        // Step 4: release everything the lazy-free manager still holds, then allocate.
        if (rmvarPointer == null) {
            this.lazyCudaFreeMemoryManager.clearAll();
            if (this.allocator.canAllocate(j)) {
                rmvarPointer = cudaMallocNoWarn(pointer, j, "allocate a new pointer after eager free");
            }
        }
        // Step 5: size-based eviction of a single unlocked object that is large enough.
        if (rmvarPointer == null) {
            long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            // Among unlocked objects whose worst-case contiguous size is >= j, pick the smallest.
            Optional min = this.matrixMemoryManager.gpuObjects.stream().filter(gPUObject -> {
                return !gPUObject.isLocked() && this.matrixMemoryManager.getWorstCaseContiguousMemorySize(gPUObject) >= j;
            }).min((gPUObject2, gPUObject3) -> {
                return worstCaseContiguousMemorySizeCompare(gPUObject2, gPUObject3);
            });
            if (min.isPresent()) {
                evictOrClear((GPUObject) min.get(), str);
                rmvarPointer = cudaMallocNoWarn(pointer, j, null);
                if (rmvarPointer == null) {
                    LOG.warn("cudaMalloc failed after clearing/evicting based on size.");
                }
                // Eviction statistics are only recorded when a candidate was actually evicted.
                if (DMLScript.STATISTICS) {
                    long nanoTime2 = System.nanoTime() - nanoTime;
                    GPUStatistics.cudaEvictTime.add(nanoTime2);
                    GPUStatistics.cudaEvictSizeTime.add(nanoTime2);
                    GPUStatistics.cudaEvictCount.increment();
                    GPUStatistics.cudaEvictSizeCount.increment();
                }
            }
        }
        // Step 6: policy-based eviction of unlocked objects, one at a time.
        if (rmvarPointer == null) {
            long nanoTime3 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            long availableMemory = this.allocator.getAvailableMemory();
            // z becomes true once the running estimate says the request should fit.
            boolean z = false;
            List list = (List) this.matrixMemoryManager.gpuObjects.stream().filter(gPUObject4 -> {
                return !gPUObject4.isLocked();
            }).collect(Collectors.toList());
            Collections.sort(list, new EvictionPolicyBasedComparator(j));
            while (rmvarPointer == null && list.size() > 0) {
                // Candidates are taken from the END of the sorted list.
                GPUObject gPUObject5 = (GPUObject) list.remove(list.size() - 1);
                evictOrClear(gPUObject5, str);
                if (!z) {
                    // NOTE(review): getSizeOnDevice() is read AFTER evictOrClear(); if eviction
                    // releases the object's device buffers this may under-count the freed
                    // memory and postpone the retry — confirm against GPUObject.
                    availableMemory += gPUObject5.getSizeOnDevice();
                    if (availableMemory >= j) {
                        z = true;
                    }
                }
                // Avoid a cudaMalloc attempt until the estimate says it can succeed.
                if (z) {
                    rmvarPointer = cudaMallocNoWarn(pointer, j, null);
                }
                if (DMLScript.STATISTICS) {
                    GPUStatistics.cudaEvictCount.increment();
                }
            }
            if (DMLScript.STATISTICS) {
                GPUStatistics.cudaEvictTime.add(System.nanoTime() - nanoTime3);
            }
        }
        // Step 7: last resort — clear every unlocked matrix and try once more.
        if (rmvarPointer == null) {
            LOG.warn("Potential fragmentation of the GPU memory. Forcibly evicting all ...");
            LOG.info("Before clearAllUnlocked, GPU Memory info:" + toString());
            this.matrixMemoryManager.clearAllUnlocked(str);
            LOG.info("GPU Memory info after evicting all unlocked matrices:" + toString());
            rmvarPointer = cudaMallocNoWarn(pointer, j, null);
        }
        if (rmvarPointer == null) {
            throw new DMLRuntimeException("There is not enough memory on device for this matrix, requested = " + byteCountToDisplaySize(j) + ". \n " + toString());
        }
        // Zero the first j bytes of whatever pointer was obtained (new or reused) and time the memset.
        long nanoTime4 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        JCuda.cudaMemset(rmvarPointer, 0, j);
        addMiscTime(str, GPUStatistics.cudaMemSet0Time, GPUStatistics.cudaMemSet0Count, GPUInstruction.MISC_TIMER_SET_ZERO, nanoTime4);
        return rmvarPointer;
    }

    /**
     * Compares two GPU objects by their worst-case contiguous memory size, ascending.
     *
     * @param gPUObject  first object
     * @param gPUObject2 second object
     * @return a negative value, zero, or a positive value if the first object's size is less
     *         than, equal to, or greater than the second's
     */
    private int worstCaseContiguousMemorySizeCompare(GPUObject gPUObject, GPUObject gPUObject2) {
        long firstSize = this.matrixMemoryManager.getWorstCaseContiguousMemorySize(gPUObject);
        long secondSize = this.matrixMemoryManager.getWorstCaseContiguousMemorySize(gPUObject2);
        // Long.compare instead of subtracting and testing the sign: the difference of two
        // longs can overflow and flip the result.
        return Long.compare(firstSize, secondSize);
    }

    /**
     * Removes a GPU object's data from the device: dirty objects are handed to
     * {@code copyFromDeviceToHost}, all others to {@code clearData}.
     *
     * @param gPUObject the object to evict or clear
     * @param str       opcode/instruction name forwarded to the GPU object
     */
    private void evictOrClear(GPUObject gPUObject, String str) {
        if (!gPUObject.isDirty()) {
            gPUObject.clearData(str, true);
            return;
        }
        gPUObject.copyFromDeviceToHost(str, true, true);
    }

    /**
     * Appends a histogram of allocation call sites for the given pointers to {@code sb}.
     * Each line has the form {@code >>callsite => count}, where the call site is the
     * concatenation of the frames at the depths listed in
     * {@code DEBUG_MEMORY_LEAK_STACKTRACE_DEPTH}.
     *
     * @param set pointers to report; each is looked up in {@code allPointers}
     * @param sb  builder receiving the report
     */
    private void printPointers(Set<Pointer> set, StringBuilder sb) {
        Map<String, Integer> countPerCallSite = new HashMap<>();
        for (Pointer ptr : set) {
            PointerInfo info = this.allPointers.get(ptr);
            StringBuilder callSite = new StringBuilder();
            for (int depth : DEBUG_MEMORY_LEAK_STACKTRACE_DEPTH) {
                callSite.append(getCallerInfo(info.stackTraceElements, depth));
            }
            // Start at 1 for a new call site, otherwise add 1 to the running count.
            countPerCallSite.merge(callSite.toString(), 1, Integer::sum);
        }
        for (Map.Entry<String, Integer> entry : countPerCallSite.entrySet()) {
            sb.append(">>").append(entry.getKey()).append(" => ").append(entry.getValue()).append("\n");
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Frees a tracked pointer through the allocator after removing it from
     * {@code allPointers} and, if present, from the lazy-free manager.
     *
     * @param pointer the pointer to free; must be tracked by this manager
     * @throws RuntimeException if the pointer is not tracked
     */
    public void guardedCudaFree(Pointer pointer) {
        // allPointers never stores null values, so a null lookup means the pointer is unknown.
        PointerInfo info = this.allPointers.get(pointer);
        if (info == null) {
            throw new RuntimeException("Attempting to free an unaccounted pointer:" + pointer);
        }
        long bytes = info.getSizeInBytes();
        if (LOG.isTraceEnabled()) {
            LOG.trace("Free-ing up the pointer of size " + byteCountToDisplaySize(bytes));
        }
        // Drop all bookkeeping before handing the pointer to the allocator.
        this.allPointers.remove(pointer);
        this.lazyCudaFreeMemoryManager.removeIfPresent(bytes, pointer);
        this.allocator.free(pointer);
        if (DMLScript.SYNCHRONIZE_GPU) {
            JCuda.cudaDeviceSynchronize();
        }
    }

    /**
     * Releases a pointer either eagerly or lazily.
     *
     * <p>Eager: the pointer is freed immediately via {@link #guardedCudaFree} and the
     * de-allocation statistics are updated. Lazy: the pointer is handed to the lazy-free
     * manager together with its recorded size.
     *
     * @param str     opcode/instruction name used for statistics
     * @param pointer the pointer to release
     * @param z       {@code true} for an eager free, {@code false} for a lazy one
     * @throws DMLRuntimeException declared by the signature
     * @throws RuntimeException    if the pointer is not tracked by this manager
     */
    public void free(String str, Pointer pointer, boolean z) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Free-ing the pointer with eager=" + z);
        }
        if (z) {
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            guardedCudaFree(pointer);
            addMiscTime(str, GPUStatistics.cudaDeAllocTime, GPUStatistics.cudaDeAllocCount, GPUInstruction.MISC_TIMER_CUDA_FREE, start);
            return;
        }
        // Lazy path: the pointer must be known so that its size can be recorded.
        PointerInfo info = this.allPointers.get(pointer);
        if (info == null) {
            LOG.info("GPU memory info before failure:" + toString());
            throw new RuntimeException("ERROR : Internal state corrupted, cache block size map is not aware of a block it trying to free up");
        }
        this.lazyCudaFreeMemoryManager.add(info.getSizeInBytes(), pointer);
    }

    /**
     * Removes a GPU object from the matrix memory manager's {@code gpuObjects} collection.
     *
     * @param gPUObject the object to remove
     */
    public void removeGPUObject(GPUObject gPUObject) {
        final boolean debug = LOG.isDebugEnabled();
        if (debug) {
            LOG.debug("Removing the GPU object: " + gPUObject);
        }
        matrixMemoryManager.gpuObjects.remove(gPUObject);
    }

    /**
     * Releases everything this manager tracks: each registered GPU object is either copied
     * back to the host (if dirty) or cleared, the object registry is emptied, and every
     * remaining tracked pointer is freed.
     */
    public void clearMemory() {
        for (GPUObject gpuObj : this.matrixMemoryManager.gpuObjects) {
            if (!gpuObj.isDirty()) {
                gpuObj.clearData(null, true);
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Attempted to free GPU Memory when a block[" + gpuObj + "] is still on GPU memory, copying it back to host.");
            }
            gpuObj.copyFromDeviceToHost(null, true, true);
        }
        this.matrixMemoryManager.gpuObjects.clear();
        // Iterate over a snapshot: guardedCudaFree removes entries from allPointers.
        for (Pointer leftover : new HashSet<>(this.allPointers.keySet())) {
            guardedCudaFree(leftover);
        }
        this.allPointers.clear();
    }

    /**
     * Computes the set difference {@code set \ set2}.
     *
     * @param set  the candidates
     * @param set2 the elements to exclude
     * @return a new {@link HashSet} holding every element of {@code set} that {@code set2}
     *         does not contain; neither input is modified
     */
    private Set<Pointer> nonIn(Set<Pointer> set, Set<Pointer> set2) {
        return set.stream()
                .filter(candidate -> !set2.contains(candidate))
                .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Frees every tracked pointer that is not among those returned by
     * {@code matrixMemoryManager.getPointers(false, true)}.
     */
    public void clearTemporaryMemory() {
        Set<Pointer> keep = this.matrixMemoryManager.getPointers(false, true);
        // nonIn returns a fresh set, so freeing (which edits allPointers) is safe while iterating.
        Set<Pointer> temporaries = nonIn(this.allPointers.keySet(), keep);
        for (Pointer temporary : temporaries) {
            guardedCudaFree(temporary);
        }
    }

    /**
     * Records the time elapsed since {@code j} in the given global counters and, when
     * fine-grained statistics are enabled and an opcode is given, in the per-instruction
     * misc timers. Does nothing unless {@code DMLScript.STATISTICS} is set.
     *
     * @param str        opcode/instruction name, or {@code null} to skip the per-instruction timer
     * @param longAdder  global timer that receives the elapsed nanoseconds
     * @param longAdder2 global counter that is incremented by one
     * @param str2       name of the misc timer
     * @param j          start time as returned by {@link System#nanoTime()}
     */
    private void addMiscTime(String str, LongAdder longAdder, LongAdder longAdder2, String str2, long j) {
        if (!DMLScript.STATISTICS) {
            return;
        }
        long elapsed = System.nanoTime() - j;
        longAdder.add(elapsed);
        longAdder2.increment();
        if (str != null && DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, str2, elapsed);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Records the time elapsed since {@code j} in the per-instruction misc timer, provided an
     * opcode is given and {@code DMLScript.FINEGRAINED_STATISTICS} is enabled.
     *
     * @param str  opcode/instruction name, or {@code null} to record nothing
     * @param str2 name of the misc timer
     * @param j    start time as returned by {@link System#nanoTime()}
     */
    public void addMiscTime(String str, String str2, long j) {
        if (str != null && DMLScript.FINEGRAINED_STATISTICS) {
            GPUStatistics.maintainCPMiscTimes(str, str2, System.nanoTime() - j);
        }
    }

    /**
     * Builds a tabular summary of the memory tracked by this manager: object, pointer and
     * byte counts for locked / unlocked-dirty / unlocked-non-dirty GPU objects, the lazily
     * freed pointers, the pointers not attributed to either manager, all tracked pointers,
     * and the free/total device memory reported by {@code cudaMemGetInfo}.
     *
     * @return the multi-line report
     */
    @Override
    public String toString() {
        long lockedBytes = 0;
        int lockedObjects = 0;
        int lockedPointers = 0;
        long dirtyBytes = 0;
        int dirtyObjects = 0;
        int dirtyPointers = 0;
        long cleanBytes = 0;
        int cleanObjects = 0;
        int cleanPointers = 0;
        // Classify every registered GPU object into exactly one of the three buckets.
        for (GPUObject gpuObj : this.matrixMemoryManager.gpuObjects) {
            if (gpuObj.isLocked()) {
                lockedObjects++;
                lockedBytes += gpuObj.getSizeOnDevice();
                lockedPointers += this.matrixMemoryManager.getPointers(gpuObj).size();
            } else if (gpuObj.isDirty()) {
                dirtyObjects++;
                dirtyBytes += gpuObj.getSizeOnDevice();
                dirtyPointers += this.matrixMemoryManager.getPointers(gpuObj).size();
            } else {
                cleanObjects++;
                cleanBytes += gpuObj.getSizeOnDevice();
                cleanPointers += this.matrixMemoryManager.getPointers(gpuObj).size();
            }
        }

        long trackedBytes = 0;
        for (PointerInfo info : this.allPointers.values()) {
            trackedBytes += info.getSizeInBytes();
        }

        Set<Pointer> unattributed = getNonMatrixLockedPointers();
        long unattributedBytes = 0;
        for (Pointer ptr : unattributed) {
            unattributedBytes += this.allPointers.get(ptr).sizeInBytes;
        }

        final String rowFormat = "%-35s%-15s%-15s%-15s\n";
        StringBuilder report = new StringBuilder();
        report.append("\n====================================================\n");
        report.append(String.format(rowFormat, "", "Num Objects", "Num Pointers", "Size"));
        report.append(String.format(rowFormat, "Unlocked Dirty GPU objects", dirtyObjects, dirtyPointers, byteCountToDisplaySize(dirtyBytes)));
        report.append(String.format(rowFormat, "Unlocked NonDirty GPU objects", cleanObjects, cleanPointers, byteCountToDisplaySize(cleanBytes)));
        report.append(String.format(rowFormat, "Locked GPU objects", lockedObjects, lockedPointers, byteCountToDisplaySize(lockedBytes)));
        report.append(String.format(rowFormat, "Cached rmvar-ed pointers", "-", this.lazyCudaFreeMemoryManager.getNumPointers(), byteCountToDisplaySize(this.lazyCudaFreeMemoryManager.getTotalMemoryAllocated())));
        report.append(String.format(rowFormat, "Non-matrix/non-cached pointers", "-", unattributed.size(), byteCountToDisplaySize(unattributedBytes)));
        report.append(String.format(rowFormat, "All pointers", "-", this.allPointers.size(), byteCountToDisplaySize(trackedBytes)));

        // cudaMemGetInfo writes the free amount into the first array and the total into the second.
        long[] freeMem = {0};
        long[] totalMem = {0};
        JCuda.cudaMemGetInfo(freeMem, totalMem);
        report.append(String.format(rowFormat, "Free mem (from cudaMemGetInfo)", "-", "-", byteCountToDisplaySize(freeMem[0])));
        report.append(String.format(rowFormat, "Total mem (from cudaMemGetInfo)", "-", "-", byteCountToDisplaySize(totalMem[0])));
        report.append("====================================================\n");
        return report.toString();
    }
}
