# Source code for langchain.callbacks.base

"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

if TYPE_CHECKING:
    from langchain.schema.agent import AgentAction, AgentFinish
    from langchain.schema.document import Document
    from langchain.schema.messages import BaseMessage
    from langchain.schema.output import LLMResult


class RetrieverManagerMixin:
    """Mixin with default hooks for the completion of a retriever run.

    Both hooks are no-ops that return ``None``; subclasses may override them.
    """

    def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever run fails with ``error``.

        ``run_id`` names the failed run and ``parent_run_id`` its parent run,
        if any. The default implementation does nothing.
        """
        return None

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever finishes running, with its ``documents``.

        The default implementation does nothing.
        """
        return None


class LLMManagerMixin:
    """Mixin with default hooks for events that occur during an LLM run.

    Every hook is a no-op returning ``None``; subclasses may override them.
    """

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called for each new ``token`` an LLM emits.

        Only available when streaming is enabled. The default does nothing.
        """
        return None

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM finishes running, with its ``response``.

        The default does nothing.
        """
        return None

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM run fails with ``error``.

        The default does nothing.
        """
        return None


class ChainManagerMixin:
    """Mixin with default hooks for chain completion and agent steps.

    Every hook is a no-op returning ``None``; subclasses may override them.
    """

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain finishes running, with its ``outputs``.

        The default does nothing.
        """
        return None

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain run fails with ``error``.

        The default does nothing.
        """
        return None

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an agent takes ``action``.

        The default does nothing.
        """
        return None

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an agent ends, with its ``finish`` value.

        The default does nothing.
        """
        return None


class ToolManagerMixin:
    """Mixin with default hooks for the completion of a tool run.

    Both hooks are no-ops returning ``None``; subclasses may override them.
    """

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool finishes running, with its ``output``.

        The default does nothing.
        """
        return None

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool run fails with ``error``.

        The default does nothing.
        """
        return None


class CallbackManagerMixin:
    """Mixin with default hooks for the start of LLM, chain, tool and retriever runs.

    All hooks are no-ops returning ``None`` except ``on_chat_model_start``,
    which raises ``NotImplementedError`` unless a subclass overrides it.
    """

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM starts running on ``prompts``.

        The default does nothing.
        """
        return None

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chat model starts running on ``messages``.

        Raises:
            NotImplementedError: Always, naming the concrete class, unless a
                subclass overrides this hook.
        """
        class_name = self.__class__.__name__
        raise NotImplementedError(
            f"{class_name} does not implement `on_chat_model_start`"
        )

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a retriever starts running on ``query``.

        The default does nothing.
        """
        return None

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain starts running on ``inputs``.

        The default does nothing.
        """
        return None

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool starts running on ``input_str``.

        The default does nothing.
        """
        return None


