package org.apache.sysml.runtime.compress;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.apache.sysml.runtime.compress.ColGroup;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;

/* loaded from: input_file:org/apache/sysml/runtime/compress/ColGroupDDC.class */
/**
 * Abstract base for DDC column groups (presumably "dense dictionary coding").
 *
 * <p>Every row of the group maps, via {@link #getCode(int)}, to one tuple of the value dictionary
 * {@code _values}. Tuples are stored back to back, so the tuple of code {@code k} occupies
 * {@code _values[k * getNumCols() .. k * getNumCols() + getNumCols() - 1]}. Subclasses supply the
 * per-row code storage through {@link #getData(int, int)}, {@link #setData(int, int)} and
 * {@link #getCode(int)}.
 */
public abstract class ColGroupDDC extends ColGroupValue {
    private static final long serialVersionUID = -3204391646123465004L;

    /**
     * Cell iterator over rows {@code [rl, ru)} of this group.
     *
     * <p>Cells are visited row by row; within a row, local columns are visited in order. When
     * {@code inclZeros} is false, cells whose value is exactly {@code 0} are skipped.
     *
     * <p>{@link #next()} returns the same {@link IJV} buffer on every call, overwritten each time.
     * Callers must copy it if they need to keep a cell.
     */
    private class DDCIterator implements Iterator<IJV> {
        // iterator configuration
        private final int _ru;
        private final boolean _inclZeros;

        // iterator state: (_rpos, _cpos) is the position of the next cell to return, _value its value
        private final IJV _buff = new IJV();
        private int _rpos;
        private int _cpos;
        private double _value = 0.0d;

        /**
         * @param rl first row to visit (inclusive)
         * @param ru row bound (exclusive)
         * @param inclZeros whether cells with value 0 are returned
         */
        public DDCIterator(int rl, int ru, boolean inclZeros) {
            _ru = ru;
            _inclZeros = inclZeros;
            _rpos = rl;
            // start one column "before" column 0, so the first advance lands on (rl, 0)
            _cpos = -1;
            getNextValue();
        }

        @Override
        public boolean hasNext() {
            return _rpos < _ru;
        }

        /**
         * Returns the current cell and advances to the next one.
         *
         * @return the shared, reused cell buffer
         * @throws NoSuchElementException if the iterator is exhausted
         */
        @Override
        public IJV next() {
            if (!hasNext()) {
                // without this guard, the buffer would be filled with row _ru and a stale value
                throw new NoSuchElementException(
                    "DDC iterator exhausted at row " + _rpos + " (ru=" + _ru + ")");
            }
            _buff.set(_rpos, _colIndexes[_cpos], _value);
            getNextValue();
            return _buff;
        }

        /** Advances (_rpos, _cpos) to the next cell to return, loading its value into _value. */
        private void getNextValue() {
            final int numCols = getNumCols();
            do {
                if (_cpos + 1 >= numCols) {
                    _rpos++;
                    _cpos = 0;
                } else {
                    _cpos++;
                }
                if (_rpos >= _ru) {
                    return; // exhausted; hasNext() now reports false
                }
                _value = getData(_rpos, _cpos);
            } while (!_inclZeros && _value == 0.0d);
        }
    }

    /**
     * Row iterator that copies each requested row's dictionary tuple into a caller-provided
     * row buffer, placing each value at its global column index.
     */
    private class DDCRowIterator extends ColGroup.ColGroupRowIterator {
        /**
         * Creates the iterator. The row range is not needed: {@link #next} reads rows by index.
         *
         * @param rl first row (unused)
         * @param ru row bound (unused)
         */
        public DDCRowIterator(int rl, int ru) {
            super();
        }

        /**
         * Writes the tuple of row {@code rowIx} into {@code buff} at this group's column indexes.
         *
         * @param buff output row buffer, indexed by global column
         * @param rowIx row to materialize
         * @param segIx unused here
         * @param last unused here
         */
        @Override
        public void next(double[] buff, int rowIx, int segIx, boolean last) {
            final int numCols = ColGroupDDC.this.getNumCols();
            final int off = ColGroupDDC.this.getCode(rowIx) * numCols;
            for (int j = 0; j < numCols; j++) {
                buff[ColGroupDDC.this._colIndexes[j]] = ColGroupDDC.this._values[off + j];
            }
        }
    }

    /** No-arg constructor (presumably for deserialization; confirm with subclasses). */
    public ColGroupDDC() {
    }

    /**
     * @param colIndices global column indexes of this group
     * @param numRows number of rows
     * @param ubm uncompressed bitmap the dictionary is built from (handled by the superclass)
     */
    public ColGroupDDC(int[] colIndices, int numRows, UncompressedBitmap ubm) {
        super(colIndices, numRows, ubm);
    }

    /**
     * @param colIndices global column indexes of this group
     * @param numRows number of rows
     * @param values dictionary tuples, stored back to back
     */
    public ColGroupDDC(int[] colIndices, int numRows, double[] values) {
        super(colIndices, numRows, values);
    }

