/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFactory;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.stats.NativeStatistics;

/**
 * Dispatches dense matrix multiplication, transpose-self multiplication and conv2d operations
 * to {@link NativeHelper}. Whenever the native library is not loaded, the inputs are not
 * eligible, or the native call reports a failure (negative / -1 return value), the methods fall
 * back to the Java implementations in {@link LibMatrixMult} and {@link LibMatrixDNN}.
 */
public class LibMatrixNative {
    private static final Log LOG = LogFactory.getLog(LibMatrixNative.class.getName());

    /** Operand size in bytes (0x10000000 = 256 MiB) used by {@link #isMatMultMemoryBound}. */
    private static final long MEM_BOUND_OPERAND_BYTES = 0x10000000L;

    // Per-thread direct float buffers, reused across single-precision native calls and only
    // re-allocated when a larger capacity is needed (see toFloatBuffer).
    private static final ThreadLocal<FloatBuffer> inBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> biasBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> filterBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> outBuff = new ThreadLocal<>();

    /**
     * Decides whether a multiplication of an {@code m1Rlen x m1Clen} by an
     * {@code m1Clen x m2Clen} matrix is treated as memory bound: at least one of the three
     * dimensions is 1 and at least one operand exceeds 256 MiB at 8 bytes per cell.
     *
     * @param m1Rlen number of rows of the left operand
     * @param m1Clen number of columns of the left operand (rows of the right operand)
     * @param m2Clen number of columns of the right operand
     * @return {@code true} if the operation is considered memory bound
     */
    public static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int m2Clen) {
        boolean hasUnitDimension = m1Rlen == 1 || m1Clen == 1 || m2Clen == 1;
        if (!hasUnitDimension) {
            return false;
        }
        // long arithmetic: the products of two int dimensions may exceed the int range
        long leftBytes = 8L * m1Rlen * m1Clen;
        long rightBytes = 8L * m1Clen * m2Clen;
        return leftBytes > MEM_BOUND_OPERAND_BYTES || rightBytes > MEM_BOUND_OPERAND_BYTES;
    }

    /**
     * Multiplies {@code m1} by {@code m2}, using the native library when it is loaded and the
     * inputs qualify, and {@link LibMatrixMult#matrixMult} otherwise.
     *
     * @param m1 left operand
     * @param m2 right operand
     * @param ret output block to reuse; may be {@code null}, in which case the native path
     *            allocates a new dense block
     * @param k number of threads; values {@code <= 0} are replaced by
     *          {@link NativeHelper#getMaxNumThreads()} when the native library is loaded
     * @return the block holding the result
     */
    public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        if (!NativeHelper.isNativeLibraryLoaded()) {
            LOG.warn("Was valid for native MM but native lib was not loaded");
            return LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
        if (k <= 0) {
            k = NativeHelper.getMaxNumThreads();
        }
        if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
            return LibMatrixMult.emptyMatrixMult(m1, m2, ret);
        }
        if (isValidForNativeMatMult(m1, m2)) {
            if (ret == null) {
                ret = new MatrixBlock(m1.rlen, m2.clen, false);
            } else {
                ret.reset(m1.rlen, m2.clen, false);
            }
            ret.allocateBlock();
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            long nnz = isSinglePrecision()
                ? nativeMatMultSingle(m1, m2, ret, k)
                : nativeMatMultDouble(m1, m2, ret, k);
            if (nnz > -1L) {
                if (DMLScript.STATISTICS) {
                    NativeStatistics.incrementLibMatrixMultTime(System.nanoTime() - start);
                    NativeStatistics.incrementNumLibMatrixMultCalls();
                }
                ret.setNonZeros(nnz);
                ret.examSparsity();
                return ret;
            }
            NativeStatistics.incrementFailuresCounter();
            LOG.warn("matrixMult: Native mat mult failed. Falling back to java version (loaded=" + NativeHelper.isNativeLibraryLoaded() + ", sparse=" + (m1.isInSparseFormat() | m2.isInSparseFormat()) + ")");
        }
        return LibMatrixMult.matrixMult(m1, m2, ret, k);
    }

    /**
     * Checks the preconditions of the native multiply: not memory bound, both inputs dense,
     * {@code m2} contiguous, {@code m1} contiguous when running in single precision, and an
     * output whose size at 8 bytes per cell stays below {@link Integer#MAX_VALUE}.
     */
    private static boolean isValidForNativeMatMult(MatrixBlock m1, MatrixBlock m2) {
        return !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
            && !m1.isInSparseFormat()
            && !m2.isInSparseFormat()
            && (m1.getDenseBlock().isContiguous() || !isSinglePrecision())
            && m2.getDenseBlock().isContiguous()
            // Sized from the operand dimensions rather than from ret: ret may be null here and
            // has not yet been reset to the output dimensions.
            && 8L * m1.rlen * m2.clen < Integer.MAX_VALUE;
    }

    /** Runs the multiply through float buffers; returns the native nnz result (negative on failure). */
    private static long nativeMatMultSingle(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
        FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
        FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
        long nnz = NativeHelper.smmdd(fin1, fin2, fout, m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
        if (nnz > -1L) {
            // only copy back a valid result; on failure the caller recomputes in Java
            fromFloatBuffer(fout, ret.getDenseBlockValues());
        }
        return nnz;
    }

    /**
     * Runs the multiply on double arrays, block by block when {@code m1} is not contiguous.
     * Returns the accumulated nnz, or -1 as soon as any native call reports a failure.
     */
    private static long nativeMatMultDouble(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
        DenseBlock a = m1.getDenseBlock();
        if (a.isContiguous()) {
            return NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(), ret.getDenseBlockValues(), m1.rlen, m1.clen, m2.clen, k);
        }
        long nnz = 0L;
        for (int bix = 0; bix < a.numBlocks(); ++bix) {
            int blockRows = a.blockSize(bix);
            double[] tmp = new double[blockRows * m2.clen];
            long blockNnz = NativeHelper.dmmdd(a.valuesAt(bix), m2.getDenseBlockValues(), tmp, blockRows, m1.clen, m2.clen, k);
            if (blockNnz < 0L) {
                // do not fold a failure code into the running sum, it would mask the failure
                return -1L;
            }
            nnz += blockNnz;
            int rl = bix * a.blockSize();
            ret.getDenseBlock().set(rl, rl + blockRows, 0, m2.clen, DenseBlockFactory.createDenseBlock(tmp, new int[]{blockRows, m2.clen}));
        }
        return nnz;
    }

    /**
     * Transpose-self matrix multiplication of {@code m1} into {@code ret}, using the native
     * library when it is loaded and the input qualifies, and
     * {@link LibMatrixMult#matrixMultTransposeSelf} otherwise. Returns immediately, leaving
     * {@code ret} untouched, when {@code m1} is empty.
     *
     * @param m1 input matrix
     * @param ret output block (must not be {@code null})
     * @param leftTrans forwarded to the native / Java kernels; a non-contiguous {@code m1} is
     *                  only handled natively when this flag is {@code true}
     * @param k number of threads
     */
    public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {
        if (m1.isEmptyBlock(false)) {
            return;
        }
        boolean useNative = NativeHelper.isNativeLibraryLoaded()
            && (ret.clen > 1 || ret.getLength() == 1L)
            && !LibMatrixMult.isOuterProductTSMM(m1.rlen, m1.clen, leftTrans)
            && !m1.sparse
            && (m1.getDenseBlock().isContiguous() || leftTrans);
        if (useNative) {
            ret.sparse = false;
            ret.allocateDenseBlock();
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            DenseBlock a = m1.getDenseBlock();
            double[] cvals = ret.getDenseBlockValues();
            long nnz;
            if (a.isContiguous()) {
                nnz = NativeHelper.tsmm(a.valuesAt(0), cvals, m1.rlen, m1.clen, leftTrans, k);
            } else if (tsmmBlockwise(a, cvals, m1.clen, leftTrans, k)) {
                nnz = ret.recomputeNonZeros();
            } else {
                nnz = -1L;
            }
            if (nnz > -1L) {
                if (DMLScript.STATISTICS) {
                    NativeStatistics.incrementLibMatrixMultTime(System.nanoTime() - start);
                    NativeStatistics.incrementNumLibMatrixMultCalls();
                }
                ret.setNonZeros(nnz);
                ret.examSparsity();
                return;
            }
            LOG.warn("Native TSMM failed. Falling back to java version.");
            NativeStatistics.incrementFailuresCounter();
        }
        if (k > 1) {
            LibMatrixMult.matrixMultTransposeSelf(m1, ret, leftTrans, k);
        } else {
            LibMatrixMult.matrixMultTransposeSelf(m1, ret, leftTrans);
        }
    }

    /**
     * Calls the native tsmm once per block of {@code a} and adds each {@code clen x clen}
     * partial result into {@code cvals}.
     *
     * @return {@code false} if any native call reported a failure; in that case the partial
     *         sums in {@code cvals} are cleared so the caller can fall back cleanly
     */
    private static boolean tsmmBlockwise(DenseBlock a, double[] cvals, int clen, boolean leftTrans, int k) {
        int len = clen * clen;
        for (int bix = 0; bix < a.numBlocks(); ++bix) {
            double[] tmp = new double[len];
            long rc = NativeHelper.tsmm(a.valuesAt(bix), tmp, a.blockSize(bix), clen, leftTrans, k);
            if (rc < 0L) {
                // discard what earlier blocks accumulated; the result would be incomplete
                Arrays.fill(cvals, 0d);
                return false;
            }
            LibMatrixMult.vectAdd(tmp, cvals, 0, 0, len);
        }
        return true;
    }

    /**
     * conv2d (optionally with bias add when {@code params.bias} is set) via the native library
     * for dense input and filter, with fallback to {@link LibMatrixDNN#conv2d}.
     *
     * @param input input block
     * @param filter filter block
     * @param outputBlock output block; its dense values are written by the native call
     * @param params operator parameters; {@code numThreads} is updated in place
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, DnnParameters params) {
        LibMatrixDNN.checkInputsConv2d(input, filter, outputBlock, params);
        if (params.numThreads <= 0) {
            params.numThreads = NativeHelper.getMaxNumThreads();
        }
        if (NativeHelper.isNativeLibraryLoaded() && !input.isInSparseFormat() && !filter.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            long nnz;
            if (params.bias == null) {
                nnz = NativeHelper.conv2dDense(input.getDenseBlockValues(), filter.getDenseBlockValues(), outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W, params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            } else {
                if (params.bias.isInSparseFormat()) {
                    params.bias.sparseToDense();
                }
                if (isSinglePrecision() && !NativeHelper.getCurrentBLAS().equalsIgnoreCase("mkl")) {
                    FloatBuffer finput = toFloatBuffer(input.getDenseBlockValues(), inBuff, true);
                    FloatBuffer fbias = toFloatBuffer(params.bias.getDenseBlockValues(), biasBuff, true);
                    FloatBuffer ffilter = toFloatBuffer(filter.getDenseBlockValues(), filterBuff, true);
                    FloatBuffer foutput = toFloatBuffer(outputBlock.getDenseBlockValues(), outBuff, false);
                    nnz = NativeHelper.sconv2dBiasAddDense(finput, fbias, ffilter, foutput, params.N, params.C, params.H, params.W, params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
                    if (nnz != -1L) {
                        fromFloatBuffer(foutput, outputBlock.getDenseBlockValues());
                    }
                } else {
                    nnz = NativeHelper.dconv2dBiasAddDense(input.getDenseBlockValues(), params.bias.getDenseBlockValues(), filter.getDenseBlockValues(), outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W, params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
                }
            }
            if (nnz != -1L) {
                if (DMLScript.STATISTICS) {
                    NativeStatistics.incrementConv2dTime(System.nanoTime() - start);
                    NativeStatistics.incrementNumConv2dCalls();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            LOG.warn("Native conv2d call returned with error - falling back to java operator.");
            // The single-precision bias path wrote only to the float buffer and skipped the
            // copy-back on failure, so the output block is reset in every other case only.
            if (!isSinglePrecision() || params.bias == null) {
                outputBlock.reset();
            }
            NativeStatistics.incrementFailuresCounter();
        }
        LibMatrixDNN.conv2d(input, filter, outputBlock, params);
    }

    /**
     * Constrains {@code params.numThreads} via {@link OptimizerUtils#getConstrainedNumThreads}
     * and forces it to 1 unless the output is thread safe and more than one thread remains.
     */
    private static void setNumThreads(DnnParameters params) {
        params.numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        boolean multiThreaded = params.isOutputThreadSafe() && params.numThreads > 1;
        if (!multiThreaded) {
            params.numThreads = 1;
        }
    }

    /**
     * conv2d backward-filter via the native library for dense {@code input} and {@code dout},
     * with fallback to {@link LibMatrixDNN#conv2dBackwardFilter}.
     *
     * @param input input block
     * @param dout block passed to the native call as second operand
     * @param outputBlock output block; its dense values are written by the native call
     * @param params operator parameters; {@code numThreads} is updated in place
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        LibMatrixDNN.checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
        if (params.numThreads <= 0) {
            params.numThreads = NativeHelper.getMaxNumThreads();
        }
        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !input.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            long nnz = NativeHelper.conv2dBackwardFilterDense(input.getDenseBlockValues(), dout.getDenseBlockValues(), outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W, params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            if (nnz != -1L) {
                if (DMLScript.STATISTICS) {
                    NativeStatistics.incrementConv2dBwdFilterTime(System.nanoTime() - start);
                    NativeStatistics.incrementNumConv2dBwdFilterCalls();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            // NOTE(review): unlike conv2d, the output block is not reset before the fallback -
            // confirm whether the Java operator overwrites or accumulates into it.
            LOG.warn("Native conv2dBackwardFilter call returned with error - falling back to java operator.");
            NativeStatistics.incrementFailuresCounter();
        }
        LibMatrixDNN.conv2dBackwardFilter(input, dout, outputBlock, params);
    }

    /**
     * conv2d backward-data via the native library for dense {@code filter} and {@code dout},
     * with fallback to {@link LibMatrixDNN#conv2dBackwardData}.
     *
     * @param filter filter block
     * @param dout block passed to the native call as second operand
     * @param outputBlock output block; its dense values are written by the native call
     * @param params operator parameters; {@code numThreads} is updated in place
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, DnnParameters params) {
        LibMatrixDNN.checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
        if (params.numThreads <= 0) {
            params.numThreads = NativeHelper.getMaxNumThreads();
        }
        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !filter.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            long nnz = NativeHelper.conv2dBackwardDataDense(filter.getDenseBlockValues(), dout.getDenseBlockValues(), outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W, params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            if (nnz != -1L) {
                if (DMLScript.STATISTICS) {
                    NativeStatistics.incrementConv2dBwdDataTime(System.nanoTime() - start);
                    NativeStatistics.incrementNumConv2dBwdDataCalls();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            // NOTE(review): unlike conv2d, the output block is not reset before the fallback -
            // confirm whether the Java operator overwrites or accumulates into it.
            LOG.warn("Native conv2dBackwardData call returned with error - falling back to java operator.");
            NativeStatistics.incrementFailuresCounter();
        }
        LibMatrixDNN.conv2dBackwardData(filter, dout, outputBlock, params);
    }

    /**
     * @return {@code true} if the configuration value {@code sysds.floating.point.precision}
     *         equals {@code "single"}; {@code false} otherwise, including when it is unset
     */
    public static boolean isSinglePrecision() {
        // constant-first equals: a missing configuration value yields false instead of an NPE
        return "single".equals(ConfigurationManager.getDMLConfig().getTextValue("sysds.floating.point.precision"));
    }

    /**
     * Returns the calling thread's cached direct float buffer from {@code buff}, allocating a
     * new one (native byte order) when none exists or its capacity is below
     * {@code input.length}, and optionally fills it with {@code input} narrowed to float.
     *
     * @param input source values; only its length is used when {@code copy} is {@code false}
     * @param buff thread-local cache slot to read and update
     * @param copy whether to copy {@code input} into the buffer
     * @return the buffer; its capacity may exceed {@code input.length} when reused
     * @throws ArithmeticException if {@code 4 * input.length} does not fit into an int
     */
    private static FloatBuffer toFloatBuffer(double[] input, ThreadLocal<FloatBuffer> buff, boolean copy) {
        FloatBuffer cached = buff.get();
        if (cached == null || cached.capacity() < input.length) {
            // multiplyExact: a silently wrapped byte count would allocate an undersized buffer
            int numBytes = Math.multiplyExact(Float.BYTES, input.length);
            cached = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder()).asFloatBuffer();
            buff.set(cached);
        }
        final FloatBuffer target = cached;
        if (copy) {
            // absolute puts at distinct indices, so the parallel stream never touches position
            IntStream.range(0, input.length).parallel().forEach(i -> target.put(i, (float) input[i]));
        }
        return target;
    }

    /**
     * Copies the first {@code output.length} floats of {@code buff} into {@code output},
     * widening each value to double.
     *
     * @param buff source buffer, read with absolute gets
     * @param output destination array, fully overwritten
     */
    public static void fromFloatBuffer(FloatBuffer buff, double[] output) {
        Arrays.parallelSetAll(output, i -> buff.get(i));
    }
}

