package org.apache.sysml.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnDropoutDescriptor;
import jcuda.jcudnn.cudnnFilterDescriptor;
import jcuda.jcudnn.cudnnRNNDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixCuDNNRnnAlgorithm.class */
/**
 * Owns the cuDNN descriptors and GPU buffers needed to run a single-layer, unidirectional
 * RNN (relu/tanh), LSTM or GRU through cuDNN on behalf of one instruction.
 *
 * <p>Instances must be closed (preferably with try-with-resources) to destroy the cuDNN
 * descriptors and to release the device memory obtained from {@link GPUContext}. If the
 * constructor fails, everything it had already acquired is released before the exception
 * propagates. {@link #close()} may safely be called more than once.
 */
public class LibMatrixCuDNNRnnAlgorithm implements AutoCloseable {
    // cudnnRNNMode_t values passed to cudnnSetRNNDescriptor_v6.
    private static final int CUDNN_RNN_RELU = 0;
    private static final int CUDNN_RNN_TANH = 1;
    private static final int CUDNN_LSTM = 2;
    private static final int CUDNN_GRU = 3;
    // cudnnRNNInputMode_t CUDNN_LINEAR_INPUT.
    private static final int CUDNN_LINEAR_INPUT = 0;
    // cudnnDirectionMode_t CUDNN_UNIDIRECTIONAL.
    private static final int CUDNN_UNIDIRECTIONAL = 0;
    // cudnnRNNAlgo_t CUDNN_RNN_ALGO_STANDARD.
    private static final int CUDNN_RNN_ALGO_STANDARD = 0;
    // cudnnTensorFormat_t CUDNN_TENSOR_NCHW, used for the flat weight filter descriptor.
    private static final int CUDNN_TENSOR_NCHW = 0;
    // Number of stacked recurrent layers configured in the RNN descriptor.
    private static final int NUM_LAYERS = 1;
    // Dropout probability handed to cuDNN; 0 means no dropout is applied.
    private static final float DROPOUT_PROBABILITY = 0.0f;
    // Fixed seed for cuDNN's dropout RNG state.
    private static final long DROPOUT_SEED = 12345L;

    GPUContext gCtx;
    String instName;
    cudnnDropoutDescriptor dropoutDesc;
    cudnnRNNDescriptor rnnDesc;
    // One descriptor per time step (array length = sequence length).
    cudnnTensorDescriptor[] xDesc;
    cudnnTensorDescriptor[] dxDesc;
    cudnnTensorDescriptor[] yDesc;
    cudnnTensorDescriptor[] dyDesc;
    // Hidden/cell state descriptors (and their gradients).
    cudnnTensorDescriptor hxDesc;
    cudnnTensorDescriptor cxDesc;
    cudnnTensorDescriptor hyDesc;
    cudnnTensorDescriptor cyDesc;
    cudnnTensorDescriptor dhxDesc;
    cudnnTensorDescriptor dcxDesc;
    cudnnTensorDescriptor dhyDesc;
    cudnnTensorDescriptor dcyDesc;
    // Flat weight (and weight-gradient) filter descriptors.
    cudnnFilterDescriptor wDesc;
    cudnnFilterDescriptor dwDesc;
    // Device buffers; each size is non-zero only once the matching buffer was allocated.
    long sizeInBytes;
    Pointer workSpace;
    long reserveSpaceSizeInBytes;
    Pointer reserveSpace;
    long dropOutSizeInBytes;
    Pointer dropOutStateSpace;

