package org.apache.sysml.runtime.instructions.gpu.context;

import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcusolver.JCusolverDn;
import jcuda.jcusolver.cusolverDnHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/GPUContext.class */
/**
 * Context for one CUDA device.
 *
 * <p>Binds a device number to:
 * <ul>
 *   <li>the cuDNN, cuBLAS and cuSPARSE library handles;</li>
 *   <li>a lazily created cuSOLVER-Dn handle;</li>
 *   <li>the {@link JCudaKernels} instance;</li>
 *   <li>a {@link GPUMemoryManager} used for allocation and freeing on that device.</li>
 * </ul>
 */
public class GPUContext {
    protected static final Log LOG = LogFactory.getLog(GPUContext.class.getName());

    /** Minimum required CUDA compute capability, major version. */
    static final int MAJOR_REQUIRED = 3;

    /** Minimum required CUDA compute capability, minor version (checked only when major == MAJOR_REQUIRED). */
    static final int MINOR_REQUIRED = 0;

    private final int deviceNum;
    private cudnnHandle cudnnHandle;
    private cublasHandle cublasHandle;
    private cusparseHandle cusparseHandle;

    // volatile: this handle is lazily created under double-checked locking in getCusolverDnHandle().
    private volatile cusolverDnHandle cusolverDnHandle;

    private JCudaKernels kernels;
    private GPUMemoryManager memoryManager;

    /**
     * Returns the memory manager owned by this context.
     *
     * @return the memory manager created in the constructor
     */
    public GPUMemoryManager getMemoryManager() {
        return this.memoryManager;
    }

