from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
TypedDict,
TypeVar,
Union,
cast,
)
from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
class RunnableConfig(TypedDict, total=False):
    """Per-call configuration accepted by every Runnable method.

    Every key is optional (``total=False``).
    """

    #: Tags for this call and any sub-calls (eg. a Chain calling an LLM).
    #: You can use these to filter calls.
    tags: List[str]

    #: Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
    #: Keys should be strings, values should be JSON-serializable.
    metadata: Dict[str, Any]

    #: Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
    #: Tags are passed to all callbacks, metadata is passed to handle*Start
    #: callbacks.
    callbacks: Callbacks
# Type of the value a Runnable receives.
Input = TypeVar("Input")
# Type of the value a Runnable produces. When a RunnableSequence streams, it
# combines the chunks of its last step with the ``+`` operator to report a
# final output to callbacks. Types that support ``+`` (e.g. str, list) get a
# combined value; for types that do not (e.g. dict), accumulation is disabled
# and ``None`` is reported instead.
Output = TypeVar("Output")
# Type of the other operand when composing Runnables with ``|``.
Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC):
    """Abstract unit of work that can be invoked, batched and streamed.

    Subclasses must implement :meth:`invoke`. The default :meth:`ainvoke`,
    :meth:`batch`, :meth:`abatch`, :meth:`stream` and :meth:`astream` are
    all built on top of ``invoke``/``ainvoke``. Runnables compose with the
    ``|`` operator into a :class:`RunnableSequence`.
    """

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        ],
    ) -> RunnableSequence[Input, Other]:
        """Compose ``self | other``: run ``self`` first, then ``other``.

        ``other`` is coerced to a Runnable first (callables are wrapped in a
        RunnableLambda, dicts in a RunnableMap).
        """
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Other], Any],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        ],
    ) -> RunnableSequence[Other, Output]:
        """Compose ``other | self`` when the left operand does not handle
        ``|`` itself (e.g. a plain callable or dict): run ``other`` first,
        then ``self``."""
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    @abstractmethod
    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Transform a single input into an output.

        Args:
            input: The value to process.
            config: Optional per-call configuration (tags, metadata, callbacks).

        Returns:
            The output for ``input``.
        """
        ...

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Async version of :meth:`invoke`.

        The default runs the synchronous :meth:`invoke` in the running event
        loop's default executor, so it does not block the loop.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.invoke, input, config
        )

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Invoke on every element of ``inputs`` using a thread pool.

        Args:
            inputs: Values to process.
            config: Either a single config (copied for every input) or a list
                holding exactly one config per input.
            max_concurrency: Maximum number of worker threads; ``None`` lets
                ``ThreadPoolExecutor`` pick its default.

        Returns:
            The outputs, in the same order as ``inputs``.

        Raises:
            ValueError: If ``config`` is a list of the wrong length.
        """
        configs = self._get_config_list(config, len(inputs))

        # If there's only one input, don't bother with the executor
        if len(inputs) == 1:
            return [self.invoke(inputs[0], configs[0])]

        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(executor.map(self.invoke, inputs, configs))

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async version of :meth:`batch`.

        Runs :meth:`ainvoke` for every input, with at most
        ``max_concurrency`` calls in flight (``None`` means no limit).
        Outputs are returned in the same order as ``inputs``.
        """
        configs = self._get_config_list(config, len(inputs))
        coros = map(self.ainvoke, inputs, configs)

        return await _gather_with_concurrency(max_concurrency, *coros)

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Stream the output. The default yields the whole result of
        :meth:`invoke` as a single chunk."""
        yield self.invoke(input, config)

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Async version of :meth:`stream`. The default yields the whole
        result of :meth:`ainvoke` as a single chunk."""
        yield await self.ainvoke(input, config)

    def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
        """
        Bind arguments to a Runnable, returning a new Runnable.

        The returned RunnableBinding forwards ``kwargs`` as extra keyword
        arguments on every call to this Runnable.
        """
        return RunnableBinding(bound=self, kwargs=kwargs)

    def _get_config_list(
        self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
    ) -> List[RunnableConfig]:
        """Normalize ``config`` into a list of ``length`` configs.

        A list is validated and returned as-is. A single config (or ``None``,
        treated as ``{}``) is shallow-copied once per input, so callees can
        modify their own copy safely.

        Raises:
            ValueError: If ``config`` is a list whose length is not ``length``.
        """
        if isinstance(config, list) and len(config) != length:
            raise ValueError(
                f"config must be a list of the same length as inputs, "
                f"but got {len(config)} configs for {length} inputs"
            )

        return (
            config
            if isinstance(config, list)
            else [config.copy() if config is not None else {} for _ in range(length)]
        )

    def _call_with_config(
        self,
        func: Callable[[Input], Output],
        input: Input,
        config: Optional[RunnableConfig],
        run_type: Optional[str] = None,
    ) -> Output:
        """Call ``func(input)`` inside a callback run for this Runnable.

        Starts a chain run configured from ``config`` (callbacks, tags,
        metadata), reports success via ``on_chain_end`` or failure via
        ``on_chain_error`` (re-raising the error), and returns ``func``'s
        result. Non-dict inputs/outputs are wrapped as ``{"input": ...}`` /
        ``{"output": ...}`` for the callbacks.
        """
        from langchain.callbacks.manager import CallbackManager

        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            inheritable_tags=config.get("tags"),
            inheritable_metadata=config.get("metadata"),
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input if isinstance(input, dict) else {"input": input},
            run_type=run_type,
        )
        try:
            output = func(input)
        # Also catch KeyboardInterrupt (like every other run-managing method
        # in this module) so an interrupted run is still closed with an error.
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(
                output if isinstance(output, dict) else {"output": output}
            )
            return output
class RunnableSequence(Serializable, Runnable[Input, Output]):
    """A chain of Runnables where each step receives the previous step's output.

    Usually built with the ``|`` operator. ``first`` receives the sequence's
    input, ``middle`` holds zero or more intermediate steps, and ``last``
    produces the sequence's output. Each call starts one root callback run
    for the whole sequence; every step runs as a child of that root run.
    """

    first: Runnable[Input, Any]
    middle: List[Runnable[Any, Any]] = Field(default_factory=list)
    last: Runnable[Any, Output]

    @property
    def steps(self) -> List[Runnable[Any, Any]]:
        """All steps in execution order: ``first``, ``middle``..., ``last``."""
        return [self.first] + self.middle + [self.last]

    @property
    def lc_serializable(self) -> bool:
        return True

    class Config:
        arbitrary_types_allowed = True

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        ],
    ) -> RunnableSequence[Input, Other]:
        """Append ``other`` after this sequence.

        If ``other`` is itself a RunnableSequence its steps are spliced in,
        so the result stays one flat sequence.
        """
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=self.first,
                middle=self.middle + [self.last] + [other.first] + other.middle,
                last=other.last,
            )
        else:
            return RunnableSequence(
                first=self.first,
                middle=self.middle + [self.last],
                last=_coerce_to_runnable(other),
            )

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        ],
    ) -> RunnableSequence[Other, Output]:
        """Prepend ``other`` before this sequence.

        If ``other`` is itself a RunnableSequence its steps are spliced in,
        so the result stays one flat sequence.
        """
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=other.first,
                middle=other.middle + [other.last] + [self.first] + self.middle,
                last=self.last,
            )
        else:
            return RunnableSequence(
                first=_coerce_to_runnable(other),
                middle=[self.first] + self.middle,
                last=self.last,
            )

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Run every step in order, feeding each output into the next step."""
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke all steps in sequence
        try:
            for step in self.steps:
                input = step.invoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            # finish the root run
            run_manager.on_chain_end(
                input if isinstance(input, dict) else {"output": input}
            )
            return cast(Output, input)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Async version of :meth:`invoke`: awaits each step's ``ainvoke``."""
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke all steps in sequence
        try:
            for step in self.steps:
                input = await step.ainvoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            # finish the root run
            await run_manager.on_chain_end(
                input if isinstance(input, dict) else {"output": input}
            )
            return cast(Output, input)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Run the sequence on many inputs, calling ``batch`` on each step.

        Each input gets its own root run. Step by step, the whole batch is
        passed to that step's ``batch`` so step-level batching optimizations
        apply. If any step fails, every root run is closed with the error.
        """
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self), item if isinstance(item, dict) else {"input": item}
            )
            for cm, item in zip(callback_managers, inputs)
        ]

        # invoke
        try:
            for step in self.steps:
                inputs = step.batch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(cfg, rm.get_child())
                        for rm, cfg in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
        except (KeyboardInterrupt, Exception) as e:
            for rm in run_managers:
                rm.on_chain_error(e)
            raise
        else:
            # finish the root runs
            for rm, output in zip(run_managers, inputs):
                rm.on_chain_end(
                    output if isinstance(output, dict) else {"output": output}
                )
            return cast(List[Output], inputs)

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async version of :meth:`batch`, calling ``abatch`` on each step."""
        from langchain.callbacks.manager import (
            AsyncCallbackManager,
            AsyncCallbackManagerForChainRun,
        )

        # setup callbacks
        configs = self._get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=cfg.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=cfg.get("tags"),
                local_tags=None,
                inheritable_metadata=cfg.get("metadata"),
                local_metadata=None,
            )
            for cfg in configs
        ]
        # start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self), item if isinstance(item, dict) else {"input": item}
                )
                for cm, item in zip(callback_managers, inputs)
            )
        )

        # invoke .abatch() on each step
        # this uses batching optimizations in Runnable subclasses, like LLM
        try:
            for step in self.steps:
                inputs = await step.abatch(
                    inputs,
                    [
                        # each step a child run of the corresponding root run
                        _patch_config(cfg, rm.get_child())
                        for rm, cfg in zip(run_managers, configs)
                    ],
                    max_concurrency=max_concurrency,
                )
        except (KeyboardInterrupt, Exception) as e:
            await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
            raise
        else:
            # finish the root runs
            await asyncio.gather(
                *(
                    rm.on_chain_end(
                        output if isinstance(output, dict) else {"output": output}
                    )
                    for rm, output in zip(run_managers, inputs)
                )
            )
            return cast(List[Output], inputs)

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Invoke all steps but the last, then stream the last step's output.

        Chunks are yielded unchanged. For the root run's ``on_chain_end``,
        chunks are combined with ``+``. If the chunk type does not support
        ``+``, accumulation stops and ``{"output": None}`` is reported.
        """
        from langchain.callbacks.manager import CallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = CallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke the first steps
        try:
            for step in [self.first] + self.middle:
                input = step.invoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise

        # stream the last step
        final: Union[Output, None] = None
        final_supported = True
        try:
            for output in self.last.stream(
                input,
                # mark the last step as a child run
                _patch_config(config, run_manager.get_child()),
            ):
                yield output
                # Accumulate output if possible, otherwise disable accumulation
                if final_supported:
                    if final is None:
                        final = output
                    else:
                        try:
                            # Build a new object instead of using ``+=``:
                            # ``final`` starts out as the first chunk already
                            # yielded to the caller, and an in-place add would
                            # mutate it for mutable chunk types (e.g. lists).
                            final = final + output  # type: ignore[operator]
                        except TypeError:
                            final = None
                            final_supported = False
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            # finish the root run
            run_manager.on_chain_end(
                final if isinstance(final, dict) else {"output": final}
            )

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Async version of :meth:`stream`: awaits the first steps' ``ainvoke``
        and streams the last step via ``astream``."""
        from langchain.callbacks.manager import AsyncCallbackManager

        # setup callbacks
        config = config or {}
        callback_manager = AsyncCallbackManager.configure(
            inheritable_callbacks=config.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=config.get("tags"),
            local_tags=None,
            inheritable_metadata=config.get("metadata"),
            local_metadata=None,
        )
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self), input if isinstance(input, dict) else {"input": input}
        )

        # invoke the first steps
        try:
            for step in [self.first] + self.middle:
                input = await step.ainvoke(
                    input,
                    # mark each step as a child run
                    _patch_config(config, run_manager.get_child()),
                )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise

        # stream the last step
        final: Union[Output, None] = None
        final_supported = True
        try:
            async for output in self.last.astream(
                input,
                # mark the last step as a child run
                _patch_config(config, run_manager.get_child()),
            ):
                yield output
                # Accumulate output if possible, otherwise disable accumulation
                if final_supported:
                    if final is None:
                        final = output
                    else:
                        try:
                            # Build a new object instead of using ``+=``:
                            # ``final`` starts out as the first chunk already
                            # yielded to the caller, and an in-place add would
                            # mutate it for mutable chunk types (e.g. lists).
                            final = final + output  # type: ignore[operator]
                        except TypeError:
                            final = None
                            final_supported = False
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            # finish the root run
            await run_manager.on_chain_end(
                final if isinstance(final, dict) else {"output": final}
            )
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
    """Run several Runnables on the same input and collect their outputs.

    The result is a dict mapping each step's name to that step's output.
    Steps given as callables or nested mappings are coerced to Runnables at
    construction time.
    """

    steps: Mapping[str, Runnable[Input, Any]]

    def __init__(
        self,
        steps: Mapping[
            str,
            Union[
                Runnable[Input, Any],
                Callable[[Input], Any],
                Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
            ],
        ],
    ) -> None:
        coerced = {name: _coerce_to_runnable(step) for name, step in steps.items()}
        super().__init__(steps=coerced)

    @property
    def lc_serializable(self) -> bool:
        return True

    class Config:
        arbitrary_types_allowed = True

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        """Invoke every step on ``input`` in a thread pool.

        Returns:
            ``{step_name: step_output}`` in the order of ``self.steps``.
        """
        from langchain.callbacks.manager import CallbackManager

        cfg = config or {}
        manager = CallbackManager.configure(
            inheritable_callbacks=cfg.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=cfg.get("tags"),
            local_tags=None,
            inheritable_metadata=cfg.get("metadata"),
            local_metadata=None,
        )
        # one root run for the whole map
        run_manager = manager.on_chain_start(dumpd(self), {"input": input})

        try:
            # snapshot so callers mutating self.steps mid-call cannot interfere
            snapshot = dict(self.steps)
            with ThreadPoolExecutor() as pool:
                pending = [
                    # every step is a child run of the root run
                    pool.submit(
                        step.invoke, input, _patch_config(cfg, run_manager.get_child())
                    )
                    for step in snapshot.values()
                ]
                output = dict(zip(snapshot, (fut.result() for fut in pending)))
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        else:
            run_manager.on_chain_end(output)
            return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Dict[str, Any]:
        """Await every step's ``ainvoke`` on ``input`` concurrently.

        Returns:
            ``{step_name: step_output}`` in the order of ``self.steps``.
        """
        from langchain.callbacks.manager import AsyncCallbackManager

        cfg = config or {}
        manager = AsyncCallbackManager.configure(
            inheritable_callbacks=cfg.get("callbacks"),
            local_callbacks=None,
            verbose=False,
            inheritable_tags=cfg.get("tags"),
            local_tags=None,
            inheritable_metadata=cfg.get("metadata"),
            local_metadata=None,
        )
        # one root run for the whole map
        run_manager = await manager.on_chain_start(dumpd(self), {"input": input})

        try:
            # snapshot so callers mutating self.steps mid-call cannot interfere
            snapshot = dict(self.steps)
            pending = (
                # every step is a child run of the root run
                step.ainvoke(input, _patch_config(cfg, run_manager.get_child()))
                for step in snapshot.values()
            )
            results = await asyncio.gather(*pending)
            output = dict(zip(snapshot, results))
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_chain_error(e)
            raise
        else:
            await run_manager.on_chain_end(output)
            return output
