package org.apache.sysml.hops.estim;

import java.util.BitSet;
import java.util.stream.IntStream;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.estim.SparsityEstimator;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;

/* loaded from: input_file:org/apache/sysml/hops/estim/EstimatorBitsetMM.class */
public class EstimatorBitsetMM extends SparsityEstimator {

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/hops/estim/EstimatorBitsetMM$BitsetMatrix.class */
    /**
     * Boolean synopsis of a matrix: one bit per cell, set iff the cell is non-zero.
     * Subclasses choose the physical layout; this base class provides the shared
     * (optionally multi-threaded) build and matrix-multiply drivers.
     */
    public static abstract class BitsetMatrix {
        /** Number of rows. */
        protected final int _rlen;
        /** Number of columns. */
        protected final int _clen;
        /** Number of non-zeros (set bits) recorded for this synopsis. */
        protected long _nonZeros = 0;

        public BitsetMatrix(int rlen, int clen) {
            this._rlen = rlen;
            this._clen = clen;
        }

        public int getNumRows() {
            return this._rlen;
        }

        public int getNumColumns() {
            return this._clen;
        }

        public long getNonZeros() {
            return this._nonZeros;
        }

        /** Returns true iff cell (r, c) is marked as non-zero. */
        public abstract boolean get(int r, int c);

        /** Marks cell (r, c) as non-zero. */
        public abstract void set(int r, int c);

        /**
         * Builds the synopsis from the non-zero pattern of {@code in} and takes over its
         * non-zero count. Inputs with more than 10240 non-zeros are built in parallel over
         * disjoint row ranges when {@code MULTI_THREADED_BUILD} is enabled.
         *
         * @param in the matrix block whose non-zero pattern is recorded
         */
        protected void init(MatrixBlock in) {
            if (in.isEmptyBlock(false)) {
                return;
            }
            if (!SparsityEstimator.MULTI_THREADED_BUILD || in.getNonZeros() <= 10240) {
                buildIntern(in, 0, in.getNumRows());
            } else {
                int k = 4 * InfrastructureAnalyzer.getLocalParallelism();
                int blklen = rowBlockLength(this._rlen, k);
                IntStream.range(0, k).parallel().forEach(i ->
                    buildIntern(in, i * blklen, Math.min((i + 1) * blklen, this._rlen)));
            }
            this._nonZeros = in.getNonZeros();
        }

        /**
         * Computes the non-zero pattern of the boolean product {@code this %*% m2} together
         * with its number of set bits. When {@code MULTI_THREADED_ESTIM} is enabled and the
         * inputs hold more than 10240 cells in total, output row ranges are computed in parallel.
         *
         * @param m2 right-hand side; must be of a type accepted by this subclass's matMultIntern
         * @return a new synopsis of size rows(this) x cols(m2)
         */
        public BitsetMatrix matMult(BitsetMatrix m2) {
            BitsetMatrix out = createBitSetMatrix(this._rlen, m2._clen);
            if (getNonZeros() == 0 || m2.getNonZeros() == 0) {
                return out; // empty input => empty product
            }
            // widen before multiplying: the cell counts can exceed Integer.MAX_VALUE
            long size = (long) this._rlen * this._clen + (long) m2._rlen * m2._clen;
            if (!SparsityEstimator.MULTI_THREADED_ESTIM || size <= 10240) {
                out._nonZeros = matMultIntern(m2, out, 0, this._rlen);
            } else {
                int k = 4 * InfrastructureAnalyzer.getLocalParallelism();
                int blklen = rowBlockLength(this._rlen, k);
                out._nonZeros = IntStream.range(0, k).parallel()
                    .mapToLong(i -> matMultIntern(m2, out, i * blklen, Math.min((i + 1) * blklen, this._rlen)))
                    .sum();
            }
            return out;
        }