    /**
     * Writes rows {@code [rl, ru)} of this group into {@code target}, at the same row indexes
     * and at this group's global column indexes.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int rl, int ru) {
        final int numCols = _colIndexes.length;
        for (int r = rl; r < ru; r++) {
            for (int j = 0; j < numCols; j++) {
                target.quickSetValue(r, _colIndexes[j], getData(r, j));
            }
        }
    }

    /**
     * Writes all rows of this group into {@code target}. Local column {@code j} is written to
     * target column {@code colIndexTargets[getColIndex(j)]}.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) {
        final int numRows = getNumRows();
        final int numCols = getNumCols();
        for (int r = 0; r < numRows; r++) {
            for (int j = 0; j < numCols; j++) {
                target.quickSetValue(r, colIndexTargets[getColIndex(j)], getData(r, j));
            }
        }
    }

    /** Writes local column {@code colpos} of every row into column 0 of {@code target}. */
    @Override
    public void decompressToBlock(MatrixBlock target, int colpos) {
        final int numRows = getNumRows();
        for (int r = 0; r < numRows; r++) {
            target.quickSetValue(r, 0, getData(r, colpos));
        }
    }

    /**
     * Returns the value at row {@code r} and global column {@code c}.
     *
     * <p>{@code _colIndexes} is binary-searched, so it must be sorted ascending.
     *
     * @throws RuntimeException if column {@code c} is not part of this group
     */
    @Override
    public double get(int r, int c) {
        final int ix = Arrays.binarySearch(_colIndexes, c);
        if (ix < 0) {
            throw new RuntimeException("Column index " + c + " not in DDC group.");
        }
        return getData(r, ix);
    }

