package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.opt.InterestingPoint;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;

/* loaded from: input_file:org/apache/sysml/hops/codegen/opt/PlanAnalyzer.class */
/**
 * Static analysis of the plans recorded in a {@link CPlanMemoTable}.
 *
 * <p>The memo table's plan references are split into connected sub-graphs ("plan partitions").
 * For each partition, the analyzer derives:
 * <ul>
 *   <li>the partition's root nodes;</li>
 *   <li>its input nodes;</li>
 *   <li>the non-root nodes that have consumers outside the partition;</li>
 *   <li>the materialization-point candidates;</li>
 *   <li>optionally, an extended list of {@link InterestingPoint}s (multi-consumer and
 *       template-change points).</li>
 * </ul>
 */
public class PlanAnalyzer {
    private static final Log LOG = LogFactory.getLog(PlanAnalyzer.class.getName());

    /**
     * Splits the plans of the given memo table into connected partitions and analyzes each one.
     *
     * @param memo  memo table holding the candidate plans per hop id
     * @param roots list of hops forwarded to the sub-graph search; presumably the DAG roots,
     *              but it is currently not read there (NOTE(review): confirm intent)
     * @param ext   if {@code true}, also compute the extended interesting points
     *              (multi-consumer and template-change points); otherwise {@code null}
     *              is passed to the partition in their place
     * @return one {@link PlanPartition} per non-empty connected sub-graph, in discovery order
     */
    public static Collection<PlanPartition> analyzePlanPartitions(CPlanMemoTable memo, ArrayList<Hop> roots, boolean ext) {
        Collection<HashSet<Long>> subGraphs = getConnectedSubGraphs(memo, roots);
        ArrayList<PlanPartition> partitions = new ArrayList<>();
        for (HashSet<Long> partition : subGraphs) {
            // evaluation order matches the original so trace logs appear in the same sequence
            HashSet<Long> partRoots = getPartitionRootNodes(memo, partition);
            HashSet<Long> inputs = getPartitionInputNodes(partRoots, partition, memo);
            ArrayList<Long> matPoints = getMaterializationPoints(partRoots, partition, memo);
            HashSet<Long> extConsumed = getNodesWithNonPartitionConsumers(partRoots, partition, memo);
            InterestingPoint[] matPointsExt = ext
                ? getMaterializationPointsExt(partRoots, partition, matPoints, memo)
                : null;
            boolean hasOuter = partition.stream()
                .anyMatch(id -> memo.contains(id.longValue(), TemplateBase.TemplateType.OUTER));
            partitions.add(new PlanPartition(partition, partRoots, inputs,
                extConsumed, matPoints, matPointsExt, hasOuter));
        }
        return partitions;
    }

    /**
     * Computes the connected sub-graphs of plan references in the memo table.
     *
     * <p>First builds an inverted index from each hop id referenced as a plan input
     * ({@code isPlanRef}) to the ids of the memo entries referencing it. Then starts a
     * traversal from every memo entry that no other entry references as a plan input.
     *
     * @param memo  memo table to partition
     * @param roots currently unused
     * @return the non-empty sets of hop ids, one per connected sub-graph
     */
    private static Collection<HashSet<Long>> getConnectedSubGraphs(CPlanMemoTable memo, ArrayList<Hop> roots) {
        // inverted index: referenced input hop id -> ids of the memo entries consuming it as a plan
        HashMap<Long, HashSet<Long>> consumers = new HashMap<>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> entry : memo.getPlans().entrySet()) {
            for (CPlanMemoTable.MemoTableEntry me : entry.getValue()) {
                for (int i = 0; i < 3; i++) {
                    if (me.isPlanRef(i)) {
                        consumers.computeIfAbsent(me.input(i), k -> new HashSet<>()).add(entry.getKey());
                    }
                }
            }
        }

