package org.apache.sysml.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.MultiThreadedHop;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.util.ProgramConverter;
import org.apache.sysml.utils.Statistics;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.class */
public class ParamservUtils {
    /** Prefix constant that {@code copyFunction} in this class hands to {@code getCompleteFuncName}. */
    public static final String PS_FUNC_PREFIX = "_ps_";
    /** Commons-logging logger registered under this class's name. */
    protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
    /**
     * Mutable, publicly writable seed value, initialised to -1; it is not read anywhere else in this class.
     * NOTE(review): presumably -1 means "no seed set" - confirm against the code that assigns it.
     */
    public static long SEED = -1;

    /**
     * Creates a new {@link ListObject} holding the same entries (and names) as the given list.
     * Matrix entries are replaced by shallow copies produced via {@code createShallowCopy};
     * nested lists and frames are rejected; every other entry is reused as-is.
     *
     * @param listObject the list to copy
     * @param z          if {@code true}, the source list is cleaned up after the copy was built
     * @return the newly created list
     * @throws DMLRuntimeException if an entry is a {@link ListObject} or a {@link FrameObject}
     */
    public static ListObject copyList(ListObject listObject, boolean z) {
        final int length = listObject.getLength();
        List<Data> copiedEntries = new ArrayList<>();
        for (int pos = 0; pos < length; pos++) {
            Data entry = listObject.slice(pos);
            if (entry instanceof MatrixObject) {
                copiedEntries.add(createShallowCopy((MatrixObject) entry));
                continue;
            }
            boolean unsupported = (entry instanceof ListObject) || (entry instanceof FrameObject);
            if (unsupported) {
                throw new DMLRuntimeException("Copy list: does not support list or frame.");
            }
            copiedEntries.add(entry);
        }
        ListObject copy = new ListObject(copiedEntries, listObject.getNames());
        // The source is only released once the copy has been fully assembled.
        if (z) {
            cleanupListObject(listObject);
        }
        return copy;
    }

    /**
     * Removes the variable {@code str} from the execution context, treats it as a list and
     * cleans up its entries according to the list's own status flags.
     *
     * @param executionContext context that owns the variable
     * @param str              name of the list variable to remove and clean up
     */
    public static void cleanupListObject(ExecutionContext executionContext, String str) {
        Data removed = executionContext.removeVariable(str);
        ListObject list = (ListObject) removed;
        boolean[] status = list.getStatus();
        cleanupListObject(executionContext, list, status);
    }

    /**
     * Removes the variable {@code str} from the execution context, treats it as a list and
     * cleans up the entries selected by {@code zArr}.
     *
     * @param executionContext context that owns the variable
     * @param str              name of the list variable to remove and clean up
     * @param zArr             per-entry flags; {@code null} selects every entry
     */
    public static void cleanupListObject(ExecutionContext executionContext, String str, boolean[] zArr) {
        ListObject list = (ListObject) executionContext.removeVariable(str);
        cleanupListObject(executionContext, list, zArr);
    }

    /**
     * Cleans up the entries of the given list, using the list's own status flags to
     * decide which entries are released.
     *
     * @param executionContext context used to perform the cleanup
     * @param listObject       list whose entries are cleaned up
     */
    public static void cleanupListObject(ExecutionContext executionContext, ListObject listObject) {
        boolean[] status = listObject.getStatus();
        cleanupListObject(executionContext, listObject, status);
    }

    /**
     * Cleans up the selected entries of a list.
     *
     * @param executionContext context used to perform the cleanup
     * @param listObject       list whose entries are cleaned up
     * @param zArr             per-entry flags; entry {@code i} is cleaned up when {@code zArr[i]} is
     *                         {@code true}; a {@code null} array selects every entry
     */
    public static void cleanupListObject(ExecutionContext executionContext, ListObject listObject, boolean[] zArr) {
        for (int idx = 0; idx < listObject.getLength(); idx++) {
            boolean selected = (zArr == null) || zArr[idx];
            if (!selected) {
                continue;
            }
            cleanupData(executionContext, listObject.getData().get(idx));
        }
    }

