package org.apache.sysml.runtime.instructions.gpu.context;

import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.jcusparse.cusparseMatDescr;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.class */
/**
 * Device-side (GPU) sparse matrix in compressed sparse row (CSR) format, plus static helpers to
 * allocate, fill, size and copy such matrices through JCuda / JCusparse.
 *
 * <p>The three device arrays are {@link #val} ({@code nnz} values), {@link #rowPtr}
 * ({@code rows + 1} int row offsets) and {@link #colInd} ({@code nnz} int column indices).
 * Every instance uses the shared zero-based descriptor from
 * {@link #getDefaultCuSparseMatrixDescriptor()}.
 */
public class CSRPointer {
    private static final Log LOG = LogFactory.getLog(CSRPointer.class.getName());

    /** Density (nnz / (rows * cols)) below which {@link #isUltraSparse(int, int)} returns true. */
    private static final double ULTRA_SPARSITY_TURN_POINT = 4.0E-5d;

    /** Size in bytes of one int entry of {@link #rowPtr} / {@link #colInd}. */
    private static final long SIZE_OF_INT = 4L;

    // Values of the CUDA runtime enum cudaMemcpyKind passed to JCuda.cudaMemcpy.
    private static final int CUDA_MEMCPY_HOST_TO_DEVICE = 1;
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;
    private static final int CUDA_MEMCPY_DEVICE_TO_DEVICE = 3;

    // Values of the cuSPARSE enums cusparseMatrixType_t, cusparseIndexBase_t, cusparsePointerMode_t.
    private static final int CUSPARSE_MATRIX_TYPE_GENERAL = 0;
    private static final int CUSPARSE_INDEX_BASE_ZERO = 0;
    private static final int CUSPARSE_POINTER_MODE_HOST = 0;

    /** Sentinel pre-filled into the host nnz result buffer to detect whether cuSPARSE wrote to it. */
    private static final int NNZ_NOT_WRITTEN = -1;

    /** Shared default descriptor, lazily created by {@link #getDefaultCuSparseMatrixDescriptor()}. */
    public static cusparseMatDescr matrixDescriptor;

    private final GPUContext gpuContext;

    /** Number of non-zero entries. */
    public long nnz;

    /** Device pointer to the {@code nnz} non-zero values. */
    public Pointer val = new Pointer();

    /** Device pointer to the {@code rows + 1} int row offsets. */
    public Pointer rowPtr = new Pointer();

    /** Device pointer to the {@code nnz} int column indices. */
    public Pointer colInd = new Pointer();

    /** cuSPARSE matrix descriptor; set to the shared default descriptor on construction. */
    public cusparseMatDescr descr;

    private CSRPointer(GPUContext gCtx) {
        this.gpuContext = gCtx;
        allocateMatDescrPointer();
    }

    /** Returns the number of bytes needed for {@code numElements} values of LibMatrixCUDA's data type. */
    private static long getDataTypeSizeOf(long numElements) {
        return numElements * LibMatrixCUDA.sizeOfDataType;
    }

    /** Returns the number of bytes needed for {@code numElements} ints. */
    private static long getIntSizeOf(long numElements) {
        return numElements * SIZE_OF_INT;
    }