    /**
     * Creates all cuDNN descriptors and allocates the workspace (and, when requested, the
     * training reserve space) for the given RNN configuration.
     *
     * @param ec          execution context (not used by this constructor)
     * @param gCtx        GPU context providing the cuDNN handle and device allocations
     * @param instName    instruction name used to attribute GPU allocations
     * @param rnnMode     one of {@code rnn_relu}, {@code rnn_tanh}, {@code lstm}, {@code gru}
     *                    (case-insensitive)
     * @param batchSize   first dimension of each per-step input/output descriptor
     * @param seqLength   number of time steps (number of per-step descriptors)
     * @param hiddenSize  hidden state size
     * @param numFeatures input feature size
     * @param isTraining  if true, the cuDNN training reserve space is also allocated
     * @param w           weight pointer (not used by this constructor)
     * @throws DMLRuntimeException if a dimension is non-positive, the mode is unsupported,
     *         the LSTM parameter count does not match cuDNN's, or an allocation fails
     */
    public LibMatrixCuDNNRnnAlgorithm(ExecutionContext ec, GPUContext gCtx, String instName, String rnnMode,
            int batchSize, int seqLength, int hiddenSize, int numFeatures, boolean isTraining, Pointer w)
            throws DMLRuntimeException {
        this.gCtx = gCtx;
        this.instName = instName;
        // Reject bad shapes before acquiring anything; seqLength >= 1 is also required because
        // xDesc[0] is used to query the parameter size.
        if (batchSize <= 0 || seqLength <= 0 || hiddenSize <= 0 || numFeatures <= 0) {
            throw new DMLRuntimeException("Invalid RNN dimensions: N=" + batchSize + ", T=" + seqLength
                    + ", M=" + hiddenSize + ", D=" + numFeatures);
        }
        try {
            allocateSequenceDescriptors(batchSize, seqLength, hiddenSize, numFeatures);
            allocateStateDescriptors(batchSize, hiddenSize);
            initDropoutDescriptor();
            initRnnDescriptor(rnnMode, hiddenSize);
            int numWeights = getExpectedNumWeights();
            checkNumWeights(rnnMode, numFeatures, hiddenSize, numWeights);
            this.wDesc = allocateFilterDescriptor(numWeights);
            this.dwDesc = allocateFilterDescriptor(numWeights);
            allocateWorkspaces(seqLength, isTraining);
        } catch (final Throwable t) {
            // The caller never receives this instance, so nothing else could release what was
            // acquired so far. Release it here; keep the original failure as the primary one.
            try {
                releaseResources();
            } catch (RuntimeException cleanupFailure) {
                t.addSuppressed(cleanupFailure);
            }
            throw t; // precise rethrow: only DMLRuntimeException or unchecked throwables
        }
    }

    /** Allocates the per-time-step input/output descriptors and their gradient counterparts. */
    private void allocateSequenceDescriptors(int batchSize, int seqLength, int hiddenSize, int numFeatures)
            throws DMLRuntimeException {
        this.xDesc = new cudnnTensorDescriptor[seqLength];
        this.dxDesc = new cudnnTensorDescriptor[seqLength];
        this.yDesc = new cudnnTensorDescriptor[seqLength];
        this.dyDesc = new cudnnTensorDescriptor[seqLength];
        for (int t = 0; t < seqLength; t++) {
            this.xDesc[t] = allocateTensorDescriptorWithStride(batchSize, numFeatures, 1);
            this.dxDesc[t] = allocateTensorDescriptorWithStride(batchSize, numFeatures, 1);
            this.yDesc[t] = allocateTensorDescriptorWithStride(batchSize, hiddenSize, 1);
            this.dyDesc[t] = allocateTensorDescriptorWithStride(batchSize, hiddenSize, 1);
        }
    }

    /** Allocates the hidden/cell state descriptors, all shaped (NUM_LAYERS, batchSize, hiddenSize). */
    private void allocateStateDescriptors(int batchSize, int hiddenSize) throws DMLRuntimeException {
        // First dimension is numLayers * numDirections; the RNN is configured unidirectional.
        this.hxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.dhxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.cxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.dcxDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.hyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.dhyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.cyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
        this.dcyDesc = allocateTensorDescriptorWithStride(NUM_LAYERS, batchSize, hiddenSize);
    }

