package org.apache.sysml.runtime.transform.encode;

import java.io.IOException;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.transform.TfUtils;
import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysml/runtime/transform/encode/EncoderMVImpute.class */
public class EncoderMVImpute extends Encoder {
    private static final long serialVersionUID = 9057868620144662194L;
    private MVMethod[] _mvMethodList;
    private MVMethod[] _mvscMethodList;
    private BitSet _isMVScaled;
    private CM _varFn;
    private Mean _meanFn;
    private KahanObject[] _meanList;
    private long[] _countList;
    private CM_COV_Object[] _varList;
    private int[] _scnomvList;
    private MVMethod[] _scnomvMethodList;
    private KahanObject[] _scnomvMeanList;
    private long[] _scnomvCountList;
    private CM_COV_Object[] _scnomvVarList;
    private String[] _replacementList;
    private String[] _NAstrings;
    private List<Integer> _rcList;
    private HashMap<Integer, HashMap<String, Long>> _hist;

    /* loaded from: input_file:org/apache/sysml/runtime/transform/encode/EncoderMVImpute$MVMethod.class */
    public enum MVMethod {
        INVALID,
        GLOBAL_MEAN,
        GLOBAL_MODE,
        CONSTANT
    }

    public String[] getReplacements() {
        return this._replacementList;
    }

    public KahanObject[] getMeans() {
        return this._meanList;
    }

    public CM_COV_Object[] getVars() {
        return this._varList;
    }

    public KahanObject[] getMeans_scnomv() {
        return this._scnomvMeanList;
    }

    public CM_COV_Object[] getVars_scnomv() {
        return this._scnomvVarList;
    }

    public EncoderMVImpute(JSONObject jSONObject, String[] strArr, int i) throws JSONException {
        super(null, i);
        this._mvMethodList = null;
        this._mvscMethodList = null;
        this._isMVScaled = null;
        this._varFn = CM.getCMFnObject(CMOperator.AggregateOperationTypes.VARIANCE);
        this._meanFn = Mean.getMeanFnObject();
        this._meanList = null;
        this._countList = null;
        this._varList = null;
        this._scnomvList = null;
        this._scnomvMethodList = null;
        this._scnomvMeanList = null;
        this._scnomvCountList = null;
        this._scnomvVarList = null;
        this._replacementList = null;
        this._NAstrings = null;
        this._rcList = null;
        this._hist = null;
        initColList(TfMetaUtils.parseJsonObjectIDList(jSONObject, strArr, TfUtils.TXMETHOD_IMPUTE));
        parseMethodsAndReplacments(jSONObject);
        this._hist = new HashMap<>();
    }

    public EncoderMVImpute(JSONObject jSONObject, String[] strArr, String[] strArr2, int i) throws JSONException {
        super(null, i);
        this._mvMethodList = null;
        this._mvscMethodList = null;
        this._isMVScaled = null;
        this._varFn = CM.getCMFnObject(CMOperator.AggregateOperationTypes.VARIANCE);
        this._meanFn = Mean.getMeanFnObject();
        this._meanList = null;
        this._countList = null;
        this._varList = null;
        this._scnomvList = null;
        this._scnomvMethodList = null;
        this._scnomvMeanList = null;
        this._scnomvCountList = null;
        this._scnomvVarList = null;
        this._replacementList = null;
        this._NAstrings = null;
        this._rcList = null;
        this._hist = null;
        boolean containsKey = jSONObject.containsKey(TfUtils.TXMETHOD_IMPUTE);
        boolean containsKey2 = jSONObject.containsKey(TfUtils.TXMETHOD_SCALE);
        this._NAstrings = strArr2;
        if (containsKey) {
            JSONObject jSONObject2 = (JSONObject) jSONObject.get(TfUtils.TXMETHOD_IMPUTE);
            JSONArray jSONArray = (JSONArray) jSONObject2.get(TfUtils.JSON_ATTRS);
            JSONArray jSONArray2 = (JSONArray) jSONObject2.get(TfUtils.JSON_MTHD);
            int size = jSONArray.size();
            this._colList = new int[size];
            this._mvMethodList = new MVMethod[size];
            this._meanList = new KahanObject[size];
            this._countList = new long[size];
            this._varList = new CM_COV_Object[size];
            this._isMVScaled = new BitSet(this._colList.length);
            this._isMVScaled.clear();
            for (int i2 = 0; i2 < this._colList.length; i2++) {
                this._colList[i2] = UtilFunctions.toInt(jSONArray.get(i2));
                this._mvMethodList[i2] = MVMethod.values()[UtilFunctions.toInt(jSONArray2.get(i2))];
                this._meanList[i2] = new KahanObject(0.0d, 0.0d);
            }
            this._replacementList = new String[size];
            JSONArray jSONArray3 = (JSONArray) jSONObject2.get(TfUtils.JSON_CONSTS);
            for (int i3 = 0; i3 < jSONArray3.size(); i3++) {
                if (jSONArray3.get(i3) == null) {
                    this._replacementList[i3] = "NaN";
                } else {
                    this._replacementList[i3] = jSONArray3.get(i3).toString();
                }
            }
        } else {
            this._colList = null;
            this._mvMethodList = null;
            this._meanList = null;
            this._countList = null;
            this._replacementList = null;
        }
        if (!containsKey2) {
            this._scnomvCountList = null;
            this._scnomvMeanList = null;
            this._scnomvVarList = null;
            return;
        }
        if (this._colList != null) {
            this._mvscMethodList = new MVMethod[this._colList.length];
        }
        JSONObject jSONObject3 = (JSONObject) jSONObject.get(TfUtils.TXMETHOD_SCALE);
        JSONArray jSONArray4 = (JSONArray) jSONObject3.get(TfUtils.JSON_ATTRS);
        JSONArray jSONArray5 = (JSONArray) jSONObject3.get(TfUtils.JSON_MTHD);
        int size2 = jSONArray4.size();
        int[] iArr = new int[size2];
        int i4 = 0;
        for (int i5 = 0; i5 < size2; i5++) {
            int i6 = UtilFunctions.toInt(jSONArray4.get(i5));
            byte b = (byte) UtilFunctions.toInt(jSONArray5.get(i5));
            iArr[i5] = i6;
            int isApplicable = isApplicable(i6);
            if (isApplicable != -1) {
                this._isMVScaled.set(isApplicable);
                this._mvscMethodList[isApplicable] = MVMethod.values()[b];
                this._varList[isApplicable] = new CM_COV_Object();
            } else {
                i4++;
            }
        }
        if (i4 > 0) {
            this._scnomvList = new int[i4];
            this._scnomvMethodList = new MVMethod[i4];
            this._scnomvMeanList = new KahanObject[i4];
            this._scnomvCountList = new long[i4];
            this._scnomvVarList = new CM_COV_Object[i4];
            int i7 = 0;
            for (int i8 = 0; i8 < size2; i8++) {
                int i9 = UtilFunctions.toInt(jSONArray4.get(i8));
                byte b2 = (byte) UtilFunctions.toInt(jSONArray5.get(i8));
                if (isApplicable(i9) == -1) {
                    this._scnomvList[i7] = i9;
                    this._scnomvMethodList[i7] = MVMethod.values()[b2];
                    this._scnomvMeanList[i7] = new KahanObject(0.0d, 0.0d);
                    this._scnomvVarList[i7] = new CM_COV_Object();
                    i7++;
                }
            }
        }
    }