    /**
     * Creates a context for the given device.
     *
     * <p>The constructor:
     * <ul>
     *   <li>makes the device current on the calling thread;</li>
     *   <li>sets the device flags to blocking synchronization;</li>
     *   <li>creates the cuDNN, cuBLAS and cuSPARSE handles and the kernels;</li>
     *   <li>creates a memory manager.</li>
     * </ul>
     *
     * <p>When {@code DMLScript.STATISTICS} is set, the library-initialization time is recorded in
     * {@code GPUStatistics.cudaLibrariesInitTime}.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected}.
     * It is kept {@code public} to avoid breaking existing callers.
     *
     * @param deviceNum CUDA device ordinal
     */
    public GPUContext(int deviceNum) {
        this.deviceNum = deviceNum;
        JCuda.cudaSetDevice(deviceNum);
        // Block the host thread on synchronization instead of spinning (CUDA flag value 0x04).
        JCuda.cudaSetDeviceFlags(JCuda.cudaDeviceScheduleBlockingSync);
        long start = DMLScript.STATISTICS ? System.nanoTime() : -1L;
        initializeCudaLibraryHandles();
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
        }
        this.memoryManager = new GPUMemoryManager(this);
    }

    /**
     * Returns the CUDA device that is current for the calling thread.
     *
     * @return the device ordinal reported by {@code cudaGetDevice}
     */
    public static int cudaGetDevice() {
        int[] device = new int[1];
        JCuda.cudaGetDevice(device);
        return device[0];
    }

    /**
     * Logs the memory manager's state at DEBUG level, prefixed by {@code prefix}.
     *
     * @param prefix text placed before the memory-manager description
     */
    public void printMemoryInfo(String prefix) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(prefix + ": " + this.memoryManager.toString());
        }
    }

    /** Creates any of the cuDNN, cuBLAS and cuSPARSE handles and the kernels that are still null. */
    private void initializeCudaLibraryHandles() throws DMLRuntimeException {
        if (this.cudnnHandle == null) {
            this.cudnnHandle = new cudnnHandle();
            JCudnn.cudnnCreate(this.cudnnHandle);
        }
        if (this.cublasHandle == null) {
            this.cublasHandle = new cublasHandle();
            JCublas2.cublasCreate(this.cublasHandle);
        }
        if (this.cusparseHandle == null) {
            this.cusparseHandle = new cusparseHandle();
            JCusparse.cusparseCreate(this.cusparseHandle);
        }
        if (this.kernels == null) {
            this.kernels = new JCudaKernels();
        }
    }

    /**
     * Returns the device number this context is bound to.
     *
     * @return the CUDA device ordinal passed to the constructor
     */
    public int getDeviceNum() {
        return this.deviceNum;
    }

    /**
     * Makes this context's device current on the calling thread.
     *
     * <p>Also recreates any library handles that have been cleared.
     */
    public void initializeThread() {
        JCuda.cudaSetDevice(this.deviceNum);
        initializeCudaLibraryHandles();
    }

    /**
     * Allocates device memory through the memory manager.
     *
     * @param instructionName name passed to the memory manager (presumably for bookkeeping/statistics)
     * @param size allocation size passed to {@code GPUMemoryManager.malloc} (presumably bytes)
     * @return the pointer returned by the memory manager
     */
    public Pointer allocate(String instructionName, long size) {
        return this.memoryManager.malloc(instructionName, size);
    }

    /**
     * Frees device memory through the memory manager.
     *
     * @param instructionName name passed to the memory manager
     * @param toFree pointer to release
     * @param eager passed through to {@code GPUMemoryManager.free}; presumably requests immediate
     *     release rather than deferred reuse — confirm against GPUMemoryManager
     */
    public void cudaFreeHelper(String instructionName, Pointer toFree, boolean eager) {
        this.memoryManager.free(instructionName, toFree, eager);
    }

    /**
     * Returns the available memory as reported by the memory manager's allocator.
     *
     * @return available memory reported by the allocator
     */
    public long getAvailableMemory() {
        return this.memoryManager.allocator.getAvailableMemory();
    }

    /**
     * Verifies that at least one CUDA device is present.
     *
     * <p>Also verifies that every device has compute capability of at least
     * {@value #MAJOR_REQUIRED}.{@value #MINOR_REQUIRED}.
     *
     * @throws DMLRuntimeException if no device is reported, or a device is below the minimum capability
     */
    public void ensureComputeCapability() {
        int[] deviceCount = {-1};
        JCuda.cudaGetDeviceCount(deviceCount);
        // -1 means the call did not report a count; 0 means no devices. Both are fatal.
        if (deviceCount[0] <= 0) {
            throw new DMLRuntimeException("Call to cudaGetDeviceCount returned 0 devices");
        }
        for (int device = 0; device < deviceCount[0]; device++) {
            cudaDeviceProp props = GPUContextPool.getGPUProperties(device);
            if (!meetsMinimumComputeCapability(props.major, props.minor)) {
                throw new DMLRuntimeException("One of the CUDA cards on the system has compute capability lower than "
                        + MAJOR_REQUIRED + "." + MINOR_REQUIRED);
            }
        }
    }

    /** Returns true iff compute capability {@code major.minor} is at least MAJOR_REQUIRED.MINOR_REQUIRED. */
    private static boolean meetsMinimumComputeCapability(int major, int minor) {
        return major > MAJOR_REQUIRED || (major == MAJOR_REQUIRED && minor >= MINOR_REQUIRED);
    }

    /**
     * Creates a GPU object for the given matrix.
     *
     * <p>The new object is registered with this context's GPU matrix memory manager.
     *
     * @param matrixObject the matrix to back on the GPU
     * @return the newly created and registered GPU object
     */
    public GPUObject createGPUObject(MatrixObject matrixObject) {
        GPUObject gpuObject = new GPUObject(this, matrixObject);
        getMemoryManager().getGPUMatrixMemoryManager().addGPUObject(gpuObject);
        return gpuObject;
    }

    /**
     * Returns the device properties for this context's device.
     *
     * @return the properties obtained from {@code GPUContextPool}
     */
    public cudaDeviceProp getGPUProperties() {
        return GPUContextPool.getGPUProperties(this.deviceNum);
    }

    /**
     * Returns the maximum number of threads per block for this device.
     *
     * @return {@code maxThreadsPerBlock} from the device properties
     */
    public int getMaxThreadsPerBlock() {
        return getGPUProperties().maxThreadsPerBlock;
    }

    /**
     * Returns the maximum grid size in the first (x) dimension.
     *
     * @return {@code maxGridSize[0]} from the device properties
     */
    public int getMaxBlocks() {
        return getGPUProperties().maxGridSize[0];
    }

    /**
     * Returns the shared memory available per block.
     *
     * @return {@code sharedMemPerBlock} from the device properties
     */
    public long getMaxSharedMemory() {
        return getGPUProperties().sharedMemPerBlock;
    }

    /**
     * Returns the warp size of this device.
     *
     * @return {@code warpSize} from the device properties
     */
    public int getWarpSize() {
        return getGPUProperties().warpSize;
    }

    /**
     * Returns the cuDNN handle.
     *
     * @return the cuDNN handle, or null after {@link #destroy()}
     */
    public cudnnHandle getCudnnHandle() {
        return this.cudnnHandle;
    }

    /**
     * Returns the cuBLAS handle.
     *
     * @return the cuBLAS handle, or null after {@link #destroy()}
     */
    public cublasHandle getCublasHandle() {
        return this.cublasHandle;
    }

    /**
     * Returns the cuSPARSE handle.
     *
     * @return the cuSPARSE handle, or null after {@link #destroy()}
     */
    public cusparseHandle getCusparseHandle() {
        return this.cusparseHandle;
    }

    /**
     * Returns the cuSOLVER-Dn handle, creating it on first use.
     *
     * <p>Creation is guarded by double-checked locking on this context.
     *
     * @return an initialized cuSOLVER-Dn handle
     */
    public cusolverDnHandle getCusolverDnHandle() {
        cusolverDnHandle handle = this.cusolverDnHandle;
        if (handle == null) {
            synchronized (this) {
                handle = this.cusolverDnHandle;
                if (handle == null) {
                    handle = new cusolverDnHandle();
                    JCusolverDn.cusolverDnCreate(handle);
                    // Publish only after creation, so a thread taking the unsynchronized fast path
                    // never observes a non-null but uninitialized handle.
                    this.cusolverDnHandle = handle;
                }
            }
        }
        return handle;
    }

    /**
     * Returns the kernels created during handle initialization.
     *
     * @return the {@link JCudaKernels} instance
     */
    public JCudaKernels getKernels() {
        return this.kernels;
    }

    /** Clears all memory held by the memory manager, then destroys and nulls the library handles. */
    public void destroy() {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : this context was destroyed, this = " + toString());
        }
        clearMemory();
        deleteCudaLibraryHandles();
    }

    /** Destroys each non-null library handle and then nulls all handle fields. */
    private void deleteCudaLibraryHandles() {
        if (this.cudnnHandle != null) {
            JCudnn.cudnnDestroy(this.cudnnHandle);
        }
        if (this.cublasHandle != null) {
            JCublas2.cublasDestroy(this.cublasHandle);
        }
        if (this.cusparseHandle != null) {
            JCusparse.cusparseDestroy(this.cusparseHandle);
        }
        cusolverDnHandle solverHandle = this.cusolverDnHandle;
        if (solverHandle != null) {
            JCusolverDn.cusolverDnDestroy(solverHandle);
        }
        this.cudnnHandle = null;
        this.cublasHandle = null;
        this.cusparseHandle = null;
        this.cusolverDnHandle = null;
    }

    /** Delegates to {@code GPUMemoryManager.clearMemory()}. */
    public void clearMemory() {
        this.memoryManager.clearMemory();
    }

    /** Delegates to {@code GPUMemoryManager.clearTemporaryMemory()}. */
    public void clearTemporaryMemory() {
        this.memoryManager.clearTemporaryMemory();
    }

    @Override
    public String toString() {
        return "GPUContext{deviceNum=" + this.deviceNum + '}';
    }
}
