package org.springframework.ai.vectorstore;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.core.io.Resource;

/* loaded from: input_file:org/springframework/ai/vectorstore/SimpleVectorStore.class */
public class SimpleVectorStore implements VectorStore {
    private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);
    protected Map<String, Document> store = new ConcurrentHashMap();
    protected EmbeddingClient embeddingClient;

    /* loaded from: input_file:org/springframework/ai/vectorstore/SimpleVectorStore$EmbeddingMath.class */
    public class EmbeddingMath {
        private EmbeddingMath() {
            throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
        }

        public static double cosineSimilarity(List<Double> list, List<Double> list2) {
            if (list == null || list2 == null) {
                throw new RuntimeException("Vectors must not be null");
            }
            if (list.size() != list2.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double dotProduct = dotProduct(list, list2);
            double norm = norm(list);
            double norm2 = norm(list2);
            if (norm == SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL || norm2 == SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL) {
                throw new IllegalArgumentException("Vectors cannot have zero norm");
            }
            return dotProduct / (Math.sqrt(norm) * Math.sqrt(norm2));
        }

        public static double dotProduct(List<Double> list, List<Double> list2) {
            if (list.size() != list2.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double d = 0.0d;
            for (int i = 0; i < list.size(); i++) {
                d += list.get(i).doubleValue() * list2.get(i).doubleValue();
            }
            return d;
        }

        public static double norm(List<Double> list) {
            return dotProduct(list, list);
        }
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/SimpleVectorStore$Similarity.class */
    public static class Similarity {
        private String key;
        private double score;

        public Similarity(String str, double d) {
            this.key = str;
            this.score = d;
        }
    }

    public SimpleVectorStore(EmbeddingClient embeddingClient) {
        Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null");
        this.embeddingClient = embeddingClient;
    }

    @Override // org.springframework.ai.vectorstore.VectorStore
    public void add(List<Document> list) {
        for (Document document : list) {
            logger.info("Calling EmbeddingClient for document id = {}", document.getId());
            document.setEmbedding(this.embeddingClient.embed(document));
            this.store.put(document.getId(), document);
        }
    }

    @Override // org.springframework.ai.vectorstore.VectorStore
    public Optional<Boolean> delete(List<String> list) {
        Iterator<String> it = list.iterator();
        while (it.hasNext()) {
            this.store.remove(it.next());
        }
        return Optional.of(true);
    }

    @Override // org.springframework.ai.vectorstore.VectorStore
    public List<Document> similaritySearch(SearchRequest searchRequest) {
        if (searchRequest.getFilterExpression() != null) {
            throw new UnsupportedOperationException("The [" + String.valueOf(getClass()) + "] doesn't support metadata filtering!");
        }
        List<Double> userQueryEmbedding = getUserQueryEmbedding(searchRequest.getQuery());
        return this.store.values().stream().map(document -> {
            return new Similarity(document.getId(), EmbeddingMath.cosineSimilarity(userQueryEmbedding, document.getEmbedding()));
        }).filter(similarity -> {
            return similarity.score >= searchRequest.getSimilarityThreshold();
        }).sorted(Comparator.comparingDouble(similarity2 -> {
            return similarity2.score;
        }).reversed()).limit(searchRequest.getTopK()).map(similarity3 -> {
            return this.store.get(similarity3.key);
        }).toList();
    }

    public void save(File file) {
        String vectorDbAsJson = getVectorDbAsJson();
        try {
            if (file.exists()) {
                logger.info("Overwriting existing vector store file: {}", file);
            } else {
                logger.info("Creating new vector store file: {}", file);
                file.createNewFile();
            }
            FileOutputStream fileOutputStream = new FileOutputStream(file);
            try {
                OutputStreamWriter outputStreamWriter = new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8);
                try {
                    outputStreamWriter.write(vectorDbAsJson);
                    outputStreamWriter.flush();
                    outputStreamWriter.close();
                    fileOutputStream.close();
                } catch (Throwable th) {
                    try {
                        outputStreamWriter.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                    throw th;
                }
            } catch (Throwable th3) {
                try {
                    fileOutputStream.close();
                } catch (Throwable th4) {
                    th3.addSuppressed(th4);
                }
                throw th3;
            }
        } catch (IOException e) {
            logger.error("IOException occurred while saving vector store file.", e);
            throw new RuntimeException(e);
        } catch (NullPointerException e2) {
            logger.error("NullPointerException occurred while saving vector store file.", e2);
            throw new RuntimeException(e2);
        } catch (SecurityException e3) {
            logger.error("SecurityException occurred while saving vector store file.", e3);
            throw new RuntimeException(e3);
        }
    }

    public void load(File file) {
        try {
            this.store = (Map) new ObjectMapper().readValue(file, new TypeReference<HashMap<String, Document>>() { // from class: org.springframework.ai.vectorstore.SimpleVectorStore.1
            });
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public void load(Resource resource) {
        TypeReference<HashMap<String, Document>> typeReference = new TypeReference<HashMap<String, Document>>() { // from class: org.springframework.ai.vectorstore.SimpleVectorStore.2
        };
        try {
            this.store = (Map) new ObjectMapper().readValue(resource.getInputStream(), typeReference);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private String getVectorDbAsJson() {
        try {
            return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(this.store);
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Error serializing documentMap to JSON.", e);
        }
    }

    private List<Double> getUserQueryEmbedding(String str) {
        return this.embeddingClient.embed(str);
    }
}