    /** Creates the dropout descriptor and allocates the dropout RNG state buffer cuDNN asks for. */
    private void initDropoutDescriptor() throws DMLRuntimeException {
        this.dropoutDesc = new cudnnDropoutDescriptor();
        JCudnn.cudnnCreateDropoutDescriptor(this.dropoutDesc);
        long[] stateBytes = {-1};
        JCudnn.cudnnDropoutGetStatesSize(this.gCtx.getCudnnHandle(), stateBytes);
        this.dropOutStateSpace = new Pointer();
        if (stateBytes[0] != 0) {
            this.dropOutStateSpace = this.gCtx.allocate(this.instName, stateBytes[0]);
        }
        // Recorded only after a successful allocation so cleanup never frees an unallocated pointer.
        this.dropOutSizeInBytes = stateBytes[0];
        JCudnn.cudnnSetDropoutDescriptor(this.dropoutDesc, this.gCtx.getCudnnHandle(), DROPOUT_PROBABILITY,
                this.dropOutStateSpace, this.dropOutSizeInBytes, DROPOUT_SEED);
    }

    /** Creates and configures the RNN descriptor for the requested mode and hidden size. */
    private void initRnnDescriptor(String rnnMode, int hiddenSize) throws DMLRuntimeException {
        int mode = getCuDNNRnnMode(rnnMode);
        this.rnnDesc = new cudnnRNNDescriptor();
        JCudnn.cudnnCreateRNNDescriptor(this.rnnDesc);
        JCudnn.cudnnSetRNNDescriptor_v6(this.gCtx.getCudnnHandle(), this.rnnDesc, hiddenSize, NUM_LAYERS,
                this.dropoutDesc, CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL, mode, CUDNN_RNN_ALGO_STANDARD,
                LibMatrixCUDA.CUDNN_DATA_TYPE);
    }

    /**
     * For LSTM, verifies cuDNN's parameter count equals 4 * M * (D + M + 2).
     * The product is computed in long so large sizes cannot overflow.
     */
    private static void checkNumWeights(String rnnMode, int numFeatures, int hiddenSize, int numWeights)
            throws DMLRuntimeException {
        if (!rnnMode.equalsIgnoreCase("lstm")) {
            return;
        }
        long expected = ((long) numFeatures + hiddenSize + 2) * 4L * hiddenSize;
        if (expected != numWeights) {
            throw new DMLRuntimeException("Incorrect number of RNN parameters " + expected + " != " + numWeights
                    + ", where numFeatures=" + numFeatures + ", hiddenSize=" + hiddenSize);
        }
    }

    /** Allocates the cuDNN workspace and, when training, the training reserve space. */
    private void allocateWorkspaces(int seqLength, boolean isTraining) throws DMLRuntimeException {
        this.workSpace = new Pointer();
        this.reserveSpace = new Pointer();
        long workBytes = getWorkspaceSize(seqLength);
        if (workBytes != 0) {
            this.workSpace = this.gCtx.allocate(this.instName, workBytes);
        }
        this.sizeInBytes = workBytes;
        this.reserveSpaceSizeInBytes = 0L;
        if (isTraining) {
            long reserveBytes = getReservespaceSize(seqLength);
            if (reserveBytes != 0) {
                this.reserveSpace = this.gCtx.allocate(this.instName, reserveBytes);
            }
            this.reserveSpaceSizeInBytes = reserveBytes;
        }
    }

    /**
     * Returns the number of linear layers (weight matrices) per recurrent layer for a mode.
     * Not referenced within this class.
     */
    private int getNumLinearLayers(String rnnMode) throws DMLRuntimeException {
        if (rnnMode.equalsIgnoreCase("rnn_relu") || rnnMode.equalsIgnoreCase("rnn_tanh")) {
            return 2;
        }
        if (rnnMode.equalsIgnoreCase("lstm")) {
            return 8;
        }
        if (rnnMode.equalsIgnoreCase("gru")) {
            return 6;
        }
        throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
    }

