package org.apache.sysml.hops.estim;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.estim.SparsityEstimator;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/hops/estim/EstimatorDensityMap.class */
/**
 * Sparsity estimator based on density maps.
 *
 * <p>Each input matrix is summarized by a {@link DensityMap}: a coarse-grained matrix whose cell
 * (i, j) describes block (i, j) of the input, where blocks are {@code _b x _b} except possibly in
 * the last block row/column. Output maps for matrix multiplication, element-wise multiplication
 * and element-wise addition are derived cell-by-cell from the input maps; the combination
 * formulas treat the non-zeros inside a block as independently placed.
 */
public class EstimatorDensityMap extends SparsityEstimator {
    /** Default block size used by {@link #EstimatorDensityMap()}. */
    private static final int BLOCK_SIZE = 256;

    /** Number of input rows/columns summarized by one density-map cell. */
    private final int _b;

    /**
     * Block-wise summary of a matrix.
     *
     * <p>The map is in one of two states, tracked by {@code _scaled}:
     * <ul>
     *   <li>unscaled: each cell holds the non-zero count of its input block;</li>
     *   <li>scaled: each cell holds the fraction of non-zero cells of its input block.</li>
     * </ul>
     * {@link #toSparsity()} and {@link #toNnz()} convert between the two states in place.
     */
    public static class DensityMap {
        /** Map cells, one per block of the summarized matrix. */
        private final MatrixBlock _map;
        /** Number of rows of the summarized (original) matrix. */
        private final int _rlen;
        /** Number of columns of the summarized (original) matrix. */
        private final int _clen;
        /** Block size used to partition the summarized matrix. */
        private final int _b;
        /** True if {@code _map} holds per-block sparsity, false if it holds per-block nnz counts. */
        private boolean _scaled;

        /**
         * Builds an unscaled density map (per-block non-zero counts) of the given matrix.
         *
         * @param matrixBlock matrix to summarize
         * @param i block size
         */
        public DensityMap(MatrixBlock matrixBlock, int i) {
            this._rlen = matrixBlock.getNumRows();
            this._clen = matrixBlock.getNumColumns();
            this._b = i;
            this._map = init(matrixBlock);
            this._scaled = false;
        }

        /**
         * Wraps an already computed map.
         *
         * @param matrixBlock map cells, one per block
         * @param i number of rows of the summarized matrix
         * @param i2 number of columns of the summarized matrix
         * @param i3 block size
         * @param z true if the cells hold sparsity values, false if they hold nnz counts
         */
        public DensityMap(MatrixBlock matrixBlock, int i, int i2, int i3, boolean z) {
            this._rlen = i;
            this._clen = i2;
            this._b = i3;
            this._map = matrixBlock;
            this._scaled = z;
        }

        /** @return number of rows of the map (block rows) */
        public int getNumRows() {
            return _map.getNumRows();
        }

        /** @return number of columns of the map (block columns) */
        public int getNumColumns() {
            return _map.getNumColumns();
        }

        /** @return number of rows of the summarized matrix */
        public int getNumRowsOrig() {
            return _rlen;
        }

        /** @return number of columns of the summarized matrix */
        public int getNumColumnsOrig() {
            return _clen;
        }

        /**
         * Returns the (estimated) number of non-zeros of the summarized matrix, rounded to the
         * nearest integer.
         *
         * <p>Side effect: a scaled map is converted back to nnz counts in place.
         */
        public long getNonZeros() {
            if (_scaled) {
                toNnz();
            }
            return Math.round(_map.sum());
        }

        /**
         * @param i zero-based block-row index
         * @return number of original rows in block row {@code i}, as computed by
         *     {@code UtilFunctions.computeBlockSize}
         */
        public int getRowBlockize(int i) {
            return UtilFunctions.computeBlockSize(_rlen, i + 1, _b);
        }

        /**
         * @param i zero-based block-column index
         * @return number of original columns in block column {@code i}, as computed by
         *     {@code UtilFunctions.computeBlockSize}
         */
        public int getColBlockize(int i) {
            return UtilFunctions.computeBlockSize(_clen, i + 1, _b);
        }

        /** @return value of map cell ({@code i}, {@code i2}) in the current representation */
        public double get(int i, int i2) {
            return _map.quickGetValue(i, i2);
        }

        /** Converts per-block nnz counts into per-block sparsity (no-op if already scaled). */
        public void toSparsity() {
            if (_scaled) {
                return;
            }
            rescale(true);
            _scaled = true;
        }

        /** Converts per-block sparsity back into per-block nnz counts (no-op if unscaled). */
        public void toNnz() {
            if (!_scaled) {
                return;
            }
            rescale(false);
            _scaled = false;
        }