        /**
         * Returns the number of rows per task so that {@code k} tasks cover all {@code len}
         * rows, i.e. ceil(len / k). Floating-point division is required here: taking the
         * ceil of an integer quotient would silently drop the trailing rows.
         */
        private static int rowBlockLength(int len, int k) {
            return (int) Math.ceil((double) len / k);
        }

        protected abstract BitsetMatrix createBitSetMatrix(int rlen, int clen);

        /** Records the non-zero pattern of rows [rl, ru) of {@code in}. */
        protected abstract void buildIntern(MatrixBlock in, int rl, int ru);

        /** Computes output rows [rl, ru) of this x m2 into {@code out}; returns their set-bit count. */
        protected abstract long matMultIntern(BitsetMatrix m2, BitsetMatrix out, int rl, int ru);

        protected abstract BitsetMatrix and(BitsetMatrix other);

        protected abstract BitsetMatrix or(BitsetMatrix other);

        protected abstract BitsetMatrix rbind(BitsetMatrix other);

        protected abstract BitsetMatrix cbind(BitsetMatrix other);

        protected abstract BitsetMatrix flip();

        /**
         * Returns the transposed pattern as a new {@link BitsetMatrix1} of size cols x rows,
         * with its non-zero count set to the number of transferred bits.
         */
        public BitsetMatrix transpose() {
            BitsetMatrix1 ret = new BitsetMatrix1(getNumColumns(), getNumRows());
            long nnz = 0;
            for (int r = 0; r < getNumRows(); r++) {
                for (int c = 0; c < getNumColumns(); c++) {
                    if (get(r, c)) {
                        ret.set(c, r);
                        nnz++;
                    }
                }
            }
            ret._nonZeros = nnz;
            return ret;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/hops/estim/EstimatorBitsetMM$BitsetMatrix1.class */
    /**
     * Bitset synopsis backed by a single row-major {@code long[]}. Every row occupies
     * {@code _rowLen = ceil(clen / 64)} words; cell (r, c) is bit {@code c % 64} of word
     * {@code r * _rowLen + c / 64}. Padding bits past {@code _clen} in a row's last word
     * are kept clear so that word-level counting and OR-ing stay exact.
     */
    public static class BitsetMatrix1 extends BitsetMatrix {
        /** Number of 64-bit words per row. */
        private final int _rowLen;
        /** Row-major bit storage of size rlen * _rowLen. */
        private final long[] _data;

        public BitsetMatrix1(int rlen, int clen) {
            super(rlen, clen);
            this._rowLen = (int) Math.ceil(clen / 64.0d);
            this._data = new long[rlen * this._rowLen];
        }

        /** Creates a synopsis holding the non-zero pattern of {@code in}. */
        public BitsetMatrix1(MatrixBlock in) {
            this(in.getNumRows(), in.getNumColumns());
            init(in);
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected BitsetMatrix createBitSetMatrix(int rlen, int clen) {
            return new BitsetMatrix1(rlen, clen);
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected void buildIntern(MatrixBlock in, int rl, int ru) {
            if (in.isInSparseFormat()) {
                SparseBlock a = in.getSparseBlock();
                for (int i = rl; i < ru; i++) {
                    if (a.isEmpty(i)) {
                        continue;
                    }
                    int alen = a.size(i);
                    int apos = a.pos(i);
                    int[] aix = a.indexes(i);
                    for (int k = apos; k < apos + alen; k++) {
                        set(i, aix[k]);
                    }
                }
            } else {
                DenseBlock a = in.getDenseBlock();
                int clen = in.getNumColumns();
                for (int i = rl; i < ru; i++) {
                    double[] avals = a.values(i);
                    int apos = a.pos(i);
                    for (int j = 0; j < clen; j++) {
                        if (avals[apos + j] != 0.0d) {
                            set(i, j);
                        }
                    }
                }
            }
        }

        /**
         * Cache-blocked boolean matrix multiply for output rows [rl, ru): for every set bit
         * a(i,k) of this matrix, row k of {@code bsb2} is OR-ed into row i of {@code bsc2}.
         *
         * @return number of set bits in output rows [rl, ru)
         */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected long matMultIntern(BitsetMatrix bsb2, BitsetMatrix bsc2, int rl, int ru) {
            BitsetMatrix1 bsb = (BitsetMatrix1) bsb2;
            final long[] b = bsb._data;
            final long[] c = ((BitsetMatrix1) bsc2)._data;
            final int cd = this._clen;
            final int n = bsb._clen;
            final int n64 = bsb._rowLen;
            // block sizes over rows of this, the common dimension, and output columns (in bits)
            final int blocksizeI = 32;
            final int blocksizeK = 24;
            final int blocksizeJ = 1024 * 64;

            long lnnz = 0;
            for (int bi = rl; bi < ru; bi += blocksizeI) {
                int bimin = Math.min(ru, bi + blocksizeI);
                for (int bk = 0; bk < cd; bk += blocksizeK) {
                    int bkmin = Math.min(cd, bk + blocksizeK);
                    for (int bj = 0; bj < n; bj += blocksizeJ) {
                        // words covered by this column block and its first word (bj is a multiple of 64)
                        int bjlen64 = (int) Math.ceil((Math.min(n, bj + blocksizeJ) - bj) / 64.0d);
                        int bj64 = bj / 64;
                        for (int i = bi, off = i * this._rowLen; i < bimin; i++, off += this._rowLen) {
                            for (int k = bk; k < bkmin; k++) {
                                if (getCol(off, k)) {
                                    or(b, c, k * n64 + bj64, i * n64 + bj64, bjlen64);
                                }
                            }
                        }
                    }
                }
                // all contributions to output rows [bi, bimin) are complete: count them once
                lnnz += card(c, bi * n64, (bimin - bi) * n64);
            }
            return lnnz;
        }

        /** Casts {@code other} to this layout, failing with a uniform message for all binary ops. */
        private BitsetMatrix1 asBitsetMatrix1(BitsetMatrix other) {
            if (!(other instanceof BitsetMatrix1)) {
                throw new HopsException("Incompatible bitset types: " + getClass().getSimpleName()
                    + " and " + other.getClass().getSimpleName());
            }
            return (BitsetMatrix1) other;
        }

        /** Cell-wise AND (non-zero pattern of an element-wise product). */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix and(BitsetMatrix bitsetMatrix) {
            BitsetMatrix1 that = asBitsetMatrix1(bitsetMatrix);
            BitsetMatrix1 ret = new BitsetMatrix1(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                ret._data[i] = this._data[i] & that._data[i];
            }
            ret._nonZeros = card(ret._data, 0, ret._data.length);
            return ret;
        }

        /** Cell-wise OR (non-zero pattern of an element-wise sum). */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix or(BitsetMatrix bitsetMatrix) {
            BitsetMatrix1 that = asBitsetMatrix1(bitsetMatrix);
            BitsetMatrix1 ret = new BitsetMatrix1(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                ret._data[i] = this._data[i] | that._data[i];
            }
            ret._nonZeros = card(ret._data, 0, ret._data.length);
            return ret;
        }

        /** Stacks {@code bitsetMatrix} below this matrix; assumes both have the same column count. */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix rbind(BitsetMatrix bitsetMatrix) {
            BitsetMatrix1 that = asBitsetMatrix1(bitsetMatrix);
            BitsetMatrix1 ret = new BitsetMatrix1(getNumRows() + that.getNumRows(), getNumColumns());
            System.arraycopy(this._data, 0, ret._data, 0, this._rlen * this._rowLen);
            System.arraycopy(that._data, 0, ret._data, this._rlen * this._rowLen, that._rlen * this._rowLen);
            ret._nonZeros = card(ret._data, 0, ret._data.length);
            return ret;
        }

        /** Appends the columns of {@code bitsetMatrix} to the right of this matrix. */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix cbind(BitsetMatrix bitsetMatrix) {
            BitsetMatrix1 that = asBitsetMatrix1(bitsetMatrix);
            BitsetMatrix1 ret = new BitsetMatrix1(getNumRows(), getNumColumns() + that.getNumColumns());
            // left part: rows keep their word alignment, only the row stride changes
            for (int i = 0; i < getNumRows(); i++) {
                System.arraycopy(this._data, i * this._rowLen, ret._data, i * ret._rowLen, this._rowLen);
            }
            // right part: column offsets are generally not word-aligned, so copy bit by bit
            for (int i = 0; i < getNumRows(); i++) {
                for (int j = 0; j < that.getNumColumns(); j++) {
                    if (that.get(i, j)) {
                        ret.set(i, getNumColumns() + j);
                    }
                }
            }
            ret._nonZeros = card(ret._data, 0, ret._data.length);
            return ret;
        }

        /** Cell-wise complement (non-zero pattern of {@code X == 0}). */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix flip() {
            BitsetMatrix1 ret = new BitsetMatrix1(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                ret._data[i] = ~this._data[i];
            }
            // clear the padding bits past _clen in every row's last word; otherwise they
            // would be counted by card() and propagated by or() in later operations
            int tail = this._clen % 64;
            if (tail != 0) {
                long mask = (1L << tail) - 1;
                for (int off = this._rowLen - 1; off < ret._data.length; off += this._rowLen) {
                    ret._data[off] &= mask;
                }
            }
            ret._nonZeros = card(ret._data, 0, ret._data.length);
            return ret;
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public void set(int r, int c) {
            // long shift: Java masks the distance to its low 6 bits, i.e. c % 64
            this._data[r * this._rowLen + wordIndex(c)] |= 1L << c;
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public boolean get(int r, int c) {
            return (this._data[r * this._rowLen + wordIndex(c)] & (1L << c)) != 0;
        }

        /** Like {@link #get} but with a precomputed row offset {@code off = r * _rowLen}. */
        private boolean getCol(int off, int c) {
            return (this._data[off + wordIndex(c)] & (1L << c)) != 0;
        }

        private static int wordIndex(int c) {
            return c >> 6;
        }

        /** Counts the set bits in words [ai, ai + len) of {@code a}. */
        private static long card(long[] a, int ai, int len) {
            long cnt = 0;
            for (int i = ai; i < ai + len; i++) {
                cnt += Long.bitCount(a[i]);
            }
            return cnt;
        }

        /** ORs {@code len} words of {@code a} starting at {@code ai} into {@code c} starting at {@code ci}. */
        private static void or(long[] a, long[] c, int ai, int ci, int len) {
            final int bn = len % 8;
            // leading remainder that does not fill a full group of 8 words
            for (int j = 0; j < bn; j++, ai++, ci++) {
                c[ci] |= a[ai];
            }
            // manually unrolled groups of 8 words
            for (int j = bn; j < len; j += 8, ai += 8, ci += 8) {
                c[ci + 0] |= a[ai + 0];
                c[ci + 1] |= a[ai + 1];
                c[ci + 2] |= a[ai + 2];
                c[ci + 3] |= a[ai + 3];
                c[ci + 4] |= a[ai + 4];
                c[ci + 5] |= a[ai + 5];
                c[ci + 6] |= a[ai + 6];
                c[ci + 7] |= a[ai + 7];
            }
        }
    }

    /* loaded from: input_file:org/apache/sysml/hops/estim/EstimatorBitsetMM$BitsetMatrix2.class */
    /**
     * Bitset synopsis backed by one {@link BitSet} per row. A {@code null} row stands for an
     * all-zero row (empty sparse rows and empty product rows are never allocated); every
     * operation below treats it that way, and {@link #set} allocates rows on demand.
     */
    private static class BitsetMatrix2 extends BitsetMatrix {
        /** Per-row bit patterns; null entries denote empty rows. */
        private final BitSet[] _data;

        public BitsetMatrix2(int rlen, int clen) {
            super(rlen, clen);
            this._data = new BitSet[this._rlen];
        }

        public BitsetMatrix2(MatrixBlock in) {
            this(in.getNumRows(), in.getNumColumns());
            init(in);
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected BitsetMatrix createBitSetMatrix(int rlen, int clen) {
            return new BitsetMatrix2(rlen, clen);
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected void buildIntern(MatrixBlock in, int rl, int ru) {
            int clen = in.getNumColumns();
            if (!in.isInSparseFormat()) {
                DenseBlock a = in.getDenseBlock();
                for (int i = rl; i < ru; i++) {
                    BitSet row = new BitSet(clen);
                    this._data[i] = row;
                    double[] avals = a.values(i);
                    int apos = a.pos(i);
                    for (int j = 0; j < clen; j++) {
                        if (avals[apos + j] != 0.0d) {
                            row.set(j);
                        }
                    }
                }
                return;
            }
            SparseBlock a = in.getSparseBlock();
            for (int i = rl; i < ru; i++) {
                if (a.isEmpty(i)) {
                    continue; // empty rows stay unallocated (null)
                }
                BitSet row = new BitSet(clen);
                this._data[i] = row;
                int alen = a.size(i);
                int apos = a.pos(i);
                int[] aix = a.indexes(i);
                for (int k = apos; k < apos + alen; k++) {
                    row.set(aix[k]);
                }
            }
        }

        /**
         * Row-wise boolean matrix multiply for output rows [rl, ru): output row i is the OR
         * of all rows k of {@code bsb2} for which a(i,k) is set.
         *
         * @return number of set bits in output rows [rl, ru)
         */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected long matMultIntern(BitsetMatrix bsb2, BitsetMatrix bsc2, int rl, int ru) {
            BitsetMatrix2 bsb = (BitsetMatrix2) bsb2;
            BitsetMatrix2 bsc = (BitsetMatrix2) bsc2;
            final int cd = this._clen;
            final int n = bsb._clen;
            long lnnz = 0;
            for (int i = rl; i < ru; i++) {
                BitSet a = this._data[i];
                if (a == null) {
                    continue; // empty row of this => empty output row
                }
                BitSet out = new BitSet(n);
                bsc._data[i] = out;
                // visit only the set bits a(i,k) within the common dimension
                for (int k = a.nextSetBit(0); k >= 0 && k < cd; k = a.nextSetBit(k + 1)) {
                    BitSet brow = bsb._data[k];
                    if (brow != null) {
                        out.or(brow);
                    }
                }
                lnnz += out.cardinality();
            }
            return lnnz;
        }

        /** Casts {@code other} to this layout, failing with a uniform message for all binary ops. */
        private BitsetMatrix2 asBitsetMatrix2(BitsetMatrix other) {
            if (!(other instanceof BitsetMatrix2)) {
                throw new HopsException("Incompatible bitset types: " + getClass().getSimpleName()
                    + " and " + other.getClass().getSimpleName());
            }
            return (BitsetMatrix2) other;
        }

        /** Returns an independent copy of {@code row}, or null for an empty row. */
        private static BitSet copyRow(BitSet row) {
            return (row == null) ? null : (BitSet) row.clone();
        }

        /** Cell-wise AND (non-zero pattern of an element-wise product). */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix and(BitsetMatrix bitsetMatrix) {
            BitsetMatrix2 that = asBitsetMatrix2(bitsetMatrix);
            BitsetMatrix2 ret = new BitsetMatrix2(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                if (this._data[i] == null || that._data[i] == null) {
                    continue; // AND with an empty row is empty
                }
                BitSet row = (BitSet) this._data[i].clone();
                row.and(that._data[i]);
                ret._data[i] = row;
                ret._nonZeros += row.cardinality();
            }
            return ret;
        }

        /** Cell-wise OR (non-zero pattern of an element-wise sum). */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix or(BitsetMatrix bitsetMatrix) {
            BitsetMatrix2 that = asBitsetMatrix2(bitsetMatrix);
            BitsetMatrix2 ret = new BitsetMatrix2(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                BitSet a = this._data[i];
                BitSet b = that._data[i];
                if (a == null && b == null) {
                    continue;
                }
                BitSet row = (a != null) ? (BitSet) a.clone() : new BitSet(this._clen);
                if (b != null) {
                    row.or(b);
                }
                ret._data[i] = row;
                ret._nonZeros += row.cardinality();
            }
            return ret;
        }

        /** Stacks {@code bitsetMatrix} below this matrix; assumes both have the same column count. */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix rbind(BitsetMatrix bitsetMatrix) {
            BitsetMatrix2 that = asBitsetMatrix2(bitsetMatrix);
            BitsetMatrix2 ret = new BitsetMatrix2(getNumRows() + that.getNumRows(), getNumColumns());
            // copy rows so the result does not alias (and cannot mutate) the inputs
            for (int i = 0; i < this._rlen; i++) {
                ret._data[i] = copyRow(this._data[i]);
            }
            for (int i = 0; i < that._rlen; i++) {
                ret._data[this._rlen + i] = copyRow(that._data[i]);
            }
            for (BitSet row : ret._data) {
                if (row != null) {
                    ret._nonZeros += row.cardinality();
                }
            }
            return ret;
        }

        /** Appends the columns of {@code bitsetMatrix} to the right of this matrix. */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        protected BitsetMatrix cbind(BitsetMatrix bitsetMatrix) {
            BitsetMatrix2 that = asBitsetMatrix2(bitsetMatrix);
            BitsetMatrix2 ret = new BitsetMatrix2(getNumRows(), getNumColumns() + that.getNumColumns());
            for (int i = 0; i < getNumRows(); i++) {
                BitSet row = (this._data[i] != null) ? (BitSet) this._data[i].clone() : new BitSet(ret._clen);
                BitSet right = that._data[i];
                if (right != null) {
                    for (int j = right.nextSetBit(0); j >= 0 && j < that._clen; j = right.nextSetBit(j + 1)) {
                        row.set(getNumColumns() + j);
                    }
                }
                ret._data[i] = row;
                ret._nonZeros += row.cardinality();
            }
            return ret;
        }

        /** Cell-wise complement (non-zero pattern of {@code X == 0}) over exactly {@code _clen} columns. */
        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public BitsetMatrix flip() {
            BitsetMatrix2 ret = new BitsetMatrix2(getNumRows(), getNumColumns());
            for (int i = 0; i < this._data.length; i++) {
                BitSet row = (this._data[i] != null) ? (BitSet) this._data[i].clone() : new BitSet(this._clen);
                // flip the logical column range only; BitSet.size() includes unused capacity
                row.flip(0, this._clen);
                ret._data[i] = row;
                ret._nonZeros += row.cardinality();
            }
            return ret;
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public boolean get(int r, int c) {
            BitSet row = this._data[r];
            return row != null && row.get(c);
        }

        @Override // org.apache.sysml.hops.estim.EstimatorBitsetMM.BitsetMatrix
        public void set(int r, int c) {
            BitSet row = this._data[r];
            if (row == null) {
                row = new BitSet(this._clen);
                this._data[r] = row;
            }
            row.set(c);
        }
    }

    /**
     * Estimates the output characteristics of the operation at {@code mMNode}. Non-leaf
     * children are estimated first (left, then right) so that their synopses are available.
     * The bitset result is stored as the node's synopsis, and its row count, column count
     * and non-zero count are recorded as the node's matrix characteristics.
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public MatrixCharacteristics estim(MMNode mMNode) {
        MMNode left = mMNode.getLeft();
        MMNode right = mMNode.getRight();
        if (!left.isLeaf()) {
            estim(left);
        }
        if (!right.isLeaf()) {
            estim(right);
        }
        BitsetMatrix leftMap = synopsisOf(left);
        BitsetMatrix rightMap = synopsisOf(right);
        BitsetMatrix outMap = estimInternal(leftMap, rightMap, mMNode.getOp());
        mMNode.setSynopsis(outMap);
        MatrixCharacteristics mc = new MatrixCharacteristics(
            outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
        return mMNode.setMatrixCharacteristics(mc);
    }

    /** Returns a fresh bitset for a leaf's data, or the synopsis already computed for an inner node. */
    private BitsetMatrix synopsisOf(MMNode node) {
        if (node.isLeaf()) {
            return new BitsetMatrix1(node.getData());
        }
        return (BitsetMatrix) node.getSynopsis();
    }

    /**
     * Estimates the output sparsity of the matrix product of the two blocks by delegating
     * to the op-code based overload with {@code OpCode.MM}.
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        final SparsityEstimator.OpCode product = SparsityEstimator.OpCode.MM;
        return this.estim(matrixBlock, matrixBlock2, product);
    }

    /**
     * Estimates the output sparsity of the binary operation {@code opCode} on the two blocks.
     * Ops for which {@code isExactMetadataOp} holds are answered from metadata alone; all
     * others are evaluated on bitset synopses of the inputs.
     *
     * @return sparsity of the resulting non-zero pattern
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, SparsityEstimator.OpCode opCode) {
        if (isExactMetadataOp(opCode)) {
            return estimExactMetaData(matrixBlock.getMatrixCharacteristics(), matrixBlock2.getMatrixCharacteristics(), opCode).getSparsity();
        }
        BitsetMatrix m1Map = new BitsetMatrix1(matrixBlock);
        // same block on both sides (e.g. a self product): reuse the synopsis instead of rebuilding it
        BitsetMatrix m2Map = (matrixBlock == matrixBlock2) ? m1Map : new BitsetMatrix1(matrixBlock2);
        BitsetMatrix outMap = estimInternal(m1Map, m2Map, opCode);
        return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
    }

    /**
     * Estimates the output sparsity of the unary operation {@code opCode} on the block.
     * Ops for which {@code isExactMetadataOp} holds are answered from metadata alone; all
     * others are evaluated on a bitset synopsis of the input.
     *
     * @return sparsity of the resulting non-zero pattern
     */
    @Override // org.apache.sysml.hops.estim.SparsityEstimator
    public double estim(MatrixBlock matrixBlock, SparsityEstimator.OpCode opCode) {
        if (isExactMetadataOp(opCode)) {
            return estimExactMetaData(matrixBlock.getMatrixCharacteristics(), null, opCode).getSparsity();
        }
        BitsetMatrix outMap = estimInternal(new BitsetMatrix1(matrixBlock), null, opCode);
        return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
    }

    /**
     * Dispatches {@code op} to the matching bitset operation. {@code m2Map} is only read by
     * binary ops and may be null for unary ones. NEQZERO returns the input synopsis itself.
     *
     * @throws NotImplementedException for DIAG, RESHAPE and any other unsupported op
     */
    private BitsetMatrix estimInternal(BitsetMatrix m1Map, BitsetMatrix m2Map, SparsityEstimator.OpCode op) {
        BitsetMatrix out;
        switch (op) {
            // binary operations
            case MM:
                out = m1Map.matMult(m2Map);
                break;
            case MULT:
                out = m1Map.and(m2Map);
                break;
            case PLUS:
                out = m1Map.or(m2Map);
                break;
            case RBIND:
                out = m1Map.rbind(m2Map);
                break;
            case CBIND:
                out = m1Map.cbind(m2Map);
                break;
            // unary operations
            case NEQZERO:
                out = m1Map;
                break;
            case EQZERO:
                out = m1Map.flip();
                break;
            case TRANS:
                out = m1Map.transpose();
                break;
            default: // DIAG, RESHAPE, and anything else
                throw new NotImplementedException();
        }
        return out;
    }
}
