package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import org.apache.sysml.hops.OptimizerUtils;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu.class */
public class LibMatrixDNNRelu {

    /* loaded from: input_file:org/apache/sysml/runtime/matrix/data/LibMatrixDNNRelu$ReluBackward.class */
    /**
     * Task that runs the ReLU backward step on the row range {@code [_rl, _ru)}.
     *
     * <p>The task reads {@code input1} and {@code input2} from the supplied parameters, writes to
     * {@code output}, and delegates to the {@code reluBackward*} kernel that matches the
     * dense/sparse format reported by the two inputs.
     */
    public static class ReluBackward implements Callable<Long> {
        public final int _rl;
        public final int _ru;
        private final DnnParameters _params;

        /**
         * Creates a task for one row range.
         *
         * @param i first row (inclusive) processed by this task
         * @param i2 row bound (exclusive) processed by this task
         * @param dnnParameters holder of {@code input1}, {@code input2} and {@code output}
         */
        public ReluBackward(int i, int i2, DnnParameters dnnParameters) {
            _rl = i;
            _ru = i2;
            _params = dnnParameters;
        }

        /**
         * Picks the kernel matching the input formats and runs it on this task's rows.
         *
         * @return {@code 0} when either input reports an empty block; otherwise the value of
         *     {@code output.recomputeNonZeros(_rl, _ru - 1)} (presumably the non-zero count of
         *     the rows handled here; confirm against {@code MatrixBlock})
         * @throws Exception declared by {@link Callable#call()}
         */
        @Override // java.util.concurrent.Callable
        public Long call() throws Exception {
            final MatrixBlock in1 = _params.input1;
            final MatrixBlock in2 = _params.input2;
            final MatrixBlock out = _params.output;
            final int numCols = in1.getNumColumns();

            // Nothing is written when either operand is empty.
            if (in1.isEmptyBlock(false) || in2.isEmptyBlock(false)) {
                return 0L;
            }

            final boolean firstSparse = in1.isInSparseFormat();
            final boolean secondSparse = in2.isInSparseFormat();

            if (firstSparse) {
                if (secondSparse) {
                    LibMatrixDNNRelu.reluBackwardSparseSparse(
                            in1.getSparseBlock(), in2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
                } else {
                    LibMatrixDNNRelu.reluBackwardSparseDense(
                            in1.getSparseBlock(), in2.getDenseBlock(), out.getSparseBlock(), _rl, _ru);
                }
            } else {
                if (secondSparse) {
                    LibMatrixDNNRelu.reluBackwardDenseSparse(
                            in1.getDenseBlock(), in2.getSparseBlock(), out.getSparseBlock(), _rl, _ru);
                } else {
                    // Only the all-dense case writes into a dense output block.
                    LibMatrixDNNRelu.reluBackwardDenseDense(
                            in1.getDenseBlock(), in2.getDenseBlock(), out.getDenseBlock(), numCols, _rl, _ru);
                }
            }

            // recomputeNonZeros receives an inclusive upper row, hence _ru - 1.
            return Long.valueOf(out.recomputeNonZeros(_rl, _ru - 1));
        }
    }

