Source code for langchain.llms.azureml_endpoint

import json
import urllib.request
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env


[docs]class AzureMLEndpointClient(object): """AzureML Managed Endpoint client.""" def __init__( self, endpoint_url: str, endpoint_api_key: str, deployment_name: str = "" ) -> None: """Initialize the class.""" if not endpoint_api_key or not endpoint_url: raise ValueError( """A key/token and REST endpoint should be provided to invoke the endpoint""" ) self.endpoint_url = endpoint_url self.endpoint_api_key = endpoint_api_key self.deployment_name = deployment_name
[docs] def call(self, body: bytes, **kwargs: Any) -> bytes: """call.""" # The azureml-model-deployment header will force the request to go to a # specific deployment. Remove this header to have the request observe the # endpoint traffic rules. headers = { "Content-Type": "application/json", "Authorization": ("Bearer " + self.endpoint_api_key), } if self.deployment_name != "": headers["azureml-model-deployment"] = self.deployment_name req = urllib.request.Request(self.endpoint_url, body, headers) response = urllib.request.urlopen(req, timeout=kwargs.get("timeout", 50)) result = response.read() return result
class ContentFormatterBase: """Transform request and response of AzureML endpoint to match with required schema. """ """ Example: .. code-block:: python class ContentFormatter(ContentFormatterBase): content_type = "application/json" accepts = "application/json" def format_request_payload( self, prompt: str, model_kwargs: Dict ) -> bytes: input_str = json.dumps( { "inputs": {"input_string": [prompt]}, "parameters": model_kwargs, } ) return str.encode(input_str) def format_response_payload(self, output: str) -> str: response_json = json.loads(output) return response_json[0]["0"] """ content_type: Optional[str] = "application/json" """The MIME type of the input data passed to the endpoint""" accepts: Optional[str] = "application/json" """The MIME type of the response data returned from the endpoint""" @staticmethod def escape_special_characters(prompt: str) -> str: """Escapes any special characters in `prompt`""" escape_map = { "\\": "\\\\", '"': '\\"', "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", "\t": "\\t", } # Replace each occurrence of the specified characters with escaped versions for escape_sequence, escaped_sequence in escape_map.items(): prompt = prompt.replace(escape_sequence, escaped_sequence) return prompt @abstractmethod def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: """Formats the request body according to the input schema of the model. Returns bytes or seekable file like object in the format specified in the content_type request header. """ @abstractmethod def format_response_payload(self, output: bytes) -> str: """Formats the response body according to the output schema of the model. Returns the data type that is received from the response. """
[docs]class GPT2ContentFormatter(ContentFormatterBase): """Content handler for GPT2"""
[docs] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: prompt = ContentFormatterBase.escape_special_characters(prompt) request_payload = json.dumps( {"inputs": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs} ) return str.encode(request_payload)
[docs] def format_response_payload(self, output: bytes) -> str: return json.loads(output)[0]["0"]
[docs]class OSSContentFormatter(GPT2ContentFormatter): """Deprecated: Kept for backwards compatibility Content handler for LLMs from the OSS catalog.""" content_formatter: Any = None def __init__(self) -> None: super().__init__() warnings.warn( """`OSSContentFormatter` will be deprecated in the future. Please use `GPT2ContentFormatter` instead. """ )
[docs]class HFContentFormatter(ContentFormatterBase): """Content handler for LLMs from the HuggingFace catalog."""
[docs] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: ContentFormatterBase.escape_special_characters(prompt) request_payload = json.dumps( {"inputs": [f'"{prompt}"'], "parameters": model_kwargs} ) return str.encode(request_payload)
[docs] def format_response_payload(self, output: bytes) -> str: return json.loads(output)[0]["generated_text"]
[docs]class DollyContentFormatter(ContentFormatterBase): """Content handler for the Dolly-v2-12b model"""
[docs] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: prompt = ContentFormatterBase.escape_special_characters(prompt) request_payload = json.dumps( { "input_data": {"input_string": [f'"{prompt}"']}, "parameters": model_kwargs, } ) return str.encode(request_payload)
[docs] def format_response_payload(self, output: bytes) -> str: return json.loads(output)[0]
[docs]class LlamaContentFormatter(ContentFormatterBase): """Content formatter for LLaMa"""
[docs] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: """Formats the request according the the chosen api""" prompt = ContentFormatterBase.escape_special_characters(prompt) request_payload = json.dumps( { "input_data": { "input_string": [f'"{prompt}"'], "parameters": model_kwargs, } } ) return str.encode(request_payload)
[docs] def format_response_payload(self, output: bytes) -> str: """Formats response""" return json.loads(output)[0]["0"]
[docs]class AzureMLOnlineEndpoint(LLM, BaseModel): """Azure ML Online Endpoint models. Example: .. code-block:: python azure_llm = AzureMLOnlineEndpoint( endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score", endpoint_api_key="my-api-key", content_formatter=content_formatter, ) """ # noqa: E501 endpoint_url: str = "" """URL of pre-existing Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_URL`.""" endpoint_api_key: str = "" """Authentication Key for Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_API_KEY`.""" deployment_name: str = "" """Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed to constructor or specified as env var `AZUREML_DEPLOYMENT_NAME`.""" http_client: Any = None #: :meta private: content_formatter: Any = None """The content formatter that provides an input and output transform function to handle formats between the LLM and the endpoint""" model_kwargs: Optional[dict] = None """Key word arguments to pass to the model."""
[docs] @validator("http_client", always=True, allow_reuse=True) @classmethod def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient: """Validate that api key and python package exists in environment.""" endpoint_key = get_from_dict_or_env( values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY" ) endpoint_url = get_from_dict_or_env( values, "endpoint_url", "AZUREML_ENDPOINT_URL" ) deployment_name = get_from_dict_or_env( values, "deployment_name", "AZUREML_DEPLOYMENT_NAME", "" ) http_client = AzureMLEndpointClient(endpoint_url, endpoint_key, deployment_name) return http_client
@property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" _model_kwargs = self.model_kwargs or {} return { **{"deployment_name": self.deployment_name}, **{"model_kwargs": _model_kwargs}, } @property def _llm_type(self) -> str: """Return type of llm.""" return "azureml_endpoint" def _call( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """Call out to an AzureML Managed Online endpoint. Args: prompt: The prompt to pass into the model. stop: Optional list of stop words to use when generating. Returns: The string generated by the model. Example: .. code-block:: python response = azureml_model("Tell me a joke.") """ _model_kwargs = self.model_kwargs or {} request_payload = self.content_formatter.format_request_payload( prompt, _model_kwargs ) response_payload = self.http_client.call(request_payload, **kwargs) generated_text = self.content_formatter.format_response_payload( response_payload ) return generated_text