    /**
     * Releases a data object if it is cacheable; any other kind of data is ignored.
     *
     * @param executionContext context used to perform the cleanup
     * @param data             the data object to release
     */
    public static void cleanupData(ExecutionContext executionContext, Data data) {
        if (!(data instanceof CacheableData)) {
            return;
        }
        CacheableData<?> cacheable = (CacheableData<?>) data;
        // Cleanup is switched on explicitly before delegating to the context.
        cacheable.enableCleanup(true);
        executionContext.cleanupCacheableData(cacheable);
    }

    /**
     * Removes the variable {@code str} from the execution context and releases it via
     * {@link #cleanupData(ExecutionContext, Data)}.
     *
     * @param executionContext context that owns the variable
     * @param str              name of the variable to remove and clean up
     */
    public static void cleanupData(ExecutionContext executionContext, String str) {
        Data removed = executionContext.removeVariable(str);
        cleanupData(executionContext, removed);
    }

    /**
     * Cleans up the entries of the given list using a freshly created execution context.
     *
     * @param listObject list whose entries are cleaned up
     */
    public static void cleanupListObject(ListObject listObject) {
        ExecutionContext freshContext = ExecutionContextFactory.createContext();
        cleanupListObject(freshContext, listObject);
    }

    /**
     * Wraps a matrix block into a new matrix object with cleanup enabled.
     *
     * @param matrixBlock the block to wrap
     * @return the new matrix object
     */
    public static MatrixObject newMatrixObject(MatrixBlock matrixBlock) {
        final boolean cleanupEnabled = true;
        return newMatrixObject(matrixBlock, cleanupEnabled);
    }

    /**
     * Wraps a matrix block into a new DOUBLE-typed matrix object that is backed by a unique
     * temporary file name and binary-block metadata with unknown (-1) dimensions.
     *
     * @param matrixBlock the block to install as the object's data
     * @param z           value passed to {@code enableCleanup} on the new object
     * @return the new matrix object
     */
    public static MatrixObject newMatrixObject(MatrixBlock matrixBlock, boolean z) {
        String tempFileName = OptimizerUtils.getUniqueTempFileName();
        MatrixCharacteristics characteristics = new MatrixCharacteristics(
            -1L, -1L, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize());
        MetaDataFormat metadata = new MetaDataFormat(
            characteristics, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, tempFileName, metadata);
        // Install the block, then release it again.
        result.acquireModify(matrixBlock);
        result.release();
        result.enableCleanup(z);
        return result;
    }

    /**
     * Creates a new matrix object around the very same matrix block as the given one
     * (the block itself is not copied); cleanup is disabled on the new object.
     *
     * @param matrixObject the matrix object whose block is reused
     * @return a new matrix object sharing the block
     */
    public static MatrixObject createShallowCopy(MatrixObject matrixObject) {
        MatrixBlock sharedBlock = matrixObject.acquireReadAndRelease();
        return newMatrixObject(sharedBlock, false);
    }

    /**
     * Slices the block of the given matrix object via {@link #sliceMatrixBlock} and wraps the
     * result into a new matrix object with cleanup disabled.
     *
     * @param matrixObject the matrix to slice
     * @param j            lower bound handed to {@code sliceMatrixBlock}
     * @param j2           upper bound handed to {@code sliceMatrixBlock}
     * @return a new matrix object holding the slice
     */
    public static MatrixObject sliceMatrix(MatrixObject matrixObject, long j, long j2) {
        MatrixBlock source = matrixObject.acquireReadAndRelease();
        MatrixBlock sliced = sliceMatrixBlock(source, j, j2);
        return newMatrixObject(sliced, false);
    }

    /**
     * Slices a matrix block. Both bounds are narrowed to {@code int} and then decremented by one
     * before being passed to {@code MatrixBlock.slice}
     * (presumably converting 1-based to 0-based indexes - TODO confirm).
     *
     * @param matrixBlock the block to slice
     * @param j           lower bound
     * @param j2          upper bound
     * @return the result of {@code MatrixBlock.slice}
     */
    public static MatrixBlock sliceMatrixBlock(MatrixBlock matrixBlock, long j, long j2) {
        int lower = ((int) j) - 1;
        int upper = ((int) j2) - 1;
        return matrixBlock.slice(lower, upper);
    }