    /** Queries cuDNN for the workspace size (bytes) needed for {@code seqLength} steps. */
    private long getWorkspaceSize(int seqLength) {
        long[] bytes = new long[1];
        JCudnn.cudnnGetRNNWorkspaceSize(this.gCtx.getCudnnHandle(), this.rnnDesc, seqLength, this.xDesc, bytes);
        return bytes[0];
    }

    /** Queries cuDNN for the training reserve-space size (bytes) needed for {@code seqLength} steps. */
    private long getReservespaceSize(int seqLength) {
        long[] bytes = new long[1];
        JCudnn.cudnnGetRNNTrainingReserveSize(this.gCtx.getCudnnHandle(), this.rnnDesc, seqLength, this.xDesc,
                bytes);
        return bytes[0];
    }

    /** Maps a case-insensitive mode name to its cudnnRNNMode_t value. */
    private int getCuDNNRnnMode(String rnnMode) throws DMLRuntimeException {
        if (rnnMode.equalsIgnoreCase("rnn_relu")) {
            return CUDNN_RNN_RELU;
        }
        if (rnnMode.equalsIgnoreCase("rnn_tanh")) {
            return CUDNN_RNN_TANH;
        }
        if (rnnMode.equalsIgnoreCase("lstm")) {
            return CUDNN_LSTM;
        }
        if (rnnMode.equalsIgnoreCase("gru")) {
            return CUDNN_GRU;
        }
        throw new DMLRuntimeException("Unsupported rnn mode:" + rnnMode);
    }

    /** Returns the number of weight elements cuDNN expects (parameter bytes / element size). */
    private int getExpectedNumWeights() throws DMLRuntimeException {
        long[] paramBytes = {-1};
        JCudnn.cudnnGetRNNParamsSize(this.gCtx.getCudnnHandle(), this.rnnDesc, this.xDesc[0], paramBytes,
                LibMatrixCUDA.CUDNN_DATA_TYPE);
        return LibMatrixCUDA.toInt(paramBytes[0] / LibMatrixCUDA.sizeOfDataType);
    }

    /** Creates a 3-D filter descriptor of shape (numWeights, 1, 1) for the flat weight buffer. */
    private cudnnFilterDescriptor allocateFilterDescriptor(int numWeights) {
        cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
        JCudnn.cudnnCreateFilterDescriptor(filterDesc);
        JCudnn.cudnnSetFilterNdDescriptor(filterDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, CUDNN_TENSOR_NCHW, 3,
                new int[]{numWeights, 1, 1});
        return filterDesc;
    }

