# Source code for langchain.retrievers.google_cloud_enterprise_search

"""Retriever wrapper for Google Cloud Enterprise Search on Gen App Builder."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from google.cloud.discoveryengine_v1beta import (
        SearchRequest,
        SearchResult,
        SearchServiceClient,
    )


def _proto_struct_to_builtin(value: Any) -> Any:
    """Recursively converts mapping/sequence containers into plain dicts/lists.

    Used on the protobuf ``Struct`` fields of an Enterprise Search document,
    which proto-plus exposes as mapping containers (key lookup, no attribute
    access) whose nested values may themselves be proto-plus containers.

    Args:
        value: Any value; ``None`` and scalars are returned unchanged.

    Returns:
        ``value`` with every mapping turned into a ``dict`` and every
        non-string sequence turned into a ``list``, recursively.
    """
    import collections.abc

    if isinstance(value, collections.abc.Mapping):
        return {key: _proto_struct_to_builtin(item) for key, item in value.items()}
    if isinstance(value, collections.abc.Sequence) and not isinstance(
        value, (str, bytes, bytearray)
    ):
        return [_proto_struct_to_builtin(item) for item in value]
    return value


class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
    """Retriever for the Google Cloud Enterprise Search Service API.

    For the detailed explanation of the Enterprise Search concepts
    and configuration parameters refer to the product documentation.

    https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
    """

    project_id: str
    """Google Cloud Project ID."""
    search_engine_id: str
    """Enterprise Search engine ID."""
    serving_config_id: str = "default_config"
    """Enterprise Search serving config ID."""
    location_id: str = "global"
    """Enterprise Search engine location."""
    filter: Optional[str] = None
    """Filter expression."""
    get_extractive_answers: bool = False
    """If True return Extractive Answers, otherwise return Extractive Segments."""
    max_documents: int = Field(default=5, ge=1, le=100)
    """The maximum number of documents to return."""
    max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
    """The maximum number of extractive answers returned in each search result.
    At most 5 answers will be returned for each SearchResult.
    """
    max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
    """The maximum number of extractive segments returned in each search result.
    Currently one segment will be returned for each SearchResult.
    """
    query_expansion_condition: int = Field(default=1, ge=0, le=2)
    """Specification to determine under which conditions query expansion should occur.
    0 - Unspecified query expansion condition. In this case, server behavior
        defaults to disabled.
    1 - Disabled query expansion. Only the exact search query is used, even if
        SearchResponse.total_size is zero.
    2 - Automatic query expansion built by the Search API.
    """
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, credentials will be ascertained from
    the environment."""

    # Private (underscore) attributes, populated in ``__init__``.
    _client: SearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validates the environment.

        Ensures the Discovery Engine client library is importable and fills
        ``project_id`` / ``search_engine_id`` from the ``PROJECT_ID`` /
        ``SEARCH_ENGINE_ID`` environment variables when not passed explicitly.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` is not installed.
        """
        try:
            from google.cloud import discoveryengine_v1beta  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc

        values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
        values["search_engine_id"] = get_from_dict_or_env(
            values, "search_engine_id", "SEARCH_ENGINE_ID"
        )

        return values

    def __init__(self, **data: Any) -> None:
        """Initializes private fields (the API client and serving config path)."""
        from google.cloud.discoveryengine_v1beta import SearchServiceClient

        super().__init__(**data)

        self._client = SearchServiceClient(credentials=self.credentials)

        self._serving_config = self._client.serving_config_path(
            project=self.project_id,
            location=self.location_id,
            data_store=self.search_engine_id,
            serving_config=self.serving_config_id,
        )

    def _convert_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Converts a sequence of search results to a list of LangChain documents.

        One ``Document`` is produced per extractive answer found in a result's
        ``derived_struct_data`` (or per extractive segment when the result has
        no answers).

        Args:
            results: Search results returned by the search call.

        Returns:
            The LangChain documents. Results without extractive content
            contribute nothing.
        """
        documents: List[Document] = []

        for result in results:
            # ``derived_struct_data`` / ``struct_data`` are protobuf Structs,
            # exposed as mappings (key lookup only), so convert them to plain
            # Python containers before reading them.
            derived_struct_data = (
                _proto_struct_to_builtin(result.document.derived_struct_data) or {}
            )
            struct_data = _proto_struct_to_builtin(result.document.struct_data) or {}

            link = derived_struct_data.get("link") or ""
            chunks = (
                derived_struct_data.get("extractive_answers")
                or derived_struct_data.get("extractive_segments")
                or []
            )

            for chunk in chunks:
                if not isinstance(chunk, dict):
                    continue
                source = link
                # NOTE(review): the API JSON appears to use camelCase
                # ("pageNumber"); the snake_case key is accepted too.
                page_number = chunk.get("pageNumber", chunk.get("page_number"))
                if page_number is not None:
                    source += f":{page_number}"
                # A fresh dict per chunk: one chunk's page suffix must not leak
                # into another document's metadata.
                metadata = {
                    **struct_data,
                    "source": source,
                    "id": result.document.id,
                }
                documents.append(
                    Document(page_content=chunk.get("content", ""), metadata=metadata)
                )

        return documents

    def _create_search_request(self, query: str) -> SearchRequest:
        """Prepares a SearchRequest object for ``query`` from this retriever's settings."""
        from google.cloud.discoveryengine_v1beta import SearchRequest

        query_expansion_spec = SearchRequest.QueryExpansionSpec(
            condition=self.query_expansion_condition,
        )

        # Ask for either extractive answers or extractive segments, not both.
        if self.get_extractive_answers:
            extractive_content_spec = (
                SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                    max_extractive_answer_count=self.max_extractive_answer_count,
                )
            )
        else:
            extractive_content_spec = (
                SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                    max_extractive_segment_count=self.max_extractive_segment_count,
                )
            )

        content_search_spec = SearchRequest.ContentSearchSpec(
            extractive_content_spec=extractive_content_spec,
        )

        return SearchRequest(
            query=query,
            filter=self.filter,
            serving_config=self._serving_config,
            page_size=self.max_documents,
            content_search_spec=content_search_spec,
            query_expansion_spec=query_expansion_spec,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query."""
        search_request = self._create_search_request(query)
        response = self._client.search(search_request)
        documents = self._convert_search_response(response.results)

        return documents