/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * Column encoder that replaces the string found in one input column by a
 * pre-computed embedding vector.
 *
 * <p>The encoder holds a recode map (token to 1-based row index) and an embedding
 * matrix. {@link #initEmbeddings(MatrixBlock)} joins the two into a
 * token-to-vector lookup table that {@link #applyDense} uses to fill output rows.
 */
public class ColumnEncoderWordEmbedding
extends ColumnEncoder {
    /** Embedding matrix; one row per distinct token, one column per embedding dimension. */
    private MatrixBlock _wordEmbeddings = new MatrixBlock();
    /** Maps a token to its 1-based row index in {@link #_wordEmbeddings}. */
    private Map<Object, Long> _rcdMap = new HashMap<>();
    /** Token-to-vector lookup table; populated by {@link #initEmbeddings(MatrixBlock)}. */
    private Map<String, double[]> _embMap;

    /**
     * Creates an encoder with column id {@code -1}, an empty recode map and an
     * empty embedding matrix. Also the constructor used by {@code Externalizable}
     * deserialization before {@link #readExternal(ObjectInput)} fills the state.
     */
    public ColumnEncoderWordEmbedding() {
        super(-1);
    }

    /**
     * Looks up the recode index of a token.
     *
     * @param key token to look up
     * @return the mapped index, or {@code -1} if the token is not in the recode map
     */
    private long lookupRCDMap(Object key) {
        return this._rcdMap.getOrDefault(key, -1L);
    }

    /**
     * @return the number of columns of the embedding matrix, i.e. the length of
     *         one embedding vector
     */
    @Override
    public int getDomainSize() {
        return this._wordEmbeddings.getNumColumns();
    }

    /**
     * @return the number of rows of the embedding matrix
     */
    public int getNrDistinctEmbeddings() {
        return this._wordEmbeddings.getNumRows();
    }

    /**
     * Creates an encoder for the given column. The recode map and embedding
     * matrix start out empty (never {@code null}) so that size queries and
     * serialization are safe before {@link #initMetaData(FrameBlock)} and
     * {@link #initEmbeddings(MatrixBlock)} have been called.
     *
     * @param colID column id passed to the super class; this class reads the
     *              input frame at index {@code colID - 1}
     */
    protected ColumnEncoderWordEmbedding(int colID) {
        super(colID);
    }

    /**
     * Not supported: a word embedding is a vector, not a single code.
     *
     * @throws NotImplementedException always
     */
    @Override
    protected double getCode(CacheBlock<?> in, int row) {
        throw new NotImplementedException();
    }

    /**
     * Not supported: a word embedding is a vector, not a single code.
     *
     * @throws NotImplementedException always
     */
    @Override
    protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) {
        throw new NotImplementedException();
    }

    /**
     * Copies one row of the embedding matrix into a fresh array.
     *
     * @param r 0-based row index into the embedding matrix
     * @return array of length {@link #getDomainSize()} holding row {@code r}
     */
    private double[] getEmbeddedingFromEmbeddingMatrix(long r) {
        final int dim = this.getDomainSize();
        final int row = (int)r;
        double[] embedding = new double[dim];
        for (int c = 0; c < dim; ++c) {
            // Column index is relative to the embedding matrix only. Offsetting it by
            // the input column id would run past the last column for any _colID > 1,
            // because dim already equals the matrix's column count.
            embedding[c] = this._wordEmbeddings.quickGetValue(row, c);
        }
        return embedding;
    }

    /**
     * Writes the embedding vector of every known token in rows
     * {@code [rowStart, rowEnd)} of the input column into the same row of
     * {@code out}. Rows whose token is {@code null}, empty, or has no embedding
     * are left untouched.
     *
     * @param in        input block; tokens are read from column {@code _colID - 1}
     * @param out       output matrix whose rows are set via {@code quickSetRow}
     * @param outputCol not used by this implementation
     * @param rowStart  first row to process (inclusive)
     * @param blk       block size handed to {@code UtilFunctions.getEndIndex} to
     *                  derive the exclusive end row
     */
    @Override
    public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        final int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        final int srcCol = this._colID - 1;
        for (int i = rowStart; i < rowEnd; ++i) {
            String key = in.getString(i, srcCol);
            if (key == null || key.isEmpty()) {
                continue;
            }
            double[] embedding = this._embMap.get(key);
            if (embedding != null) {
                out.quickSetRow(i, embedding);
            }
        }
    }

    /**
     * @return {@link ColumnEncoder.TransformType#WORD_EMBEDDING}
     */
    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.WORD_EMBEDDING;
    }

    /**
     * Not supported by this encoder.
     *
     * @throws NotImplementedException always
     */
    @Override
    public void build(CacheBlock<?> in) {
        throw new NotImplementedException();
    }

    /**
     * Not supported by this encoder.
     *
     * @throws NotImplementedException always
     */
    @Override
    public void allocateMetaData(FrameBlock meta) {
        throw new NotImplementedException();
    }

    /**
     * Not supported by this encoder.
     *
     * @throws NotImplementedException always
     */
    @Override
    public FrameBlock getMetaData(FrameBlock out) {
        throw new NotImplementedException();
    }

    /**
     * Replaces the recode map with the one stored in column {@code _colID - 1}
     * of {@code meta}. Does nothing if {@code meta} is {@code null} or has no rows.
     *
     * @param meta metadata frame providing the recode map
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        if (meta == null || meta.getNumRows() <= 0) {
            return;
        }
        this._rcdMap = meta.getRecodeMap(this._colID - 1);
    }

    /**
     * Stores the embedding matrix and rebuilds the token-to-vector lookup table
     * from the current recode map. Recode indices are treated as 1-based and are
     * shifted to 0-based matrix rows.
     *
     * @param embeddings embedding matrix to use from now on
     */
    @Override
    public void initEmbeddings(MatrixBlock embeddings) {
        this._wordEmbeddings = embeddings;
        this._embMap = new HashMap<>();
        for (Map.Entry<Object, Long> entry : this._rcdMap.entrySet()) {
            double[] vector = this.getEmbeddedingFromEmbeddingMatrix(entry.getValue() - 1L);
            this._embMap.put((String)entry.getKey(), vector);
        }
    }

    /**
     * Serializes the super-class state, then the recode map (entry count followed
     * by UTF key / long value pairs), then the embedding matrix.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        out.writeInt(this._rcdMap.size());
        for (Map.Entry<Object, Long> e : this._rcdMap.entrySet()) {
            out.writeUTF(e.getKey().toString());
            out.writeLong(e.getValue());
        }
        this._wordEmbeddings.write(out);
    }

    /**
     * Restores the state written by {@link #writeExternal(ObjectOutput)} in the
     * same order and rebuilds the token-to-vector lookup table.
     *
     * @param in stream to read from
     * @throws IOException if reading fails
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        final int size = in.readInt();
        for (int j = 0; j < size; ++j) {
            String key = in.readUTF();
            long value = in.readLong();
            this._rcdMap.put(key, value);
        }
        this._wordEmbeddings.readExternal(in);
        this.initEmbeddings(this._wordEmbeddings);
    }
}

