/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.zhipuai;

import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import reactor.core.publisher.Flux;

/**
 * Chat model that talks to the ZhiPu AI chat-completion endpoint through {@link ZhiPuAiApi}.
 *
 * <p>Offers a blocking {@link #call(Prompt)} and a reactive {@link #stream(Prompt)} entry
 * point. Both build a {@link ZhiPuAiApi.ChatCompletionRequest} from the prompt, the
 * prompt-level options and the default options held by this instance, and both are run
 * through the configured {@link RetryTemplate}.
 *
 * <p>Function (tool) calling is delegated to the {@link AbstractFunctionCallSupport} base
 * class; the protected {@code do*} methods and {@link #isToolFunctionCall(ResponseEntity)}
 * provide the ZhiPu AI specific pieces it needs.
 */
public class ZhiPuAiChatModel
        extends AbstractFunctionCallSupport<ZhiPuAiApi.ChatCompletionMessage, ZhiPuAiApi.ChatCompletionRequest, ResponseEntity<ZhiPuAiApi.ChatCompletion>>
        implements ChatModel, StreamingChatModel {

    private static final Logger logger = LoggerFactory.getLogger(ZhiPuAiChatModel.class);

    /** Options merged into every request; never {@code null} (enforced by the constructor). */
    private final ZhiPuAiChatOptions defaultOptions;

    /** Retry policy wrapped around {@link #call(Prompt)} and {@link #stream(Prompt)}. */
    public final RetryTemplate retryTemplate;

    /** Low-level client used for the actual HTTP exchanges. */
    private final ZhiPuAiApi zhiPuAiApi;

    /**
     * Creates a model using {@link ZhiPuAiApi#DEFAULT_CHAT_MODEL} and a temperature of 0.7.
     *
     * @param zhiPuAiApi the API client; must not be {@code null}
     */
    public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi) {
        this(zhiPuAiApi, ZhiPuAiChatOptions.builder().withModel(ZhiPuAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
    }

    /**
     * Creates a model with the given default options, no function callback context and
     * {@link RetryUtils#DEFAULT_RETRY_TEMPLATE}.
     *
     * @param zhiPuAiApi the API client; must not be {@code null}
     * @param options the default options; must not be {@code null}
     */
    public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options) {
        this(zhiPuAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
    }

    /**
     * Creates a fully configured model.
     *
     * @param zhiPuAiApi the API client; must not be {@code null}
     * @param options the default options; must not be {@code null}
     * @param functionCallbackContext context handed to the function-calling base class; may be {@code null}
     * @param retryTemplate the retry policy; must not be {@code null}
     * @throws IllegalArgumentException if a required argument is {@code null}
     */
    public ZhiPuAiChatModel(ZhiPuAiApi zhiPuAiApi, ZhiPuAiChatOptions options, FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
        super(functionCallbackContext);
        Assert.notNull(zhiPuAiApi, "ZhiPuAiApi must not be null");
        Assert.notNull(options, "Options must not be null");
        Assert.notNull(retryTemplate, "RetryTemplate must not be null");
        this.zhiPuAiApi = zhiPuAiApi;
        this.defaultOptions = options;
        this.retryTemplate = retryTemplate;
    }

    /**
     * Sends the prompt as a non-streaming request and converts every returned choice into a
     * {@link Generation}.
     *
     * @param prompt the prompt to send
     * @return the response; contains no generations when the API returned an empty body
     */
    public ChatResponse call(Prompt prompt) {
        ZhiPuAiApi.ChatCompletionRequest request = this.createRequest(prompt, false);
        return this.retryTemplate.execute(ctx -> {
            ResponseEntity<ZhiPuAiApi.ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
            ZhiPuAiApi.ChatCompletion chatCompletion = completionEntity.getBody();
            if (chatCompletion == null) {
                logger.warn("No chat completion returned for prompt: {}", prompt);
                return new ChatResponse(List.of());
            }
            List<Generation> generations = chatCompletion.choices().stream()
                    .map(choice -> this.toGeneration(chatCompletion.id(), choice))
                    .toList();
            return new ChatResponse(generations);
        });
    }

    /**
     * Converts one choice of a non-streaming completion into a {@link Generation}.
     * Generation metadata is attached only when the choice carries a finish reason, which
     * mirrors the null handling in {@link #toMap} and in the streaming path.
     */
    private Generation toGeneration(String id, ZhiPuAiApi.ChatCompletion.Choice choice) {
        Generation generation = new Generation(choice.message().content(), this.toMap(id, choice));
        if (choice.finishReason() != null) {
            generation = generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
        }
        return generation;
    }

    /**
     * Builds the property map stored on a {@link Generation}: the completion {@code id} plus
     * {@code role} and {@code finishReason} when the choice provides them.
     */
    private Map<String, Object> toMap(String id, ZhiPuAiApi.ChatCompletion.Choice choice) {
        Map<String, Object> properties = new HashMap<>();
        ZhiPuAiApi.ChatCompletionMessage message = choice.message();
        if (message.role() != null) {
            properties.put("role", message.role().name());
        }
        if (choice.finishReason() != null) {
            properties.put("finishReason", choice.finishReason().name());
        }
        properties.put("id", id);
        return properties;
    }

    /**
     * Sends the prompt as a streaming request and maps every received chunk to a
     * {@link ChatResponse}.
     *
     * <p>A chunk that cannot be processed is logged and replaced by an empty response rather
     * than terminating the stream.
     *
     * @param prompt the prompt to send
     * @return a stream of partial responses
     */
    public Flux<ChatResponse> stream(Prompt prompt) {
        ZhiPuAiApi.ChatCompletionRequest request = this.createRequest(prompt, true);
        return this.retryTemplate.execute(ctx -> {
            Flux<ZhiPuAiApi.ChatCompletionChunk> completionChunks = this.zhiPuAiApi.chatCompletionStream(request);
            // Remembers the first role seen per completion id so that later chunks,
            // which may omit the role, can still report it.
            Map<String, String> roleMap = new ConcurrentHashMap<>();
            return completionChunks.map(this::chunkToChatCompletion).map(chunkCompletion -> {
                try {
                    ZhiPuAiApi.ChatCompletion chatCompletion = this.handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chunkCompletion))).getBody();
                    String id = chatCompletion.id();
                    List<Generation> generations = chatCompletion.choices().stream()
                            .map(choice -> this.toStreamingGeneration(id, choice, roleMap))
                            .toList();
                    return new ChatResponse(generations);
                }
                catch (Exception e) {
                    // Deliberate best effort: keep the stream alive when a single chunk is bad.
                    logger.error("Error processing chat completion", e);
                    return new ChatResponse(List.of());
                }
            });
        });
    }

    /**
     * Converts one choice of a streamed chunk into a {@link Generation}, recording and
     * looking up the role for the completion {@code id} in {@code roleMap}.
     */
    private Generation toStreamingGeneration(String id, ZhiPuAiApi.ChatCompletion.Choice choice, Map<String, String> roleMap) {
        if (choice.message().role() != null) {
            roleMap.putIfAbsent(id, choice.message().role().name());
        }
        String finish = choice.finishReason() != null ? choice.finishReason().name() : "";
        // Map.of rejects null values, so fall back to an empty role when none has been seen
        // for this id yet instead of failing (and thereby dropping) the whole chunk.
        String role = roleMap.getOrDefault(id, "");
        Generation generation = new Generation(choice.message().content(), Map.of("id", id, "role", role, "finishReason", finish));
        if (choice.finishReason() != null) {
            generation = generation.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
        }
        return generation;
    }

    /**
     * Re-packages a streamed chunk as a {@link ZhiPuAiApi.ChatCompletion}, using each
     * choice's delta as its message, so that the chunk can be handled like a regular
     * completion.
     */
    private ZhiPuAiApi.ChatCompletion chunkToChatCompletion(ZhiPuAiApi.ChatCompletionChunk chunk) {
        List<ZhiPuAiApi.ChatCompletion.Choice> choices = chunk.choices().stream()
                .map(cc -> new ZhiPuAiApi.ChatCompletion.Choice(cc.finishReason(), cc.index(), cc.delta(), cc.logprobs()))
                .toList();
        return new ZhiPuAiApi.ChatCompletion(chunk.id(), choices, chunk.created(), chunk.model(), chunk.systemFingerprint(), "chat.completion", null);
    }

    /**
     * Builds the API request for a prompt.
     *
     * <p>Steps, in order: convert the prompt instructions (text plus optional media) to
     * API messages; merge the prompt-level options; merge the default options; finally add
     * the tool definitions for every function enabled by either set of options.
     *
     * @param prompt the prompt to convert
     * @param stream whether the request is a streaming request
     * @return the request to send
     * @throws IllegalArgumentException if the prompt options are not {@link ChatOptions},
     *         or a media item carries an unsupported data type
     */
    ZhiPuAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
        Set<String> functionsForThisRequest = new HashSet<>();

        List<ZhiPuAiApi.ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
            // The text part always comes first, followed by one image entry per media item.
            List<ZhiPuAiApi.ChatCompletionMessage.MediaContent> contents = new ArrayList<>(List.of(new ZhiPuAiApi.ChatCompletionMessage.MediaContent(m.getContent())));
            if (!CollectionUtils.isEmpty(m.getMedia())) {
                contents.addAll(m.getMedia().stream()
                        .map(media -> new ZhiPuAiApi.ChatCompletionMessage.MediaContent(new ZhiPuAiApi.ChatCompletionMessage.MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
                        .toList());
            }
            return new ZhiPuAiApi.ChatCompletionMessage(contents, ZhiPuAiApi.ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
        }).toList();

        ZhiPuAiApi.ChatCompletionRequest request = new ZhiPuAiApi.ChatCompletionRequest(chatCompletionMessages, stream);

        ModelOptions modelOptions = prompt.getOptions();
        if (modelOptions != null) {
            if (modelOptions instanceof ChatOptions runtimeOptions) {
                ZhiPuAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, ZhiPuAiChatOptions.class);
                Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, true);
                functionsForThisRequest.addAll(promptEnabledFunctions);
                request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ZhiPuAiApi.ChatCompletionRequest.class);
            } else {
                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + modelOptions.getClass().getSimpleName());
            }
        }

        if (this.defaultOptions != null) {
            Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, false);
            functionsForThisRequest.addAll(defaultEnabledFunctions);
            request = ModelOptionsUtils.merge(request, this.defaultOptions, ZhiPuAiApi.ChatCompletionRequest.class);
        }

        if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
            request = ModelOptionsUtils.merge(ZhiPuAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), request, ZhiPuAiApi.ChatCompletionRequest.class);
        }
        return request;
    }

    /**
     * Turns media payload data into the string placed in an image-URL content entry.
     *
     * @param mimeType the media type, used for the {@code data:} URL prefix of binary data
     * @param mediaContentData either a {@code byte[]} (encoded as a Base64 data URL) or a
     *        {@link String} (returned unchanged)
     * @return the string representation of the media data
     * @throws IllegalArgumentException if the data is {@code null} or of any other type
     */
    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
        if (mediaContentData instanceof byte[] bytes) {
            return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
        }
        if (mediaContentData instanceof String text) {
            return text;
        }
        // Guard the null case explicitly so callers get the documented exception, not an NPE.
        String actualType = mediaContentData == null ? "null" : mediaContentData.getClass().getSimpleName();
        throw new IllegalArgumentException("Unsupported media data type: " + actualType);
    }

    /**
     * Resolves the named function callbacks and wraps each one as an API tool definition.
     */
    private List<ZhiPuAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
        return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
            ZhiPuAiApi.FunctionTool.Function function = new ZhiPuAiApi.FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(), functionCallback.getInputTypeSchema());
            return new ZhiPuAiApi.FunctionTool(function);
        }).toList();
    }

    /**
     * Executes every tool call requested by the model, appends each result to the
     * conversation as a {@code TOOL} message and builds the follow-up request.
     *
     * @param previousRequest the request that produced the tool calls; its settings are merged into the new request
     * @param responseMessage the model message containing the tool calls
     * @param conversationHistory the messages exchanged so far; tool results are appended to this list
     * @return a non-streaming request carrying the extended conversation
     * @throws IllegalStateException if a requested function has no registered callback
     */
    protected ZhiPuAiApi.ChatCompletionRequest doCreateToolResponseRequest(ZhiPuAiApi.ChatCompletionRequest previousRequest, ZhiPuAiApi.ChatCompletionMessage responseMessage, List<ZhiPuAiApi.ChatCompletionMessage> conversationHistory) {
        for (ZhiPuAiApi.ChatCompletionMessage.ToolCall toolCall : responseMessage.toolCalls()) {
            String functionName = toolCall.function().name();
            String functionArguments = toolCall.function().arguments();
            FunctionCallback callback = this.functionCallbackRegister.get(functionName);
            if (callback == null) {
                throw new IllegalStateException("No function callback found for function name: " + functionName);
            }
            String functionResponse = callback.call(functionArguments);
            conversationHistory.add(new ZhiPuAiApi.ChatCompletionMessage(functionResponse, ZhiPuAiApi.ChatCompletionMessage.Role.TOOL, functionName, toolCall.id(), null));
        }
        ZhiPuAiApi.ChatCompletionRequest newRequest = new ZhiPuAiApi.ChatCompletionRequest(conversationHistory, false);
        return ModelOptionsUtils.merge(newRequest, previousRequest, ZhiPuAiApi.ChatCompletionRequest.class);
    }

    /**
     * Returns the messages carried by the given request.
     */
    protected List<ZhiPuAiApi.ChatCompletionMessage> doGetUserMessages(ZhiPuAiApi.ChatCompletionRequest request) {
        return request.messages();
    }

    /**
     * Returns the message of the first choice of the given completion.
     * Expects a response with a body and at least one choice.
     */
    protected ZhiPuAiApi.ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ZhiPuAiApi.ChatCompletion> chatCompletion) {
        return chatCompletion.getBody().choices().iterator().next().message();
    }

    /**
     * Performs a blocking chat-completion call.
     */
    protected ResponseEntity<ZhiPuAiApi.ChatCompletion> doChatCompletion(ZhiPuAiApi.ChatCompletionRequest request) {
        return this.zhiPuAiApi.chatCompletionEntity(request);
    }

    /**
     * Performs a streaming chat-completion call, exposing each chunk as a completion
     * wrapped in a {@link ResponseEntity}.
     */
    protected Flux<ResponseEntity<ZhiPuAiApi.ChatCompletion>> doChatCompletionStream(ZhiPuAiApi.ChatCompletionRequest request) {
        return this.zhiPuAiApi.chatCompletionStream(request).map(this::chunkToChatCompletion).map(Optional::ofNullable).map(ResponseEntity::of);
    }

    /**
     * Tells whether the response asks for tool execution: its first choice must contain
     * tool calls and have finished with {@code TOOL_CALLS}.
     *
     * @return {@code false} for a missing body or an empty choice list
     */
    protected boolean isToolFunctionCall(ResponseEntity<ZhiPuAiApi.ChatCompletion> chatCompletion) {
        ZhiPuAiApi.ChatCompletion body = chatCompletion.getBody();
        if (body == null) {
            return false;
        }
        List<ZhiPuAiApi.ChatCompletion.Choice> choices = body.choices();
        if (CollectionUtils.isEmpty(choices)) {
            return false;
        }
        ZhiPuAiApi.ChatCompletion.Choice choice = choices.get(0);
        return !CollectionUtils.isEmpty(choice.message().toolCalls()) && choice.finishReason() == ZhiPuAiApi.ChatCompletionFinishReason.TOOL_CALLS;
    }

    /**
     * Returns the default options as produced by {@code ZhiPuAiChatOptions.fromOptions}.
     */
    public ChatOptions getDefaultOptions() {
        return ZhiPuAiChatOptions.fromOptions(this.defaultOptions);
    }
}

