Source code for langchain.llms.vertexai

from __future__ import annotations

import asyncio
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional

from pydantic import BaseModel, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.utilities.vertexai import (
    init_vertexai,
    raise_vertex_import_error,
)

if TYPE_CHECKING:
    from vertexai.language_models._language_models import _LanguageModel


def is_codey_model(model_name: str) -> bool:
    """Report whether a model name refers to a Codey model.

    The check is a plain, case-sensitive substring match: any name that
    contains ``"code"`` is treated as a Codey model.

    Args:
        model_name: The model name to check.

    Returns:
        True if ``"code"`` occurs anywhere in ``model_name``, False otherwise.
    """
    codey_marker = "code"
    return codey_marker in model_name
def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]:
    """Build the retry decorator used for Vertex AI prediction calls.

    Args:
        llm: The model wrapper whose ``max_retries`` setting is handed to
            the retry decorator.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        Google API error types listed in this function.
    """
    # Import the submodule explicitly. A bare ``import google.api_core`` is
    # not guaranteed to load ``google.api_core.exceptions``; attribute access
    # on the package then only works if some other module imported it first.
    from google.api_core import exceptions as core_exceptions

    errors = [
        core_exceptions.ResourceExhausted,
        core_exceptions.ServiceUnavailable,
        core_exceptions.Aborted,
        core_exceptions.DeadlineExceeded,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries  # type: ignore
    )
def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
    """Use tenacity to retry the completion call.

    Args:
        llm: The model wrapper; its ``client.predict`` is the call being
            retried, and it is also used to build the retry decorator.
        *args: Positional arguments forwarded to ``llm.client.predict``.
        **kwargs: Keyword arguments forwarded to ``llm.client.predict``.

    Returns:
        Whatever ``llm.client.predict`` returns.
    """

    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
        # ``llm.client`` is resolved on every attempt, not captured up front.
        return llm.client.predict(*args, **kwargs)

    # Wrap the inner call explicitly instead of using ``@`` syntax; the inner
    # function keeps its name so anything that reports it sees the same value.
    guarded_call = _create_retry_decorator(llm)(_completion_with_retry)
    return guarded_call(*args, **kwargs)
class _VertexAICommon(BaseModel):
    # Settings and helper methods shared by the Vertex AI model wrappers.
    # (Kept as a comment rather than a class docstring so the model's
    # generated metadata is unaffected.)

    client: "_LanguageModel" = None  #: :meta private:
    model_name: str
    "Model name to use."
    temperature: float = 0.0
    "Sampling temperature, it controls the degree of randomness in token selection."
    max_output_tokens: int = 128
    "Token limit determines the maximum amount of text output from one prompt."
    top_p: float = 0.95
    "Tokens are selected from most probable to least until the sum of their "
    "probabilities equals the top-p value. Top-p is ignored for Codey models."
    top_k: int = 40
    "How the model selects tokens for output, the next token is selected from "
    "among the top-k most probable tokens. Top-k is ignored for Codey models."
    stop: Optional[List[str]] = None
    "Optional list of stop words to use when generating."
    project: Optional[str] = None
    "The default GCP project to use when making Vertex API calls."
    location: str = "us-central1"
    "The default location to use when making API calls."
    credentials: Any = None
    "The default custom credentials (google.auth.credentials.Credentials) to use "
    "when making API calls. If not provided, credentials will be ascertained from "
    "the environment."
    request_parallelism: int = 5
    "The amount of parallelism allowed for requests issued to VertexAI models. "
    "Default is 5."
    max_retries: int = 6
    """The maximum number of retries to make when generating."""
    task_executor: ClassVar[Optional[Executor]] = None

    @property
    def is_codey_model(self) -> bool:
        """Whether ``model_name`` designates a Codey model."""
        return is_codey_model(self.model_name)

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Sampling parameters sent with every prediction request.

        ``top_k`` and ``top_p`` are only included for non-Codey models.
        """
        params: Dict[str, Any] = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_output_tokens,
        }
        if not self.is_codey_model:
            params["top_k"] = self.top_k
            params["top_p"] = self.top_p
        return params

    def _predict(
        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
        """Run one prediction and apply stop words to the returned text.

        Args:
            prompt: The prompt to pass into the model.
            stop: Stop words for this call (optional).
            **kwargs: Per-call parameters; they override the defaults from
                ``_default_params`` on key collisions.

        Returns:
            The ``text`` of the prediction result after stop-word handling.
        """
        request_params = dict(self._default_params)
        request_params.update(kwargs)
        response = completion_with_retry(
            self, prompt, **request_params
        )  # type: ignore
        return self._enforce_stop_words(response.text, stop)

    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
        """Apply stop words to ``text``.

        When ``stop`` is None the instance-level ``self.stop`` is used
        instead. If the resulting stop list is empty or None, ``text`` is
        returned unchanged.
        """
        if stop is None:
            stop = self.stop
        return enforce_stop_tokens(text, stop) if stop else text

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM type."""
        return "vertexai"

    @classmethod
    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
        """Return the class-level executor, creating it on first use.

        ``request_parallelism`` only matters for the call that creates the
        executor; once ``task_executor`` is set it is returned as is.
        """
        executor = cls.task_executor
        if executor is None:
            executor = ThreadPoolExecutor(max_workers=request_parallelism)
            cls.task_executor = executor
        return executor

    @classmethod
    def _try_init_vertexai(cls, values: Dict) -> None:
        """Call ``init_vertexai`` with the project/location/credentials values.

        Only the ``project``, ``location`` and ``credentials`` entries of
        ``values`` are forwarded; every other key is ignored.
        """
        allowed = ("project", "location", "credentials")
        init_kwargs = {
            name: value for name, value in values.items() if name in allowed
        }
        init_vertexai(**init_kwargs)
