/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.milvus;

import com.alibaba.fastjson.JSONObject;
import io.micrometer.observation.ObservationRegistry;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.exception.ParamException;
import io.milvus.grpc.DataType;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.DropCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.collection.ReleaseCollectionParam;
import io.milvus.param.dml.DeleteParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.param.index.DescribeIndexParam;
import io.milvus.param.index.DropIndexParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.milvus.MilvusFilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class MilvusVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    public static final int OPENAI_EMBEDDING_DIMENSION_SIZE = 1536;
    public static final int INVALID_EMBEDDING_DIMENSION = -1;
    public static final String DEFAULT_DATABASE_NAME = "default";
    public static final String DEFAULT_COLLECTION_NAME = "vector_store";
    public static final String DOC_ID_FIELD_NAME = "doc_id";
    public static final String CONTENT_FIELD_NAME = "content";
    public static final String METADATA_FIELD_NAME = "metadata";
    public static final String EMBEDDING_FIELD_NAME = "embedding";
    private static final String DISTANCE_FIELD_NAME = "distance";
    private static final Logger logger = LoggerFactory.getLogger(MilvusVectorStore.class);
    private static final Map<MetricType, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(MetricType.COSINE, VectorStoreSimilarityMetric.COSINE, MetricType.L2, VectorStoreSimilarityMetric.EUCLIDEAN, MetricType.IP, VectorStoreSimilarityMetric.DOT);
    public final FilterExpressionConverter filterExpressionConverter = new MilvusFilterExpressionConverter();
    private final MilvusServiceClient milvusClient;
    @Deprecated(forRemoval=true, since="1.0.0-M5")
    private final MilvusVectorStoreConfig config;
    private final boolean initializeSchema;
    private final BatchingStrategy batchingStrategy;
    private final String databaseName;
    private final String collectionName;
    private final int embeddingDimension;
    private final IndexType indexType;
    private final MetricType metricType;
    private final String indexParameters;
    private final String idFieldName;
    private final boolean isAutoId;
    private final String contentFieldName;
    private final String metadataFieldName;
    private final String embeddingFieldName;

    @Deprecated(forRemoval=true, since="1.0.0-M5")
    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
        this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, (BatchingStrategy)new TokenCountBatchingStrategy());
    }

    @Deprecated(forRemoval=true, since="1.0.0-M5")
    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, boolean initializeSchema, BatchingStrategy batchingStrategy) {
        this(milvusClient, embeddingModel, MilvusVectorStoreConfig.defaultConfig(), initializeSchema, batchingStrategy);
    }

    @Deprecated(forRemoval=true, since="1.0.0-M5")
    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, MilvusVectorStoreConfig config, boolean initializeSchema, BatchingStrategy batchingStrategy) {
        this(milvusClient, embeddingModel, config, initializeSchema, batchingStrategy, ObservationRegistry.NOOP, null);
    }

    @Deprecated(forRemoval=true, since="1.0.0-M5")
    public MilvusVectorStore(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel, MilvusVectorStoreConfig config, boolean initializeSchema, BatchingStrategy batchingStrategy, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention) {
        this(((Builder)((Builder)MilvusVectorStore.builder(milvusClient, embeddingModel).observationRegistry(observationRegistry)).customObservationConvention(customObservationConvention)).initializeSchema(initializeSchema).batchingStrategy(batchingStrategy));
    }

    protected MilvusVectorStore(Builder builder) {
        super((AbstractVectorStoreBuilder)builder);
        Assert.notNull((Object)builder.milvusClient, (String)"milvusClient must not be null");
        this.milvusClient = builder.milvusClient;
        this.batchingStrategy = builder.batchingStrategy;
        this.initializeSchema = builder.initializeSchema;
        this.config = null;
        this.databaseName = builder.databaseName;
        this.collectionName = builder.collectionName;
        this.embeddingDimension = builder.embeddingDimension;
        this.indexType = builder.indexType;
        this.metricType = builder.metricType;
        this.indexParameters = builder.indexParameters;
        this.idFieldName = builder.idFieldName;
        this.isAutoId = builder.isAutoId;
        this.contentFieldName = builder.contentFieldName;
        this.metadataFieldName = builder.metadataFieldName;
        this.embeddingFieldName = builder.embeddingFieldName;
    }

    public static Builder builder(MilvusServiceClient milvusServiceClient, EmbeddingModel embeddingModel) {
        return new Builder(milvusServiceClient, embeddingModel);
    }

    public void doAdd(List<Document> documents) {
        Assert.notNull(documents, (String)"Documents must not be null");
        ArrayList<String> docIdArray = new ArrayList<String>();
        ArrayList<String> contentArray = new ArrayList<String>();
        ArrayList<JSONObject> metadataArray = new ArrayList<JSONObject>();
        ArrayList<List> embeddingArray = new ArrayList<List>();
        List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        for (Document document : documents) {
            docIdArray.add(document.getId());
            contentArray.add(document.getText());
            metadataArray.add(new JSONObject(document.getMetadata()));
            embeddingArray.add(EmbeddingUtils.toList((float[])((float[])embeddings.get(documents.indexOf(document)))));
        }
        ArrayList<InsertParam.Field> fields = new ArrayList<InsertParam.Field>();
        if (!this.isAutoId) {
            fields.add(new InsertParam.Field(this.idFieldName, docIdArray));
        }
        fields.add(new InsertParam.Field(this.contentFieldName, contentArray));
        fields.add(new InsertParam.Field(this.metadataFieldName, metadataArray));
        fields.add(new InsertParam.Field(this.embeddingFieldName, embeddingArray));
        InsertParam insertParam = InsertParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).withFields(fields).build();
        R status = this.milvusClient.insert(insertParam);
        if (status.getException() != null) {
            throw new RuntimeException("Failed to insert:", status.getException());
        }
    }

    public Optional<Boolean> doDelete(List<String> idList) {
        Assert.notNull(idList, (String)"Document id list must not be null");
        String deleteExpression = String.format("%s in [%s]", this.idFieldName, idList.stream().map(id -> "'" + id + "'").collect(Collectors.joining(",")));
        R status = this.milvusClient.delete(DeleteParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).withExpr(deleteExpression).build());
        long deleteCount = ((MutationResult)status.getData()).getDeleteCnt();
        if (deleteCount != (long)idList.size()) {
            logger.warn(String.format("Deleted only %s entries from requested %s ", deleteCount, idList.size()));
        }
        return Optional.of(status.getStatus().intValue() == R.Status.Success.getCode());
    }

    public List<Document> doSimilaritySearch(SearchRequest request) {
        R respSearch;
        String nativeFilterExpressions = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        Assert.notNull((Object)request.getQuery(), (String)"Query string must not be null");
        ArrayList<String> outFieldNames = new ArrayList<String>();
        outFieldNames.add(this.idFieldName);
        outFieldNames.add(this.contentFieldName);
        outFieldNames.add(this.metadataFieldName);
        float[] embedding = this.embeddingModel.embed(request.getQuery());
        SearchParam.Builder searchParamBuilder = SearchParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).withConsistencyLevel(ConsistencyLevelEnum.STRONG).withMetricType(this.metricType).withOutFields(outFieldNames).withTopK(Integer.valueOf(request.getTopK())).withVectors(List.of(EmbeddingUtils.toList((float[])embedding))).withVectorFieldName(this.embeddingFieldName);
        if (StringUtils.hasText((String)nativeFilterExpressions)) {
            searchParamBuilder.withExpr(nativeFilterExpressions);
        }
        if ((respSearch = this.milvusClient.search(searchParamBuilder.build())).getException() != null) {
            throw new RuntimeException("Search failed!", respSearch.getException());
        }
        SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(((SearchResults)respSearch.getData()).getResults());
        return wrapperSearch.getRowRecords(0).stream().filter(rowRecord -> (double)this.getResultSimilarity((QueryResultsWrapper.RowRecord)rowRecord) >= request.getSimilarityThreshold()).map(rowRecord -> {
            String docId = String.valueOf(rowRecord.get(this.idFieldName));
            String content = (String)rowRecord.get(this.contentFieldName);
            JSONObject metadata = null;
            try {
                metadata = (JSONObject)rowRecord.get(this.metadataFieldName);
                metadata.put(DocumentMetadata.DISTANCE.value(), (Object)Float.valueOf(1.0f - this.getResultSimilarity((QueryResultsWrapper.RowRecord)rowRecord)));
            }
            catch (ParamException paramException) {
                // empty catch block
            }
            return Document.builder().id(docId).text(content).metadata(metadata != null ? metadata.getInnerMap() : Map.of()).score(Double.valueOf(this.getResultSimilarity((QueryResultsWrapper.RowRecord)rowRecord))).build();
        }).toList();
    }

    private float getResultSimilarity(QueryResultsWrapper.RowRecord rowRecord) {
        Float distance = (Float)rowRecord.get(DISTANCE_FIELD_NAME);
        return this.metricType == MetricType.IP || this.metricType == MetricType.COSINE ? distance.floatValue() : 1.0f - distance.floatValue();
    }

    public void afterPropertiesSet() throws Exception {
        if (!this.initializeSchema) {
            return;
        }
        this.createCollection();
    }

    void releaseCollection() {
        if (this.isDatabaseCollectionExists()) {
            this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder().withCollectionName(this.collectionName).build());
        }
    }

    private boolean isDatabaseCollectionExists() {
        return (Boolean)this.milvusClient.hasCollection(HasCollectionParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).build()).getData();
    }

    void createCollection() {
        R indexStatus;
        R indexDescriptionResponse;
        if (!this.isDatabaseCollectionExists()) {
            this.createCollection(this.databaseName, this.collectionName, this.idFieldName, this.isAutoId, this.contentFieldName, this.metadataFieldName, this.embeddingFieldName);
        }
        if ((indexDescriptionResponse = this.milvusClient.describeIndex(DescribeIndexParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).build())).getData() == null && (indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).withFieldName(this.embeddingFieldName).withIndexType(this.indexType).withMetricType(this.metricType).withExtraParam(this.indexParameters).withSyncMode(Boolean.FALSE).build())).getException() != null) {
            throw new RuntimeException("Failed to create Index", indexStatus.getException());
        }
        R loadCollectionStatus = this.milvusClient.loadCollection(LoadCollectionParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).build());
        if (loadCollectionStatus.getException() != null) {
            throw new RuntimeException("Collection loading failed!", loadCollectionStatus.getException());
        }
    }

    void createCollection(String databaseName, String collectionName, String idFieldName, boolean isAutoId, String contentFieldName, String metadataFieldName, String embeddingFieldName) {
        FieldType docIdFieldType = FieldType.newBuilder().withName(idFieldName).withDataType(DataType.VarChar).withMaxLength(Integer.valueOf(36)).withPrimaryKey(true).withAutoID(isAutoId).build();
        FieldType contentFieldType = FieldType.newBuilder().withName(contentFieldName).withDataType(DataType.VarChar).withMaxLength(Integer.valueOf(65535)).build();
        FieldType metadataFieldType = FieldType.newBuilder().withName(metadataFieldName).withDataType(DataType.JSON).build();
        FieldType embeddingFieldType = FieldType.newBuilder().withName(embeddingFieldName).withDataType(DataType.FloatVector).withDimension(Integer.valueOf(this.embeddingDimensions())).build();
        CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder().withDatabaseName(databaseName).withCollectionName(collectionName).withDescription("Spring AI Vector Store").withConsistencyLevel(ConsistencyLevelEnum.STRONG).withShardsNum(2).addFieldType(docIdFieldType).addFieldType(contentFieldType).addFieldType(metadataFieldType).addFieldType(embeddingFieldType).build();
        R collectionStatus = this.milvusClient.createCollection(createCollectionReq);
        if (collectionStatus.getException() != null) {
            throw new RuntimeException("Failed to create collection", collectionStatus.getException());
        }
    }

    int embeddingDimensions() {
        if (this.embeddingDimension != -1) {
            return this.embeddingDimension;
        }
        try {
            int embeddingDimensions = this.embeddingModel.dimensions();
            if (embeddingDimensions > 0) {
                return embeddingDimensions;
            }
        }
        catch (Exception e) {
            logger.warn("Failed to obtain the embedding dimensions from the embedding model and fall backs to default:" + this.embeddingDimension, (Throwable)e);
        }
        return 1536;
    }

    void dropCollection() {
        R status = this.milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder().withCollectionName(this.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Release collection failed!", status.getException());
        }
        status = this.milvusClient.dropIndex(DropIndexParam.newBuilder().withCollectionName(this.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Drop Index failed!", status.getException());
        }
        status = this.milvusClient.dropCollection(DropCollectionParam.newBuilder().withDatabaseName(this.databaseName).withCollectionName(this.collectionName).build());
        if (status.getException() != null) {
            throw new RuntimeException("Drop Collection failed!", status.getException());
        }
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.MILVUS.value(), (String)operationName).collectionName(this.collectionName).dimensions(Integer.valueOf(this.embeddingModel.dimensions())).similarityMetric(this.getSimilarityMetric()).namespace(this.databaseName);
    }

    private String getSimilarityMetric() {
        if (!SIMILARITY_TYPE_MAPPING.containsKey(this.metricType)) {
            return this.metricType.name();
        }
        return SIMILARITY_TYPE_MAPPING.get(this.metricType).value();
    }

    @Deprecated(forRemoval=true, since="1.0.0-M5")
    public static final class MilvusVectorStoreConfig {
        private final String databaseName;
        private final String collectionName;
        private final int embeddingDimension;
        private final IndexType indexType;
        private final MetricType metricType;
        private final String indexParameters;
        private final String idFieldName;
        private final boolean isAutoId;
        private final String contentFieldName;
        private final String metadataFieldName;
        private final String embeddingFieldName;

        private MilvusVectorStoreConfig(Builder builder) {
            this.databaseName = builder.databaseName;
            this.collectionName = builder.collectionName;
            this.embeddingDimension = builder.embeddingDimension;
            this.indexType = builder.indexType;
            this.metricType = builder.metricType;
            this.indexParameters = builder.indexParameters;
            this.idFieldName = builder.idFieldName;
            this.isAutoId = builder.isAutoId;
            this.contentFieldName = builder.contentFieldName;
            this.metadataFieldName = builder.metadataFieldName;
            this.embeddingFieldName = builder.embeddingFieldName;
        }

        public static Builder builder() {
            return new Builder();
        }

        public static MilvusVectorStoreConfig defaultConfig() {
            return MilvusVectorStoreConfig.builder().build();
        }

        @Deprecated(forRemoval=true, since="1.0.0-M5")
        public static final class Builder {
            private String databaseName = "default";
            private String collectionName = "vector_store";
            private int embeddingDimension = -1;
            private IndexType indexType = IndexType.IVF_FLAT;
            private MetricType metricType = MetricType.COSINE;
            private String indexParameters = "{\"nlist\":1024}";
            private String idFieldName = "doc_id";
            private boolean isAutoId = false;
            private String contentFieldName = "content";
            private String metadataFieldName = "metadata";
            private String embeddingFieldName = "embedding";

            private Builder() {
            }

            public Builder withMetricType(MetricType metricType) {
                Assert.notNull((Object)metricType, (String)"Collection Name must not be empty");
                Assert.isTrue((metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE ? 1 : 0) != 0, (String)"Only the text metric types IP and L2 are supported");
                this.metricType = metricType;
                return this;
            }

            public Builder withIndexType(IndexType indexType) {
                this.indexType = indexType;
                return this;
            }

            public Builder withIndexParameters(String indexParameters) {
                this.indexParameters = indexParameters;
                return this;
            }

            public Builder withDatabaseName(String databaseName) {
                this.databaseName = databaseName;
                return this;
            }

            public Builder withCollectionName(String collectionName) {
                this.collectionName = collectionName;
                return this;
            }

            public Builder withEmbeddingDimension(int newEmbeddingDimension) {
                Assert.isTrue((newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768 ? 1 : 0) != 0, (String)"Dimension has to be withing the boundaries 1 and 32768 (inclusively)");
                this.embeddingDimension = newEmbeddingDimension;
                return this;
            }

            public Builder withIDFieldName(String idFieldName) {
                this.idFieldName = idFieldName;
                return this;
            }

            public Builder withAutoId(boolean isAutoId) {
                this.isAutoId = isAutoId;
                return this;
            }

            public Builder withContentFieldName(String contentFieldName) {
                this.contentFieldName = contentFieldName;
                return this;
            }

            public Builder withMetadataFieldName(String metadataFieldName) {
                this.metadataFieldName = metadataFieldName;
                return this;
            }

            public Builder withEmbeddingFieldName(String embeddingFieldName) {
                this.embeddingFieldName = embeddingFieldName;
                return this;
            }

            public MilvusVectorStoreConfig build() {
                return new MilvusVectorStoreConfig(this);
            }
        }
    }

    public static final class Builder
    extends AbstractVectorStoreBuilder<Builder> {
        private final MilvusServiceClient milvusClient;
        private String databaseName = "default";
        private String collectionName = "vector_store";
        private int embeddingDimension = -1;
        private IndexType indexType = IndexType.IVF_FLAT;
        private MetricType metricType = MetricType.COSINE;
        private String indexParameters = "{\"nlist\":1024}";
        private String idFieldName = "doc_id";
        private boolean isAutoId = false;
        private String contentFieldName = "content";
        private String metadataFieldName = "metadata";
        private String embeddingFieldName = "embedding";
        private boolean initializeSchema = false;
        private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

        private Builder(MilvusServiceClient milvusClient, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull((Object)milvusClient, (String)"milvusClient must not be null");
            this.milvusClient = milvusClient;
        }

        public Builder metricType(MetricType metricType) {
            Assert.notNull((Object)metricType, (String)"Collection Name must not be empty");
            Assert.isTrue((metricType == MetricType.IP || metricType == MetricType.L2 || metricType == MetricType.COSINE ? 1 : 0) != 0, (String)"Only the text metric types IP and L2 are supported");
            this.metricType = metricType;
            return this;
        }

        public Builder indexType(IndexType indexType) {
            this.indexType = indexType;
            return this;
        }

        public Builder indexParameters(String indexParameters) {
            this.indexParameters = indexParameters;
            return this;
        }

        public Builder databaseName(String databaseName) {
            this.databaseName = databaseName;
            return this;
        }

        public Builder collectionName(String collectionName) {
            this.collectionName = collectionName;
            return this;
        }

        public Builder embeddingDimension(int newEmbeddingDimension) {
            Assert.isTrue((newEmbeddingDimension >= 1 && newEmbeddingDimension <= 32768 ? 1 : 0) != 0, (String)"Dimension has to be withing the boundaries 1 and 32768 (inclusively)");
            this.embeddingDimension = newEmbeddingDimension;
            return this;
        }

        public Builder iDFieldName(String idFieldName) {
            this.idFieldName = idFieldName;
            return this;
        }

        public Builder autoId(boolean isAutoId) {
            this.isAutoId = isAutoId;
            return this;
        }

        public Builder contentFieldName(String contentFieldName) {
            this.contentFieldName = contentFieldName;
            return this;
        }

        public Builder metadataFieldName(String metadataFieldName) {
            this.metadataFieldName = metadataFieldName;
            return this;
        }

        public Builder embeddingFieldName(String embeddingFieldName) {
            this.embeddingFieldName = embeddingFieldName;
            return this;
        }

        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        public Builder batchingStrategy(BatchingStrategy batchingStrategy) {
            Assert.notNull((Object)batchingStrategy, (String)"batchingStrategy must not be null");
            this.batchingStrategy = batchingStrategy;
            return this;
        }

        public MilvusVectorStore build() {
            return new MilvusVectorStore(this);
        }
    }
}

