package org.apache.sysml.runtime.controlprogram.parfor;

import java.util.LinkedList;
import java.util.List;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.Task;
import org.apache.sysml.runtime.instructions.cp.IntObject;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/TaskPartitionerFactoring.class */
/**
 * Task partitioner that splits the iteration space {@code [_fromVal, _toVal]}
 * (stepping by {@code _incrVal}) into waves of tasks. In every wave up to
 * {@code _numThreads} tasks of equal size are created, where the size is
 * derived from the number of iterations still unassigned (see
 * {@link #determineNextBatchSize(long, int)}).
 */
public class TaskPartitionerFactoring extends TaskPartitioner {
    /**
     * Largest batch size that is still materialized as an explicit SET of
     * iteration values; larger batches are encoded as a RANGE task.
     */
    private static final long MAX_SET_BATCH_SIZE = 3;

    /** Number of tasks created per wave; also the divisor used for batch sizing. */
    private final int _numThreads;

    /**
     * Creates a factoring partitioner.
     *
     * @param j          forwarded unchanged to the superclass constructor
     * @param i          number of threads, i.e. number of tasks created per wave
     * @param str        forwarded unchanged to the superclass constructor
     * @param intObject  forwarded unchanged to the superclass constructor
     * @param intObject2 forwarded unchanged to the superclass constructor
     * @param intObject3 forwarded unchanged to the superclass constructor
     */
    public TaskPartitionerFactoring(long j, int i, String str, IntObject intObject, IntObject intObject2, IntObject intObject3) {
        super(j, str, intObject, intObject2, intObject3);
        this._numThreads = i;
    }

    /**
     * Creates all tasks eagerly and returns them in creation order.
     *
     * @return the list of created tasks
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.TaskPartitioner
    public List<Task> createTasks() {
        List<Task> tasks = new LinkedList<Task>();
        final long to = this._toVal.getLongValue();
        final long incr = this._incrVal.getLongValue();
        final int numThreads = this._numThreads;
        long remaining = this._numIter;
        long next = this._fromVal.getLongValue();
        while (next <= to) {
            // One wave: the batch size is fixed for all tasks of this wave.
            final long batchSize = determineNextBatchSize(remaining, numThreads);
            remaining -= batchSize * numThreads;
            final Task.TaskType type = taskTypeFor(batchSize);
            for (int t = 0; t < numThreads && next <= to; t++) {
                Task task = new Task(this._iterVarName, type);
                tasks.add(task);
                next = fillTask(task, type, next, to, incr, batchSize);
            }
        }
        return tasks;
    }

    /**
     * Creates all tasks and enqueues each of them into the given queue as soon
     * as it is complete; the queue input is closed afterwards.
     *
     * @param localTaskQueue queue receiving the created tasks
     * @return the number of tasks that were enqueued
     * @throws DMLRuntimeException wrapping any exception raised while creating
     *                             or enqueueing a task
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.TaskPartitioner
    public long createTasks(LocalTaskQueue<Task> localTaskQueue) {
        long numCreatedTasks = 0;
        final long to = this._toVal.getLongValue();
        final long incr = this._incrVal.getLongValue();
        final int numThreads = this._numThreads;
        long remaining = this._numIter;
        long next = this._fromVal.getLongValue();
        try {
            while (next <= to) {
                // One wave: the batch size is fixed for all tasks of this wave.
                final long batchSize = determineNextBatchSize(remaining, numThreads);
                remaining -= batchSize * numThreads;
                final Task.TaskType type = taskTypeFor(batchSize);
                for (int t = 0; t < numThreads && next <= to; t++) {
                    Task task = new Task(this._iterVarName, type);
                    next = fillTask(task, type, next, to, incr, batchSize);
                    localTaskQueue.enqueueTask(task);
                    numCreatedTasks++;
                }
            }
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
        localTaskQueue.closeInput();
        return numCreatedTasks;
    }

    /**
     * Computes the per-task batch size of the next wave as
     * {@code ceil(j / (2 * i))}, but never less than 1.
     *
     * <p>The ceiling is computed with exact integer arithmetic. Dividing the
     * two integers first and applying {@code Math.ceil} afterwards would
     * silently round down, because the fraction is already lost by the
     * integer division.
     *
     * @param j number of iterations that are not yet assigned to a task
     * @param i number of threads
     * @return the batch size, at least 1
     * @throws ArithmeticException if {@code i} is zero
     */
    protected long determineNextBatchSize(long j, int i) {
        final long divisor = 2L * i;
        // Truncated quotient plus one if a positive remainder is left over.
        long batchSize = j / divisor + ((j % divisor > 0) ? 1 : 0);
        if (batchSize < 1) {
            batchSize = 1;
        }
        return batchSize;
    }

    /** Chooses the compact RANGE encoding for large batches and SET otherwise. */
    private static Task.TaskType taskTypeFor(long batchSize) {
        return batchSize > MAX_SET_BATCH_SIZE ? Task.TaskType.RANGE : Task.TaskType.SET;
    }

    /**
     * Adds the iterations of one batch to the given task and returns the first
     * iteration value that has not been consumed by it.
     *
     * <p>A SET task receives every iteration value individually; a RANGE task
     * receives exactly three entries: first value, last value (capped at
     * {@code to}) and increment.
     */
    private static long fillTask(Task task, Task.TaskType type, long start, long to, long incr, long batchSize) {
        if (type == Task.TaskType.SET) {
            long value = start;
            for (long n = 0; n < batchSize && value <= to; n++) {
                task.addIteration(new IntObject(value));
                value += incr;
            }
            return value;
        }
        final long last = Math.min(start + ((batchSize - 1) * incr), to);
        task.addIteration(new IntObject(start));
        task.addIteration(new IntObject(last));
        task.addIteration(new IntObject(incr));
        return last + incr;
    }
}