    /** Creates a fully packed 3-D tensor descriptor with dims (d0, d1, d2) and strides (d1*d2, d2, 1). */
    private static cudnnTensorDescriptor allocateTensorDescriptorWithStride(int d0, int d1, int d2)
            throws DMLRuntimeException {
        cudnnTensorDescriptor tensorDesc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDesc);
        JCudnn.cudnnSetTensorNdDescriptor(tensorDesc, LibMatrixCUDA.CUDNN_DATA_TYPE, 3, new int[]{d0, d1, d2},
                new int[]{d2 * d1, d2, 1});
        return tensorDesc;
    }

    /**
     * Destroys all descriptors and frees all device buffers held by this object.
     * Idempotent: a second call is a no-op.
     *
     * @throws RuntimeException wrapping the first {@link DMLRuntimeException} raised while freeing a
     *         device buffer; later failures are attached as suppressed exceptions
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        releaseResources();
    }

    /**
     * Implementation of {@link #close()}. It is private so the constructor can use it without
     * calling an overridable method.
     */
    private void releaseResources() {
        if (this.dropoutDesc != null) {
            JCudnn.cudnnDestroyDropoutDescriptor(this.dropoutDesc);
            this.dropoutDesc = null;
        }
        if (this.rnnDesc != null) {
            JCudnn.cudnnDestroyRNNDescriptor(this.rnnDesc);
            this.rnnDesc = null;
        }
        this.hxDesc = destroyTensorDescriptor(this.hxDesc);
        this.dhxDesc = destroyTensorDescriptor(this.dhxDesc);
        this.hyDesc = destroyTensorDescriptor(this.hyDesc);
        this.dhyDesc = destroyTensorDescriptor(this.dhyDesc);
        this.cxDesc = destroyTensorDescriptor(this.cxDesc);
        this.dcxDesc = destroyTensorDescriptor(this.dcxDesc);
        this.cyDesc = destroyTensorDescriptor(this.cyDesc);
        this.dcyDesc = destroyTensorDescriptor(this.dcyDesc);
        this.wDesc = destroyFilterDescriptor(this.wDesc);
        this.dwDesc = destroyFilterDescriptor(this.dwDesc);
        this.xDesc = destroyTensorDescriptors(this.xDesc);
        this.dxDesc = destroyTensorDescriptors(this.dxDesc);
        this.yDesc = destroyTensorDescriptors(this.yDesc);
        this.dyDesc = destroyTensorDescriptors(this.dyDesc);

        // Attempt every free even if an earlier one fails, so one failure does not leak the rest.
        // Sizes are reset so a repeated close() never frees the same buffer twice.
        RuntimeException failure = freeDeviceBuffer(this.workSpace, this.sizeInBytes, null);
        this.workSpace = null;
        this.sizeInBytes = 0L;
        failure = freeDeviceBuffer(this.reserveSpace, this.reserveSpaceSizeInBytes, failure);
        this.reserveSpace = null;
        this.reserveSpaceSizeInBytes = 0L;
        failure = freeDeviceBuffer(this.dropOutStateSpace, this.dropOutSizeInBytes, failure);
        this.dropOutStateSpace = null;
        this.dropOutSizeInBytes = 0L;
        if (failure != null) {
            throw failure;
        }
    }

    /** Destroys {@code desc} if non-null; returns {@code null} so the caller can clear its field. */
    private static cudnnTensorDescriptor destroyTensorDescriptor(cudnnTensorDescriptor desc) {
        if (desc != null) {
            JCudnn.cudnnDestroyTensorDescriptor(desc);
        }
        return null;
    }

    /** Destroys every non-null element of {@code descs}; returns {@code null} to clear the field. */
    private static cudnnTensorDescriptor[] destroyTensorDescriptors(cudnnTensorDescriptor[] descs) {
        if (descs != null) {
            // Elements may be null if construction failed part-way through filling the array.
            for (cudnnTensorDescriptor desc : descs) {
                destroyTensorDescriptor(desc);
            }
        }
        return null;
    }

    /** Destroys {@code desc} if non-null; returns {@code null} so the caller can clear its field. */
    private static cudnnFilterDescriptor destroyFilterDescriptor(cudnnFilterDescriptor desc) {
        if (desc != null) {
            JCudnn.cudnnDestroyFilterDescriptor(desc);
        }
        return null;
    }

    /**
     * Frees {@code buffer} through the GPU context if it was allocated ({@code bufferBytes != 0}).
     *
     * @param pending the failure already recorded by an earlier free, or {@code null}
     * @return the failure to report: {@code pending} if it was non-null (with any new failure
     *         added as suppressed), otherwise a new wrapped failure, or {@code null} if none
     */
    private RuntimeException freeDeviceBuffer(Pointer buffer, long bufferBytes, RuntimeException pending) {
        if (bufferBytes == 0 || buffer == null) {
            return pending;
        }
        try {
            this.gCtx.cudaFreeHelper(this.instName, buffer, DMLScript.EAGER_CUDA_FREE);
            return pending;
        } catch (DMLRuntimeException e) {
            RuntimeException wrapped = new RuntimeException(e);
            if (pending == null) {
                return wrapped;
            }
            pending.addSuppressed(wrapped);
            return pending;
        }
    }
}
