package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaData;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;

/**
 * HOP rewrite that turns persistent reads of registered input variables into transient reads,
 * and persistent writes of registered output variables into transient writes.
 *
 * <p>Persistent reads/writes whose variable names are not registered are left unchanged and a
 * warning is logged. When a {@link LocalVariableMap} is supplied, the metadata of registered
 * inputs is captured up front. That metadata is used to disable reblocks of binary-block inputs
 * that do not need one.
 */
public class RewriteRemovePersistentReadWrite extends HopRewriteRule {
    private static final Log LOG = LogFactory.getLog(RewriteRemovePersistentReadWrite.class.getName());

    /** Names of registered inputs whose persistent reads become transient reads. */
    private final HashSet<String> _inputs;
    /** Names of registered outputs whose persistent writes become transient writes. */
    private final HashSet<String> _outputs;
    /** Metadata of registered inputs found as cacheable data in the supplied variable map. */
    private final HashMap<String, MetaData> _inputsMeta;

    /**
     * Creates the rewrite without input metadata, so no reblock is ever disabled.
     *
     * @param inputs names of registered input variables; {@code null} is treated as empty
     * @param outputs names of registered output variables; {@code null} is treated as empty
     */
    public RewriteRemovePersistentReadWrite(String[] inputs, String[] outputs) {
        this(inputs, outputs, null);
    }

    /**
     * Creates the rewrite and captures the metadata of every registered input that is bound to
     * cacheable data in {@code vars}.
     *
     * @param inputs names of registered input variables; {@code null} is treated as empty
     * @param outputs names of registered output variables; {@code null} is treated as empty
     * @param vars optional variable map used to look up input metadata; may be {@code null}
     */
    public RewriteRemovePersistentReadWrite(String[] inputs, String[] outputs, LocalVariableMap vars) {
        _inputs = toNameSet(inputs);
        _outputs = toNameSet(outputs);
        _inputsMeta = new HashMap<>();
        if (vars != null && inputs != null) {
            for (String varname : inputs) {
                Data dat = vars.get(varname);
                // instanceof is false for null, so unbound names are skipped
                if (dat instanceof CacheableData<?>) {
                    _inputsMeta.put(varname, ((CacheableData<?>) dat).getMetaData());
                }
            }
        }
    }

    /** Copies the given names into a new set; a {@code null} array yields an empty set. */
    private static HashSet<String> toNameSet(String[] names) {
        HashSet<String> set = new HashSet<>();
        if (names != null) {
            for (String name : names) {
                set.add(name);
            }
        }
        return set;
    }

    /**
     * Applies the rewrite to every root DAG.
     *
     * @return the same list instance, or {@code null} if {@code roots} is {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (Hop root : roots) {
            rule_RemovePersistentDataOp(root);
        }
        return roots;
    }

    /**
     * Applies the rewrite to a single DAG.
     *
     * @return the same hop, or {@code null} if {@code root} is {@code null}
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        rule_RemovePersistentDataOp(root);
        return root;
    }

    /**
     * Recursively processes {@code hop} and its inputs bottom-up, rewriting persistent data ops.
     * Hops already marked visited are skipped, and every processed hop is marked visited.
     */
    private void rule_RemovePersistentDataOp(Hop hop) {
        // shared sub-DAGs are processed only once
        if (hop.isVisited()) {
            return;
        }
        // inputs first; index loop kept so the list is re-read on each step as before
        ArrayList<Hop> inputs = hop.getInput();
        for (int i = 0; i < inputs.size(); i++) {
            rule_RemovePersistentDataOp(inputs.get(i));
        }
        if (hop instanceof DataOp) {
            DataOp dop = (DataOp) hop;
            switch (dop.getDataOpType()) {
                case PERSISTENTREAD:
                    rewritePersistentRead(dop);
                    break;
                case PERSISTENTWRITE:
                    rewritePersistentWrite(dop);
                    break;
                default:
                    // all other data op types are left untouched
                    break;
            }
        }
        hop.setVisited();
    }

    /**
     * Converts a persistent read of a registered input into a transient read. Scalar reads also
     * drop their filename parameter. If captured metadata describes binary-block data and the
     * read is a frame or its block sizes match, the reblock is disabled.
     * Unregistered reads are only logged.
     */
    private void rewritePersistentRead(DataOp dop) {
        String name = dop.getName();
        if (!_inputs.contains(name)) {
            LOG.warn("Non-registered persistent read of variable '" + name + "' (line " + dop.getBeginLine() + ").");
            return;
        }
        dop.setDataOpType(Hop.DataOpTypes.TRANSIENTREAD);
        if (dop.getDataType() == Expression.DataType.SCALAR) {
            dop.removeInput(DataExpression.IO_FILENAME);
        }
        // a missing key or null value both fail the instanceof test
        MetaData meta = _inputsMeta.get(name);
        if (dop.requiresReblock() && meta instanceof MetaDataFormat) {
            MetaDataFormat fmt = (MetaDataFormat) meta;
            MatrixCharacteristics mc = fmt.getMatrixCharacteristics();
            boolean matchingBlocksize = mc.getRowsPerBlock() == dop.getRowsInBlock()
                && mc.getColsPerBlock() == dop.getColsInBlock();
            // binary-block matrices with matching block sizes, and binary-block frames,
            // do not require a reblock
            if (fmt.getInputInfo() == InputInfo.BinaryBlockInputInfo
                && (matchingBlocksize || dop.getDataType() == Expression.DataType.FRAME)) {
                dop.setRequiresReblock(false);
            }
        }
    }

    /**
     * Converts a persistent write of a registered output into a transient write. The write
     * takes its block sizes from input 0, and scalar writes drop their filename parameter.
     * Unregistered writes are only logged.
     */
    private void rewritePersistentWrite(DataOp dop) {
        String name = dop.getName();
        if (!_outputs.contains(name)) {
            LOG.warn("Non-registered persistent write of variable '" + name + "' (line " + dop.getBeginLine() + ").");
            return;
        }
        dop.setDataOpType(Hop.DataOpTypes.TRANSIENTWRITE);
        // input 0 is presumably the value being written
        Hop value = dop.getInput().get(0);
        dop.setRowsInBlock(value.getRowsInBlock());
        dop.setColsInBlock(value.getColsInBlock());
        if (dop.getDataType() == Expression.DataType.SCALAR) {
            dop.removeInput(DataExpression.IO_FILENAME);
        }
    }
}