    /**
     * Narrows a long to an int, failing instead of silently truncating.
     *
     * @param value value to narrow
     * @return {@code value} as an int
     * @throws DMLRuntimeException if {@code value} lies outside the int range
     */
    public static int toIntExact(long value) {
        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Cannot be cast to int:" + value);
        }
        return (int) value;
    }

    /**
     * Returns the shared cuSPARSE descriptor for a general matrix with zero-based indexing,
     * creating it on first use. Initialization is not synchronized.
     *
     * @return the shared default matrix descriptor
     */
    public static cusparseMatDescr getDefaultCuSparseMatrixDescriptor() {
        if (matrixDescriptor == null) {
            // Fully configure the descriptor before publishing it in the static field.
            cusparseMatDescr descriptor = new cusparseMatDescr();
            JCusparse.cusparseCreateMatDescr(descriptor);
            JCusparse.cusparseSetMatType(descriptor, CUSPARSE_MATRIX_TYPE_GENERAL);
            JCusparse.cusparseSetMatIndexBase(descriptor, CUSPARSE_INDEX_BASE_ZERO);
            matrixDescriptor = descriptor;
        }
        return matrixDescriptor;
    }

    /**
     * Estimates the device memory, in bytes, occupied by a CSR matrix.
     *
     * @param nnz  number of non-zeros
     * @param rows number of rows
     * @return bytes for the value, row-pointer and column-index arrays plus a descriptor allowance
     */
    public static long estimateSize(long nnz, long rows) {
        long sizeOfValArray = getDataTypeSizeOf(nnz);
        long sizeOfRowPtrArray = getIntSizeOf(rows + 1);
        long sizeOfColIndArray = getIntSizeOf(nnz);
        // NOTE(review): presumably models the native descriptor as four int-sized fields — confirm.
        long sizeOfDescr = getIntSizeOf(4L);
        return sizeOfValArray + sizeOfRowPtrArray + sizeOfColIndArray + sizeOfDescr;
    }

    /**
     * Copies a host-side CSR matrix into the already-allocated device arrays of {@code dest}.
     *
     * @param gCtx   GPU context used for the value transfer
     * @param dest   destination whose {@code val}, {@code rowPtr}, {@code colInd} are written and
     *               whose {@code nnz} is set (only after the arguments have been validated)
     * @param rows   number of rows (must be non-negative)
     * @param nnz    number of non-zeros (must be non-negative)
     * @param rowPtr host row offsets, at least {@code rows + 1} entries
     * @param colInd host column indices, at least {@code nnz} entries
     * @param values host values, at least {@code nnz} entries
     * @throws DMLRuntimeException if any argument is out of range or an array is too short
     */
    public static void copyToDevice(GPUContext gCtx, CSRPointer dest, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (rows < 0) {
            throw new DMLRuntimeException("Incorrect input parameter: rows=" + rows);
        }
        if (nnz < 0) {
            throw new DMLRuntimeException("Incorrect input parameter: nnz=" + nnz);
        }
        long requiredRowPtrLength = rows + 1L; // long arithmetic: rows + 1 may exceed int range
        if (rowPtr.length < requiredRowPtrLength) {
            throw new DMLRuntimeException("The length of rowPtr needs to be greater than or equal to " + requiredRowPtrLength);
        }
        if (colInd.length < nnz) {
            throw new DMLRuntimeException("The length of colInd needs to be greater than or equal to " + nnz);
        }
        if (values.length < nnz) {
            throw new DMLRuntimeException("The length of values needs to be greater than or equal to " + nnz);
        }
        dest.nnz = nnz;
        LibMatrixCUDA.cudaSupportFunctions.hostToDevice(gCtx, values, dest.val, null);
        JCuda.cudaMemcpy(dest.rowPtr, Pointer.to(rowPtr), getIntSizeOf(requiredRowPtrLength), CUDA_MEMCPY_HOST_TO_DEVICE);
        JCuda.cudaMemcpy(dest.colInd, Pointer.to(colInd), getIntSizeOf(nnz), CUDA_MEMCPY_HOST_TO_DEVICE);
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaToDevTime.add(System.nanoTime() - t0);
            GPUStatistics.cudaToDevCount.add(3L);
        }
    }

    /**
     * Copies the row offsets and column indices (not the values) of {@code src} back to the host.
     *
     * @param src    device-side CSR matrix
     * @param rows   number of rows; {@code rows + 1} row offsets are copied
     * @param nnz    number of column indices to copy
     * @param rowPtr host destination for the row offsets
     * @param colInd host destination for the column indices
     */
    public static void copyPtrToHost(CSRPointer src, int rows, long nnz, int[] rowPtr, int[] colInd) {
        JCuda.cudaMemcpy(Pointer.to(rowPtr), src.rowPtr, getIntSizeOf(rows + 1L), CUDA_MEMCPY_DEVICE_TO_HOST);
        JCuda.cudaMemcpy(Pointer.to(colInd), src.colInd, getIntSizeOf(nnz), CUDA_MEMCPY_DEVICE_TO_HOST);
    }

    /**
     * Allocates the output of a sparse matrix addition C = op(A, B) via cusparseXcsrgeamNnz.
     *
     * @param gCtx   GPU context used for allocation
     * @param handle cuSPARSE handle (its pointer mode is set to host)
     * @param a      first input
     * @param b      second input
     * @param m      number of rows
     * @param n      number of columns
     * @return a new CSRPointer with {@code nnz} computed and all device arrays allocated
     * @throws DMLRuntimeException if either input has too many non-zeros for cuSPARSE's int API
     */
    public static CSRPointer allocateForDgeam(GPUContext gCtx, cusparseHandle handle, CSRPointer a, CSRPointer b, int m, int n) {
        checkNnzSupportedByCuSparse(a, b);
        CSRPointer c = new CSRPointer(gCtx);
        step1AllocateRowPointers(gCtx, handle, c, m);
        step2GatherNNZGeam(gCtx, handle, a, b, c, m, n);
        step3AllocateValNInd(gCtx, handle, c);
        return c;
    }

    /**
     * Allocates the output of a sparse matrix product C = op(A) * op(B) via cusparseXcsrgemmNnz.
     *
     * @param gCtx   GPU context used for allocation
     * @param handle cuSPARSE handle (its pointer mode is set to host)
     * @param a      left input
     * @param transA cuSPARSE operation applied to {@code a}
     * @param b      right input
     * @param transB cuSPARSE operation applied to {@code b}
     * @param m      number of rows of C
     * @param n      number of columns of C
     * @param k      inner dimension
     * @return a new CSRPointer with {@code nnz} computed and all device arrays allocated
     * @throws DMLRuntimeException if either input has too many non-zeros for cuSPARSE's int API
     */
    public static CSRPointer allocateForMatrixMultiply(GPUContext gCtx, cusparseHandle handle, CSRPointer a, int transA, CSRPointer b, int transB, int m, int n, int k) {
        CSRPointer c = new CSRPointer(gCtx);
        step1AllocateRowPointers(gCtx, handle, c, m);
        step2GatherNNZGemm(gCtx, handle, a, transA, b, transB, c, m, n, k);
        step3AllocateValNInd(gCtx, handle, c);
        return c;
    }

    /**
     * Allocates uninitialized device arrays for a CSR matrix with the given shape.
     * For {@code nnz == 0} no device memory is allocated.
     *
     * @param gCtx GPU context used for allocation
     * @param nnz  number of non-zeros (must be non-negative)
     * @param rows number of rows (must be positive)
     * @return the new CSRPointer
     * @throws DMLRuntimeException if {@code nnz < 0} or {@code rows <= 0}
     */
    public static CSRPointer allocateEmpty(GPUContext gCtx, long nnz, long rows) {
        LOG.trace("GPU : allocateEmpty from CSRPointer with nnz=" + nnz + " and rows=" + rows + ", GPUContext=" + gCtx);
        if (nnz < 0) {
            throw new DMLRuntimeException("Incorrect usage of internal API, number of non zeroes is less than 0 when trying to allocate sparse data on GPU");
        }
        if (rows <= 0) {
            throw new DMLRuntimeException("Incorrect usage of internal API, number of rows is less than or equal to 0 when trying to allocate sparse data on GPU");
        }
        CSRPointer r = new CSRPointer(gCtx);
        r.nnz = nnz;
        if (nnz == 0) {
            return r;
        }
        r.val = gCtx.allocate(null, getDataTypeSizeOf(nnz));
        r.rowPtr = gCtx.allocate(null, getIntSizeOf(rows + 1));
        r.colInd = gCtx.allocate(null, getIntSizeOf(nnz));
        return r;
    }

    /** Throws if either operand's nnz does not fit cuSPARSE's int-typed nnz arguments. */
    private static void checkNnzSupportedByCuSparse(CSRPointer a, CSRPointer b) {
        if (a.nnz >= Integer.MAX_VALUE || b.nnz >= Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Number of non zeroes is larger than supported by cuSparse");
        }
    }

    /** Sets host pointer mode on {@code handle} and allocates the {@code rows + 1} row offsets of {@code c}. */
    private static void step1AllocateRowPointers(GPUContext gCtx, cusparseHandle handle, CSRPointer c, int rows) {
        LOG.trace("GPU : step1AllocateRowPointers, GPUContext=" + gCtx);
        // Host pointer mode: the nnz result below is written into a host-side int array.
        JCusparse.cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST);
        c.rowPtr = gCtx.allocate(null, getIntSizeOf(rows + 1L));
    }

    /** Fills {@code c.rowPtr} and computes {@code c.nnz} for C = op(A, B). */
    private static void step2GatherNNZGeam(GPUContext gCtx, cusparseHandle handle, CSRPointer a, CSRPointer b, CSRPointer c, int m, int n) {
        LOG.trace("GPU : step2GatherNNZGeam for DGEAM, GPUContext=" + gCtx);
        int[] nnzTotal = {NNZ_NOT_WRITTEN};
        JCusparse.cusparseXcsrgeamNnz(handle, m, n,
                a.descr, toIntExact(a.nnz), a.rowPtr, a.colInd,
                b.descr, toIntExact(b.nnz), b.rowPtr, b.colInd,
                c.descr, c.rowPtr, Pointer.to(nnzTotal));
        c.nnz = (nnzTotal[0] != NNZ_NOT_WRITTEN) ? nnzTotal[0] : readNnzFromRowPtr(c, m);
    }

    /** Fills {@code c.rowPtr} and computes {@code c.nnz} for C = op(A) * op(B). */
    private static void step2GatherNNZGemm(GPUContext gCtx, cusparseHandle handle, CSRPointer a, int transA, CSRPointer b, int transB, CSRPointer c, int m, int n, int k) {
        LOG.trace("GPU : step2GatherNNZGemm for DGEMM, GPUContext=" + gCtx);
        checkNnzSupportedByCuSparse(a, b);
        int[] nnzTotal = {NNZ_NOT_WRITTEN};
        JCusparse.cusparseXcsrgemmNnz(handle, transA, transB, m, n, k,
                a.descr, toIntExact(a.nnz), a.rowPtr, a.colInd,
                b.descr, toIntExact(b.nnz), b.rowPtr, b.colInd,
                c.descr, c.rowPtr, Pointer.to(nnzTotal));
        c.nnz = (nnzTotal[0] != NNZ_NOT_WRITTEN) ? nnzTotal[0] : readNnzFromRowPtr(c, m);
    }

    /**
     * Fallback when cuSPARSE did not write the nnz result: derives nnz from the device row offsets
     * as {@code rowPtr[rows] - rowPtr[0]}.
     */
    private static long readNnzFromRowPtr(CSRPointer c, int rows) {
        int[] lastOffset = {0};
        int[] firstOffset = {0};
        JCuda.cudaMemcpy(Pointer.to(lastOffset), c.rowPtr.withByteOffset(getIntSizeOf(rows)), getIntSizeOf(1L), CUDA_MEMCPY_DEVICE_TO_HOST);
        JCuda.cudaMemcpy(Pointer.to(firstOffset), c.rowPtr, getIntSizeOf(1L), CUDA_MEMCPY_DEVICE_TO_HOST);
        return lastOffset[0] - firstOffset[0];
    }

    /** Allocates {@code c.val} and {@code c.colInd} sized for {@code c.nnz} entries. */
    private static void step3AllocateValNInd(GPUContext gCtx, cusparseHandle handle, CSRPointer c) {
        LOG.trace("GPU : step3AllocateValNInd, GPUContext=" + gCtx);
        c.val = gCtx.allocate(null, getDataTypeSizeOf(c.nnz));
        c.colInd = gCtx.allocate(null, getIntSizeOf(c.nnz));
    }

    /**
     * Creates a device-to-device deep copy of this matrix.
     *
     * @param rows number of rows; {@code rows + 1} row offsets are copied
     * @return a new CSRPointer owning freshly allocated copies of all three arrays
     */
    public CSRPointer clone(int rows) {
        CSRPointer copy = new CSRPointer(getGPUContext());
        copy.nnz = this.nnz;
        long valBytes = getDataTypeSizeOf(copy.nnz);
        // A CSR row-pointer array holds rows + 1 offsets; the last one terminates the final row.
        long rowPtrBytes = getIntSizeOf(rows + 1L);
        long colIndBytes = getIntSizeOf(copy.nnz);
        copy.val = allocate(valBytes);
        copy.rowPtr = allocate(rowPtrBytes);
        copy.colInd = allocate(colIndBytes);
        JCuda.cudaMemcpy(copy.val, this.val, valBytes, CUDA_MEMCPY_DEVICE_TO_DEVICE);
        JCuda.cudaMemcpy(copy.rowPtr, this.rowPtr, rowPtrBytes, CUDA_MEMCPY_DEVICE_TO_DEVICE);
        JCuda.cudaMemcpy(copy.colInd, this.colInd, colIndBytes, CUDA_MEMCPY_DEVICE_TO_DEVICE);
        return copy;
    }

    /** Allocates {@code size} bytes through this matrix's GPU context. */
    private Pointer allocate(long size) {
        return getGPUContext().allocate(null, size);
    }

    private GPUContext getGPUContext() {
        return this.gpuContext;
    }

    /**
     * Reports whether the density nnz / (rows * cols) is below the ultra-sparse threshold.
     *
     * @param rows number of rows
     * @param cols number of columns
     * @return true if the matrix is ultra-sparse
     */
    public boolean isUltraSparse(int rows, int cols) {
        double density = ((double) this.nnz / rows) / cols;
        return density < ULTRA_SPARSITY_TURN_POINT;
    }

    /** Points {@link #descr} at the shared default descriptor. */
    private void allocateMatDescrPointer() {
        this.descr = getDefaultCuSparseMatrixDescriptor();
    }

    /**
     * Converts this CSR matrix into a newly allocated column-major dense device matrix.
     *
     * @param sparseHandle cuSPARSE handle used for the conversion
     * @param blasHandle   cuBLAS handle (not used by this implementation)
     * @param rows         number of rows (also the leading dimension of the output)
     * @param cols         number of columns
     * @param instName     instruction name for fine-grained statistics, or null to skip timing
     * @return device pointer to {@code rows * cols} values; the caller owns it
     */
    public Pointer toColumnMajorDenseMatrix(cusparseHandle sparseHandle, cublasHandle blasHandle, int rows, int cols, String instName) {
        boolean timed = DMLScript.FINEGRAINED_STATISTICS && instName != null;
        long t0 = timed ? System.nanoTime() : 0L;
        LOG.trace("GPU : sparse -> column major dense (inside CSRPointer) on " + this + ", GPUContext=" + getGPUContext());
        Pointer dense = allocate(rows * getDataTypeSizeOf(cols));
        if (this.val == null || this.rowPtr == null || this.colInd == null || this.nnz <= 0) {
            // NOTE(review): the dense buffer is returned unwritten here; this relies on
            // GPUContext.allocate returning zeroed memory — confirm.
            LOG.debug("in CSRPointer, the values array, row pointers array or column indices array was null");
        } else {
            LibMatrixCUDA.cudaSupportFunctions.cusparsecsr2dense(sparseHandle, rows, cols, this.descr, this.val, this.rowPtr, this.colInd, dense, rows);
        }
        if (timed) {
            GPUStatistics.maintainCPMiscTimes(instName, "s2d", System.nanoTime() - t0);
        }
        return dense;
    }

    /** Frees the device arrays, using {@code DMLScript.EAGER_CUDA_FREE} to choose eager vs. lazy freeing. */
    public void deallocate() {
        deallocate(DMLScript.EAGER_CUDA_FREE);
    }

    /**
     * Frees the device arrays and nulls out the pointer fields.
     *
     * @param eager passed through to {@code GPUContext.cudaFreeHelper}
     */
    public void deallocate(boolean eager) {
        // NOTE(review): nothing is freed when nnz == 0, yet allocateForDgeam/allocateForMatrixMultiply
        // and clone may still have allocated rowPtr in that case — confirm whether that leaks.
        if (this.nnz > 0) {
            if (this.val != null) {
                getGPUContext().cudaFreeHelper(null, this.val, eager);
            }
            if (this.rowPtr != null) {
                getGPUContext().cudaFreeHelper(null, this.rowPtr, eager);
            }
            if (this.colInd != null) {
                getGPUContext().cudaFreeHelper(null, this.colInd, eager);
            }
        }
        this.val = null;
        this.rowPtr = null;
        this.colInd = null;
    }

    @Override
    public String toString() {
        return "CSRPointer{nnz=" + this.nnz + '}';
    }
}