    private void parseMethodsAndReplacments(JSONObject jSONObject) throws JSONException {
        JSONArray jSONArray = (JSONArray) jSONObject.get(TfUtils.TXMETHOD_IMPUTE);
        this._mvMethodList = new MVMethod[jSONArray.size()];
        this._replacementList = new String[jSONArray.size()];
        this._meanList = new KahanObject[jSONArray.size()];
        this._countList = new long[jSONArray.size()];
        for (int i = 0; i < jSONArray.size(); i++) {
            JSONObject jSONObject2 = (JSONObject) jSONArray.get(i);
            this._mvMethodList[i] = MVMethod.valueOf(jSONObject2.get("method").toString().toUpperCase());
            if (this._mvMethodList[i] == MVMethod.CONSTANT) {
                this._replacementList[i] = jSONObject2.getString("value").toString();
            }
            this._meanList[i] = new KahanObject(0.0d, 0.0d);
        }
    }

    public void prepare(String[] strArr) throws IOException {
        try {
            if (this._colList != null) {
                for (int i = 0; i < this._colList.length; i++) {
                    int i2 = this._colList[i];
                    String unquote = UtilFunctions.unquote(strArr[i2 - 1].trim());
                    try {
                        if (!TfUtils.isNA(this._NAstrings, unquote)) {
                            long[] jArr = this._countList;
                            int i3 = i;
                            jArr[i3] = jArr[i3] + 1;
                            if (this._mvMethodList[i] == MVMethod.GLOBAL_MEAN || this._isMVScaled.get(i)) {
                                double parseToDouble = UtilFunctions.parseToDouble(unquote);
                                this._meanFn.execute2(this._meanList[i], parseToDouble, this._countList[i]);
                                if (this._isMVScaled.get(i) && this._mvscMethodList[i] == MVMethod.GLOBAL_MODE) {
                                    this._varFn.execute(this._varList[i], parseToDouble);
                                }
                            }
                        }
                    } catch (NumberFormatException e) {
                        throw new RuntimeException("Encountered \"" + unquote + "\" in column ID \"" + i2 + "\", when expecting a numeric value. Consider adding \"" + unquote + "\" to na.strings, along with an appropriate imputation method.");
                    }
                }
            }
            if (this._scnomvList != null) {
                for (int i4 = 0; i4 < this._scnomvList.length; i4++) {
                    double parseToDouble2 = UtilFunctions.parseToDouble(UtilFunctions.unquote(strArr[this._scnomvList[i4] - 1].trim()));
                    long[] jArr2 = this._scnomvCountList;
                    int i5 = i4;
                    jArr2[i5] = jArr2[i5] + 1;
                    this._meanFn.execute2(this._scnomvMeanList[i4], parseToDouble2, this._scnomvCountList[i4]);
                    if (this._scnomvMethodList[i4] == MVMethod.GLOBAL_MODE) {
                        this._varFn.execute(this._scnomvVarList[i4], parseToDouble2);
                    }
                }
            }
        } catch (Exception e2) {
            throw new IOException(e2);
        }
    }