        /**
         * Divides ({@code divide == true}) or multiplies every non-zero cell by the number of
         * cells of its block.
         */
        private void rescale(boolean divide) {
            DenseBlock cells = _map.getDenseBlock();
            if (cells == null) {
                // No dense block allocated (init() leaves the map unallocated for an all-zero
                // input): every cell is zero, which is identical in both representations.
                return;
            }
            int mapRows = _map.getNumRows();
            int mapCols = _map.getNumColumns();
            for (int bi = 0; bi < mapRows; bi++) {
                int rowSize = getRowBlockize(bi);
                for (int bj = 0; bj < mapCols; bj++) {
                    double v = cells.get(bi, bj);
                    if (v != 0) {
                        int colSize = getColBlockize(bj);
                        cells.set(bi, bj, divide ? (v / rowSize) / colSize : v * rowSize * colSize);
                    }
                }
            }
        }

        /**
         * Builds the unscaled map of {@code in}: cell (bi, bj) receives the number of non-zeros
         * of input block (bi, bj). Reads {@code _rlen}, {@code _clen} and {@code _b}, which the
         * constructor assigns before calling this method.
         */
        private MatrixBlock init(MatrixBlock in) {
            // Divide as double so a partial boundary block gets its own cell; integer division
            // would floor first and make the incr() calls below index past the map.
            int mapRows = (int) Math.ceil((double) _rlen / _b);
            int mapCols = (int) Math.ceil((double) _clen / _b);
            MatrixBlock out = new MatrixBlock(mapRows, mapCols, false);
            if (in.isEmptyBlock(false)) {
                return out; // all-zero input: leave the map unallocated (all cells zero)
            }
            DenseBlock cells = out.allocateBlock().getDenseBlock();

            if (in.getLength() == in.getNonZeros()) {
                // Fully dense input: each block's nnz count equals its number of cells.
                for (int bi = 0; bi < mapRows; bi++) {
                    int rowSize = UtilFunctions.computeBlockSize(_rlen, bi + 1, _b);
                    for (int bj = 0; bj < mapCols; bj++) {
                        int colSize = UtilFunctions.computeBlockSize(_clen, bj + 1, _b);
                        cells.set(bi, bj, (double) rowSize * colSize);
                    }
                }
                out.setNonZeros((long) mapRows * mapCols);
                return out;
            }

            if (in.isInSparseFormat()) {
                SparseBlock sparse = in.getSparseBlock();
                for (int r = 0; r < in.getNumRows(); r++) {
                    if (sparse.isEmpty(r)) {
                        continue;
                    }
                    int start = sparse.pos(r);
                    int end = start + sparse.size(r);
                    int[] colIdx = sparse.indexes(r);
                    int bi = r / _b;
                    for (int k = start; k < end; k++) {
                        cells.incr(bi, colIdx[k] / _b);
                    }
                }
            } else {
                for (int r = 0; r < _rlen; r++) {
                    int bi = r / _b;
                    for (int col = 0; col < _clen; col++) {
                        if (in.quickGetValue(r, col) != 0) {
                            cells.incr(bi, col / _b);
                        }
                    }
                }
            }
            out.recomputeNonZeros();
            return out;
        }
    }

    /** Creates an estimator with the default block size {@link #BLOCK_SIZE}. */
    public EstimatorDensityMap() {
        this(BLOCK_SIZE);
    }

    /**
     * Creates an estimator with a custom block size.
     *
     * @param i number of rows/columns summarized by one density-map cell
     */
    public EstimatorDensityMap(int i) {
        this._b = i;
    }

    /**
     * Estimates the output characteristics of an operation DAG node. Inner children are
     * estimated first; the resulting density map is stored as the node's synopsis.
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public MatrixCharacteristics estim(MMNode mMNode) {
        MMNode left = mMNode.getLeft();
        MMNode right = mMNode.getRight();
        if (!left.isLeaf()) {
            estim(left);
        }
        if (!right.isLeaf()) {
            estim(right);
        }
        DensityMap outMap = estimIntern(synopsisOf(left), synopsisOf(right), mMNode.getOp());
        mMNode.setSynopsis(outMap);
        return mMNode.setMatrixCharacteristics(
            new MatrixCharacteristics(left.getRows(), right.getCols(), outMap.getNonZeros()));
    }

    /** Returns the cached synopsis of an inner node, or a fresh density map of a leaf's data. */
    private DensityMap synopsisOf(MMNode node) {
        return node.isLeaf()
            ? new DensityMap(node.getData(), _b)
            : (DensityMap) node.getSynopsis();
    }