class RunManagerMixin:
    """Mixin with a default hook for arbitrary text emitted during a run.

    The hook is a no-op returning ``None``; subclasses may override it.
    """

    def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called with an arbitrary piece of ``text``.

        The default does nothing.
        """
        return None


class BaseCallbackHandler(
    LLMManagerMixin,
    ChainManagerMixin,
    ToolManagerMixin,
    RetrieverManagerMixin,
    CallbackManagerMixin,
    RunManagerMixin,
):
    """Base callback handler that can be used to handle callbacks from langchain.

    Combines every callback mixin into one class, so all hooks are available
    with their default implementations. The ``ignore_*`` properties all
    return ``False`` here; subclasses may override them.
    """

    # NOTE(review): neither flag is read in this module. Presumably the
    # callback manager consults them: ``raise_error`` to re-raise handler
    # exceptions, and ``run_inline`` to run an async handler inline rather
    # than concurrently. Confirm against the manager implementation.
    raise_error: bool = False
    run_inline: bool = False

    @property
    def ignore_llm(self) -> bool:
        """Whether LLM callbacks are skipped (never, by default)."""
        return False

    @property
    def ignore_retry(self) -> bool:
        """Whether retry callbacks are skipped (never, by default)."""
        return False

    @property
    def ignore_chain(self) -> bool:
        """Whether chain callbacks are skipped (never, by default)."""
        return False

    @property
    def ignore_agent(self) -> bool:
        """Whether agent callbacks are skipped (never, by default)."""
        return False

    @property
    def ignore_retriever(self) -> bool:
        """Whether retriever callbacks are skipped (never, by default)."""
        return False

    @property
    def ignore_chat_model(self) -> bool:
        """Whether chat model callbacks are skipped (never, by default)."""
        return False
class AsyncCallbackHandler(BaseCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain.

    Redefines every hook as a coroutine. All of them are no-ops except
    ``on_chat_model_start``, which raises ``NotImplementedError`` when awaited
    unless a subclass overrides it.
    """

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when an LLM starts running on ``prompts``; does nothing."""

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Awaited when a chat model starts running on ``messages``.

        Raises:
            NotImplementedError: Always, naming the concrete class, unless a
                subclass overrides this hook.
        """
        class_name = self.__class__.__name__
        raise NotImplementedError(
            f"{class_name} does not implement `on_chat_model_start`"
        )

    async def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited for each new LLM ``token`` (streaming only); does nothing."""

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when an LLM finishes with ``response``; does nothing."""

    async def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when an LLM run fails with ``error``; does nothing."""

    async def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a chain starts running on ``inputs``; does nothing."""

    async def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a chain finishes with ``outputs``; does nothing."""

    async def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a chain run fails with ``error``; does nothing."""

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a tool starts running on ``input_str``; does nothing."""

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a tool finishes with ``output``; does nothing."""

    async def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a tool run fails with ``error``; does nothing."""

    async def on_text(
        self,
        text: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited with an arbitrary piece of ``text``; does nothing."""

    async def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when an agent takes ``action``; does nothing."""

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when an agent ends with ``finish``; does nothing."""

    async def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a retriever starts running on ``query``; does nothing."""

    async def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a retriever finishes with ``documents``; does nothing."""

    async def on_retriever_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Awaited when a retriever run fails with ``error``; does nothing."""
class BaseCallbackManager(CallbackManagerMixin):
    """Base callback manager that handles callbacks from LangChain.

    Each kind of configuration is stored twice:

    * the entries that apply to this manager: ``handlers``, ``tags`` and
      ``metadata``;
    * the entries marked as inheritable: ``inheritable_handlers``,
      ``inheritable_tags`` and ``inheritable_metadata``.

    This class only maintains those collections. Presumably subclasses
    propagate the inheritable ones to child runs (not shown here).
    """

    def __init__(
        self,
        handlers: List[BaseCallbackHandler],
        inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
        parent_run_id: Optional[UUID] = None,
        *,
        tags: Optional[List[str]] = None,
        inheritable_tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        inheritable_metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize callback manager.

        Args:
            handlers: Handlers registered on this manager.
            inheritable_handlers: Handlers marked as inheritable.
            parent_run_id: ID of the parent run, if any.
            tags: Tags set on this manager.
            inheritable_tags: Tags marked as inheritable.
            metadata: Metadata set on this manager.
            inheritable_metadata: Metadata marked as inheritable.

        Every list/dict argument is shallow-copied. The ``add_*``/``remove_*``
        methods mutate these collections in place. Without the copy they
        would also mutate the caller's objects, and would double-mutate a
        single list passed as both e.g. ``tags`` and ``inheritable_tags``.
        """
        self.handlers: List[BaseCallbackHandler] = list(handlers or [])
        self.inheritable_handlers: List[BaseCallbackHandler] = list(
            inheritable_handlers or []
        )
        self.parent_run_id: Optional[UUID] = parent_run_id
        self.tags: List[str] = list(tags or [])
        self.inheritable_tags: List[str] = list(inheritable_tags or [])
        self.metadata: Dict[str, Any] = dict(metadata or {})
        self.inheritable_metadata: Dict[str, Any] = dict(inheritable_metadata or {})

    @property
    def is_async(self) -> bool:
        """Whether the callback manager is async (``False`` for this base class)."""
        return False

    def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Add a handler to the callback manager.

        The handler is appended only if it is not already registered. With
        ``inherit=True`` it is also added to ``inheritable_handlers`` (again
        without duplicates).
        """
        if handler not in self.handlers:
            self.handlers.append(handler)
        if inherit and handler not in self.inheritable_handlers:
            self.inheritable_handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager.

        The handler is removed from ``handlers`` and ``inheritable_handlers``
        wherever it is present. This means handlers added with
        ``inherit=False`` can be removed too.

        Raises:
            ValueError: If the handler is registered in neither list.
        """
        in_local = handler in self.handlers
        in_inheritable = handler in self.inheritable_handlers
        if not (in_local or in_inheritable):
            raise ValueError(f"{handler!r} is not registered on this callback manager")
        if in_local:
            self.handlers.remove(handler)
        if in_inheritable:
            self.inheritable_handlers.remove(handler)

    def set_handlers(
        self, handlers: List[BaseCallbackHandler], inherit: bool = True
    ) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = []
        self.inheritable_handlers = []
        for handler in handlers:
            self.add_handler(handler, inherit=inherit)

    def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
        """Set handler as the only handler on the callback manager."""
        self.set_handlers([handler], inherit=inherit)

    def add_tags(self, tags: List[str], inherit: bool = True) -> None:
        """Add tags to the callback manager.

        A tag that is already present in either list is first removed and
        then re-appended. It therefore moves after the older tags instead of
        being duplicated. With ``inherit=False`` a re-added tag is no longer
        inheritable.
        """
        for tag in tags:
            # Check both lists: a tag may exist only as inheritable.
            if tag in self.tags or tag in self.inheritable_tags:
                self.remove_tags([tag])
        self.tags.extend(tags)
        if inherit:
            self.inheritable_tags.extend(tags)

    def remove_tags(self, tags: List[str]) -> None:
        """Remove tags from the callback manager.

        Each tag is removed from ``tags`` and ``inheritable_tags`` wherever it
        is present. This means tags added with ``inherit=False`` can be
        removed too.

        Raises:
            ValueError: If a tag is present in neither list.
        """
        for tag in tags:
            in_local = tag in self.tags
            in_inheritable = tag in self.inheritable_tags
            if not (in_local or in_inheritable):
                raise ValueError(f"Tag {tag!r} is not set on this callback manager")
            if in_local:
                self.tags.remove(tag)
            if in_inheritable:
                self.inheritable_tags.remove(tag)

    def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
        """Merge ``metadata`` into this manager, and into the inheritable
        metadata when ``inherit`` is true. Existing keys are overwritten."""
        self.metadata.update(metadata)
        if inherit:
            self.inheritable_metadata.update(metadata)

    def remove_metadata(self, keys: List[str]) -> None:
        """Remove metadata keys from the callback manager.

        Each key is removed from ``metadata`` and ``inheritable_metadata``
        wherever it is present. This means keys added with ``inherit=False``
        can be removed too.

        Raises:
            KeyError: If a key is present in neither mapping.
        """
        for key in keys:
            if key not in self.metadata and key not in self.inheritable_metadata:
                raise KeyError(key)
            self.metadata.pop(key, None)
            self.inheritable_metadata.pop(key, None)
# Type of a ``callbacks`` argument: a list of handlers, a callback manager,
# or None.
Callbacks = Union[List[BaseCallbackHandler], BaseCallbackManager, None]