    /**
     * Splits the rows {@code [0, N)} into contiguous ranges and creates one {@link ReluBackward}
     * task per range.
     *
     * <p>The range length is {@code ceil(N / k)}, where {@code k} is the value returned by
     * {@code OptimizerUtils.getConstrainedNumThreads(dnnParameters.numThreads)}, so at most
     * {@code k} tasks are created for a positive {@code k}. No tasks are created when
     * {@code N <= 0}.
     *
     * @param dnnParameters parameters providing {@code N}, {@code numThreads} and the blocks
     *     that every task operates on
     * @return the tasks in ascending row order; empty when there are no rows
     */
    public static ArrayList<Callable<Long>> getReluBackwardWorkers(DnnParameters dnnParameters) {
        ArrayList<Callable<Long>> workers = new ArrayList<>();
        final int numRows = dnnParameters.N;

        // Divide in floating point: an integer division would truncate before Math.ceil runs,
        // producing a chunk size of 0 (and a loop that never advances) whenever N < k.
        // The lower bound of 1 guarantees progress for any thread count.
        final int rowsPerTask = Math.max(1,
                (int) Math.ceil((double) numRows / OptimizerUtils.getConstrainedNumThreads(dnnParameters.numThreads)));

        int rowStart = 0;
        while (rowStart < numRows) {
            // Compute the bound in long so rowStart + rowsPerTask cannot wrap around.
            final int rowEnd = (int) Math.min((long) rowStart + rowsPerTask, (long) numRows);
            workers.add(new ReluBackward(rowStart, rowEnd, dnnParameters));
            rowStart = rowEnd;
        }
        return workers;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Dense/dense kernel: for every row in {@code [i2, i3)} and every one of the {@code i}
     * columns, copies the value of {@code denseBlock2} into {@code denseBlock3} where the
     * corresponding value of {@code denseBlock} is strictly positive, and writes {@code 0}
     * otherwise.
     *
     * <p>The row offset comes from {@code denseBlock.pos(row)} and is applied to all three value
     * arrays, which assumes the three blocks share the same row layout (TODO confirm).
     *
     * @param denseBlock values tested against zero
     * @param denseBlock2 values passed through where the test succeeds
     * @param denseBlock3 destination block
     * @param i number of columns to process per row
     * @param i2 first row (inclusive)
     * @param i3 row bound (exclusive)
     */
    public static void reluBackwardDenseDense(DenseBlock denseBlock, DenseBlock denseBlock2, DenseBlock denseBlock3, int i, int i2, int i3) {
        for (int row = i2; row < i3; row++) {
            final double[] gate = denseBlock.values(row);
            final double[] passThrough = denseBlock2.values(row);
            final double[] target = denseBlock3.values(row);
            final int rowBegin = denseBlock.pos(row);
            final int rowEnd = rowBegin + i;
            for (int idx = rowBegin; idx < rowEnd; idx++) {
                if (gate[idx] > 0.0d) {
                    target[idx] = passThrough[idx];
                } else {
                    target[idx] = 0.0d;
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Dense/sparse kernel: for every row in {@code [i, i2)} that is non-empty in
     * {@code sparseBlock}, walks the stored entries of that row and appends to
     * {@code sparseBlock2} either the entry's value (when the value of {@code denseBlock} at the
     * same column is strictly positive) or {@code 0}.
     *
     * <p>Rows that are empty in {@code sparseBlock} are skipped without touching the output.
     *
     * @param denseBlock values tested against zero
     * @param sparseBlock entries passed through where the test succeeds
     * @param sparseBlock2 destination block; each processed row is allocated with the entry count
     *     of the matching {@code sparseBlock} row before appending
     * @param i first row (inclusive)
     * @param i2 row bound (exclusive)
     */
    public static void reluBackwardDenseSparse(DenseBlock denseBlock, SparseBlock sparseBlock, SparseBlock sparseBlock2, int i, int i2) {
        for (int row = i; row < i2; row++) {
            if (sparseBlock.isEmpty(row)) {
                continue;
            }
            final int entryBegin = sparseBlock.pos(row);
            final int entryCount = sparseBlock.size(row);
            final int[] columns = sparseBlock.indexes(row);
            final double[] passThrough = sparseBlock.values(row);
            final double[] gate = denseBlock.values(row);
            final int gateOffset = denseBlock.pos(row);
            sparseBlock2.allocate(row, entryCount);

            final int entryEnd = entryBegin + entryCount;
            for (int k = entryBegin; k < entryEnd; k++) {
                final int col = columns[k];
                double result = 0.0d;
                if (gate[gateOffset + col] > 0.0d) {
                    result = passThrough[k];
                }
                sparseBlock2.append(row, col, result);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Sparse/dense kernel: for every row in {@code [i, i2)} that is non-empty in
     * {@code sparseBlock}, walks the stored entries of that row and appends to
     * {@code sparseBlock2} either the value of {@code denseBlock} at the entry's column (when the
     * entry's own value is strictly positive) or {@code 0}.
     *
     * <p>Rows that are empty in {@code sparseBlock} are skipped without touching the output.
     *
     * @param sparseBlock entries tested against zero
     * @param denseBlock values passed through where the test succeeds
     * @param sparseBlock2 destination block; each processed row is allocated with the entry count
     *     of the matching {@code sparseBlock} row before appending
     * @param i first row (inclusive)
     * @param i2 row bound (exclusive)
     */
    public static void reluBackwardSparseDense(SparseBlock sparseBlock, DenseBlock denseBlock, SparseBlock sparseBlock2, int i, int i2) {
        for (int row = i; row < i2; row++) {
            if (sparseBlock.isEmpty(row)) {
                continue;
            }
            final int entryBegin = sparseBlock.pos(row);
            final int entryCount = sparseBlock.size(row);
            final int[] columns = sparseBlock.indexes(row);
            final double[] gate = sparseBlock.values(row);
            final double[] passThrough = denseBlock.values(row);
            final int denseOffset = denseBlock.pos(row);
            sparseBlock2.allocate(row, entryCount);

            final int entryEnd = entryBegin + entryCount;
            for (int k = entryBegin; k < entryEnd; k++) {
                final int col = columns[k];
                double result = 0.0d;
                if (gate[k] > 0.0d) {
                    result = passThrough[denseOffset + col];
                }
                sparseBlock2.append(row, col, result);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Sparse/sparse kernel: for every row in {@code [i, i2)} that is non-empty in both inputs,
     * walks the stored entries of the {@code sparseBlock2} row and appends to
     * {@code sparseBlock3} either the entry's value (when {@code sparseBlock.get(row, column)}
     * is strictly positive) or {@code 0}.
     *
     * <p>Rows that are empty in either input are skipped without touching the output.
     *
     * @param sparseBlock block looked up per column and tested against zero
     * @param sparseBlock2 entries passed through where the test succeeds
     * @param sparseBlock3 destination block; each processed row is allocated with the entry count
     *     of the matching {@code sparseBlock2} row before appending
     * @param i first row (inclusive)
     * @param i2 row bound (exclusive)
     */
    public static void reluBackwardSparseSparse(SparseBlock sparseBlock, SparseBlock sparseBlock2, SparseBlock sparseBlock3, int i, int i2) {
        for (int row = i; row < i2; row++) {
            if (sparseBlock.isEmpty(row) || sparseBlock2.isEmpty(row)) {
                continue;
            }
            final int entryBegin = sparseBlock2.pos(row);
            final int entryCount = sparseBlock2.size(row);
            final int[] columns = sparseBlock2.indexes(row);
            final double[] passThrough = sparseBlock2.values(row);
            sparseBlock3.allocate(row, entryCount);

            final int entryEnd = entryBegin + entryCount;
            for (int k = entryBegin; k < entryEnd; k++) {
                final int col = columns[k];
                double result = 0.0d;
                if (sparseBlock.get(row, col) > 0.0d) {
                    result = passThrough[k];
                }
                sparseBlock3.append(row, col, result);
            }
        }
    }
}
