# Source code for langchain.retrievers.self_query.base

"""Retriever that generates and executes structured queries over its own data source."""

from typing import Any, Dict, List, Optional, Type, cast

from pydantic import BaseModel, Field, root_validator

from langchain import LLMChain
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.vectorstores import (
    Chroma,
    DeepLake,
    MyScale,
    Pinecone,
    Qdrant,
    VectorStore,
    Weaviate,
)


def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
    """Return a query translator instance suited to ``vectorstore``.

    The vector store's class, followed by each of its base classes in method
    resolution order, is looked up in the table of built-in translators. This
    keeps the support check consistent with the ``isinstance`` checks below,
    so a subclass of a supported vector store is accepted as well.

    Args:
        vectorstore: The vector store to pick a translator for.

    Returns:
        A newly constructed translator. For Qdrant and MyScale stores the
        translator is built with the store's ``metadata_payload_key`` /
        ``metadata_column`` respectively; all others are built with no
        arguments.

    Raises:
        ValueError: If neither the vector store's class nor any of its base
            classes has a built-in translator.
    """
    vectorstore_cls = vectorstore.__class__
    BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
        Pinecone: PineconeTranslator,
        Chroma: ChromaTranslator,
        Weaviate: WeaviateTranslator,
        Qdrant: QdrantTranslator,
        MyScale: MyScaleTranslator,
        DeepLake: DeepLakeTranslator,
    }
    # Walk the MRO so the most specific supported ancestor wins; an exact
    # class match is always found first, preserving the previous behaviour.
    translator_cls = next(
        (
            BUILTIN_TRANSLATORS[base]
            for base in vectorstore_cls.__mro__
            if base in BUILTIN_TRANSLATORS
        ),
        None,
    )
    if translator_cls is None:
        raise ValueError(
            f"Self query retriever with Vector Store type {vectorstore_cls}"
            f" not supported."
        )
    # These translators need to know where the store keeps its metadata.
    if isinstance(vectorstore, Qdrant):
        return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
    if isinstance(vectorstore, MyScale):
        return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
    return translator_cls()


class SelfQueryRetriever(BaseRetriever, BaseModel):
    """Retriever that uses a vector store and an LLM to generate
    the vector store queries."""

    vectorstore: VectorStore
    """The underlying vector store from which documents will be retrieved."""
    llm_chain: LLMChain
    """The LLMChain for generating the vector store queries."""
    search_type: str = "similarity"
    """The search type to perform on the vector store."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass in to the vector store search."""
    structured_query_translator: Visitor
    """Translator for turning internal query language into vectorstore search params."""
    verbose: bool = False
    """Whether to print the structured query produced by the LLM chain."""
    use_original_query: bool = False
    """Use original query instead of the revised new query from LLM"""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_translator(cls, values: Dict) -> Dict:
        """Fill in a built-in translator when none was supplied.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            The same mapping, with ``structured_query_translator`` set from the
            vector store when it was absent and a vector store was given.

        Raises:
            ValueError: If the vector store has no built-in translator.
        """
        if "structured_query_translator" in values:
            return values
        # When "vectorstore" is missing too, leave the values untouched so that
        # field validation reports the missing fields instead of this
        # pre-validator failing with a bare KeyError.
        if "vectorstore" in values:
            values["structured_query_translator"] = _get_builtin_translator(
                values["vectorstore"]
            )
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for
            run_manager: callback manager whose child callbacks are passed to
                the LLM chain call

        Returns:
            List of relevant documents
        """
        inputs = self.llm_chain.prep_inputs({"query": query})
        structured_query = cast(
            StructuredQuery,
            self.llm_chain.predict_and_parse(
                callbacks=run_manager.get_child(), **inputs
            ),
        )
        if self.verbose:
            print(structured_query)
        new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
            structured_query
        )
        # A limit on the structured query becomes the "k" search argument.
        if structured_query.limit is not None:
            new_kwargs["k"] = structured_query.limit

        if self.use_original_query:
            new_query = query

        # Translator-produced arguments override the configured defaults.
        search_kwargs = {**self.search_kwargs, **new_kwargs}
        docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
        return docs

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        document_contents: str,
        metadata_field_info: List[AttributeInfo],
        structured_query_translator: Optional[Visitor] = None,
        chain_kwargs: Optional[Dict] = None,
        enable_limit: bool = False,
        use_original_query: bool = False,
        **kwargs: Any,
    ) -> "SelfQueryRetriever":
        """Build a retriever, constructing the query-generating chain from an LLM.

        Args:
            llm: Language model passed to ``load_query_constructor_chain``.
            vectorstore: Vector store that the retriever will search.
            document_contents: Passed to ``load_query_constructor_chain``.
            metadata_field_info: Passed to ``load_query_constructor_chain``.
            structured_query_translator: Translator to use. When ``None``, a
                built-in translator is chosen from the vector store.
            chain_kwargs: Extra keyword arguments for
                ``load_query_constructor_chain``. The mapping is copied, so the
                caller's dict is not modified. ``allowed_comparators`` and
                ``allowed_operators`` default to the translator's values when
                not present.
            enable_limit: Passed to ``load_query_constructor_chain``.
            use_original_query: Search with the original query text rather
                than the query rewritten by the LLM.
            **kwargs: Additional fields forwarded to the class constructor.

        Returns:
            The configured ``SelfQueryRetriever``.

        Raises:
            ValueError: If no translator is given and the vector store has no
                built-in translator.
        """
        if structured_query_translator is None:
            structured_query_translator = _get_builtin_translator(vectorstore)
        # Copy before adding defaults so the caller's dict is left untouched.
        chain_kwargs = dict(chain_kwargs) if chain_kwargs else {}
        if "allowed_comparators" not in chain_kwargs:
            chain_kwargs[
                "allowed_comparators"
            ] = structured_query_translator.allowed_comparators
        if "allowed_operators" not in chain_kwargs:
            chain_kwargs[
                "allowed_operators"
            ] = structured_query_translator.allowed_operators
        llm_chain = load_query_constructor_chain(
            llm,
            document_contents,
            metadata_field_info,
            enable_limit=enable_limit,
            **chain_kwargs,
        )
        return cls(
            llm_chain=llm_chain,
            vectorstore=vectorstore,
            use_original_query=use_original_query,
            structured_query_translator=structured_query_translator,
            **kwargs,
        )