    public MVMethod getMethod(int i) {
        int isApplicable = isApplicable(i);
        return isApplicable == -1 ? MVMethod.INVALID : this._mvMethodList[isApplicable];
    }

    public long getNonMVCount(int i) {
        int isApplicable = isApplicable(i);
        if (isApplicable == -1) {
            return 0L;
        }
        return this._countList[isApplicable];
    }

    public String getReplacement(int i) {
        int isApplicable = isApplicable(i);
        if (isApplicable == -1) {
            return null;
        }
        return this._replacementList[isApplicable];
    }

    @Override // org.apache.sysml.runtime.transform.encode.Encoder
    public MatrixBlock encode(FrameBlock frameBlock, MatrixBlock matrixBlock) {
        build(frameBlock);
        return apply(frameBlock, matrixBlock);
    }

    @Override // org.apache.sysml.runtime.transform.encode.Encoder
    public void build(FrameBlock frameBlock) {
        for (int i = 0; i < this._colList.length; i++) {
            try {
                int i2 = this._colList[i];
                if (this._mvMethodList[i] == MVMethod.GLOBAL_MEAN) {
                    long j = this._countList[i];
                    for (int i3 = 0; i3 < frameBlock.getNumRows(); i3++) {
                        this._meanFn.execute2(this._meanList[i], UtilFunctions.objectToDouble(frameBlock.getSchema()[i2 - 1], frameBlock.get(i3, i2 - 1)), j + i3 + 1);
                    }
                    this._replacementList[i] = String.valueOf(this._meanList[i]._sum);
                    long[] jArr = this._countList;
                    int i4 = i;
                    jArr[i4] = jArr[i4] + frameBlock.getNumRows();
                } else if (this._mvMethodList[i] == MVMethod.GLOBAL_MODE) {
                    HashMap<String, Long> hashMap = this._hist.containsKey(Integer.valueOf(i2)) ? this._hist.get(Integer.valueOf(i2)) : new HashMap<>();
                    for (int i5 = 0; i5 < frameBlock.getNumRows(); i5++) {
                        String valueOf = String.valueOf(frameBlock.get(i5, i2 - 1));
                        if (valueOf != null && !valueOf.isEmpty()) {
                            Long l = hashMap.get(valueOf);
                            hashMap.put(valueOf, Long.valueOf(l != null ? l.longValue() + 1 : 1L));
                        }
                    }
                    this._hist.put(Integer.valueOf(i2), hashMap);
                    long j2 = Long.MIN_VALUE;
                    for (Map.Entry<String, Long> entry : hashMap.entrySet()) {
                        if (entry.getValue().longValue() > j2) {
                            this._replacementList[i] = entry.getKey();
                            j2 = entry.getValue().longValue();
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    @Override // org.apache.sysml.runtime.transform.encode.Encoder
    public MatrixBlock apply(FrameBlock frameBlock, MatrixBlock matrixBlock) {
        for (int i = 0; i < frameBlock.getNumRows(); i++) {
            for (int i2 = 0; i2 < this._colList.length; i2++) {
                int i3 = this._colList[i2];
                if (Double.isNaN(matrixBlock.quickGetValue(i, i3 - 1))) {
                    matrixBlock.quickSetValue(i, i3 - 1, Double.parseDouble(this._replacementList[i2]));
                }
            }
        }
        return matrixBlock;
    }

    @Override // org.apache.sysml.runtime.transform.encode.Encoder
    public FrameBlock getMetaData(FrameBlock frameBlock) {
        for (int i = 0; i < this._colList.length; i++) {
            frameBlock.getColumnMetadata(this._colList[i] - 1).setMvValue(this._replacementList[i]);
        }
        return frameBlock;
    }

    @Override // org.apache.sysml.runtime.transform.encode.Encoder
    public void initMetaData(FrameBlock frameBlock) {
        for (int i = 0; i < this._colList.length; i++) {
            int i2 = this._colList[i];
            String unquote = UtilFunctions.unquote(frameBlock.getColumnMetadata(i2 - 1).getMvValue());
            if (this._rcList.contains(Integer.valueOf(i2))) {
                Long l = frameBlock.getRecodeMap(i2 - 1).get(unquote);
                if (l == null) {
                    throw new RuntimeException("Missing recode value for impute value '" + unquote + "' (colID=" + i2 + ").");
                }
                this._replacementList[i] = l.toString();
            } else {
                this._replacementList[i] = unquote;
            }
        }
    }

    public void initRecodeIDList(List<Integer> list) {
        this._rcList = list;
    }

    public HashMap<String, Long> getHistogram(int i) {
        return this._hist.get(Integer.valueOf(i));
    }
}
