Source code for langchain.embeddings.clarifai
import logging
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
# Module-level logger; the embedding methods below use it to log failed API statuses.
logger = logging.getLogger(__name__)
class ClarifaiEmbeddings(BaseModel, Embeddings):
    """Clarifai embedding models.

    To use, you should have the ``clarifai`` python package installed, and the
    environment variable ``CLARIFAI_PAT`` set with your personal access token or
    pass it as the named parameter ``pat`` to the constructor. ``user_id``,
    ``app_id`` and ``model_id`` must also be provided.

    Example:
        .. code-block:: python

            from langchain.embeddings import ClarifaiEmbeddings

            clarifai = ClarifaiEmbeddings(
                pat="my-personal-access-token",
                user_id="my-user-id",
                app_id="my-app-id",
                model_id="my-model-id",
            )
    """

    stub: Any  #: :meta private:
    """Clarifai stub."""
    userDataObject: Any
    """Clarifai user data object."""
    model_id: Optional[str] = None
    """Model id to use."""
    model_version_id: Optional[str] = None
    """Model version id to use."""
    app_id: Optional[str] = None
    """Clarifai application id to use."""
    user_id: Optional[str] = None
    """Clarifai user id to use."""
    pat: Optional[str] = None
    """Clarifai personal access token to use."""
    api_base: str = "https://api.clarifai.com"
    """Base URL handed to the Clarifai auth helper."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves ``pat`` from the passed values or the ``CLARIFAI_PAT``
        environment variable, checks that ``user_id``, ``app_id`` and
        ``model_id`` were supplied, and then builds the Clarifai stub and user
        data object that the embedding methods use.

        Args:
            values: The field values collected by pydantic.

        Returns:
            The same mapping with ``pat``, ``userDataObject`` and ``stub``
            filled in.

        Raises:
            ValueError: If ``pat``, ``user_id``, ``app_id`` or ``model_id`` is
                missing.
            ImportError: If the ``clarifai`` package is not installed.
        """
        values["pat"] = get_from_dict_or_env(values, "pat", "CLARIFAI_PAT")
        if values["pat"] is None:
            raise ValueError("Please provide a pat.")

        # Checked in the same order as before so the first missing field
        # reported to the caller does not change.
        for required in ("user_id", "app_id", "model_id"):
            if values.get(required) is None:
                raise ValueError(f"Please provide a {required}.")

        try:
            from clarifai.auth.helper import ClarifaiAuthHelper
            from clarifai.client import create_stub
        except ImportError as err:
            raise ImportError(
                "Could not import clarifai python package. "
                "Please install it with `pip install clarifai`."
            ) from err

        auth = ClarifaiAuthHelper(
            user_id=values["user_id"],
            app_id=values["app_id"],
            pat=values["pat"],
            base=values["api_base"],
        )
        values["userDataObject"] = auth.get_user_app_id_proto()
        values["stub"] = create_stub(auth)
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Clarifai's embedding models.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty ``texts`` yields an
            empty list without contacting the service.

        Raises:
            ImportError: If the ``clarifai_grpc`` modules cannot be imported.
            Exception: If the response status is not ``SUCCESS``.
        """
        if not texts:
            # Nothing to embed; skip the request entirely.
            return []

        try:
            from clarifai_grpc.grpc.api import (
                resources_pb2,
                service_pb2,
            )
            from clarifai_grpc.grpc.api.status import status_code_pb2
        except ImportError as err:
            raise ImportError(
                "Could not import clarifai python package. "
                "Please install it with `pip install clarifai`."
            ) from err

        post_model_outputs_request = service_pb2.PostModelOutputsRequest(
            user_app_id=self.userDataObject,
            model_id=self.model_id,
            version_id=self.model_version_id,
            inputs=[
                resources_pb2.Input(
                    data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
                )
                for t in texts
            ],
        )
        post_model_outputs_response = self.stub.PostModelOutputs(
            post_model_outputs_request
        )

        if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
            logger.error(post_model_outputs_response.status)
            outputs = post_model_outputs_response.outputs
            # Test the outputs collection itself: indexing [0] before checking
            # for emptiness would raise and hide the real failure below.
            first_output_failure = outputs[0].status if outputs else None
            raise Exception(
                f"Post model outputs failed, status: "
                f"{post_model_outputs_response.status}, first output failure: "
                f"{first_output_failure}"
            )

        return [
            list(o.data.embeddings[0].vector)
            for o in post_model_outputs_response.outputs
        ]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Clarifai's embedding models.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.

        Raises:
            ImportError: If the ``clarifai_grpc`` modules cannot be imported.
            Exception: If the response status is not ``SUCCESS``.
        """
        # A query is a one-element batch; reuse the batch path so request
        # construction and error handling live in one place.
        return self.embed_documents([text])[0]