/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.memory.elasticsearch;

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.SortOrder;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
import co.elastic.clients.elasticsearch.core.DeleteByQueryResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation;
import co.elastic.clients.elasticsearch.core.search.Hit;
import co.elastic.clients.json.JsonpMapper;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.ElasticsearchTransport;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import com.alibaba.cloud.ai.memory.elasticsearch.ElasticsearchConfig;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.Credentials;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.ssl.SSLContextBuilder;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * {@link ChatMemoryRepository} backed by a single Elasticsearch index ({@code chat_memory}).
 *
 * <p>Every chat message is stored as its own document carrying the conversation id, the message
 * type, the message text and a millisecond timestamp that is used to restore message order on
 * reads. The index is created with an explicit mapping ({@code conversationId} and
 * {@code messageType} as keywords, {@code messageText} as text, {@code timestamp} as date) when
 * the repository is constructed and the index does not exist yet.
 */
public class ElasticsearchChatMemoryRepository implements ChatMemoryRepository, AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(ElasticsearchChatMemoryRepository.class);

    private static final String INDEX_NAME = "chat_memory";

    /**
     * Upper bound on hits requested per search. Elasticsearch returns only 10 hits when no size is
     * given, and 10000 is its default {@code index.max_result_window}.
     */
    private static final int MAX_RESULT_SIZE = 10000;

    private final ElasticsearchConfig config;

    private final ElasticsearchClient client;

    private final ObjectMapper objectMapper;

    /**
     * Creates the repository, builds the Elasticsearch client and ensures the index exists.
     *
     * @param config connection settings (host/port or node list, scheme, optional credentials)
     * @throws RuntimeException if the client cannot be built or the index cannot be initialized
     */
    public ElasticsearchChatMemoryRepository(ElasticsearchConfig config) {
        this.config = config;
        this.objectMapper = new ObjectMapper();
        // Tolerate stored documents that carry fields ChatMessage does not declare.
        this.objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        try {
            this.client = createClient();
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to create Elasticsearch client", e);
        }
        try {
            createIndexIfNotExists();
        }
        catch (Exception e) {
            // Release the HTTP client instead of leaking it when initialization fails half-way.
            closeTransport();
            throw new RuntimeException("Failed to initialize Elasticsearch index: " + INDEX_NAME, e);
        }
    }

    /** Creates the index with its explicit mapping unless it already exists. */
    private void createIndexIfNotExists() throws IOException {
        if (!this.client.indices().exists(e -> e.index(INDEX_NAME)).value()) {
            createIndex();
        }
    }

    /** Creates the index with keyword/text/date mappings for the stored message fields. */
    private void createIndex() throws IOException {
        this.client.indices().create(c -> c.index(INDEX_NAME)
                .mappings(m -> m.properties("conversationId", p -> p.keyword(k -> k))
                        .properties("messageType", p -> p.keyword(k -> k))
                        .properties("messageText", p -> p.text(t -> t))
                        .properties("timestamp", p -> p.date(d -> d))));
    }

    /**
     * Drops the index (if present) and creates it again with the current mapping. All stored
     * messages are lost.
     *
     * @throws IOException on transport failure
     */
    public void recreateIndex() throws IOException {
        if (this.client.indices().exists(e -> e.index(INDEX_NAME)).value()) {
            this.client.indices().delete(d -> d.index(INDEX_NAME));
        }
        createIndex();
    }

    /**
     * Builds the Elasticsearch client from {@link #config}, wiring in the repository's configured
     * {@link ObjectMapper} so its deserialization settings apply to stored documents.
     */
    private ElasticsearchClient createClient() throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
        RestClientBuilder restClientBuilder = RestClient.builder(resolveHosts());
        if (StringUtils.hasText(this.config.getUsername()) && StringUtils.hasText(this.config.getPassword())) {
            CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
            credentialsProvider.setCredentials(AuthScope.ANY,
                    new UsernamePasswordCredentials(this.config.getUsername(), this.config.getPassword()));
            if ("https".equalsIgnoreCase(this.config.getScheme())) {
                // NOTE(review): this trusts every server certificate and disables hostname
                // verification, which exposes credentials to man-in-the-middle attacks. Kept for
                // backward compatibility; a configurable trust store would be safer.
                logger.warn("TLS certificate and hostname verification are disabled for Elasticsearch connections");
                SSLContext sslContext = SSLContextBuilder.create().loadTrustMaterial(null, (chain, authType) -> true).build();
                restClientBuilder.setHttpClientConfigCallback(httpClientBuilder -> httpClientBuilder
                        .setDefaultCredentialsProvider(credentialsProvider)
                        .setSSLContext(sslContext)
                        .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE));
            }
            else {
                restClientBuilder.setHttpClientConfigCallback(
                        httpClientBuilder -> httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider));
            }
        }
        RestClientTransport transport = new RestClientTransport(restClientBuilder.build(),
                new JacksonJsonpMapper(this.objectMapper));
        return new ElasticsearchClient(transport);
    }

    /** Returns the configured node list as hosts, or the single host/port when no nodes are set. */
    private HttpHost[] resolveHosts() {
        if (CollectionUtils.isEmpty(this.config.getNodes())) {
            return new HttpHost[] { new HttpHost(this.config.getHost(), this.config.getPort(), this.config.getScheme()) };
        }
        return this.config.getNodes().stream().map(this::toHttpHost).toArray(HttpHost[]::new);
    }

    /**
     * Parses a {@code host:port} node entry; an entry without a port uses the configured port.
     *
     * @throws NumberFormatException if the port part is not an integer
     */
    private HttpHost toHttpHost(String node) {
        String trimmed = node.trim();
        int separator = trimmed.lastIndexOf(':');
        if (separator < 0) {
            return new HttpHost(trimmed, this.config.getPort(), this.config.getScheme());
        }
        String host = trimmed.substring(0, separator);
        int port = Integer.parseInt(trimmed.substring(separator + 1).trim());
        return new HttpHost(host, port, this.config.getScheme());
    }

    /**
     * Returns the distinct conversation ids present in the index.
     *
     * <p>Results are collapsed on {@code conversationId}, so up to {@value #MAX_RESULT_SIZE}
     * conversations are returned regardless of how many messages each holds. This relies on
     * {@code conversationId} being mapped as a keyword, as {@link #createIndex()} does.
     *
     * @throws RuntimeException wrapping any transport failure
     */
    @Override
    public List<String> findConversationIds() {
        try {
            SearchResponse<ChatMessage> response = this.client.search(s -> s.index(INDEX_NAME)
                    .size(MAX_RESULT_SIZE)
                    .query(q -> q.matchAll(m -> m))
                    .collapse(c -> c.field("conversationId")), ChatMessage.class);
            return response.hits().hits().stream()
                    .map(Hit::source)
                    .filter(Objects::nonNull)
                    .map(ChatMessage::getConversationId)
                    .filter(Objects::nonNull)
                    .distinct()
                    .collect(Collectors.toList());
        }
        catch (IOException e) {
            throw new RuntimeException("Error finding conversation IDs", e);
        }
    }

    /**
     * Loads the messages of one conversation in ascending timestamp order.
     *
     * <p>Documents whose content cannot be converted to a Spring AI {@link Message} are skipped.
     *
     * @param conversationId non-empty conversation id
     * @return the conversation's messages (at most {@value #MAX_RESULT_SIZE}), oldest first
     * @throws IllegalArgumentException if {@code conversationId} is null or blank
     * @throws RuntimeException wrapping any transport failure
     */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        try {
            logger.debug("Finding messages for conversation: {}", conversationId);
            SearchResponse<ChatMessage> response = this.client.search(s -> s.index(INDEX_NAME)
                    .size(MAX_RESULT_SIZE)
                    .query(q -> q.term(t -> t.field("conversationId").value(conversationId)))
                    .sort(sort -> sort.field(f -> f.field("timestamp").order(SortOrder.Asc))), ChatMessage.class);
            List<Message> messages = response.hits().hits().stream()
                    .map(Hit::source)
                    .filter(Objects::nonNull)
                    .map(ChatMessage::toSpringMessage)
                    .filter(Objects::nonNull)
                    .collect(Collectors.toList());
            logger.debug("Found {} messages for conversation: {}", messages.size(), conversationId);
            return messages;
        }
        catch (IOException e) {
            logger.error("Error finding messages for conversation: {}", conversationId, e);
            throw new RuntimeException("Error finding messages for conversation: " + conversationId, e);
        }
    }

    /**
     * Replaces the stored messages of a conversation with {@code messages}.
     *
     * <p>Existing documents are deleted first. Messages are stamped with strictly increasing
     * timestamps in list order so reads return them in the same order. An empty list simply
     * clears the conversation.
     *
     * @param conversationId non-empty conversation id
     * @param messages messages to store, in order; must not be null or contain nulls
     * @throws IllegalArgumentException on invalid arguments
     * @throws RuntimeException if deleting or indexing fails
     */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        Assert.notNull(messages, "messages cannot be null");
        Assert.noNullElements(messages, "messages cannot contain null elements");
        try {
            deleteByConversationId(conversationId);
            if (messages.isEmpty()) {
                // Elasticsearch rejects bulk requests without operations; nothing left to write.
                return;
            }
            // One base time plus the list position keeps messages saved within the same
            // millisecond in their original order.
            long baseTimestamp = System.currentTimeMillis();
            long offset = 0;
            BulkRequest.Builder br = new BulkRequest.Builder();
            for (Message message : messages) {
                ChatMessage chatMessage = new ChatMessage(conversationId, message, baseTimestamp + offset++);
                logger.debug("Saving message for {}: type={}, text={}", conversationId,
                        chatMessage.getMessageType(), chatMessage.getMessageText());
                br.operations(op -> op.index(idx -> idx.index(INDEX_NAME).document(chatMessage)));
            }
            BulkResponse response = this.client.bulk(br.build());
            if (response.errors()) {
                logger.error("Error saving messages: {}", bulkErrorReasons(response));
                throw new RuntimeException("Error saving messages to Elasticsearch");
            }
            this.client.indices().refresh(r -> r.index(INDEX_NAME));
            logger.info("Successfully saved {} messages for conversation {}", messages.size(), conversationId);
        }
        catch (IOException e) {
            logger.error("Error saving messages", e);
            throw new RuntimeException("Error saving messages", e);
        }
    }

    /**
     * Deletes every stored message of a conversation and refreshes the affected shards so the
     * deletion is immediately visible to searches.
     *
     * @param conversationId non-empty conversation id
     * @throws IllegalArgumentException if {@code conversationId} is null or blank
     * @throws RuntimeException if the delete-by-query reports failures or the transport fails
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        try {
            DeleteByQueryResponse response = this.client.deleteByQuery(d -> d.index(INDEX_NAME)
                    .query(q -> q.term(t -> t.field("conversationId").value(conversationId)))
                    .refresh(true));
            if (!response.failures().isEmpty()) {
                throw new RuntimeException("Error deleting messages for conversation: " + conversationId);
            }
        }
        catch (IOException e) {
            throw new RuntimeException("Error deleting messages", e);
        }
    }

    /**
     * Trims a conversation once it has reached {@code maxLimit} messages by deleting its
     * {@code deleteSize} oldest messages. The remaining documents are left untouched.
     *
     * @param conversationId non-empty conversation id
     * @param maxLimit message count at which trimming starts
     * @param deleteSize number of oldest messages to delete; must not be negative
     * @throws IllegalArgumentException on invalid arguments
     * @throws RuntimeException if the search or deletion fails
     */
    public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        Assert.isTrue(deleteSize >= 0, "deleteSize must not be negative");
        try {
            SearchResponse<ChatMessage> response = this.client.search(s -> s.index(INDEX_NAME)
                    .size(MAX_RESULT_SIZE)
                    .query(q -> q.term(t -> t.field("conversationId").value(conversationId)))
                    .sort(sort -> sort.field(f -> f.field("timestamp").order(SortOrder.Asc))), ChatMessage.class);
            List<Hit<ChatMessage>> hits = response.hits().hits();
            if (hits.size() < maxLimit) {
                return;
            }
            int deleteCount = Math.min(deleteSize, hits.size());
            if (deleteCount == 0) {
                return;
            }
            // Delete the oldest documents by id instead of wiping and re-inserting the whole
            // conversation, so a failure cannot lose the messages meant to be kept.
            BulkRequest.Builder br = new BulkRequest.Builder();
            for (Hit<ChatMessage> hit : hits.subList(0, deleteCount)) {
                String id = hit.id();
                br.operations(op -> op.delete(d -> d.index(INDEX_NAME).id(id)));
            }
            BulkResponse bulkResponse = this.client.bulk(br.build());
            if (bulkResponse.errors()) {
                logger.error("Error deleting over-limit messages: {}", bulkErrorReasons(bulkResponse));
                throw new RuntimeException("Error deleting over-limit messages from Elasticsearch");
            }
            this.client.indices().refresh(r -> r.index(INDEX_NAME));
        }
        catch (IOException e) {
            throw new RuntimeException("Error clearing over limit messages", e);
        }
    }

    /** Joins the error reasons of the failed items in a bulk response. */
    private static String bulkErrorReasons(BulkResponse response) {
        return response.items().stream()
                .filter(item -> item.error() != null)
                .map(item -> item.error().reason())
                .collect(Collectors.joining(", "));
    }

    /** Closes the client's transport, which owns the underlying low-level REST client. */
    @Override
    public void close() {
        if (this.client != null) {
            closeTransport();
        }
    }

    /** Closes the transport, logging rather than propagating I/O failures. */
    private void closeTransport() {
        try {
            this.client._transport().close();
        }
        catch (IOException e) {
            logger.warn("Error closing Elasticsearch transport", e);
        }
    }

    /**
     * Diagnostic helper: renders three raw searches (all documents, a term query on
     * {@code conversationId.keyword}, and one on {@code conversationId}) as text. This is useful
     * for telling whether an existing index uses the explicit keyword mapping or a dynamic
     * text+keyword mapping.
     *
     * @param conversationId conversation id to search for
     * @return a human-readable dump of the three responses
     * @throws IOException on transport failure
     */
    public String rawSearchQuery(String conversationId) throws IOException {
        SearchResponse<Void> allResponse = this.client.search(s -> s.index(INDEX_NAME)
                .query(q -> q.matchAll(m -> m))
                .size(100)
                .source(src -> src.fetch(true)), Void.class);
        SearchResponse<Void> byIdResponseKeyword = this.client.search(s -> s.index(INDEX_NAME)
                .query(q -> q.term(t -> t.field("conversationId.keyword").value(conversationId)))
                .size(100)
                .source(src -> src.fetch(true)), Void.class);
        SearchResponse<Void> byIdResponseNoKeyword = this.client.search(s -> s.index(INDEX_NAME)
                .query(q -> q.term(t -> t.field("conversationId").value(conversationId)))
                .size(100)
                .source(src -> src.fetch(true)), Void.class);
        StringBuilder sb = new StringBuilder();
        sb.append("=== All documents (").append(allResponse.hits().total().value()).append(") ===\n");
        sb.append(allResponse.toString()).append("\n\n");
        sb.append("=== Documents for conversation with keyword field ").append(conversationId).append(" (").append(byIdResponseKeyword.hits().total().value()).append(") ===\n");
        sb.append(byIdResponseKeyword.toString()).append("\n\n");
        sb.append("=== Documents for conversation without keyword field ").append(conversationId).append(" (").append(byIdResponseNoKeyword.hits().total().value()).append(") ===\n");
        sb.append(byIdResponseNoKeyword.toString());
        return sb.toString();
    }

    /**
     * Document shape stored in the index. The {@code message} field is only read, and only to
     * detect documents written in an older layout.
     */
    private static class ChatMessage {

        private String conversationId;

        private String messageType;

        private String messageText;

        private long timestamp;

        private Object message;

        /** No-arg constructor for JSON deserialization. */
        public ChatMessage() {
        }

        /** Creates a document for {@code message} stamped with the current time. */
        public ChatMessage(String conversationId, Message message) {
            this(conversationId, message, System.currentTimeMillis());
        }

        /** Creates a document for {@code message} with an explicit ordering timestamp. */
        public ChatMessage(String conversationId, Message message, long timestamp) {
            this.conversationId = conversationId;
            // name() is the exact counterpart of MessageType.valueOf used when reading back.
            this.messageType = message.getMessageType().name();
            this.messageText = message.getText();
            this.timestamp = timestamp;
        }

        public String getConversationId() {
            return this.conversationId;
        }

        public void setConversationId(String conversationId) {
            this.conversationId = conversationId;
        }

        public String getMessageType() {
            return this.messageType;
        }

        public void setMessageType(String messageType) {
            this.messageType = messageType;
        }

        public String getMessageText() {
            return this.messageText;
        }

        public void setMessageText(String messageText) {
            this.messageText = messageText;
        }

        public long getTimestamp() {
            return this.timestamp;
        }

        public void setTimestamp(long timestamp) {
            this.timestamp = timestamp;
        }

        public Object getMessage() {
            return this.message;
        }

        public void setMessage(Object message) {
            this.message = message;
        }

        /**
         * Converts this document to a Spring AI message.
         *
         * @return a USER/ASSISTANT/SYSTEM message; a placeholder user message for documents in
         *         the legacy layout; or {@code null} when the document is empty or its type is
         *         unknown/unsupported (callers filter nulls out rather than feeding fabricated
         *         text into the conversation)
         */
        public Message toSpringMessage() {
            if (this.messageType != null && this.messageText != null) {
                MessageType type;
                try {
                    type = MessageType.valueOf(this.messageType);
                }
                catch (IllegalArgumentException e) {
                    logger.warn("Skipping stored message with unknown type: {}", this.messageType);
                    return null;
                }
                switch (type) {
                    case USER:
                        return new UserMessage(this.messageText);
                    case ASSISTANT:
                        return new AssistantMessage(this.messageText);
                    case SYSTEM:
                        return new SystemMessage(this.messageText);
                    default:
                        logger.warn("Skipping stored message with unsupported type: {}", this.messageType);
                        return null;
                }
            }
            if (this.message != null) {
                logger.info("Using legacy message format: {}", this.message);
                return new UserMessage("Legacy message - please reindex");
            }
            return null;
        }
    }
}

