package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import jcuda.Pointer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.runtime.DMLRuntimeException;

/**
 * Keeps track of the {@link GPUObject}s registered with a {@link GPUMemoryManager} and answers
 * queries about the device pointers they hold (which pointers belong to which objects, and how
 * large their allocations are).
 */
public class GPUMatrixMemoryManager {
    protected static final Log LOG = LogFactory.getLog(GPUMatrixMemoryManager.class.getName());

    /** Owning memory manager; its {@code allPointers} map is consulted for allocation sizes. */
    GPUMemoryManager gpuManager;

    /** GPU objects registered through {@link #addGPUObject(GPUObject)}. */
    HashSet<GPUObject> gpuObjects = new HashSet<>();

    /**
     * @param gPUMemoryManager the memory manager whose pointer bookkeeping this instance reads
     */
    public GPUMatrixMemoryManager(GPUMemoryManager gPUMemoryManager) {
        this.gpuManager = gPUMemoryManager;
    }

    /**
     * Registers a GPU object so that it takes part in the pointer queries and in
     * {@link #clearAllUnlocked(String)}.
     *
     * @param gpuObj the object to register (a set is used, so re-adding is a no-op)
     */
    public void addGPUObject(GPUObject gpuObj) {
        gpuObjects.add(gpuObj);
    }

    /**
     * Returns the size in bytes of the largest single device allocation backing {@code gpuObj}.
     * <p>
     * The result depends on the object's state:
     * <ul>
     * <li>Dense pointer present: the size of that allocation, or 0 when the object's shadow
     * buffer reports {@code isBuffered()}.</li>
     * <li>Sparse (CSR) pointer with {@code nnz > 0}: the largest of the non-null {@code rowPtr},
     * {@code colInd} and {@code val} allocations.</li>
     * <li>Otherwise: 0.</li>
     * </ul>
     *
     * @param gpuObj the GPU object to inspect
     * @return the largest contiguous allocation in bytes, or 0 if none applies
     */
    public long getWorstCaseContiguousMemorySize(GPUObject gpuObj) {
        if (!gpuObj.isDensePointerNull()) {
            // NOTE(review): a buffered shadow copy is counted as needing no contiguous device
            // block -- presumably the data can be restored from the shadow buffer; confirm.
            if (gpuObj.shadowBuffer.isBuffered()) {
                return 0L;
            }
            return allocationSize(gpuObj.getDensePointer());
        }
        CSRPointer sparsePtr = gpuObj.getJcudaSparseMatrixPtr();
        long largest = 0L;
        if (sparsePtr != null && sparsePtr.nnz > 0) {
            if (sparsePtr.rowPtr != null) {
                largest = Math.max(largest, allocationSize(sparsePtr.rowPtr));
            }
            if (sparsePtr.colInd != null) {
                largest = Math.max(largest, allocationSize(sparsePtr.colInd));
            }
            if (sparsePtr.val != null) {
                largest = Math.max(largest, allocationSize(sparsePtr.val));
            }
        }
        return largest;
    }

    /**
     * Looks up the recorded size of a device allocation in the owning manager's pointer map.
     * Like the original inline lookups, this throws a NullPointerException if the pointer is not
     * present in {@code allPointers}.
     */
    private long allocationSize(Pointer ptr) {
        return gpuManager.allPointers.get(ptr).getSizeInBytes();
    }

    /**
     * Collects every device pointer currently held by {@code gpuObj}.
     * <p>
     * The result contains the dense pointer, if any, and each non-null CSR array
     * ({@code rowPtr}, {@code colInd}, {@code val}) of the sparse pointer, if any. A warning is
     * logged when the object holds both a dense and a sparse representation.
     *
     * @param gpuObj the GPU object to inspect
     * @return a new mutable set of the object's pointers (empty if it holds none)
     */
    public Set<Pointer> getPointers(GPUObject gpuObj) {
        Set<Pointer> pointers = new HashSet<>();
        boolean hasDense = !gpuObj.isDensePointerNull();
        CSRPointer sparsePtr = gpuObj.getSparseMatrixCudaPointer();
        if (hasDense && sparsePtr != null) {
            LOG.warn("Matrix allocated in both dense and sparse format");
        }
        if (hasDense) {
            pointers.add(gpuObj.getDensePointer());
        }
        if (sparsePtr != null) {
            // A CSR matrix owns three separate device arrays; report every allocated one.
            // This matches getWorstCaseContiguousMemorySize, which considers all three.
            // (An else-if chain here would return only the first non-null array.)
            if (sparsePtr.rowPtr != null) {
                pointers.add(sparsePtr.rowPtr);
            }
            if (sparsePtr.colInd != null) {
                pointers.add(sparsePtr.colInd);
            }
            if (sparsePtr.val != null) {
                pointers.add(sparsePtr.val);
            }
        }
        return pointers;
    }

    /**
     * Collects the device pointers of every registered GPU object.
     *
     * @return a new set holding the union of {@link #getPointers(GPUObject)} over all objects
     */
    public Set<Pointer> getPointers() {
        return gpuObjects.stream()
                .flatMap(gpuObj -> getPointers(gpuObj).stream())
                .collect(Collectors.toSet());
    }

    /**
     * Collects the device pointers of the registered GPU objects whose lock and dirty state
     * both match the given flags exactly.
     *
     * @param locked required value of {@code isLocked()}
     * @param dirty required value of {@code isDirty()}
     * @return a new set holding the pointers of all matching objects
     */
    public Set<Pointer> getPointers(boolean locked, boolean dirty) {
        return gpuObjects.stream()
                .filter(gpuObj -> gpuObj.isLocked() == locked && gpuObj.isDirty() == dirty)
                .flatMap(gpuObj -> getPointers(gpuObj).stream())
                .collect(Collectors.toSet());
    }

    /**
     * Clears every registered GPU object that is not locked, then unregisters them.
     * <p>
     * Objects whose {@code dirty} flag is set go through
     * {@code copyFromDeviceToHost(opcode, true, true)}; the others go through
     * {@code clearData(opcode, true)}. The objects are removed from the registry only after all
     * of them have been processed, so an exception part-way leaves the registry unchanged.
     *
     * @param opcode string forwarded to {@code copyFromDeviceToHost}/{@code clearData}
     *     (presumably the calling instruction's opcode -- confirm)
     * @throws DMLRuntimeException if copying or clearing an object fails
     */
    public void clearAllUnlocked(String opcode) throws DMLRuntimeException {
        Set<GPUObject> unlocked = gpuObjects.stream()
                .filter(gpuObj -> !gpuObj.isLocked())
                .collect(Collectors.toSet());
        if (unlocked.isEmpty()) {
            return;
        }
        if (LOG.isWarnEnabled()) {
            LOG.warn("Clearing all unlocked matrices (count=" + unlocked.size() + ").");
        }
        for (GPUObject toBeRemoved : unlocked) {
            if (toBeRemoved.dirty) {
                toBeRemoved.copyFromDeviceToHost(opcode, true, true);
            } else {
                toBeRemoved.clearData(opcode, true);
            }
        }
        gpuObjects.removeAll(unlocked);
    }
}