    /** Estimates the sparsity of the matrix product {@code matrixBlock %*% matrixBlock2}. */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        return estim(matrixBlock, matrixBlock2, SparsityEstimator.OpCode.MM);
    }

    /**
     * Estimates the output sparsity of {@code opCode} applied to the given inputs.
     *
     * @param matrixBlock first input
     * @param matrixBlock2 second input; {@code null} for unary operations
     * @param opCode operation to estimate
     * @return estimated sparsity of the output
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, SparsityEstimator.OpCode opCode) {
        if (isExactMetadataOp(opCode)) {
            // NOTE(review): assumes estimExactMetaData ignores the second characteristics for
            // unary ops; before this guard a null second input always threw an NPE here.
            return estimExactMetaData(matrixBlock.getMatrixCharacteristics(),
                matrixBlock2 != null ? matrixBlock2.getMatrixCharacteristics() : null,
                opCode).getSparsity();
        }
        DensityMap m1Map = new DensityMap(matrixBlock, _b);
        DensityMap m2Map;
        if (matrixBlock == matrixBlock2) {
            m2Map = m1Map; // self-operation: reuse the same map
        } else if (matrixBlock2 != null) {
            m2Map = new DensityMap(matrixBlock2, _b);
        } else {
            m2Map = null;
        }
        DensityMap outMap = estimIntern(m1Map, m2Map, opCode);
        // Sparsity is relative to the output dimensions, which the output map records.
        return OptimizerUtils.getSparsity(
            outMap.getNumRowsOrig(), outMap.getNumColumnsOrig(), outMap.getNonZeros());
    }

    /** Estimates the output sparsity of the unary operation {@code opCode}. */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, SparsityEstimator.OpCode opCode) {
        return estim(matrixBlock, null, opCode);
    }

    /**
     * Dispatches to the density-map propagation rule for {@code op}.
     *
     * @throws NotImplementedException for operations without a density-map rule
     */
    private DensityMap estimIntern(DensityMap m1Map, DensityMap m2Map, SparsityEstimator.OpCode op) {
        switch (op) {
            case MM:
                return estimInternMM(m1Map, m2Map);
            case MULT:
                return estimInternMult(m1Map, m2Map);
            case PLUS:
                return estimInternPlus(m1Map, m2Map);
            default:
                // RBIND, CBIND, TRANS, DIAG, RESHAPE and any others are not supported yet.
                throw new NotImplementedException(
                    "Density-map sparsity estimation not implemented for operation: " + op);
        }
    }

    /**
     * Matrix multiplication: for output block (i, j), every common block k contributes
     * {@code 1 - (1 - sA * sB)^kSize}, the chance that at least one of the {@code kSize}
     * products is non-zero. Contributions of different k are combined as a probabilistic union.
     */
    private DensityMap estimInternMM(DensityMap m1Map, DensityMap m2Map) {
        int outRows = m1Map.getNumRows();
        int common = m1Map.getNumColumns();
        int outCols = m2Map.getNumColumns();
        MatrixBlock out = new MatrixBlock(outRows, outCols, false);
        DenseBlock cells = out.allocateBlock().getDenseBlock();
        m1Map.toSparsity();
        m2Map.toSparsity();
        for (int i = 0; i < outRows; i++) {
            for (int k = 0; k < common; k++) {
                double sA = m1Map.get(i, k);
                if (sA == 0) {
                    continue;
                }
                int kSize = m1Map.getColBlockize(k);
                for (int j = 0; j < outCols; j++) {
                    double sB = m2Map.get(k, j);
                    if (sB == 0) {
                        continue;
                    }
                    double sK = 1 - Math.pow(1 - sA * sB, kSize);
                    double prev = cells.get(i, j);
                    cells.set(i, j, sK + prev - sK * prev);
                }
            }
        }
        out.recomputeNonZeros();
        return new DensityMap(out, m1Map.getNumRowsOrig(), m2Map.getNumColumnsOrig(), _b, true);
    }

    /** Element-wise multiplication: the per-block sparsities are multiplied. */
    private DensityMap estimInternMult(DensityMap m1Map, DensityMap m2Map) {
        int rows = m1Map.getNumRows();
        int cols = m1Map.getNumColumns();
        MatrixBlock out = new MatrixBlock(rows, cols, false);
        DenseBlock cells = out.allocateBlock().getDenseBlock();
        m1Map.toSparsity();
        m2Map.toSparsity();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                cells.set(i, j, m1Map.get(i, j) * m2Map.get(i, j));
            }
        }
        out.recomputeNonZeros();
        return new DensityMap(out, m1Map.getNumRowsOrig(), m1Map.getNumColumnsOrig(), _b, true);
    }

    /** Element-wise addition: per-block sparsities combine as {@code sA + sB - sA * sB}. */
    private DensityMap estimInternPlus(DensityMap m1Map, DensityMap m2Map) {
        int rows = m1Map.getNumRows();
        int cols = m1Map.getNumColumns();
        MatrixBlock out = new MatrixBlock(rows, cols, false);
        DenseBlock cells = out.allocateBlock().getDenseBlock();
        m1Map.toSparsity();
        m2Map.toSparsity();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sA = m1Map.get(i, j);
                double sB = m2Map.get(i, j);
                cells.set(i, j, sA + sB - sA * sB);
            }
        }
        out.recomputeNonZeros();
        return new DensityMap(out, m1Map.getNumRowsOrig(), m1Map.getNumColumnsOrig(), _b, true);
    }
}
