/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Federated data-partitioning scheme that keeps every federated partition on its worker and
 * asks each worker to shuffle the rows of its local features and labels with a seeded
 * permutation.
 */
public class ShuffleFederatedScheme
extends DataPartitionFederatedScheme {
    /**
     * Slices the federated features and labels per federated partition, computes the balance
     * metrics and weighting factors of the feature slices, and then asks the worker owning each
     * slice to shuffle its local feature and label rows.
     *
     * <p>Shuffle requests are sent and awaited one worker at a time. The first failure aborts
     * the remaining workers.
     *
     * @param features federated feature matrix
     * @param labels   federated label matrix; expected to be federated with the same number of
     *                 partitions as {@code features}
     * @param seed     seed used on each worker to generate the row permutation
     * @return the feature/label slices, the number of slices, their balance metrics and their
     *         weighting factors
     * @throws DMLRuntimeException if the features and labels have different numbers of
     *         partitions, if a worker reports a failed shuffle, or if a shuffle request could
     *         not be completed (the underlying exception is kept as the cause)
     */
    @Override
    public DataPartitionFederatedScheme.Result partition(MatrixObject features, MatrixObject labels, int seed) {
        List<MatrixObject> pFeatures = ShuffleFederatedScheme.sliceFederatedMatrix(features);
        List<MatrixObject> pLabels = ShuffleFederatedScheme.sliceFederatedMatrix(labels);
        // Slices are paired by index below, so both sides must have the same partition count.
        if (pFeatures.size() != pLabels.size()) {
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: features and labels have "
                + "different numbers of federated partitions (" + pFeatures.size() + " vs " + pLabels.size() + ")");
        }
        DataPartitionFederatedScheme.BalanceMetrics balanceMetrics = ShuffleFederatedScheme.getBalanceMetrics(pFeatures);
        List<Double> weightingFactors = ShuffleFederatedScheme.getWeightingFactors(pFeatures, balanceMetrics);
        for (int i = 0; i < pFeatures.size(); ++i) {
            // Each slice is expected to reference a single federated data object; only entry 0 is used.
            FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
            FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
            shuffleOnWorker(featuresData, labelsData, seed);
        }
        return new DataPartitionFederatedScheme.Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weightingFactors);
    }

    /**
     * Sends the shuffle UDF for one feature/label slice pair and blocks until the worker answers.
     *
     * @param featuresData federated features slice; the request is addressed to its variable ID
     * @param labelsData   federated labels slice belonging to the same partition
     * @param seed         permutation seed forwarded to the worker
     * @throws DMLRuntimeException if the worker reports failure or the request cannot be completed
     */
    private static void shuffleOnWorker(FederatedData featuresData, FederatedData labelsData, int seed) {
        // Both variable IDs are passed as UDF inputs so the worker shuffles features and labels together.
        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(
            new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, featuresData.getVarID(),
                new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed)));
        FederatedResponse response;
        try {
            response = udfResponse.get();
        }
        catch (InterruptedException e) {
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed: "
                + e.getMessage(), e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed: "
                + e.getMessage(), e);
        }
        // Checked outside the try so this error is not re-wrapped by the catch blocks above.
        if (!response.isSuccessful()) {
            throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF returned fail. Federated worker error message: " + response.getErrorMessage());
        }
    }

    /**
     * UDF executed on a federated worker that shuffles its local features and labels with one
     * seeded permutation, so feature and label rows stay aligned.
     *
     * <p>The class name is kept as-is because instances are serialized and sent to workers.
     */
    private static class shuffleDataOnFederatedWorker
    extends FederatedUDF {
        private static final long serialVersionUID = 3228664618781333325L;
        // Seed for ParamservUtils.generatePermutation.
        private final int _seed;

        /**
         * @param inIDs worker-side variable IDs, in the order {features, labels}
         * @param seed  permutation seed
         */
        protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
            super(inIDs);
            this._seed = seed;
        }

        /**
         * Generates a permutation over the feature rows and applies it to both the features
         * and the labels.
         *
         * @param ec   worker execution context (unused)
         * @param data resolved inputs: {@code data[0]} features, {@code data[1]} labels
         * @return a SUCCESS response
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject features = (MatrixObject)data[0];
            MatrixObject labels = (MatrixObject)data[1];
            // toIntExact throws if the row count does not fit in an int.
            // NOTE(review): assumes labels has the same row count as features — confirm.
            MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), this._seed);
            DataPartitionFederatedScheme.shuffle(features, permutationMatrixBlock);
            DataPartitionFederatedScheme.shuffle(labels, permutationMatrixBlock);
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        /** No lineage is recorded for this UDF. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

