/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.ollama;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.time.Duration;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

/**
 * {@link ChatModel} implementation that talks to an Ollama server through {@link OllamaApi}.
 *
 * <p>Supports blocking ({@link #call(Prompt)}) and streaming ({@link #stream(Prompt)}) chat,
 * Micrometer observations, and tool (function) calling via {@link AbstractToolCallSupport}.
 * When a response is detected as a tool call and tool calls are not proxied back to the
 * caller, the requested tools are executed and the extended conversation is re-submitted.
 * Usage and timing metadata are then accumulated across the rounds.
 */
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {

    private static final String DONE = "done";

    private static final String METADATA_PROMPT_EVAL_COUNT = "prompt-eval-count";

    private static final String METADATA_EVAL_COUNT = "eval-count";

    private static final String METADATA_CREATED_AT = "created-at";

    private static final String METADATA_TOTAL_DURATION = "total-duration";

    private static final String METADATA_LOAD_DURATION = "load-duration";

    private static final String METADATA_PROMPT_EVAL_DURATION = "prompt-eval-duration";

    private static final String METADATA_EVAL_DURATION = "eval-duration";

    /** Reactor context key under which the current (parent) observation is propagated. */
    private static final String OBSERVATION_CONTEXT_KEY = "micrometer.observation";

    /** Finish reasons that {@code isToolCall} checks when detecting a tool-call response. */
    private static final Set<String> TOOL_CALL_FINISH_REASONS = Set.of("stop");

    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();

    private final OllamaApi chatApi;

    private final OllamaOptions defaultOptions;

    private final ObservationRegistry observationRegistry;

    private final OllamaModelManager modelManager;

    private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

    /**
     * Creates the chat model and, unless the pull strategy is {@code null} or
     * {@link PullModelStrategy#NEVER}, asks the {@link OllamaModelManager} to pull the
     * default model.
     * @param ollamaApi low-level Ollama client; must not be null
     * @param defaultOptions options used when a prompt does not override them; must not be null
     * @param functionCallbackResolver resolver for function callbacks referenced by name
     * @param toolFunctionCallbacks function callbacks registered with this model
     * @param observationRegistry registry for Micrometer observations; must not be null
     * @param modelManagementOptions model pull/management settings; must not be null
     * @throws IllegalArgumentException if any argument documented as non-null is null
     */
    public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
            FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
            ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
        super(functionCallbackResolver, defaultOptions, toolFunctionCallbacks);
        Assert.notNull(ollamaApi, "ollamaApi must not be null");
        Assert.notNull(defaultOptions, "defaultOptions must not be null");
        Assert.notNull(observationRegistry, "observationRegistry must not be null");
        Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
        this.chatApi = ollamaApi;
        this.defaultOptions = defaultOptions;
        this.observationRegistry = observationRegistry;
        this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
        this.initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
    }

    /**
     * Returns a new {@link Builder}.
     * @return a fresh builder pre-populated with defaults
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builds response metadata for an Ollama response. When a previous response is given
     * (an earlier tool-call round), its durations and token usage are added on top.
     *
     * <p>Any duration may be null (for example on intermediate streaming chunks). A null
     * duration is treated as absent rather than dereferenced. Null token counts are
     * treated as zero.
     * @param response the Ollama response to describe; must not be null
     * @param previousChatResponse response from the previous tool-call round, or null
     * @return metadata carrying aggregated usage, durations, model, creation time and done flag
     * @throws IllegalArgumentException if {@code response} is null
     */
    static ChatResponseMetadata from(OllamaApi.ChatResponse response, ChatResponse previousChatResponse) {
        Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
        OllamaChatUsage newUsage = OllamaChatUsage.from(response);
        long promptTokens = valueOrZero(newUsage.getPromptTokens());
        long generationTokens = valueOrZero(newUsage.getGenerationTokens());
        long totalTokens = valueOrZero(newUsage.getTotalTokens());
        Duration evalDuration = response.getEvalDuration();
        Duration promptEvalDuration = response.getPromptEvalDuration();
        Duration loadDuration = response.getLoadDuration();
        Duration totalDuration = response.getTotalDuration();

        if (previousChatResponse != null && previousChatResponse.getMetadata() != null) {
            ChatResponseMetadata previous = previousChatResponse.getMetadata();
            evalDuration = addDurations(evalDuration, previous.get(METADATA_EVAL_DURATION));
            promptEvalDuration = addDurations(promptEvalDuration, previous.get(METADATA_PROMPT_EVAL_DURATION));
            loadDuration = addDurations(loadDuration, previous.get(METADATA_LOAD_DURATION));
            totalDuration = addDurations(totalDuration, previous.get(METADATA_TOTAL_DURATION));
            Usage previousUsage = previous.getUsage();
            if (previousUsage != null) {
                promptTokens += valueOrZero(previousUsage.getPromptTokens());
                generationTokens += valueOrZero(previousUsage.getGenerationTokens());
                totalTokens += valueOrZero(previousUsage.getTotalTokens());
            }
        }

        DefaultUsage aggregatedUsage = new DefaultUsage(promptTokens, generationTokens, totalTokens);
        return ChatResponseMetadata.builder()
            .usage(aggregatedUsage)
            .model(response.model())
            .keyValue(METADATA_CREATED_AT, response.createdAt())
            .keyValue(METADATA_EVAL_DURATION, evalDuration)
            .keyValue(METADATA_EVAL_COUNT, aggregatedUsage.getGenerationTokens().intValue())
            .keyValue(METADATA_LOAD_DURATION, loadDuration)
            .keyValue(METADATA_PROMPT_EVAL_DURATION, promptEvalDuration)
            .keyValue(METADATA_PROMPT_EVAL_COUNT, aggregatedUsage.getPromptTokens().intValue())
            .keyValue(METADATA_TOTAL_DURATION, totalDuration)
            .keyValue(DONE, response.done())
            .build();
    }

    /** Null-safe duration sum: a null operand is treated as absent; null only if both are. */
    private static Duration addDurations(Duration current, Duration previous) {
        if (current == null) {
            return previous;
        }
        if (previous == null) {
            return current;
        }
        return current.plus(previous);
    }

    /** Converts a possibly-null token count to a primitive, treating null as zero. */
    private static long valueOrZero(Number value) {
        return value != null ? value.longValue() : 0L;
    }

    @Override
    public ChatResponse call(Prompt prompt) {
        return this.internalCall(prompt, null);
    }

    /**
     * Executes one blocking chat round inside an observation. If the response is a tool call
     * that should be executed locally, runs the tools and recurses with the extended
     * conversation, passing this response along for metadata aggregation.
     */
    private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, false);
        ChatModelObservationContext observationContext = this.buildObservationContext(prompt, request);

        ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
            .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry)
            .observe(() -> {
                ChatResponse chatResponse = toChatResponse(this.chatApi.chat(request), previousChatResponse);
                observationContext.setResponse(chatResponse);
                return chatResponse;
            });

        if (response != null && this.shouldExecuteToolCalls(prompt, response)) {
            var toolCallConversation = this.handleToolCalls(prompt, response);
            return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
        }
        return response;
    }

    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        return this.internalStream(prompt, null);
    }

    /**
     * Streams one chat round inside an observation that is parented to any observation found
     * in the Reactor context. Tool-call chunks that should be executed locally are replaced by
     * a recursive stream over the extended conversation. The aggregated message is recorded
     * on the observation context when the stream completes.
     */
    private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        return Flux.deferContextual(contextView -> {
            OllamaApi.ChatRequest request = this.ollamaChatRequest(prompt, true);
            ChatModelObservationContext observationContext = this.buildObservationContext(prompt, request);
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
                    this.observationRegistry);
            Observation parentObservation = contextView.getOrDefault(OBSERVATION_CONTEXT_KEY, null);
            observation.parentObservation(parentObservation).start();

            Flux<ChatResponse> chatResponseFlux = this.chatApi.streamingChat(request)
                .map(chunk -> toChatResponse(chunk, previousChatResponse))
                .flatMap(response -> {
                    // Honor proxyToolCalls exactly like the blocking path does.
                    if (this.shouldExecuteToolCalls(prompt, response)) {
                        var toolCallConversation = this.handleToolCalls(prompt, response);
                        return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
                    }
                    return Flux.just(response);
                })
                .doOnError(observation::error)
                .doFinally(signalType -> observation.stop())
                .contextWrite(ctx -> ctx.put(OBSERVATION_CONTEXT_KEY, observation));

            return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
        });
    }

    /**
     * Decides whether this model should execute the tool calls in {@code response} itself.
     * That happens only when tool calls are not proxied to the caller and the response is
     * detected as a tool call.
     */
    private boolean shouldExecuteToolCalls(Prompt prompt, ChatResponse response) {
        return !this.isProxyToolCalls(prompt, this.defaultOptions)
                && this.isToolCall(response, TOOL_CALL_FINISH_REASONS);
    }

    /** Creates the observation context describing {@code request} for {@code prompt}. */
    private ChatModelObservationContext buildObservationContext(Prompt prompt, OllamaApi.ChatRequest request) {
        return ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(OllamaApi.PROVIDER_NAME)
            .requestOptions(this.buildRequestOptions(request))
            .build();
    }

    /**
     * Maps a (complete or partial) Ollama response to a Spring AI {@link ChatResponse}.
     *
     * <p>A missing message yields empty content and no tool calls. A finish reason is
     * attached only when both prompt and eval counts are present.
     */
    private static ChatResponse toChatResponse(OllamaApi.ChatResponse ollamaResponse,
            ChatResponse previousChatResponse) {
        OllamaApi.Message message = ollamaResponse.message();
        String content = message != null ? message.content() : "";
        AssistantMessage assistantMessage = new AssistantMessage(content, Map.of(), toAssistantToolCalls(message));

        ChatGenerationMetadata generationMetadata = ChatGenerationMetadata.NULL;
        if (ollamaResponse.promptEvalCount() != null && ollamaResponse.evalCount() != null) {
            generationMetadata = ChatGenerationMetadata.builder().finishReason(ollamaResponse.doneReason()).build();
        }

        Generation generation = new Generation(assistantMessage, generationMetadata);
        return new ChatResponse(List.of(generation), from(ollamaResponse, previousChatResponse));
    }

    /**
     * Converts Ollama tool calls into Spring AI tool calls. Ids are empty strings, the type is
     * {@code "function"}, and arguments are serialized to JSON. Returns an empty list when
     * {@code message} or its tool calls are null.
     */
    private static List<AssistantMessage.ToolCall> toAssistantToolCalls(OllamaApi.Message message) {
        if (message == null || message.toolCalls() == null) {
            return List.of();
        }
        return message.toolCalls()
            .stream()
            .map(toolCall -> new AssistantMessage.ToolCall("", "function", toolCall.function().name(),
                    ModelOptionsUtils.toJsonString(toolCall.function().arguments())))
            .toList();
    }

    /**
     * Translates a Spring AI prompt into an Ollama chat request.
     *
     * <p>Prompt options override the default options. Function names come from both the
     * runtime and the default options and are sent as tools.
     * @param prompt the prompt to translate
     * @param stream whether the request asks for a streamed response
     * @return the Ollama request
     * @throws IllegalArgumentException on an unsupported message type or when no model is set
     */
    OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
        List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
            if (message instanceof UserMessage userMessage) {
                OllamaApi.Message.Builder messageBuilder = OllamaApi.Message.builder(OllamaApi.Message.Role.USER)
                    .content(message.getText());
                if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                    messageBuilder.images(
                            userMessage.getMedia().stream().map(media -> this.fromMediaData(media.getData())).toList());
                }
                return List.of(messageBuilder.build());
            }
            if (message instanceof SystemMessage systemMessage) {
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.SYSTEM)
                    .content(systemMessage.getText())
                    .build());
            }
            if (message instanceof AssistantMessage assistantMessage) {
                List<OllamaApi.Message.ToolCall> toolCalls = null;
                if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                    toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
                        OllamaApi.Message.ToolCallFunction function = new OllamaApi.Message.ToolCallFunction(
                                toolCall.name(), ModelOptionsUtils.jsonToMap(toolCall.arguments()));
                        return new OllamaApi.Message.ToolCall(function);
                    }).toList();
                }
                return List.of(OllamaApi.Message.builder(OllamaApi.Message.Role.ASSISTANT)
                    .content(assistantMessage.getText())
                    .toolCalls(toolCalls)
                    .build());
            }
            if (message instanceof ToolResponseMessage toolMessage) {
                // Each tool response becomes its own TOOL-role message.
                return toolMessage.getResponses()
                    .stream()
                    .map(tr -> OllamaApi.Message.builder(OllamaApi.Message.Role.TOOL).content(tr.responseData()).build())
                    .toList();
            }
            throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(message.getMessageType()));
        }).flatMap(Collection::stream).toList();

        Set<String> functionsForThisRequest = new HashSet<>();

        OllamaOptions runtimeOptions = null;
        ChatOptions promptOptions = prompt.getOptions();
        if (promptOptions != null) {
            if (promptOptions instanceof FunctionCallingOptions functionCallingOptions) {
                runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
                        OllamaOptions.class);
            }
            else {
                runtimeOptions = ModelOptionsUtils.copyToTarget(promptOptions, ChatOptions.class, OllamaOptions.class);
            }
            functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(runtimeOptions));
        }

        if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) {
            functionsForThisRequest.addAll(this.defaultOptions.getFunctions());
        }

        OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
        if (!StringUtils.hasText(mergedOptions.getModel())) {
            throw new IllegalArgumentException("Model is not set!");
        }

        String model = mergedOptions.getModel();
        OllamaApi.ChatRequest.Builder requestBuilder = OllamaApi.ChatRequest.builder(model)
            .stream(stream)
            .messages(ollamaMessages)
            .options(mergedOptions);

        if (mergedOptions.getFormat() != null) {
            requestBuilder.format(mergedOptions.getFormat());
        }
        if (mergedOptions.getKeepAlive() != null) {
            requestBuilder.keepAlive(mergedOptions.getKeepAlive());
        }
        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            requestBuilder.tools(this.getFunctionTools(functionsForThisRequest));
        }
        return requestBuilder.build();
    }

    /**
     * Converts media payload data to the string form placed in the request's images list.
     * Byte arrays are Base64-encoded; strings are passed through unchanged.
     * @throws IllegalArgumentException for any other type, including null
     */
    private String fromMediaData(Object mediaData) {
        if (mediaData instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        if (mediaData instanceof String text) {
            return text;
        }
        String typeName = mediaData == null ? "null" : mediaData.getClass().getSimpleName();
        throw new IllegalArgumentException("Unsupported media data type: " + typeName);
    }

    /** Resolves the named function callbacks and describes each as an Ollama request tool. */
    private List<OllamaApi.ChatRequest.Tool> getFunctionTools(Set<String> functionNames) {
        return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
            OllamaApi.ChatRequest.Tool.Function function = new OllamaApi.ChatRequest.Tool.Function(
                    functionCallback.getName(), functionCallback.getDescription(),
                    functionCallback.getInputTypeSchema());
            return new OllamaApi.ChatRequest.Tool(function);
        }).toList();
    }

    /** Extracts the portable {@link ChatOptions} subset from a request, for observations. */
    private ChatOptions buildRequestOptions(OllamaApi.ChatRequest request) {
        OllamaOptions options = ModelOptionsUtils.mapToClass(request.options(), OllamaOptions.class);
        return ChatOptions.builder()
            .model(request.model())
            .frequencyPenalty(options.getFrequencyPenalty())
            .maxTokens(options.getMaxTokens())
            .presencePenalty(options.getPresencePenalty())
            .stopSequences(options.getStopSequences())
            .temperature(options.getTemperature())
            .topK(options.getTopK())
            .topP(options.getTopP())
            .build();
    }

    /**
     * Returns a copy of the default options.
     * @return a new options instance built from this model's defaults
     */
    @Override
    public ChatOptions getDefaultOptions() {
        return OllamaOptions.fromOptions(this.defaultOptions);
    }

    /** Pulls {@code model} unless the strategy is null or {@link PullModelStrategy#NEVER}. */
    private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
        if (pullModelStrategy != null && pullModelStrategy != PullModelStrategy.NEVER) {
            this.modelManager.pullModel(model, pullModelStrategy);
        }
    }

    /**
     * Replaces the observation convention used for chat observations.
     * @param observationConvention the convention to use; must not be null
     * @throws IllegalArgumentException if {@code observationConvention} is null
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        Assert.notNull(observationConvention, "observationConvention cannot be null");
        this.observationConvention = observationConvention;
    }

    /**
     * Fluent builder for {@link OllamaChatModel}.
     *
     * <p>Defaults: model {@code OllamaModel.MISTRAL}, no function callback resolver, no tool
     * callbacks, {@link ObservationRegistry#NOOP} and {@link ModelManagementOptions#defaults()}.
     */
    public static final class Builder {

        private OllamaApi ollamaApi;

        private OllamaOptions defaultOptions = OllamaOptions.builder().model(OllamaModel.MISTRAL.id()).build();

        private FunctionCallbackResolver functionCallbackResolver;

        private List<FunctionCallback> toolFunctionCallbacks = List.of();

        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

        private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();

        private Builder() {
        }

        public Builder ollamaApi(OllamaApi ollamaApi) {
            this.ollamaApi = ollamaApi;
            return this;
        }

        public Builder defaultOptions(OllamaOptions defaultOptions) {
            this.defaultOptions = defaultOptions;
            return this;
        }

        /**
         * @deprecated use {@link #functionCallbackResolver(FunctionCallbackResolver)} instead
         */
        @Deprecated
        public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackContext) {
            return this.functionCallbackResolver(functionCallbackContext);
        }

        public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
            this.functionCallbackResolver = functionCallbackResolver;
            return this;
        }

        public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
            this.toolFunctionCallbacks = toolFunctionCallbacks;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder modelManagementOptions(ModelManagementOptions modelManagementOptions) {
            this.modelManagementOptions = modelManagementOptions;
            return this;
        }

        /** @deprecated use {@link #ollamaApi(OllamaApi)} instead */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withOllamaApi(OllamaApi ollamaApi) {
            return this.ollamaApi(ollamaApi);
        }

        /** @deprecated use {@link #defaultOptions(OllamaOptions)} instead */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withDefaultOptions(OllamaOptions defaultOptions) {
            return this.defaultOptions(defaultOptions);
        }

        /** @deprecated use {@link #toolFunctionCallbacks(List)} instead */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
            return this.toolFunctionCallbacks(toolFunctionCallbacks);
        }

        /** @deprecated use {@link #observationRegistry(ObservationRegistry)} instead */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
            return this.observationRegistry(observationRegistry);
        }

        /** @deprecated use {@link #modelManagementOptions(ModelManagementOptions)} instead */
        @Deprecated(forRemoval = true, since = "1.0.0-M5")
        public Builder withModelManagementOptions(ModelManagementOptions modelManagementOptions) {
            return this.modelManagementOptions(modelManagementOptions);
        }

        /**
         * Builds the chat model; the model constructor validates required fields.
         * @return a new {@link OllamaChatModel}
         */
        public OllamaChatModel build() {
            return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
                    this.toolFunctionCallbacks, this.observationRegistry, this.modelManagementOptions);
        }

    }

}