    /**
     * For each row {@code r} in {@code [rl, ru)}, adds the number of non-zero cells of this group
     * in that row to {@code rnnz[r - rl]}.
     */
    @Override
    public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
        final int numCols = getNumCols();
        for (int r = rl; r < ru; r++) {
            int nnz = 0;
            for (int j = 0; j < numCols; j++) {
                if (getData(r, j) != 0.0d) {
                    nnz++;
                }
            }
            rnnz[r - rl] += nnz;
        }
    }

    /**
     * Applies a unary aggregate of this group onto {@code result}.
     *
     * <p>Supported operators:
     * <ul>
     *   <li>sum / sum-of-squares ({@link KahanPlus} / {@link KahanPlusSq}) with full, row-wise
     *       ({@link ReduceCol}) or column-wise ({@link ReduceRow}) reduction;</li>
     *   <li>MIN / MAX ({@link Builtin}) with the same three reductions.</li>
     * </ul>
     * Any other operator or index function leaves {@code result} unchanged.
     *
     * @param rl first row for row-wise reductions (inclusive)
     * @param ru row bound for row-wise reductions (exclusive)
     */
    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, MatrixBlock result, int rl, int ru) {
        if (op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq) {
            final KahanFunction kplus = (op.aggOp.increOp.fn instanceof KahanPlus)
                ? KahanPlus.getKahanPlusFnObject()
                : KahanPlusSq.getKahanPlusSqFnObject();
            if (op.indexFn instanceof ReduceAll) {
                computeSum(result, kplus);
            } else if (op.indexFn instanceof ReduceCol) {
                computeRowSums(result, kplus, rl, ru);
            } else if (op.indexFn instanceof ReduceRow) {
                computeColSums(result, kplus);
            }
        } else if (op.aggOp.increOp.fn instanceof Builtin) {
            final Builtin builtin = (Builtin) op.aggOp.increOp.fn;
            final Builtin.BuiltinCode code = builtin.getBuiltinCode();
            if (code == Builtin.BuiltinCode.MAX || code == Builtin.BuiltinCode.MIN) {
                // the boolean passed to the ColGroupValue helpers is always false for DDC
                // (presumably the "group has implicit zeros" flag -- confirm in ColGroupValue)
                if (op.indexFn instanceof ReduceAll) {
                    computeMxx(result, builtin, false);
                } else if (op.indexFn instanceof ReduceCol) {
                    computeRowMxx(result, builtin, rl, ru);
                } else if (op.indexFn instanceof ReduceRow) {
                    computeColMxx(result, builtin, false);
                }
            }
        }
    }

    /**
     * Adds every cell of this group into a Kahan accumulator.
     *
     * <p>The running sum and correction are read from {@code result} cells (0,0) and (0,1) and
     * written back there.
     */
    public void computeSum(MatrixBlock result, KahanFunction kplus) {
        final int numRows = getNumRows();
        final int numCols = getNumCols();
        final KahanObject kbuff = new KahanObject(result.quickGetValue(0, 0), result.quickGetValue(0, 1));
        for (int r = 0; r < numRows; r++) {
            for (int j = 0; j < numCols; j++) {
                kplus.execute2(kbuff, getData(r, j));
            }
        }
        result.quickSetValue(0, 0, kbuff._sum);
        result.quickSetValue(0, 1, kbuff._correction);
    }

    /**
     * Column-wise Kahan sum over all rows.
     *
     * <p>For global column {@code c = _colIndexes[j]}, the running sum is kept in
     * {@code result(0, c)} and its correction in {@code result(1, c)}.
     */
    protected void computeColSums(MatrixBlock result, KahanFunction kplus) {
        final int numRows = getNumRows();
        final int numCols = getNumCols();
        final KahanObject[] kbuff = new KahanObject[numCols];
        for (int j = 0; j < numCols; j++) {
            kbuff[j] = new KahanObject(result.quickGetValue(0, _colIndexes[j]),
                result.quickGetValue(1, _colIndexes[j]));
        }
        for (int r = 0; r < numRows; r++) {
            for (int j = 0; j < numCols; j++) {
                kplus.execute2(kbuff[j], getData(r, j));
            }
        }
        for (int j = 0; j < numCols; j++) {
            result.quickSetValue(0, _colIndexes[j], kbuff[j]._sum);
            result.quickSetValue(1, _colIndexes[j], kbuff[j]._correction);
        }
    }

    /**
     * Row-wise Kahan sum for rows {@code [rl, ru)}.
     *
     * <p>For row {@code r}, the running sum is kept in {@code result(r, 0)} and its correction
     * in {@code result(r, 1)}.
     */
    protected void computeRowSums(MatrixBlock result, KahanFunction kplus, int rl, int ru) {
        final int numCols = getNumCols();
        final KahanObject kbuff = new KahanObject(0.0d, 0.0d);
        for (int r = rl; r < ru; r++) {
            kbuff.set(result.quickGetValue(r, 0), result.quickGetValue(r, 1));
            for (int j = 0; j < numCols; j++) {
                kplus.execute2(kbuff, getData(r, j));
            }
            result.quickSetValue(r, 0, kbuff._sum);
            result.quickSetValue(r, 1, kbuff._correction);
        }
    }

    /**
     * Row-wise MIN/MAX for rows {@code [rl, ru)}, folded into {@code result}'s dense value array
     * at index {@code r}.
     *
     * <p>NOTE(review): assumes {@code result} is dense and allocated with one value per row --
     * {@code getDenseBlockValues()} is dereferenced without a null check.
     */
    protected void computeRowMxx(MatrixBlock result, Builtin builtin, int rl, int ru) {
        final double[] c = result.getDenseBlockValues();
        final int numCols = getNumCols();
        for (int r = rl; r < ru; r++) {
            for (int j = 0; j < numCols; j++) {
                c[r] = builtin.execute(c[r], getData(r, j));
            }
        }
    }

    /**
     * For each dictionary tuple {@code k} in {@code [0, getNumValues())}, adds
     * {@code vals[k] * tuple_k[j]} to {@code c[_colIndexes[j]]} for every local column {@code j}.
     *
     * @param vals one scale factor per dictionary tuple
     * @param c output vector, indexed by global column
     */
    public final void postScaling(double[] vals, double[] c) {
        final int numCols = getNumCols();
        final int numVals = getNumValues();
        for (int k = 0, off = 0; k < numVals; k++, off += numCols) {
            final double aval = vals[k];
            for (int j = 0; j < numCols; j++) {
                c[_colIndexes[j]] += aval * _values[off + j];
            }
        }
    }

    /**
     * Single-index data accessor; its exact meaning is defined by subclasses
     * (NOTE(review): presumably the value of row {@code i} for single-column groups -- confirm).
     */
    public abstract double getData(int i);

    /**
     * Returns the value at row {@code i}, local column {@code i2} (the position in
     * {@code _colIndexes}, not the global column index).
     */
    public abstract double getData(int i, int i2);

    /** Stores row-level data; NOTE(review): presumably (row, dictionary code) -- confirm in subclasses. */
    protected abstract void setData(int i, int i2);

    /** Returns the dictionary code (tuple index into {@code _values}) of row {@code i}. */
    protected abstract int getCode(int i);

    /** Delegates entirely to {@link ColGroupValue#estimateInMemorySize()}. */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize();
    }

    /**
     * Returns a cell iterator over rows {@code [rl, ru)}.
     *
     * @param inclZeros whether zero-valued cells are returned
     * @param rowMajor ignored: this iterator always visits cells row by row
     */
    @Override
    public Iterator<IJV> getIterator(int rl, int ru, boolean inclZeros, boolean rowMajor) {
        return new DDCIterator(rl, ru, inclZeros);
    }

    /** Returns a row iterator; the row range is not used by the DDC implementation. */
    @Override
    public ColGroup.ColGroupRowIterator getRowIterator(int rl, int ru) {
        return new DDCRowIterator(rl, ru);
    }
}