    /**
     * Builds an {@code i x i} matrix by combining an {@code i x 1} block with a sample of
     * {@code i} values out of {@code i} (third {@code sampleOperations} argument {@code false},
     * presumably "without replacement" - TODO confirm) through {@code ctableSeqOperations}.
     *
     * @param i number of entries
     * @param j seed handed to the sampling operation
     * @return the block returned by {@code ctableSeqOperations}
     */
    public static MatrixBlock generatePermutation(int i, long j) {
        MatrixBlock sequence = new MatrixBlock(i, 1, false);
        MatrixBlock sample = MatrixBlock.sampleOperations(i, i, false, j);
        MatrixBlock target = new MatrixBlock(i, i, true);
        return sequence.ctableSeqOperations(sample, 1.0d, target);
    }

    /**
     * Splits a function key into its namespace and function name.
     *
     * <p>The key is split with {@code DMLProgram.splitFunctionKey}. When the split yields exactly
     * two parts they are taken as namespace and name; otherwise the namespace is {@code null} and
     * the first part is the name.
     *
     * <p>NOTE: {@code str2} (the prefix) is accepted for interface compatibility but has no effect
     * on the result. The previous implementation branched on whether the prefix was empty, yet
     * both branches returned the same array, so the dead conditional has been removed without
     * changing behaviour.
     *
     * @param str  the function key to split
     * @param str2 prefix; currently ignored
     * @return a two-element array {@code {namespace, functionName}}
     */
    public static String[] getCompleteFuncName(String str, String str2) {
        String[] parts = DMLProgram.splitFunctionKey(str);
        boolean qualified = parts.length == 2;
        String namespace = qualified ? parts[0] : null;
        String functionName = qualified ? parts[1] : parts[0];
        return new String[]{namespace, functionName};
    }

    /**
     * Creates a new execution context whose program contains only deep copies of the two named
     * functions. Before copying, all program blocks and all function bodies of the source
     * program get parallelism {@code i} assigned through {@code recompileProgramBlocks}.
     *
     * @param executionContext context whose program provides the functions
     * @param localVariableMap variables to copy into the new context
     * @param str              key of the first function to copy
     * @param str2             key of the second function to copy
     * @param i                degree of parallelism handed to {@code recompileProgramBlocks}
     * @return the new execution context
     */
    public static ExecutionContext createExecutionContext(ExecutionContext executionContext, LocalVariableMap localVariableMap, String str, String str2, int i) {
        // Both functions are looked up before anything is recompiled.
        FunctionProgramBlock firstFunc = getFunctionBlock(executionContext, str);
        FunctionProgramBlock secondFunc = getFunctionBlock(executionContext, str2);

        Program sourceProgram = executionContext.getProgram();
        recompileProgramBlocks(i, sourceProgram.getProgramBlocks());
        for (FunctionProgramBlock functionBlock : sourceProgram.getFunctionProgramBlocks().values()) {
            recompileProgramBlocks(i, functionBlock.getChildBlocks());
        }

        FunctionProgramBlock firstCopy = copyFunction(str, firstFunc);
        FunctionProgramBlock secondCopy = copyFunction(str2, secondFunc);

        Program targetProgram = new Program();
        putFunction(targetProgram, firstCopy);
        putFunction(targetProgram, secondCopy);

        LocalVariableMap variables = new LocalVariableMap(localVariableMap);
        return ExecutionContextFactory.createContext(variables, targetProgram);
    }

    /**
     * Creates {@code i} independent execution contexts. Each one receives its own program with
     * deep copies of all functions of the source program, plus a copy of the source variable map.
     *
     * @param executionContext the context to replicate
     * @param i                number of copies to create
     * @return list with the created contexts
     */
    public static List<ExecutionContext> copyExecutionContext(ExecutionContext executionContext, int i) {
        List<ExecutionContext> copies = new ArrayList<>();
        for (int n = 0; n < i; n++) {
            Program copiedProgram = new Program();
            executionContext.getProgram().getFunctionProgramBlocks().forEach(
                (key, block) -> putFunction(copiedProgram, copyFunction(key, block)));
            LocalVariableMap copiedVars = new LocalVariableMap(executionContext.getVariables());
            copies.add(ExecutionContextFactory.createContext(copiedVars, copiedProgram));
        }
        return copies;
    }

