/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.analysis;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PayloadAttribute;
import org.apache.lucene.util.BytesRef;

public class HFModelTokenizer
extends Tokenizer {
    @Generated
    private static final Logger log = LogManager.getLogger(HFModelTokenizer.class);
    public static final String NAME = "hf_model_tokenizer";
    private static final Float DEFAULT_TOKEN_WEIGHT = Float.valueOf(1.0f);
    private final CharTermAttribute termAtt = (CharTermAttribute)this.addAttribute(CharTermAttribute.class);
    private final PayloadAttribute payloadAtt;
    private final OffsetAttribute offsetAtt = (OffsetAttribute)this.addAttribute(OffsetAttribute.class);
    private final Supplier<HuggingFaceTokenizer> tokenizerSupplier;
    private final Supplier<Map<String, Float>> tokenWeightsSupplier;
    private Encoding encoding;
    private int tokenIdx = 0;
    private int overflowingIdx = 0;

    public HFModelTokenizer(Supplier<HuggingFaceTokenizer> huggingFaceTokenizerSupplier) {
        this(huggingFaceTokenizerSupplier, null);
    }

    public HFModelTokenizer(Supplier<HuggingFaceTokenizer> huggingFaceTokenizerSupplier, Supplier<Map<String, Float>> weightsSupplier) {
        this.payloadAtt = Objects.nonNull(weightsSupplier) ? (PayloadAttribute)this.addAttribute(PayloadAttribute.class) : null;
        this.tokenizerSupplier = huggingFaceTokenizerSupplier;
        this.tokenWeightsSupplier = weightsSupplier;
    }

    public void reset() throws IOException {
        super.reset();
        this.tokenIdx = 0;
        this.overflowingIdx = -1;
        String inputStr = CharStreams.toString((Readable)this.input);
        this.encoding = StringUtils.isEmpty((CharSequence)inputStr) ? null : this.tokenizerSupplier.get().encode(inputStr, false, true);
    }

    private static boolean isLastTokenInEncodingSegment(int idx, Encoding encodingSegment) {
        return idx >= encodingSegment.getTokens().length || encodingSegment.getAttentionMask()[idx] == 0L;
    }

    public static byte[] floatToBytes(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    public static float bytesToFloat(byte[] bytes) {
        return ByteBuffer.wrap(bytes).getFloat();
    }

    public final boolean incrementToken() throws IOException {
        Encoding curEncoding;
        this.clearAttributes();
        if (Objects.isNull(this.encoding)) {
            return false;
        }
        Encoding encoding = curEncoding = this.overflowingIdx == -1 ? this.encoding : this.encoding.getOverflowing()[this.overflowingIdx];
        while (!HFModelTokenizer.isLastTokenInEncodingSegment(this.tokenIdx, curEncoding) || this.overflowingIdx < this.encoding.getOverflowing().length) {
            if (HFModelTokenizer.isLastTokenInEncodingSegment(this.tokenIdx, curEncoding)) {
                this.tokenIdx = 0;
                ++this.overflowingIdx;
                if (this.overflowingIdx >= this.encoding.getOverflowing().length) {
                    return false;
                }
                curEncoding = this.encoding.getOverflowing()[this.overflowingIdx];
                continue;
            }
            this.termAtt.append(curEncoding.getTokens()[this.tokenIdx]);
            this.offsetAtt.setOffset(curEncoding.getCharTokenSpans()[this.tokenIdx].getStart(), curEncoding.getCharTokenSpans()[this.tokenIdx].getEnd());
            if (Objects.nonNull(this.tokenWeightsSupplier)) {
                this.payloadAtt.setPayload(new BytesRef(HFModelTokenizer.floatToBytes(this.tokenWeightsSupplier.get().getOrDefault(curEncoding.getTokens()[this.tokenIdx], DEFAULT_TOKEN_WEIGHT).floatValue())));
            }
            ++this.tokenIdx;
            return true;
        }
        return false;
    }
}