class VertexAI(_VertexAICommon, LLM):
    """Google Vertex AI large language models."""

    model_name: str = "text-bison"
    "The name of the Vertex AI large language model."
    tuned_model_name: Optional[str] = None
    "The name of a tuned model. If provided, model_name is ignored."

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment.

        Initializes Vertex AI from ``values`` and stores the model client
        under ``values["client"]``:

        * a tuned ``TextGenerationModel`` when ``tuned_model_name`` is set,
        * a pretrained ``TextGenerationModel`` for non-Codey model names,
        * a pretrained ``CodeGenerationModel`` otherwise.

        Args:
            values: The field values of the model being validated.

        Returns:
            ``values`` with the ``client`` entry populated.
        """
        cls._try_init_vertexai(values)
        tuned_model_name = values.get("tuned_model_name")
        model_name = values["model_name"]
        try:
            if tuned_model_name or not is_codey_model(model_name):
                from vertexai.preview.language_models import TextGenerationModel

                if tuned_model_name:
                    values["client"] = TextGenerationModel.get_tuned_model(
                        tuned_model_name
                    )
                else:
                    values["client"] = TextGenerationModel.from_pretrained(model_name)
            else:
                from vertexai.preview.language_models import CodeGenerationModel

                values["client"] = CodeGenerationModel.from_pretrained(model_name)
        except ImportError:
            # Delegates the missing-package case to the shared helper
            # (presumably it raises a descriptive error - verify in
            # langchain.utilities.vertexai).
            raise_vertex_import_error()
        return values

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call Vertex model to get predictions based on the prompt.

        The blocking ``_predict`` call is submitted to the shared task
        executor and awaited through ``asyncio.wrap_future``.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A callback manager for async interaction with LLMs.
            **kwargs: Per-call parameters forwarded to ``_predict``, exactly
                as the synchronous ``_call`` does.

        Returns:
            The string generated by the model.
        """
        # Honour the instance's configured parallelism when the executor is
        # first created, instead of always falling back to the default.
        executor = self._get_task_executor(
            request_parallelism=self.request_parallelism
        )
        # ``Executor.submit`` forwards extra keyword arguments to the callable,
        # so per-call overrides reach ``_predict`` on the async path too.
        future = executor.submit(self._predict, prompt, stop, **kwargs)
        return await asyncio.wrap_future(future)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call Vertex model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A Callbackmanager for LLM run, optional.
            **kwargs: Per-call parameters forwarded to ``_predict``.

        Returns:
            The string generated by the model.
        """
        return self._predict(prompt, stop, **kwargs)