package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.HashMap;
import jcuda.driver.CUdevice;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import org.apache.sysml.runtime.DMLRuntimeException;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/gpu/context/ExecutionConfig.class */
/**
 * Launch configuration for a CUDA kernel: grid dimensions, block dimensions, shared-memory bytes
 * and stream.
 *
 * <p>Dimensions that are not passed to a constructor default to 1, {@code sharedMemBytes}
 * defaults to 0 and {@code stream} defaults to {@code null}.
 */
public class ExecutionConfig {
    /**
     * CUDA driver attribute id for the maximum x-dimension of a thread block
     * ({@code CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X} in the {@code CUdevice_attribute} enum).
     */
    private static final int CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2;

    /** Number of blocks along x. */
    public int gridDimX;
    /** Number of blocks along y. */
    public int gridDimY = 1;
    /** Number of blocks along z. */
    public int gridDimZ = 1;
    /** Number of threads per block along x. */
    public int blockDimX;
    /** Number of threads per block along y. */
    public int blockDimY = 1;
    /** Number of threads per block along z. */
    public int blockDimZ = 1;
    /** Shared-memory bytes requested for the launch. */
    public int sharedMemBytes = 0;
    /** Stream to launch on; {@code null} unless set by the caller. */
    public CUstream stream = null;

    /**
     * Cache of the queried maximum block x-dimension, keyed by device number.
     * NOTE(review): plain HashMap, not synchronized — if getMaxBlockDim can be reached from
     * several threads concurrently, consider a ConcurrentHashMap.
     */
    private static final HashMap<Integer, Integer> maxBlockDimForDevice = new HashMap<>();

    /**
     * Creates a 1-D configuration with an explicit shared-memory size.
     *
     * @param gridDimX       number of blocks along x
     * @param blockDimX      number of threads per block along x
     * @param sharedMemBytes shared-memory bytes requested for the launch
     */
    public ExecutionConfig(int gridDimX, int blockDimX, int sharedMemBytes) {
        this(gridDimX, blockDimX);
        this.sharedMemBytes = sharedMemBytes;
    }

    /**
     * Returns a 1-D configuration that launches at least {@code numCells} threads on device 0.
     *
     * <p>The block x-dimension is the device's maximum block x-dimension, and the grid
     * x-dimension is the ceiling of {@code numCells / blockDimX}.
     *
     * @param numCells number of threads required; must be positive
     * @return the launch configuration
     * @throws DMLRuntimeException if {@code numCells} is zero or negative
     */
    public static ExecutionConfig getConfigForSimpleVectorOperations(int numCells) {
        if (numCells == 0) {
            throw new DMLRuntimeException("Attempting to invoke a kernel with 0 threads");
        }
        if (numCells < 0) {
            throw new DMLRuntimeException(
                    "Attempting to invoke a kernel with a negative number of threads: " + numCells);
        }
        int blockDimX = getMaxBlockDim(0);
        // Exact integer ceiling division. Math.ceil(int / int) would truncate before rounding up
        // (giving 0 blocks whenever numCells < blockDimX), and (a + b - 1) / b could overflow
        // for numCells close to Integer.MAX_VALUE.
        int gridDimX = numCells / blockDimX + (numCells % blockDimX == 0 ? 0 : 1);
        return new ExecutionConfig(gridDimX, blockDimX);
    }

    /**
     * Returns a 1-D configuration with one thread per cell of a {@code rlen x clen} matrix.
     *
     * @param rlen number of rows
     * @param clen number of columns
     * @return the launch configuration
     * @throws DMLRuntimeException if {@code rlen * clen} does not fit in an {@code int}, or is
     *                             zero or negative
     */
    public static ExecutionConfig getConfigForSimpleMatrixOperations(int rlen, int clen) {
        // Multiply in long so that large matrices are rejected instead of silently wrapping
        // to a wrong (possibly too small, zero or negative) thread count.
        long numCells = (long) rlen * clen;
        if (numCells != (int) numCells) {
            throw new DMLRuntimeException("Number of cells in a " + rlen + " x " + clen
                    + " matrix does not fit in an int: " + numCells);
        }
        return getConfigForSimpleVectorOperations((int) numCells);
    }

    /**
     * Creates a 1-D configuration.
     *
     * @param gridDimX  number of blocks along x
     * @param blockDimX number of threads per block along x
     */
    public ExecutionConfig(int gridDimX, int blockDimX) {
        this.gridDimX = gridDimX;
        this.blockDimX = blockDimX;
    }

    /**
     * Creates a 2-D configuration.
     *
     * @param gridDimX  number of blocks along x
     * @param gridDimY  number of blocks along y
     * @param blockDimX number of threads per block along x
     * @param blockDimY number of threads per block along y
     */
    public ExecutionConfig(int gridDimX, int gridDimY, int blockDimX, int blockDimY) {
        this(gridDimX, blockDimX);
        this.gridDimY = gridDimY;
        this.blockDimY = blockDimY;
    }

    /**
     * Returns the maximum block x-dimension of the given device.
     *
     * <p>The value is queried through the CUDA driver API on first use and then cached per device.
     *
     * @param deviceNumber CUDA device ordinal
     * @return maximum number of threads per block along x
     */
    private static int getMaxBlockDim(int deviceNumber) {
        Integer cached = maxBlockDimForDevice.get(deviceNumber);
        if (cached != null) {
            return cached;
        }
        CUdevice device = new CUdevice();
        JCudaKernels.checkResult(JCudaDriver.cuDeviceGet(device, deviceNumber));
        int[] maxBlockDimX = {0};
        // Check the return code like cuDeviceGet above; otherwise a failed query would silently
        // cache and return 0.
        JCudaKernels.checkResult(JCudaDriver.cuDeviceGetAttribute(
                maxBlockDimX, CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device));
        maxBlockDimForDevice.put(deviceNumber, maxBlockDimX[0]);
        return maxBlockDimX[0];
    }

    @Override
    public String toString() {
        return "ExecutionConfig{gridDimX=" + this.gridDimX + ", gridDimY=" + this.gridDimY + ", gridDimZ=" + this.gridDimZ + ", blockDimX=" + this.blockDimX + ", blockDimY=" + this.blockDimY + ", blockDimZ=" + this.blockDimZ + ", sharedMemBytes=" + this.sharedMemBytes + '}';
    }
}