    /**
     * Deep-copies a function program block and sets namespace and function name of the copy to
     * the parts derived from the function key via {@code getCompleteFuncName}.
     *
     * @param str                  function key
     * @param functionProgramBlock the function block to copy
     * @return the deep copy
     */
    private static FunctionProgramBlock copyFunction(String str, FunctionProgramBlock functionProgramBlock) {
        String[] nameParts = getCompleteFuncName(str, PS_FUNC_PREFIX);
        FunctionProgramBlock copy = ProgramConverter.createDeepCopyFunctionProgramBlock(
            functionProgramBlock, new HashSet<>(), new HashSet<>());
        copy._namespace = nameParts[0];
        copy._functionName = nameParts[1];
        return copy;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Registers a function block in the program twice: as a function under its own namespace
     * and function name, and as a regular program block.
     *
     * @param program              the program to extend
     * @param functionProgramBlock the function block to register
     */
    public static void putFunction(Program program, FunctionProgramBlock functionProgramBlock) {
        String namespace = functionProgramBlock._namespace;
        String functionName = functionProgramBlock._functionName;
        program.addFunctionProgramBlock(namespace, functionName, functionProgramBlock);
        program.addProgramBlock(functionProgramBlock);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Resets the HOP DAG visit status of every given program block and then assigns the
     * degree of parallelism {@code i} to them via {@code rAssignParallelism}.
     *
     * @param i         degree of parallelism to assign
     * @param arrayList program blocks to process
     * @throws DMLRuntimeException wrapping any {@link IOException} raised during the assignment
     */
    public static void recompileProgramBlocks(int i, ArrayList<ProgramBlock> arrayList) {
        for (ProgramBlock block : arrayList) {
            DMLTranslator.resetHopsDAGVisitStatus(block.getStatementBlock());
        }
        try {
            rAssignParallelism(arrayList, i, false);
        } catch (IOException ioFailure) {
            throw new DMLRuntimeException(ioFailure);
        }
    }

    /**
     * Recursively assigns a degree of parallelism to the given program blocks.
     * Parfor blocks get {@code i} as their own degree while their children are processed with 1;
     * for/while/function/if blocks pass {@code i} on to their children; every other block has
     * its HOP roots processed. Once the running flag is {@code true}, each subsequently handled
     * block has its instructions recompiled.
     *
     * @param arrayList program blocks to process
     * @param i         degree of parallelism
     * @param z         initial value of the "something changed" flag
     * @return the flag after processing all blocks
     * @throws IOException declared because instruction recompilation may raise it
     */
    private static boolean rAssignParallelism(ArrayList<ProgramBlock> arrayList, int i, boolean z) throws IOException {
        for (ProgramBlock block : arrayList) {
            if (block instanceof ParForProgramBlock) {
                ParForProgramBlock parfor = (ParForProgramBlock) block;
                parfor.setDegreeOfParallelism(i);
                // children of a parfor are processed with a degree of 1
                z |= rAssignParallelism(parfor.getChildBlocks(), 1, z);
            } else if (block instanceof ForProgramBlock) {
                ForProgramBlock forBlock = (ForProgramBlock) block;
                z |= rAssignParallelism(forBlock.getChildBlocks(), i, z);
            } else if (block instanceof WhileProgramBlock) {
                WhileProgramBlock whileBlock = (WhileProgramBlock) block;
                z |= rAssignParallelism(whileBlock.getChildBlocks(), i, z);
            } else if (block instanceof FunctionProgramBlock) {
                FunctionProgramBlock funcBlock = (FunctionProgramBlock) block;
                z |= rAssignParallelism(funcBlock.getChildBlocks(), i, z);
            } else if (block instanceof IfProgramBlock) {
                IfProgramBlock ifBlock = (IfProgramBlock) block;
                z |= rAssignParallelism(ifBlock.getChildBlocksIfBody(), i, z);
                if (ifBlock.getChildBlocksElseBody() != null) {
                    z |= rAssignParallelism(ifBlock.getChildBlocksElseBody(), i, z);
                }
            } else {
                for (Hop root : block.getStatementBlock().getHops()) {
                    z |= rAssignParallelism(root, i, z);
                }
            }
            if (z) {
                Recompiler.recompileProgramBlockInstructions(block);
            }
        }
        return z;
    }

    /**
     * Recursively sets the maximum number of threads on every multi-threaded HOP reachable from
     * {@code hop}. HOPs already marked as visited are skipped; each processed HOP is marked
     * visited after its inputs were handled.
     *
     * @param hop root of the (sub-)DAG to process
     * @param i   maximum number of threads to set
     * @param z   incoming "something changed" flag
     * @return {@code true} if the incoming flag was set or any multi-threaded HOP was updated
     */
    private static boolean rAssignParallelism(Hop hop, int i, boolean z) {
        if (hop.isVisited()) {
            return z;
        }
        boolean changed = z;
        if (hop instanceof MultiThreadedHop) {
            MultiThreadedHop multiThreaded = (MultiThreadedHop) hop;
            multiThreaded.setMaxNumThreads(i);
            changed = true;
        }
        for (Hop input : hop.getInput()) {
            changed |= rAssignParallelism(input, i, changed);
        }
        hop.setVisited();
        return changed;
    }

    /**
     * Looks up a function program block in the program of the given context, using the
     * namespace and name obtained from the function key.
     *
     * @param executionContext context whose program is searched
     * @param str              function key
     * @return the result of {@code Program.getFunctionProgramBlock}
     */
    private static FunctionProgramBlock getFunctionBlock(ExecutionContext executionContext, String str) {
        String[] nameParts = getCompleteFuncName(str, null);
        String namespace = nameParts[0];
        String functionName = nameParts[1];
        return executionContext.getProgram().getFunctionProgramBlock(namespace, functionName);
    }

    /**
     * Appends {@code matrixBlock2} to {@code matrixBlock} via {@code MatrixBlock.append},
     * handing in a freshly allocated result block.
     *
     * @param matrixBlock  the block {@code append} is invoked on
     * @param matrixBlock2 the block to append
     * @return the block returned by {@code append}
     */
    public static MatrixBlock cbindMatrix(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        MatrixBlock output = new MatrixBlock();
        return matrixBlock.append(matrixBlock2, output);
    }

    /**
     * Groups both block RDDs by row index (see {@code groupMatrix}) and joins them on that
     * index, so every row index is paired with the grouped block of each input.
     *
     * @param javaPairRDD  first block RDD
     * @param javaPairRDD2 second block RDD
     * @return join of the two grouped RDDs, keyed by row index
     */
    public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD2) {
        JavaPairRDD<Long, MatrixBlock> firstGrouped = groupMatrix(javaPairRDD);
        JavaPairRDD<Long, MatrixBlock> secondGrouped = groupMatrix(javaPairRDD2);
        return firstGrouped.join(secondGrouped);
    }

    /**
     * Collapses all blocks that share a row index into one block by column-binding them in
     * ascending column-index order.
     *
     * <p>Previously the per-row list was sorted only inside the combine function of
     * {@code aggregateByKey}. Spark invokes that function only when partial aggregates of the
     * same key have to be merged, so a row whose blocks were all aggregated in a single pass was
     * never sorted and could be concatenated in arrival order. The list is now sorted exactly
     * once, right before concatenation, which covers both cases. The concatenation also walks
     * the list with an iterator instead of repeated indexed access.
     *
     * @param javaPairRDD blocks keyed by their matrix indexes
     * @return one column-bound block per row index
     */
    private static JavaPairRDD<Long, MatrixBlock> groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD) {
        return javaPairRDD
            // re-key every block by its row index, keeping the column index next to the block
            .mapToPair(block -> new Tuple2<Long, Tuple2<Long, MatrixBlock>>(
                Long.valueOf(block._1.getRowIndex()),
                new Tuple2<Long, MatrixBlock>(Long.valueOf(block._1.getColumnIndex()), block._2)))
            // gather all (column index, block) pairs of a row; ordering is established later
            .aggregateByKey(new LinkedList<Tuple2<Long, MatrixBlock>>(),
                (collected, entry) -> {
                    collected.add(entry);
                    return collected;
                },
                (left, right) -> {
                    left.addAll(right);
                    return left;
                })
            .mapToPair(row -> {
                List<Tuple2<Long, MatrixBlock>> ordered = new ArrayList<>(row._2);
                // sort unconditionally: the combine function above is not guaranteed to run
                ordered.sort((a, b) -> a._1.compareTo(b._1));
                Iterator<Tuple2<Long, MatrixBlock>> it = ordered.iterator();
                MatrixBlock combined = it.next()._2;
                while (it.hasNext()) {
                    combined = cbindMatrix(combined, it.next()._2);
                }
                return new Tuple2<Long, MatrixBlock>(row._1, combined);
            });
    }

