/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.embedding.onnx;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.embedding.TokenCountEstimator;
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
import dev.langchain4j.model.embedding.onnx.PoolingMode;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

public abstract class AbstractInProcessEmbeddingModel
extends DimensionAwareEmbeddingModel
implements TokenCountEstimator {
    private final Executor executor;

    protected AbstractInProcessEmbeddingModel(Executor executor) {
        this.executor = (Executor)Utils.getOrDefault((Object)executor, this::createDefaultExecutor);
    }

    private Executor createDefaultExecutor() {
        int threadPoolSize = Runtime.getRuntime().availableProcessors();
        ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(threadPoolSize, threadPoolSize, 1L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
        threadPoolExecutor.allowCoreThreadTimeOut(true);
        return threadPoolExecutor;
    }

    protected static OnnxBertBiEncoder loadFromJar(String modelFileName, String tokenizerFileName, PoolingMode poolingMode) {
        InputStream model = Thread.currentThread().getContextClassLoader().getResourceAsStream(modelFileName);
        InputStream tokenizer = Thread.currentThread().getContextClassLoader().getResourceAsStream(tokenizerFileName);
        return new OnnxBertBiEncoder(model, tokenizer, poolingMode);
    }

    static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, Path pathToTokenizer, PoolingMode poolingMode) {
        try {
            return new OnnxBertBiEncoder(Files.newInputStream(pathToModel, new OpenOption[0]), Files.newInputStream(pathToTokenizer, new OpenOption[0]), poolingMode);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, InputStream tokenizer, PoolingMode poolingMode) {
        try {
            return new OnnxBertBiEncoder(Files.newInputStream(pathToModel, new OpenOption[0]), tokenizer, poolingMode);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    protected abstract OnnxBertBiEncoder model();

    public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
        ValidationUtils.ensureNotEmpty(segments, (String)"segments");
        if (segments.size() == 1) {
            return this.embedInTheSameThread(segments.get(0));
        }
        return this.parallelizeEmbedding(segments);
    }

    private Response<List<Embedding>> embedInTheSameThread(TextSegment segment) {
        OnnxBertBiEncoder.EmbeddingAndTokenCount embeddingAndTokenCount = this.model().embed(segment.text());
        return Response.from(Collections.singletonList(Embedding.from((float[])embeddingAndTokenCount.embedding)), (TokenUsage)new TokenUsage(Integer.valueOf(embeddingAndTokenCount.tokenCount - 2)));
    }

    private Response<List<Embedding>> parallelizeEmbedding(List<TextSegment> segments) {
        List futures = segments.stream().map(segment -> CompletableFuture.supplyAsync(() -> this.model().embed(segment.text()), this.executor)).collect(Collectors.toList());
        int inputTokenCount = 0;
        ArrayList<Embedding> embeddings = new ArrayList<Embedding>();
        for (CompletableFuture future : futures) {
            try {
                OnnxBertBiEncoder.EmbeddingAndTokenCount embeddingAndTokenCount = (OnnxBertBiEncoder.EmbeddingAndTokenCount)future.get();
                embeddings.add(Embedding.from((float[])embeddingAndTokenCount.embedding));
                inputTokenCount += embeddingAndTokenCount.tokenCount - 2;
            }
            catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException(e);
            }
        }
        return Response.from(embeddings, (TokenUsage)new TokenUsage(Integer.valueOf(inputTokenCount)));
    }

    public int estimateTokenCount(String text) {
        return this.model().countTokens(text);
    }
}