        HashSet<Long> visited = new HashSet<>();
        ArrayList<HashSet<Long>> subGraphs = new ArrayList<>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> entry : memo.getPlans().entrySet()) {
            Long hopId = entry.getKey();
            // only entries without plan consumers seed a new traversal
            if (consumers.containsKey(hopId)) {
                continue;
            }
            HashSet<Long> subGraph = rGetConnectedSubGraphs(hopId.longValue(), memo, consumers, visited, new HashSet<>());
            if (!subGraph.isEmpty()) {
                subGraphs.add(subGraph);
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Connected sub graphs: " + subGraphs.size());
        }
        return subGraphs;
    }

    /**
     * Returns the partition nodes that are not referenced (via {@code input1..input3}) by any
     * memo entry of a partition node.
     */
    private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) {
        // every hop id referenced as an input by some memo entry of a partition node
        HashSet<Long> refs = new HashSet<>();
        for (Long hopId : partition) {
            if (memo.contains(hopId.longValue())) {
                for (CPlanMemoTable.MemoTableEntry me : memo.get(hopId.longValue())) {
                    refs.add(me.input1);
                    refs.add(me.input2);
                    refs.add(me.input3);
                }
            }
        }
        HashSet<Long> partRoots = new HashSet<>();
        for (Long hopId : partition) {
            if (!refs.contains(hopId)) {
                partRoots.add(hopId);
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Partition root points: " + Arrays.toString(partRoots.toArray(new Long[0])));
        }
        return partRoots;
    }

    /**
     * Collects materialization-point candidates of a partition.
     *
     * <p>A candidate is a non-root partition hop with at least two parents and a matrix data type.
     * Candidates are listed in post-order of a depth-first traversal from the roots. Afterwards,
     * roots and hops for which {@link HopRewriteUtils#isTsmmInput} holds are removed.
     */
    private static ArrayList<Long> getMaterializationPoints(HashSet<Long> roots, HashSet<Long> partition, CPlanMemoTable memo) {
        ArrayList<Long> matPoints = new ArrayList<>();
        HashSet<Long> visited = new HashSet<>();
        for (Long rootId : roots) {
            rCollectMaterializationPoints(memo.getHopRefs().get(rootId), visited, partition, roots, matPoints);
        }
        matPoints.removeIf(id -> roots.contains(id) || HopRewriteUtils.isTsmmInput(memo.getHopRefs().get(id)));
        if (LOG.isTraceEnabled()) {
            LOG.trace("Partition materialization points: " + Arrays.toString(matPoints.toArray(new Long[0])));
        }
        return matPoints;
    }

    /** Recursive post-order helper for {@link #getMaterializationPoints}; each hop is visited at most once. */
    private static void rCollectMaterializationPoints(Hop hop, HashSet<Long> visited, HashSet<Long> partition, HashSet<Long> roots, ArrayList<Long> matPoints) {
        if (visited.contains(hop.getHopID())) {
            return;
        }
        for (Hop input : hop.getInput()) {
            rCollectMaterializationPoints(input, visited, partition, roots, matPoints);
        }
        if (isMaterializationPointCandidate(hop, partition, roots)) {
            matPoints.add(hop.getHopID());
        }
        visited.add(hop.getHopID());
    }

    /** True for a non-root partition hop with at least two parents and a matrix data type. */
    private static boolean isMaterializationPointCandidate(Hop hop, HashSet<Long> partition, HashSet<Long> roots) {
        return hop.getParent().size() >= 2
            && hop.getDataType().isMatrix()
            && partition.contains(hop.getHopID())
            && !roots.contains(hop.getHopID());
    }

    /**
     * Returns the ids of hops outside the partition that are direct inputs of partition hops.
     * Only partition hops reachable from the roots through partition hops are considered.
     */
    private static HashSet<Long> getPartitionInputNodes(HashSet<Long> roots, HashSet<Long> partition, CPlanMemoTable memo) {
        HashSet<Long> inputs = new HashSet<>();
        HashSet<Long> visited = new HashSet<>();
        for (Long rootId : roots) {
            rCollectPartitionInputNodes(memo.getHopRefs().get(rootId), visited, partition, inputs);
        }
        return inputs;
    }

    /**
     * Recursive helper for {@link #getPartitionInputNodes}.
     * It descends into partition inputs and records non-partition inputs.
     */
    private static void rCollectPartitionInputNodes(Hop hop, HashSet<Long> visited, HashSet<Long> partition, HashSet<Long> inputs) {
        if (visited.contains(hop.getHopID())) {
            return;
        }
        for (Hop input : hop.getInput()) {
            if (partition.contains(input.getHopID())) {
                rCollectPartitionInputNodes(input, visited, partition, inputs);
            } else {
                inputs.add(input.getHopID());
            }
        }
        visited.add(hop.getHopID());
    }

    /** Returns the non-root partition nodes that have at least one parent outside the partition. */
    private static HashSet<Long> getNodesWithNonPartitionConsumers(HashSet<Long> roots, HashSet<Long> partition, CPlanMemoTable memo) {
        HashSet<Long> result = new HashSet<>();
        for (Long hopId : partition) {
            // cheap membership test first; the parent scan only runs for non-roots
            if (!roots.contains(hopId) && hasNonPartitionConsumer(memo.getHopRefs().get(hopId), partition)) {
                result.add(hopId);
            }
        }
        return result;
    }

    /** True if any parent of {@code hop} lies outside {@code partition}. */
    private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) {
        for (Hop parent : hop.getParent()) {
            if (!partition.contains(parent.getHopID())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the extended interesting points of a partition: MULTI_CONSUMER points for each
     * partition consumer of a materialization point, followed by TEMPLATE_CHANGE points.
     * Duplicates are removed via {@code Stream.distinct()}, which relies on
     * {@code InterestingPoint.equals}.
     *
     * @param roots currently unused
     */
    private static InterestingPoint[] getMaterializationPointsExt(HashSet<Long> roots, HashSet<Long> partition, ArrayList<Long> matPoints, CPlanMemoTable memo) {
        ArrayList<InterestingPoint> points = new ArrayList<>();
        points.addAll(getMaterializationPointConsumers(matPoints, partition, memo));
        points.addAll(getTemplateChangePoints(partition, memo));
        InterestingPoint[] distinct = points.stream().distinct().toArray(InterestingPoint[]::new);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Partition materialization points (extended): " + Arrays.toString(distinct));
        }
        return distinct;
    }

    /**
     * Emits one MULTI_CONSUMER point (consumer id, materialization point id) for every
     * parent of a materialization point that belongs to the partition.
     */
    private static ArrayList<InterestingPoint> getMaterializationPointConsumers(ArrayList<Long> matPoints, HashSet<Long> partition, CPlanMemoTable memo) {
        ArrayList<InterestingPoint> points = new ArrayList<>();
        for (Long matPointId : matPoints) {
            for (Hop consumer : memo.getHopRefs().get(matPointId).getParent()) {
                if (partition.contains(consumer.getHopID())) {
                    points.add(new InterestingPoint(InterestingPoint.DecisionType.MULTI_CONSUMER,
                        consumer.getHopID(), matPointId.longValue()));
                }
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Partition materialization point consumers: " + Arrays.toString(points.toArray(new InterestingPoint[0])));
        }
        return points;
    }

    /**
     * Emits a TEMPLATE_CHANGE point (node id, input id) for every non-negative input reference
     * {@code i} (0..2) of a partition node for which
     * {@code memo.containsNotIn(ref, memo.getDistinctTemplateTypes(node, i, true), true)} holds.
     * The exact template semantics are defined by {@link CPlanMemoTable}.
     */
    private static ArrayList<InterestingPoint> getTemplateChangePoints(HashSet<Long> partition, CPlanMemoTable memo) {
        ArrayList<InterestingPoint> points = new ArrayList<>();
        for (Long hopId : partition) {
            long[] refs = memo.getAllRefs(hopId.longValue());
            for (int i = 0; i < 3; i++) {
                if (refs[i] >= 0
                    && memo.containsNotIn(refs[i], memo.getDistinctTemplateTypes(hopId.longValue(), i, true), true)) {
                    points.add(new InterestingPoint(InterestingPoint.DecisionType.TEMPLATE_CHANGE,
                        hopId.longValue(), refs[i]));
                }
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Partition template change points: " + Arrays.toString(points.toArray(new InterestingPoint[0])));
        }
        return points;
    }

    /**
     * Recursive depth-first collection of one connected sub-graph.
     *
     * <p>Memo-table hops are added to both {@code partition} and {@code visited}. The traversal
     * then follows plan consumers (from {@code consumers}) and the hop's input references
     * ({@code getAllRefs}, skipping {@code -1}). Recursion depth grows with the sub-graph size.
     *
     * @return {@code partition}, for call chaining
     */
    private static HashSet<Long> rGetConnectedSubGraphs(long hopId, CPlanMemoTable memo, HashMap<Long, HashSet<Long>> consumers, HashSet<Long> visited, HashSet<Long> partition) {
        if (visited.contains(hopId)) {
            return partition;
        }
        // the memo table is not modified during the traversal, so this result stays valid below
        boolean inMemo = memo.contains(hopId);
        if (inMemo) {
            partition.add(hopId);
            visited.add(hopId);
        }
        HashSet<Long> hopConsumers = consumers.get(hopId);
        if (hopConsumers != null) {
            for (Long consumerId : hopConsumers) {
                rGetConnectedSubGraphs(consumerId.longValue(), memo, consumers, visited, partition);
            }
        }
        if (inMemo) {
            long[] refs = memo.getAllRefs(hopId);
            for (int i = 0; i < 3; i++) {
                if (refs[i] != -1) {
                    rGetConnectedSubGraphs(refs[i], memo, consumers, visited, partition);
                }
            }
        }
        return partition;
    }
}