    /**
     * Partitions two matrices on Spark into {@code i} partitions.
     *
     * <p>The RDD handles of both matrix objects are assembled via {@code assembleTrainingData},
     * expanded by a {@code DataPartitionerSparkMapper}, aggregated per integer key into linked
     * lists using a custom partitioner, and finally reduced by a
     * {@code DataPartitionerSparkAggregator}. When {@code DMLScript.STATISTICS} is set, the
     * elapsed time is added to the parameter-server setup time statistic.
     *
     * <p>NOTE(review): the per-key list is sorted by its {@code Long} component only inside the
     * combine function; confirm whether the aggregator depends on that order when the combine
     * function is not invoked for a key.
     *
     * @param sparkExecutionContext Spark context used to obtain the RDD handles
     * @param matrixObject          first matrix; its row and column counts are also passed on
     * @param matrixObject2         second matrix; its column count is also passed on
     * @param pSScheme              partitioning scheme handed to the mapper
     * @param i                     number of partitions
     * @return the partitioned RDD keyed by an integer
     */
    public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sparkExecutionContext, MatrixObject matrixObject, MatrixObject matrixObject2, Statement.PSScheme pSScheme, final int i) {
        // A timer is only created when statistics are enabled.
        Timing timing = DMLScript.STATISTICS ? new Timing(true) : null;
        JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> mapToPair = assembleTrainingData(sparkExecutionContext.getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo), sparkExecutionContext.getRDDHandleForMatrixObject(matrixObject2, InputInfo.BinaryBlockInputInfo)).flatMapToPair(new DataPartitionerSparkMapper(pSScheme, i, sparkExecutionContext, (int) matrixObject.getNumRows())).aggregateByKey(new LinkedList(), new Partitioner() { // from class: org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.1
            private static final long serialVersionUID = -7937781374718031224L;

            // The integer key is used directly as the partition id.
            public int getPartition(Object obj) {
                return ((Integer) obj).intValue();
            }

            public int numPartitions() {
                return i;
            }
        }, (linkedList, tuple2) -> {
            linkedList.add(tuple2);
            return linkedList;
        }, (linkedList2, linkedList3) -> {
            linkedList2.addAll(linkedList3);
            linkedList2.sort((tuple22, tuple23) -> {
                return ((Long) tuple22._1).compareTo((Long) tuple23._1);
            });
            return linkedList2;
        }).mapToPair(new DataPartitionerSparkAggregator(matrixObject.getNumColumns(), matrixObject2.getNumColumns()));
        if (DMLScript.STATISTICS) {
            // DMLScript.STATISTICS is read a second time here; timing is null if it was false above.
            Statistics.accPSSetupTime((long) timing.stop());
        }
        return mapToPair;
    }

    /**
     * Sequential variant of {@link #accrueGradients(ListObject, ListObject, boolean, boolean)}:
     * the parallel flag is fixed to {@code false}.
     *
     * @param listObject  accumulated gradients, may be {@code null}
     * @param listObject2 gradients to add
     * @param z           whether {@code listObject2} is cleaned up afterwards
     * @return the accumulated gradients
     */
    public static ListObject accrueGradients(ListObject listObject, ListObject listObject2, boolean z) {
        final boolean parallel = false;
        return accrueGradients(listObject, listObject2, parallel, z);
    }

    /**
     * Adds the matrices of {@code listObject2} element-wise onto the matrices of
     * {@code listObject}, in place. If no accumulator exists yet, a copy of {@code listObject2}
     * (see {@code copyList}) is returned instead.
     *
     * @param listObject  accumulated gradients, may be {@code null}
     * @param listObject2 gradients to add
     * @param z           if {@code true}, the list positions are processed with a parallel stream
     * @param z2          if {@code true}, {@code listObject2} is cleaned up afterwards
     * @return the accumulator, or the copy of {@code listObject2} when the accumulator was {@code null}
     */
    public static ListObject accrueGradients(ListObject listObject, ListObject listObject2, boolean z, boolean z2) {
        if (listObject == null) {
            return copyList(listObject2, z2);
        }
        IntStream positions = IntStream.range(0, listObject.getLength());
        if (z) {
            positions = positions.parallel();
        }
        positions.forEach(pos -> {
            MatrixBlock accumulated = ((MatrixObject) listObject.getData().get(pos)).acquireReadAndRelease();
            BinaryOperator plus = new BinaryOperator(Plus.getPlusFnObject());
            MatrixBlock addend = ((MatrixObject) listObject2.getData().get(pos)).acquireReadAndRelease();
            accumulated.binaryOperationsInPlace(plus, addend);
        });
        if (z2) {
            cleanupListObject(listObject2);
        }
        return listObject;
    }

    /**
     * Synthetic lambda-deserialization hook as rendered by the decompiler.
     *
     * <p>The method selects a lambda by the implementation method name stored in the
     * {@link SerializedLambda} (first switch, keyed by the name's hash code), then verifies
     * method kind, functional interface, interface method, signatures and implementing class
     * before returning the matching lambda (second switch). Any mismatch ends in an
     * {@link IllegalArgumentException}.
     *
     * <p>NOTE: this is decompiler output and is not valid Java as written: {@code z} is declared
     * {@code boolean} but assigned integer values and used with {@code true}/{@code false} case
     * labels. The lambda bodies repeat those in {@code groupMatrix} and {@code doPartitionOnSpark}.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -247569191:
                if (implMethodName.equals("lambda$groupMatrix$d66f2528$1")) {
                    z = 3;
                    break;
                }
                break;
            case -247569190:
                if (implMethodName.equals("lambda$groupMatrix$d66f2528$2")) {
                    z = true;
                    break;
                }
                break;
            case 437830700:
                if (implMethodName.equals("lambda$groupMatrix$583ed2c$1")) {
                    z = false;
                    break;
                }
                break;
            case 437830701:
                if (implMethodName.equals("lambda$groupMatrix$583ed2c$2")) {
                    z = 2;
                    break;
                }
                break;
            case 582417386:
                if (implMethodName.equals("lambda$doPartitionOnSpark$ac1b77bd$1")) {
                    z = 4;
                    break;
                }
                break;
            case 582417387:
                if (implMethodName.equals("lambda$doPartitionOnSpark$ac1b77bd$2")) {
                    z = 5;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/PairFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Lscala/Tuple2;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Lscala/Tuple2;)Lscala/Tuple2;")) {
                    return tuple2 -> {
                        return new Tuple2(Long.valueOf(((MatrixIndexes) tuple2._1).getRowIndex()), new Tuple2(Long.valueOf(((MatrixIndexes) tuple2._1).getColumnIndex()), tuple2._2));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function2") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/LinkedList;Ljava/util/LinkedList;)Ljava/util/LinkedList;")) {
                    return (linkedList2, linkedList3) -> {
                        linkedList2.addAll(linkedList3);
                        linkedList2.sort((tuple23, tuple24) -> {
                            return ((Long) tuple23._1).compareTo((Long) tuple24._1);
                        });
                        return linkedList2;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/PairFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Lscala/Tuple2;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Lscala/Tuple2;)Lscala/Tuple2;")) {
                    return tuple23 -> {
                        LinkedList linkedList4 = (LinkedList) tuple23._2;
                        MatrixBlock matrixBlock = (MatrixBlock) ((Tuple2) linkedList4.get(0))._2;
                        for (int i = 1; i < linkedList4.size(); i++) {
                            matrixBlock = cbindMatrix(matrixBlock, (MatrixBlock) ((Tuple2) linkedList4.get(i))._2);
                        }
                        return new Tuple2(tuple23._1, matrixBlock);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function2") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/LinkedList;Lscala/Tuple2;)Ljava/util/LinkedList;")) {
                    return (linkedList, tuple22) -> {
                        linkedList.add(tuple22);
                        return linkedList;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function2") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/LinkedList;Lscala/Tuple2;)Ljava/util/LinkedList;")) {
                    return (linkedList4, tuple24) -> {
                        linkedList4.add(tuple24);
                        return linkedList4;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/Function2") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/LinkedList;Ljava/util/LinkedList;)Ljava/util/LinkedList;")) {
                    return (linkedList22, linkedList32) -> {
                        linkedList22.addAll(linkedList32);
                        linkedList22.sort((tuple222, tuple232) -> {
                            return ((Long) tuple222._1).compareTo((Long) tuple232._1);
                        });
                        return linkedList22;
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