class RunnableLambda(Runnable[Input, Output]):
    """Adapt a plain callable to the Runnable interface.

    ``invoke`` calls ``func(input)`` inside a callback run.
    """

    def __init__(self, func: Callable[[Input], Output]) -> None:
        """Store ``func``.

        Raises:
            TypeError: If ``func`` is not callable.
        """
        if not callable(func):
            # Note the trailing space: the two literals are concatenated.
            raise TypeError(
                "Expected a callable type for `func`. "
                f"Instead got an unsupported type: {type(func)}"
            )
        self.func = func

    def __eq__(self, other: Any) -> bool:
        """Two RunnableLambdas are equal when their wrapped callables are."""
        if isinstance(other, RunnableLambda):
            return self.func == other.func
        else:
            return False

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Return ``func(input)``, reported to callbacks as one chain run."""
        return self._call_with_config(self.func, input, config)
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
    """A Runnable that returns its input unchanged.

    The call still goes through ``_call_with_config``, so callbacks record a
    run for it.
    """

    @property
    def lc_serializable(self) -> bool:
        return True

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
        """Return ``input`` as-is."""

        def identity(value: Input) -> Input:
            return value

        return self._call_with_config(identity, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
    """A Runnable paired with fixed extra keyword arguments.

    Every method forwards to the same method of ``bound``, passing the
    entries of ``kwargs`` as additional keyword arguments. Created by
    ``Runnable.bind``.
    """

    bound: Runnable[Input, Output]
    kwargs: Mapping[str, Any]

    class Config:
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        return True

    def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
        """Return a new binding of the same Runnable.

        Its kwargs are the current ones updated with ``kwargs``; on
        conflicting names the new value wins.
        """
        merged = dict(self.kwargs)
        merged.update(kwargs)
        return self.__class__(bound=self.bound, kwargs=merged)

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Forward to ``bound.invoke`` with the bound kwargs."""
        extra = self.kwargs
        return self.bound.invoke(input, config, **extra)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Forward to ``bound.ainvoke`` with the bound kwargs."""
        extra = self.kwargs
        return await self.bound.ainvoke(input, config, **extra)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Forward to ``bound.batch`` with the bound kwargs."""
        extra = self.kwargs
        return self.bound.batch(
            inputs, config, max_concurrency=max_concurrency, **extra
        )

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Forward to ``bound.abatch`` with the bound kwargs."""
        extra = self.kwargs
        return await self.bound.abatch(
            inputs, config, max_concurrency=max_concurrency, **extra
        )

    def stream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Forward to ``bound.stream`` with the bound kwargs."""
        extra = self.kwargs
        yield from self.bound.stream(input, config, **extra)

    async def astream(
        self, input: Input, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Forward to ``bound.astream`` with the bound kwargs."""
        extra = self.kwargs
        async for chunk in self.bound.astream(input, config, **extra):
            yield chunk
class RouterInput(TypedDict):
    """Input accepted by RouterRunnable.

    ``key`` selects which runnable handles the call; ``input`` is the value
    passed to that runnable.
    """

    key: str
    input: Any


class RouterRunnable(
    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
    """Dispatch each call to one of several Runnables chosen by key.

    Each input is a RouterInput dict: ``input["key"]`` selects an entry of
    ``runnables`` and ``input["input"]`` is forwarded to it.
    """

    runnables: Mapping[str, Runnable[Input, Output]]

    def __init__(self, runnables: Mapping[str, Runnable[Input, Output]]) -> None:
        super().__init__(runnables=runnables)

    class Config:
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        return True

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[RouterInput, Other]:
        """Compose ``self | other``: route first, then run ``other``."""
        return RunnableSequence(first=self, last=_coerce_to_runnable(other))

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[Other, Output]:
        """Compose ``other | self``: run ``other`` first, then route its output."""
        return RunnableSequence(first=_coerce_to_runnable(other), last=self)

    def _get_runnable(self, key: str) -> Runnable[Input, Output]:
        """Return the runnable registered under ``key``.

        Raises:
            ValueError: If no runnable is registered under ``key``.
        """
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")
        return self.runnables[key]

    def invoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Invoke the runnable selected by ``input["key"]`` on ``input["input"]``."""
        key = input["key"]
        actual_input = input["input"]
        return self._get_runnable(key).invoke(actual_input, config)

    async def ainvoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Async version of :meth:`invoke`."""
        key = input["key"]
        actual_input = input["input"]
        return await self._get_runnable(key).ainvoke(actual_input, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Route every input and invoke the selected runnables in a thread pool.

        All keys are validated before any runnable is called. Outputs are
        returned in the same order as ``inputs``.

        Raises:
            ValueError: If any key has no registered runnable.
        """
        keys = [item["key"] for item in inputs]
        actual_inputs = [item["input"] for item in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        runnables = [self.runnables[key] for key in keys]
        configs = self._get_config_list(config, len(inputs))
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(
                executor.map(
                    lambda runnable, input, config: runnable.invoke(input, config),
                    runnables,
                    actual_inputs,
                    configs,
                )
            )

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async version of :meth:`batch`, with at most ``max_concurrency``
        calls in flight (``None`` means no limit)."""
        keys = [item["key"] for item in inputs]
        actual_inputs = [item["input"] for item in inputs]
        if any(key not in self.runnables for key in keys):
            raise ValueError("One or more keys do not have a corresponding runnable")

        runnables = [self.runnables[key] for key in keys]
        configs = self._get_config_list(config, len(inputs))
        return await _gather_with_concurrency(
            max_concurrency,
            *(
                runnable.ainvoke(input, config)
                for runnable, input, config in zip(runnables, actual_inputs, configs)
            ),
        )

    def stream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Stream the output of the runnable selected by ``input["key"]``."""
        key = input["key"]
        actual_input = input["input"]
        yield from self._get_runnable(key).stream(actual_input, config)

    async def astream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Async version of :meth:`stream`."""
        key = input["key"]
        actual_input = input["input"]
        async for output in self._get_runnable(key).astream(actual_input, config):
            yield output
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
return config
def _coerce_to_runnable(
    thing: Union[
        Runnable[Input, Output],
        Callable[[Input], Output],
        Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
    ]
) -> Runnable[Input, Output]:
    """Convert ``thing`` into a Runnable.

    - A Runnable is returned unchanged.
    - Any other callable is wrapped in a RunnableLambda.
    - A mapping is converted into a RunnableMap, coercing each value
      recursively. Any ``Mapping`` is accepted, as the signature promises,
      not only ``dict``.

    Raises:
        TypeError: If ``thing`` is none of the above.
    """
    if isinstance(thing, Runnable):
        return thing
    if callable(thing):
        return RunnableLambda(thing)
    if isinstance(thing, Mapping):
        runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
        return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
    # Note the trailing space: the two literals are concatenated.
    raise TypeError(
        "Expected a Runnable, callable or dict. "
        f"Instead got an unsupported type: {type(thing)}"
    )