导航菜单

  • 1.vector
  • 2.milvus
  • 3.pymilvus
  • 4.rag
  • 5.rag_measure
  • 7.search
  • ragflow
  • heapq
  • HNSW
  • cosine_similarity
  • math
  • typing
  • etcd
  • minio
  • collections
  • jieba
  • random
  • beautifulsoup4
  • chromadb
  • sentence_transformers
  • numpy
  • lxml
  • openpyxl
  • PyMuPDF
  • python-docx
  • requests
  • python-pptx
  • text_splitter
  • all-MiniLM-L6-v2
  • openai
  • llm
  • BPETokenizer
  • Flask
  • RAGAS
  • BagofWords
  • langchain
  • Pydantic
  • abc
  • faiss
  • MMR
  • scikit-learn
  • LangChain 教程
    • 1. ChatOpenAI
      • 1.1 main.py
      • 1.2 langchain__init__.py
      • 1.3 chat_models.py
      • 1.4 messages.py
    • 2. ChatDeepSeek
      • 2.1. init.py
      • 2.2. deepseek.py
      • 2.3. chat_models.py
    • 3. ChatTongyi
      • 3.1. init.py
      • 3.2. tongyi.py
      • 3.3. chat_models.py
    • 4. 多轮对话
      • 4.1. conversations.py
      • 4.2. messages.py
      • 4.3. chat_models.py
    • 5. PromptTemplate
      • 5.1. PromptTemplate.py
      • 5.2. init.py
      • 5.3. prompts.py
    • 6. ChatPromptTemplate
      • 6.1. init.py
      • 6.2. ChatPromptTemplate.py
      • 6.3. prompts.py
      • 6.4. chat_models.py
    • 7. MessagesPlaceholder
      • 7.1. MessagesPlaceholder.py
      • 7.2. init.py
      • 7.3. prompts.py
    • 8. FewShotPromptTemplate
      • 8.1. init.py
      • 8.2. FewShotPromptTemplate.py
      • 8.3. prompts.py
    • 9. load_prompt
      • 9.1. prompt_template.json
      • 9.2. Load_prompt.py
      • 9.3. prompts.py
    • 10. partial
      • 10.1. partial.py
      • 10.2. prompts.py
    • 11. PipelinePromptTemplate
      • 11.1. PipelinePromptTemplate.py
      • 11.2. prompts.py
    • 12. LengthBasedExampleSelector
      • 12.1. LengthBasedExampleSelector.py
      • 12.2. example_selectors.py
      • 12.3. prompts.py
    • 13. MaxMarginalRelevanceExampleSelector
      • 13.1. MaxMarginalRelevanceExampleSelector.py
      • 13.2. embeddings.py
      • 13.3. example_selectors.py
      • 13.4. vectorstores.py
      • 13.5. chat_models.py
    • 14. SemanticSimilarityExampleSelector
      • 14.1. SemanticSimilarityExampleSelector.py
      • 14.2. example_selectors.py
    • 15. BaseExampleSelector
      • 15.1. BaseExampleSelector.py
    • 16. StrOutputParser
      • 16.1. output_parsers.py
      • 16.2. StrOutputParser.py
      • 16.3. chat_models.py
    • 17. JsonOutputParser
      • 17.1. output_parsers.py
  • 尝试匹配 { ... } 或 [ ... ]
  • 解析 JSON
  • 使用辅助函数解析 JSON
  • 如果指定了 Pydantic 模型,进行验证
  • 尝试使用 Pydantic 验证(如果可用)
  • 如果验证失败,返回原始解析结果
  • 如果有 Pydantic 模型,返回其 schema
    • 18. PydanticOutputParser
      • 18.1. PydanticOutputParser.py
      • 18.2. output_parsers.py
  • 定义 JSON 输出解析器类
    • 19.2. output_parsers.py
  • 定义 JSON 输出解析器类
  • 定义 Pydantic 输出解析器类
  • 默认修复提示模板
  • 创建修复链:Prompt -> LLM -> StrOutputParser
  • 尝试使用基础解析器解析
  • 如果已达到最大重试次数,抛出异常
  • 获取格式说明(如果解析器支持)
  • 使用 LLM 修复输出
  • 新式链式调用
  • 旧式链式调用
  • 直接调用(作为函数)
  • 确保返回的是字符串
  • 格式化提示词
  • 调用 LLM
  • 解析输出
  • 使用解析器解析
    • 20. RetryOutputParser
      • 20.1. RetryOutputParser.py
      • 20.2. output_parsers.py
  • 定义 JSON 输出解析器类
  • 定义 Pydantic 输出解析器类
  • 定义输出解析异常类
  • 定义输出修复解析器类
    • {instructions}
    • Completion:
    • {completion}
    • 错误信息:
    • {error}
  • 简单的链式调用包装类
  • 默认重试提示模板
  • 创建重试链:Prompt -> LLM -> StrOutputParser
  • 将 prompt_value 转换为字符串
  • 尝试使用基础解析器解析
  • 如果已达到最大重试次数,抛出异常
  • 使用 LLM 重新生成输出
  • 新式链式调用
  • 旧式链式调用
  • 直接调用
  • 确保返回的是字符串
  • 默认重试提示模板(包含错误信息)
  • 创建重试链:Prompt -> LLM -> StrOutputParser
  • 将 prompt_value 转换为字符串
  • 尝试使用基础解析器解析
  • 如果已达到最大重试次数,抛出异常
  • 使用 LLM 重新生成输出(包含错误信息)
  • 新式链式调用
  • 旧式链式调用
  • 直接调用
  • 确保返回的是字符串
    • 21. BaseOutputParser
      • 21.1. BaseOutputParser.py
    • 22. LCEL表达式
      • 22.1. init.py
      • 22.2. runnables.py
      • 22.3. LCEL.py
      • 22.4. chat_models.py
      • 22.5. prompts.py
      • 22.6. output_parsers.py
  • 定义 JSON 输出解析器类
  • 定义输出解析异常类
  • 定义输出修复解析器类
    • {instructions}
    • Completion:
    • {completion}
    • 错误信息:
    • {error}
  • 简单的链式调用包装类
  • 简单的 PromptValue 包装类
  • 定义重试输出解析器类
  • 定义带错误信息的重试输出解析器类
    • 23.1. runnables.py
    • 23.2. Runnable.py
    • 24. RunnableSequence和RunnableLambda
      • 24.1. runnables.py
      • 24.2. RunnableSequence.py
    • 25. RunnableParallel
      • 25.1. runnables.py
      • 25.2. RunnableParallel.py
    • 26. RunnablePassthrough
      • 26.1. RunnablePassthrough.py
      • 26.2. runnables.py
    • 27. RunnableBranch
      • 27.1. RunnableBranch.py
      • 27.2. runnables.py
    • 28. Retry
      • 28.1. with_retry.py
      • 28.2. runnables.py
    • 29. Config
      • 29.1. Config.py
      • 29.2. callbacks.py
      • 29.3. runnables.py
    • 30. ConfigurableField
      • 30.1. Configurable.py
      • 30.2. runnables.py
    • 31. ChatMessageHistory
      • 31.1. init.py
      • 31.2. chat_message_histories.py
    • 32. RunnableWithMessageHistory
      • 32.1. RunnableWithMessageHistory.py
      • 32.2. runnables.py
  • 调用父类初始化
  • 从 config 中获取 session_id
  • 如果没有 session_id,直接调用原始 runnable
  • 获取会话历史
  • 保存原始输入(用于后续更新历史)
  • 准备输入
  • 如果输入是字典,注入历史消息
  • 如果只有 input_messages_key,将当前输入和历史合并
  • 如果输入是字符串,转换为消息列表
  • 如果输入是消息列表,合并历史
  • 调用原始 runnable
  • 更新历史记录(使用原始输入,避免重复添加)
  • 获取历史消息数量(用于判断哪些是新消息)
  • 添加用户输入到历史(只添加新的消息)
  • 只添加新的消息(历史消息之后的部分)
  • 只添加新的消息(历史消息之后的部分)
  • 添加 AI 回复到历史
  • 如果是消息对象(如 AIMessage)
  • ConfigurableField 相关类

LangChain 教程 #

1. ChatOpenAI #

ChatOpenAI是 LangChain 对 OpenAI 聊天大模型(如 gpt-4、gpt-3.5 等)的简易封装。通过该类,我们可以方便地调用 OpenAI 的 API,实现与大语言模型的对话。

  • 只需指定模型名称(如 "gpt-4o"),即可创建会话实例。
  • 使用 .invoke() 方法,通过输入用户消息,快速获得 AI 的回复内容。
  • ChatOpenAI 会自动管理消息结构和底层 API 调用,开发者无需关心复杂的请求细节。
uv add langchain_openai

1.1 main.py #

# 从 langchain 导入 ChatOpenAI 类
from langchain import ChatOpenAI

# 创建一个 ChatOpenAI 实例,使用 "gpt-4o" 模型
llm = ChatOpenAI(model="gpt-4o")

# 调用模型,传入用户消息,返回 AI 的回复结果
result = llm.invoke("你好,你是谁")

# 打印 AI 回复的内容
print(result.content)

1.2 langchain__init__.py #

langchain__init__.py

# 从当前包中导入 ChatOpenAI 类
from .chat_models import ChatOpenAI
# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
__all__ = ["ChatOpenAI"]

1.3 chat_models.py #

langchain\chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage 与 HumanMessage 类
from langchain.messages import AIMessage, HumanMessage

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant
                    role = "user" if isinstance(msg, HumanMessage) else "assistant"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

1.4 messages.py #

langchain\messages.py

# 定义基础消息类
class BaseMessage:
    # 基础消息类的文档字符串
    """基础消息类"""

    # 初始化方法,content为消息内容,其余为可选参数
    def __init__(self, content: str, **kwargs):
        """
        初始化消息

        Args:
            content: 消息内容
            **kwargs: 其他可选参数
        """
        # 保存消息内容
        self.content = content
        # 获取消息类型,如果未指定则默认为"base"
        self.type = kwargs.get("type", "base")
        # 遍历其他参数,并设置为成员变量(排除type参数)
        for key, value in kwargs.items():
            if key != "type":
                setattr(self, key, value)

    # 定义当对象被str()或print时的输出内容
    def __str__(self):
        return self.content

    # 定义对象的官方字符串表示,用于debug
    def __repr__(self):
        return f"{self.__class__.__name__}(content={self.content!r})"

# 定义用户消息类,继承自BaseMessage
class HumanMessage(BaseMessage):
    # 用户消息类的文档字符串
    """用户消息"""

    # 初始化方法,调用父类构造方法,并指定type为"human"
    def __init__(self, content: str, **kwargs):
        super().__init__(content, type="human", **kwargs)

# 定义AI消息类,继承自BaseMessage
class AIMessage(BaseMessage):
    # AI消息类的文档字符串
    """AI 消息"""

    # 初始化方法,调用父类构造方法,并指定type为"ai"
    def __init__(self, content: str, **kwargs):
        super().__init__(content, type="ai", **kwargs)

2. ChatDeepSeek #

ChatDeepSeek是 LangChain 对 DeepSeek 聊天大模型的简单封装。其用法与 ChatOpenAI 高度相似,开发者可以用统一的接口来调用 DeepSeek API,与大模型进行对话。

  • 支持通过 model 参数指定 DeepSeek 的模型名称(如 "deepseek-chat" 等)。
  • 只需传入 api_key 与模型等必要参数,即可完成模型初始化。
  • 通过 .invoke() 方法,用一行代码获取模型回复。
  • 输入与输出均与 ChatOpenAI 保持一致,方便模型替换。
uv add langchain_deepseek

2.1. init.py #

langchain/init.py

+# 从当前包中导入 ChatOpenAI 和 ChatDeepSeek 类
+from .chat_models import ChatOpenAI, ChatDeepSeek

# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
+__all__ = ["ChatOpenAI", "ChatDeepSeek"]

2.2. deepseek.py #

2.deepseek.py

#from langchain_deepseek import ChatDeepSeek
# 从 langchain 导入 ChatDeepSeek 类
from langchain import ChatDeepSeek

# 创建一个 ChatOpenAI 实例,使用 "gpt-4o" 模型
llm = ChatDeepSeek(model="deepseek-chat",api_key="sk-bb99cf132b184a169b5e053b346a7c25")

# 调用模型,传入用户消息,返回 AI 的回复结果
result = llm.invoke("你是谁")

# 打印 AI 回复的内容
print(result.content)

2.3. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage 与 HumanMessage 类
from langchain.messages import AIMessage, HumanMessage

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant
                    role = "user" if isinstance(msg, HumanMessage) else "assistant"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

+
+# 定义与 DeepSeek 聊天模型交互的类
+class ChatDeepSeek:
+   
+   # 初始化方法
+   def __init__(self, model: str = "deepseek-chat", **kwargs):
+       """
+       初始化 ChatDeepSeek
+       
+       Args:
+           model: 模型名称,如 "deepseek-chat"
+           **kwargs: 其他参数(如 temperature, max_tokens 等)
+       """
+       # 设置模型名称
+       self.model = model
+       # 获取 api_key,优先从参数获取,否则从环境变量获取
+       self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
+       # 如果没有提供 api_key,则抛出异常
+       if not self.api_key:
+           raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
+       # 保存除 api_key 之外的其他参数,用于 API 调用
+       self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
+       # DeepSeek 的 API base URL
+       base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
+       # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
+       self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)
+   
+   # 调用模型生成回复的方法
+   def invoke(self, input, **kwargs):
+       """
+       调用模型生成回复
+       
+       Args:
+           input: 输入内容,可以是字符串或消息列表
+           **kwargs: 额外的 API 参数
+       
+       Returns:
+           AIMessage: AI 的回复消息
+       """
+       # 将输入数据转换为消息格式
+       messages = self._convert_input(input)
+       # 构建 API 请求参数字典
+       params = {
+           "model": self.model,
+           "messages": messages,
+           **self.model_kwargs,
+           **kwargs
+       }
+       # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
+       response = self.client.chat.completions.create(**params)
+       # 取出返回结果中的第一个选项
+       choice = response.choices[0]
+       # 获取消息内容
+       content = choice.message.content or ""
+       # 返回一个 AIMessage 对象
+       return AIMessage(content=content)
+   
+   # 内部方法,将输入转换为 API 需要的消息格式
+   def _convert_input(self, input):
+       """
+       将输入转换为 API 需要的消息格式
+       
+       Args:
+           input: 字符串或消息列表
+       
+       Returns:
+           list[dict]: API 格式的消息列表
+       """
+       # 输入为字符串时,直接封装为用户角色消息
+       if isinstance(input, str):
+           return [{"role": "user", "content": input}]
+       # 输入为列表时,需逐个元素处理
+       elif isinstance(input, list):
+           messages = []
+           # 遍历输入列表中的每个元素
+           for msg in input:
+               # 如果该元素为字符串,作为用户消息处理
+               if isinstance(msg, str):
+                   messages.append({"role": "user", "content": msg})
+               # 如果为 HumanMessage 或 AIMessage 类型,判断角色并获取内容
+               elif isinstance(msg, (HumanMessage, AIMessage)):
+                   # 判断消息类型,human 为 user,ai 为 assistant
+                   role = "user" if isinstance(msg, HumanMessage) else "assistant"
+                   # 获取消息内容属性
+                   content = msg.content if hasattr(msg, "content") else str(msg)
+                   messages.append({"role": role, "content": content})
+               # 如果该元素本身为字典,直接添加
+               elif isinstance(msg, dict):
+                   messages.append(msg)
+               # 如果是元组且长度为2,解包为 role 与 content
+               elif isinstance(msg, tuple) and len(msg) == 2:
+                   role, content = msg
+                   messages.append({"role": role, "content": content})
+           # 返回处理后的消息列表
+           return messages
+       else:
+           # 其他输入类型,转为字符串作为 user 消息
+           return [{"role": "user", "content": str(input)}]
+

3. ChatTongyi #

ChatTongyi是与阿里云通义千问大模型(Qwen)进行对话交互的类。 ChatTongyi 实现了与 langchain 统一接口兼容的调用方式,可以像调用 ChatOpenAI 模型一样便捷地使用通义千问,通过输入用户消息,获取大模型生成的回复。

使用 ChatTongyi 时,只需提供模型名称(如 "qwen-max")和通义千问的平台 API Key,即可完成初始化。 随后,可以通过 invoke 等接口传递用户输入,获得 AI 的回复结果。返回内容为 AIMessage 实例,方便后续与 langchain 的其它组件集成与串联。

uv add langchain_community  dashscope

3.1. init.py #

langchain/init.py

+# 从当前包中导入 ChatOpenAI、ChatDeepSeek 和 ChatTongyi 类
+from .chat_models import ChatOpenAI, ChatDeepSeek, ChatTongyi

# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
+__all__ = ["ChatOpenAI", "ChatDeepSeek", "ChatTongyi"]

3.2. tongyi.py #

3.tongyi.py

#from langchain_community.chat_models.tongyi import ChatTongyi
# 从 langchain 导入 ChatDeepSeek 类
from langchain import ChatTongyi

# 创建一个 ChatOpenAI 实例,使用 "gpt-4o" 模型
llm = ChatTongyi(model="qwen-max",api_key="sk-cc2054c29cf54fec92503bf7016cf383")

# 调用模型,传入用户消息,返回 AI 的回复结果
result = llm.invoke("你是谁")

# 打印 AI 回复的内容
print(result.content)

3.3. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage 与 HumanMessage 类
from langchain.messages import AIMessage, HumanMessage

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant
                    role = "user" if isinstance(msg, HumanMessage) else "assistant"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage 或 AIMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant
                    role = "user" if isinstance(msg, HumanMessage) else "assistant"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

+
+# 定义与通义千问(Tongyi)聊天模型交互的类
+class ChatTongyi:
+   
+   # 初始化方法
+   def __init__(self, model: str = "qwen-max", **kwargs):
+       """
+       初始化 ChatTongyi
+       
+       Args:
+           model: 模型名称,如 "qwen-max"
+           **kwargs: 其他参数(如 temperature, max_tokens 等)
+       """
+       # 设置模型名称
+       self.model = model
+       # 获取 api_key,优先从参数获取,否则从环境变量获取
+       self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
+       # 如果没有提供 api_key,则抛出异常
+       if not self.api_key:
+           raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
+       # 保存除 api_key 之外的其他参数,用于 API 调用
+       self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
+       # 通义千问的 API base URL(使用 OpenAI 兼容模式)
+       base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+       # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
+       self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)
+   
+   # 调用模型生成回复的方法
+   def invoke(self, input, **kwargs):
+       """
+       调用模型生成回复
+       
+       Args:
+           input: 输入内容,可以是字符串或消息列表
+           **kwargs: 额外的 API 参数
+       
+       Returns:
+           AIMessage: AI 的回复消息
+       """
+       # 将输入数据转换为消息格式
+       messages = self._convert_input(input)
+       # 构建 API 请求参数字典
+       params = {
+           "model": self.model,
+           "messages": messages,
+           **self.model_kwargs,
+           **kwargs
+       }
+       # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
+       response = self.client.chat.completions.create(**params)
+       # 取出返回结果中的第一个选项
+       choice = response.choices[0]
+       # 获取消息内容
+       content = choice.message.content or ""
+       # 返回一个 AIMessage 对象
+       return AIMessage(content=content)
+   
+   # 内部方法,将输入转换为 API 需要的消息格式
+   def _convert_input(self, input):
+       """
+       将输入转换为 API 需要的消息格式
+       
+       Args:
+           input: 字符串或消息列表
+       
+       Returns:
+           list[dict]: API 格式的消息列表
+       """
+       # 输入为字符串时,直接封装为用户角色消息
+       if isinstance(input, str):
+           return [{"role": "user", "content": input}]
+       # 输入为列表时,需逐个元素处理
+       elif isinstance(input, list):
+           messages = []
+           # 遍历输入列表中的每个元素
+           for msg in input:
+               # 如果该元素为字符串,作为用户消息处理
+               if isinstance(msg, str):
+                   messages.append({"role": "user", "content": msg})
+               # 如果为 HumanMessage 或 AIMessage 类型,判断角色并获取内容
+               elif isinstance(msg, (HumanMessage, AIMessage)):
+                   # 判断消息类型,human 为 user,ai 为 assistant
+                   role = "user" if isinstance(msg, HumanMessage) else "assistant"
+                   # 获取消息内容属性
+                   content = msg.content if hasattr(msg, "content") else str(msg)
+                   messages.append({"role": role, "content": content})
+               # 如果该元素本身为字典,直接添加
+               elif isinstance(msg, dict):
+                   messages.append(msg)
+               # 如果是元组且长度为2,解包为 role 与 content
+               elif isinstance(msg, tuple) and len(msg) == 2:
+                   role, content = msg
+                   messages.append({"role": role, "content": content})
+           # 返回处理后的消息列表
+           return messages
+       else:
+           # 其他输入类型,转为字符串作为 user 消息
+           return [{"role": "user", "content": str(input)}]

4. 多轮对话 #

messages

多轮对话是指用户与 AI 助手能够连续多轮进行交流,每一轮对话都可以参考前面的上下文,从而生成更符合语境的智能回复。在 LangChain 的设计中,通过 messages 列表来积累多轮对话的所有历史消息,包括系统设定、用户提问、AI 回答等。每条消息都有明确的角色(如 "system"、"user"、"assistant"),以便模型理解对话的上下文和参与者。

具体流程如下:

  1. 首先,构造一个包含多条消息的消息列表 messages,每条消息都指定了角色和内容。
  2. 初始化一个聊天模型对象(如 ChatOpenAI),可以指定所用的大模型,比如 "gpt-4o"。
  3. 调用模型的 invoke 方法,把包含历史消息的 messages 传递给模型,模型根据历史对话内容生成新一轮回复。
  4. 得到 AI 的回复后,可将其添加到 messages 列表中,不断积累对话历史,实现自然流畅的多轮交流。

这种机制能够让 AI 在回答每个问题时都能“记住”之前的内容,更好地理解用户的意图,实现连贯的多轮对话体验。

4.1. conversations.py #

4.conversations.py

# 导入 ChatOpenAI 类
from langchain import ChatOpenAI
# 导入消息类型:AIMessage、HumanMessage、SystemMessage
from langchain.messages import AIMessage, HumanMessage, SystemMessage

# 创建一个 ChatOpenAI 实例,指定模型为 "gpt-4o"
llm = ChatOpenAI(model="gpt-4o")

# 构建消息列表,包括系统消息、用户消息和 AI 消息
messages = [
    # 系统消息,设定 AI 助手身份
    SystemMessage(content="你是一个AI助手,请回答用户的问题。"),
    # 用户消息,介绍自己并提问
    HumanMessage(content="你好,我叫张三,你是谁?"),
    # AI 消息,AI 的自我介绍
    AIMessage(content="我是GPT-4o,一个AI助手。"),
    # 用户再次提问
    HumanMessage(content="你知道我叫什么吗?"),
]

# 调用 llm.invoke,将消息传入模型,获得 AI 回复
result = llm.invoke(messages)

# 输出 AI 回复内容
print(result.content)

4.2. messages.py #

langchain/messages.py

# 定义基础消息类
class BaseMessage:
    # 基础消息类的文档字符串
    """基础消息类"""

    # 初始化方法,content为消息内容,其余为可选参数
    def __init__(self, content: str, **kwargs):
        """
        初始化消息

        Args:
            content: 消息内容
            **kwargs: 其他可选参数
        """
        # 保存消息内容
        self.content = content
        # 获取消息类型,如果未指定则默认为"base"
        self.type = kwargs.get("type", "base")
        # 遍历其他参数,并设置为成员变量(排除type参数)
        for key, value in kwargs.items():
            if key != "type":
                setattr(self, key, value)

    # 定义当对象被str()或print时的输出内容
    def __str__(self):
        return self.content

    # 定义对象的官方字符串表示,用于debug
    def __repr__(self):
        return f"{self.__class__.__name__}(content={self.content!r})"

# 定义用户消息类,继承自BaseMessage
class HumanMessage(BaseMessage):
    # 用户消息类的文档字符串
    """用户消息"""

    # 初始化方法,调用父类构造方法,并指定type为"human"
    def __init__(self, content: str, **kwargs):
        super().__init__(content, type="human", **kwargs)

# 定义AI消息类,继承自BaseMessage
class AIMessage(BaseMessage):
    # AI消息类的文档字符串
    """AI 消息"""

    # 初始化方法,调用父类构造方法,并指定type为"ai"
    def __init__(self, content: str, **kwargs):
        super().__init__(content, type="ai", **kwargs)

+# 定义系统消息类,继承自BaseMessage
+class SystemMessage(BaseMessage):
+   # 系统消息类的文档字符串
+   """系统消息"""
+   
+   # 初始化方法,调用父类构造方法,并指定type为"system"
+   def __init__(self, content: str, **kwargs):
+       super().__init__(content, type="system", **kwargs)
+

4.3. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
+# 从 langchain.messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
+from langchain.messages import AIMessage, HumanMessage, SystemMessage

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
+               # 如果为 HumanMessage/AIMessage/SystemMessage 类型,判断角色并获取内容
+               elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
+                   # 判断消息类型,human 为 user, ai 为 assistant, system 为 system
+                   if isinstance(msg, HumanMessage):
+                       role = "user"
+                   elif isinstance(msg, AIMessage):
+                       role = "assistant"
+                   elif isinstance(msg, SystemMessage):
+                       role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
+               # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
+               elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
+                   # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
+                   if isinstance(msg, HumanMessage):
+                       role = "user"
+                   elif isinstance(msg, AIMessage):
+                       role = "assistant"
+                   elif isinstance(msg, SystemMessage):
+                       role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:

    # 初始化方法
    def __init__(self, model: str = "qwen-max", **kwargs):
        """
        初始化 ChatTongyi

        Args:
            model: 模型名称,如 "qwen-max"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # 通义千问的 API base URL(使用 OpenAI 兼容模式)
        base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串或消息列表

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
+               # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
+               elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
+                   # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
+                   if isinstance(msg, HumanMessage):
+                       role = "user"
+                   elif isinstance(msg, AIMessage):
+                       role = "assistant"
+                   elif isinstance(msg, SystemMessage):
+                       role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

5. PromptTemplate #

PromptTemplate可以让你用类似 Python 字符串格式化的方法,灵活拼装对话提示词,实现参数化的对话输入。

主要特点和用法如下:

  • 通过 PromptTemplate.from_template("模板字符串") 创建模板实例,模板字符串中可用如 {变量} 这样的占位符。
  • 使用 prompt_template.format(变量名=值) 方法,将变量实际值填入模板,得到最终提示词文本。
  • 这样可以将用户输入、业务变量等灵活地插入到聊天提示词内容中,使调用大模型更加灵活和自动化。
  • 结合 ChatOpenAI 等聊天接口一起用,实现带参数的对话,如“你好,我叫张三,你是谁?”

这种模板机制,适合批量生成问题、实用型机器人等需要动态构造输入的应用场景。

5.1. PromptTemplate.py #

5.PromptTemplate.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
# 从 langchain 中导入 ChatOpenAI 和 PromptTemplate
from langchain import ChatOpenAI, PromptTemplate

# 使用模板字符串创建 PromptTemplate 实例
prompt_template = PromptTemplate.from_template("你好,我叫{name},你是谁?")

# 创建 ChatOpenAI 实例,指定模型为 "gpt-4o"
llm = ChatOpenAI(model="gpt-4o")

# 使用 prompt_template 格式化输入,将 name 替换为“张三”后传入 llm.invoke 生成回复
result = llm.invoke(prompt_template.format(name="张三"))

# 输出 AI 回复的内容
print(result.content)

5.2. init.py #

langchain/init.py

# 从当前包中导入 ChatOpenAI、ChatDeepSeek 和 ChatTongyi 类
from .chat_models import ChatOpenAI, ChatDeepSeek, ChatTongyi
+# 从当前包中导入 PromptTemplate 类
+from .prompts import PromptTemplate

# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
+__all__ = ["ChatOpenAI", "ChatDeepSeek", "ChatTongyi", "PromptTemplate"]

5.3. prompts.py #

langchain/prompts.py

# 导入正则表达式模块
import re

# 定义 PromptTemplate 类,用于处理提示词模板
class PromptTemplate:
    # 提示词模板类的文档字符串
    """提示词模板类,用于格式化字符串模板"""

    # 初始化方法,传入模板字符串
    def __init__(self, template: str):
        """
        初始化 PromptTemplate

        Args:
            template: 模板字符串,支持 {变量名} 格式
        """
        # 保存传入的模板字符串
        self.template = template
        # 提取模板中的所有变量名,并保存到 input_variables 列表中
        self.input_variables = self._extract_variables(template)

    # 类方法,用于从模板字符串创建 PromptTemplate 实例
    @classmethod
    def from_template(cls, template: str):
        """
        从模板字符串创建 PromptTemplate 实例

        Args:
            template: 模板字符串,支持 {变量名} 格式

        Returns:
            PromptTemplate 实例
        """
        # 返回通过模板字符串实例化的 PromptTemplate 对象
        return cls(template=template)

    # 格式化方法,使用传入的参数对模板进行填充
    def format(self, **kwargs):
        """
        使用提供的参数格式化模板

        Args:
            **kwargs: 用于填充模板变量的关键字参数

        Returns:
            格式化后的字符串
        """
        # 计算哪些必需变量未被提供
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果存在未提供的变量,抛出异常
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")

        # 使用模板字符串的 format 方法进行字符串填充
        return self.template.format(**kwargs)

    # 内部方法,从模板字符串中提取变量名
    def _extract_variables(self, template: str):
        """
        从模板字符串中提取所有变量名

        Args:
            template: 模板字符串

        Returns:
            变量名列表
        """
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式} 的结构
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 用正则表达式查找所有匹配项
        matches = re.findall(pattern, template)
        # 去重并保持原有顺序,返回变量名列表
        return list(dict.fromkeys(matches))  # 保持顺序并去重

6. ChatPromptTemplate #

ChatPromptTemplate 是 LangChain 中用于多轮对话消息生成与管理的消息模板类。它允许你灵活地定义一组多角色(如 system、human、ai)的消息模板,并通过变量注入,动态生成一组完整的对话消息。

主要特点

  • 支持多消息类型:你可以用 ("system", "..."), ("human", "..."), ("ai", "...") 等多种角色模板描述一轮或多轮对话,甚至嵌入变量如 {name}、{user_input}。
  • 模板化对话历史:只需定义好消息模板,传入变量,就能快速生成符合大模型 API 的消息结构,方便多轮上下文对话。
  • 集成 with LLM:可直接将 ChatPromptTemplate 生成的 prompt_value 作为模型输入,自动转换为符合格式的消息列表,实现更灵活的对话场景。

用法举例

  1. 定义一个对话模板,包含系统、人类、AI 的多条消息,并支持变量。
  2. 通过 .invoke() 或 .format_messages() 方法传入变量,获得格式化后的消息列表。
  3. 可以将这些消息直接传递给 llm.invoke(),实现多轮上下文对话,或做进一步处理。

这种机制大大提升了多轮对话 prompt 的可复用性与可维护性,尤其适合构建需要“记忆历史轮次”的聊天机器人。

6.1. init.py #

langchain/init.py

# 从当前包中导入 ChatOpenAI、ChatDeepSeek 和 ChatTongyi 类
from .chat_models import ChatOpenAI, ChatDeepSeek, ChatTongyi
+# 从当前包中导入 PromptTemplate 及相关类
+from .prompts import (
+   PromptTemplate,
+   ChatPromptTemplate,
+   SystemMessagePromptTemplate,
+   HumanMessagePromptTemplate,
+   AIMessagePromptTemplate,
+)

# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
+__all__ = [
+   "ChatOpenAI",
+   "ChatDeepSeek",
+   "ChatTongyi",
+   "PromptTemplate",
+   "ChatPromptTemplate",
+   "SystemMessagePromptTemplate",
+   "HumanMessagePromptTemplate",
+   "AIMessagePromptTemplate",
+]

6.2. ChatPromptTemplate.py #

6.ChatPromptTemplate.py

# 从 langchain 库中导入 ChatOpenAI、ChatPromptTemplate 及消息模板类
from langchain import ChatOpenAI, ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate

# 创建一个 ChatOpenAI 实例,指定模型为 "gpt-4o"
llm = ChatOpenAI(model="gpt-4o")

# 定义一个 ChatPromptTemplate,包含系统消息、人类消息和 AI 消息,多轮对话模板
template = ChatPromptTemplate(
    [
        ("system", "你是一个乐于助人的 AI 机器人。你的名字叫{name}。"),
        ("human", "你好,你最近怎么样?"),
        ("ai", "我很好,谢谢你的关心!"),
        ("human", "{user_input}"),
    ]
)

# 使用模板注入变量,生成 prompt_value
prompt_value = template.invoke(
    {
        "name": "小助",
        "user_input": "你叫什么名字?",
    }
)

# 打印 prompt_value 转换为字符串形式的内容
print(prompt_value.to_string())

# 打印 prompt_value 转换为消息对象列表的内容
print(prompt_value.to_messages())

# 使用 prompt_value 调用 llm.invoke,获得 AI 回复
result = llm.invoke(prompt_value)

# 打印 AI 回复的内容
print(result.content)

# 使用 from_messages 方法,基于各类消息模板构造 ChatPromptTemplate
template = ChatPromptTemplate.from_messages([
    SystemMessagePromptTemplate.from_template("你是一个乐于助人的 AI 机器人。你的名字叫{name}。"),
    HumanMessagePromptTemplate.from_template("你好,你最近怎么样?"),
    AIMessagePromptTemplate.from_template("我很好,谢谢你的关心!"),
    HumanMessagePromptTemplate.from_template("{user_input}"),
])

# 格式化消息模板,传入变量,生成消息列表
prompt_messages = template.format_messages(name="小助", user_input="你叫什么名字?")

# 打印 format_messages 返回的消息列表
print("format_messages 返回的消息列表:")
print(prompt_messages)

# 使用格式化后的消息列表调用 llm.invoke,获取 AI 回复
result = llm.invoke(prompt_messages)

# 打印 AI 回复的内容
print(result.content)

6.3. prompts.py #

langchain/prompts.py

# 导入正则表达式模块
import re
+# 导入各种消息类型
+from .messages import SystemMessage, HumanMessage, AIMessage
# 定义 PromptTemplate 类,用于处理提示词模板
class PromptTemplate:
    # 提示词模板类的文档字符串
    """提示词模板类,用于格式化字符串模板"""

    # 初始化方法,传入模板字符串
    def __init__(self, template: str):
        """
        初始化 PromptTemplate

        Args:
            template: 模板字符串,支持 {变量名} 格式
        """
        # 保存传入的模板字符串
        self.template = template
        # 提取模板中的所有变量名,并保存到 input_variables 列表中
        self.input_variables = self._extract_variables(template)

    # 类方法,用于从模板字符串创建 PromptTemplate 实例
    @classmethod
    def from_template(cls, template: str):
        """
        从模板字符串创建 PromptTemplate 实例

        Args:
            template: 模板字符串,支持 {变量名} 格式

        Returns:
            PromptTemplate 实例
        """
        # 返回通过模板字符串实例化的 PromptTemplate 对象
        return cls(template=template)

    # 格式化方法,使用传入的参数对模板进行填充
    def format(self, **kwargs):
        """
        使用提供的参数格式化模板

        Args:
            **kwargs: 用于填充模板变量的关键字参数

        Returns:
            格式化后的字符串
        """
        # 计算哪些必需变量未被提供
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果存在未提供的变量,抛出异常
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")

        # 使用模板字符串的 format 方法进行字符串填充
        return self.template.format(**kwargs)

    # 内部方法,从模板字符串中提取变量名
    def _extract_variables(self, template: str):
        """
        从模板字符串中提取所有变量名

        Args:
            template: 模板字符串

        Returns:
            变量名列表
        """
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式} 的结构
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 用正则表达式查找所有匹配项
        matches = re.findall(pattern, template)
        # 去重并保持原有顺序,返回变量名列表
        return list(dict.fromkeys(matches))  # 保持顺序并去重
+# 定义消息提示词值类,用于存储格式化后的消息列表
+class ChatPromptValue:
+   """聊天提示词值类,包含格式化后的消息列表"""
+
+   # 初始化方法,接收消息列表
+   def __init__(self, messages):
+       """
+       初始化 ChatPromptValue
+
+       Args:
+           messages: 消息列表
+       """
+       # 保存消息列表到实例变量
+       self.messages = messages
+
+   # 将消息对象列表转换为字符串表示的方法
+   def to_string(self):
+       """
+       将消息列表转换为字符串表示
+
+       Returns:
+           字符串表示
+       """
+       # 创建一个空列表用于存放每条消息的字符串
+       parts = []
+       # 遍历每条消息
+       for msg in self.messages:
+           # 判断消息对象是否有 type 和 content 属性(区分类别)
+           if hasattr(msg, 'type') and hasattr(msg, 'content'):
+               # 定义角色映射字典,将内部角色类型转为大写首字母字符串
+               role_map = {
+                   "system": "System",
+                   "human": "Human",
+                   "ai": "AI"
+               }
+               # 获取角色字符串
+               role = role_map.get(msg.type, msg.type.capitalize())
+               # 组装为“角色: 内容”的形式插入 parts
+               parts.append(f"{role}: {msg.content}")
+           else:
+               # 不是预定消息对象就直接转为字符串
+               parts.append(str(msg))
+       # 用换行符拼接所有消息字符串后返回
+       return "\n".join(parts)
+
+   # 返回消息对象列表的方法
+   def to_messages(self):
+       """
+       返回消息列表
+
+       Returns:
+           消息列表
+       """
+       # 直接返回消息列表
+       return self.messages
+
+# 定义基础消息提示词模板类
+class BaseMessagePromptTemplate:
+   """基础消息提示词模板类"""
+
+   # 初始化方法,必须给定 PromptTemplate
+   def __init__(self, prompt: PromptTemplate):
+       """
+       初始化消息提示词模板
+
+       Args:
+           prompt: PromptTemplate 实例
+       """
+       # 保存 PromptTemplate 到实例变量
+       self.prompt = prompt
+
+   # 工厂方法,用模板字符串直接创建本类实例
+   @classmethod
+   def from_template(cls, template: str):
+       """
+       从模板字符串创建消息提示词模板
+
+       Args:
+           template: 模板字符串
+
+       Returns:
+           消息提示词模板实例
+       """
+       # 使用 PromptTemplate.from_template 创建 PromptTemplate
+       prompt = PromptTemplate.from_template(template)
+       # 返回本类实例
+       return cls(prompt=prompt)
+
+   # 格式化当前模板,返回消息对象
+   def format(self, **kwargs):
+       """
+       格式化消息模板
+
+       Args:
+           **kwargs: 用于填充模板变量的关键字参数
+
+       Returns:
+           格式化后的消息对象
+       """
+       # 用 PromptTemplate 格式化获取内容字符串
+       content = self.prompt.format(**kwargs)
+       # 调用具体子类的 _create_message 创建消息对象
+       return self._create_message(content)

+   # 创建消息对象(子类实现)
+   def _create_message(self, content):
+       """创建消息对象,由子类实现"""
+       # 子类必须实现该方法
+       raise NotImplementedError
+
+# 定义系统消息提示词模板类
+class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
+   """系统消息提示词模板"""
+
+   # 创建 SystemMessage 对象实现
+   def _create_message(self, content):
+       """创建 SystemMessage 对象"""
+       # 延迟导入 SystemMessage 类型
+       from langchain.messages import SystemMessage
+       # 以 content 创建 SystemMessage 并返回
+       return SystemMessage(content=content)
+
+# 定义人类消息提示词模板类
+class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
+   """人类消息提示词模板"""
+
+   # 创建 HumanMessage 对象实现
+   def _create_message(self, content):
+       """创建 HumanMessage 对象"""
+       # 延迟导入 HumanMessage 类型
+       from langchain.messages import HumanMessage
+       # 以 content 创建 HumanMessage 并返回
+       return HumanMessage(content=content)
+
+# 定义AI消息提示词模板类
+class AIMessagePromptTemplate(BaseMessagePromptTemplate):
+   """AI消息提示词模板"""
+
+   # 创建 AIMessage 对象实现
+   def _create_message(self, content):
+       """创建 AIMessage 对象"""
+       # 延迟导入 AIMessage 类型
+       from langchain.messages import AIMessage
+       # 以 content 创建 AIMessage 并返回
+       return AIMessage(content=content)
+
+# 定义聊天提示词模板类
+class ChatPromptTemplate:
+   """聊天提示词模板类,用于创建多轮对话的提示词"""
+
+   # 初始化方法,接收消息列表
+   def __init__(self, messages):
+       """
+       初始化 ChatPromptTemplate
+
+       Args:
+           messages: 消息列表,可以是:
+               - 元组列表:[(role, template), ...]
+               - 消息提示词模板列表:[SystemMessagePromptTemplate, ...]
+       """
+       # 保存消息列表到实例变量
+       self.messages = messages
+       # 提取所有输入变量到实例变量
+       self.input_variables = self._extract_input_variables()
+
+   # 类方法:由消息提示词模板列表创建 ChatPromptTemplate
+   @classmethod
+   def from_messages(cls, messages):
+       """
+       从消息提示词模板列表创建 ChatPromptTemplate
+
+       Args:
+           messages: 消息提示词模板列表
+
+       Returns:
+           ChatPromptTemplate 实例
+       """
+       # 初始化 ChatPromptTemplate 实例并返回
+       return cls(messages=messages)
+
+   # 根据输入变量填充模板并生成 ChatPromptValue
+   def invoke(self, input_variables):
+       """
+       使用提供的变量格式化模板
+
+       Args:
+           input_variables: 包含模板变量的字典
+
+       Returns:
+           ChatPromptValue 实例
+       """
+       # 创建空列表存储格式化后的消息
+       formatted_messages = []
+       # 遍历每个消息模板
+       for msg in self.messages:
+           # 如果是元组 (role, template)
+           if isinstance(msg, tuple) and len(msg) == 2:
+               role, template_str = msg
+               # 用 PromptTemplate.from_template 创建模板对象
+               prompt = PromptTemplate.from_template(template_str)
+               # 用输入变量格式化模板获取内容
+               content = prompt.format(**input_variables)
+               # 根据角色生成不同类型消息对象
+               if role == "system":
+                   formatted_messages.append(SystemMessage(content=content))
+               elif role == "human" or role == "user":
+                   formatted_messages.append(HumanMessage(content=content))
+               elif role == "ai" or role == "assistant":
+                   formatted_messages.append(AIMessage(content=content))
+           # 如果是 BaseMessagePromptTemplate 实例
+           elif isinstance(msg, BaseMessagePromptTemplate):
+               # 用输入变量格式化并生成消息对象
+               formatted_msg = msg.format(**input_variables)
+               formatted_messages.append(formatted_msg)
+           # 如果已经是消息对象,直接加入列表
+           else:
+               formatted_messages.append(msg)
+       # 封装成 ChatPromptValue 并返回
+       return ChatPromptValue(messages=formatted_messages)
+
+   # 直接返回格式化后的消息对象列表
+   def format_messages(self, **kwargs):
+       """
+       使用提供的变量格式化模板,直接返回消息列表
+
+       Args:
+           **kwargs: 用于填充模板变量的关键字参数
+
+       Returns:
+           格式化后的消息列表
+       """
+       # 新建格式化后消息的列表
+       formatted_messages = []
+       # 遍历每个消息模板
+       for msg in self.messages:
+           # 如果是元组 (role, template)
+           if isinstance(msg, tuple) and len(msg) == 2:
+               role, template_str = msg
+               # 用 PromptTemplate.from_template 创建模板对象
+               prompt = PromptTemplate.from_template(template_str)
+               # 用输入变量格式化模板获取内容
+               content = prompt.format(**kwargs)
+               # 延迟导入各种消息类型
+               from langchain.messages import SystemMessage, HumanMessage, AIMessage
+               # 根据角色生成不同类型消息对象
+               if role == "system":
+                   formatted_messages.append(SystemMessage(content=content))
+               elif role == "human" or role == "user":
+                   formatted_messages.append(HumanMessage(content=content))
+               elif role == "ai" or role == "assistant":
+                   formatted_messages.append(AIMessage(content=content))
+           # 如果是 BaseMessagePromptTemplate 实例
+           elif isinstance(msg, BaseMessagePromptTemplate):
+               # 用输入变量格式化并生成消息对象
+               formatted_msg = msg.format(**kwargs)
+               formatted_messages.append(formatted_msg)
+           # 如果已经是消息对象,直接加入列表
+           else:
+               formatted_messages.append(msg)
+       # 返回格式化后的消息对象列表
+       return formatted_messages
+
+   # 提取所有消息模板中出现的变量名,存入 input_variables
+   def _extract_input_variables(self):
+       """
+       从所有消息模板中提取输入变量
+
+       Returns:
+           输入变量列表
+       """
+       # 用集合去重存储所有变量名
+       variables = set()
+       # 遍历消息模板
+       for msg in self.messages:
+           # 如果是元组 (role, template)
+           if isinstance(msg, tuple) and len(msg) == 2:
+               _, template_str = msg
+               # 用 PromptTemplate.from_template 提取变量
+               prompt = PromptTemplate.from_template(template_str)
+               variables.update(prompt.input_variables)
+           # 如果是 BaseMessagePromptTemplate 实例
+           elif isinstance(msg, BaseMessagePromptTemplate):
+               variables.update(msg.prompt.input_variables)
+       # 返回变量名列表
+       return list(variables)

6.4. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
from langchain.messages import AIMessage, HumanMessage, SystemMessage

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
+           input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
+       # 如果是 ChatPromptValue,转换为消息列表
+       from langchain.prompts import ChatPromptValue
+       if isinstance(input, ChatPromptValue):
+           input = input.to_messages()
+       
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage/SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant, system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
+           input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
+       # 如果是 ChatPromptValue,转换为消息列表
+       from langchain.prompts import ChatPromptValue
+       if isinstance(input, ChatPromptValue):
+           input = input.to_messages()
+       
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:

    # 初始化方法
    def __init__(self, model: str = "qwen-max", **kwargs):
        """
        初始化 ChatTongyi

        Args:
            model: 模型名称,如 "qwen-max"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # 通义千问的 API base URL(使用 OpenAI 兼容模式)
        base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
+           input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
+       # 如果是 ChatPromptValue,转换为消息列表
+       from langchain.prompts import ChatPromptValue
+       if isinstance(input, ChatPromptValue):
+           input = input.to_messages()
+       
        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

7. MessagesPlaceholder #

在实际对话中,我们常常需要动态地插入“多轮对话历史”,比如让 AI 模型记住前面的聊天内容。这时候,MessagesPlaceholder 就派上用场了!

MessagesPlaceholder是一种特殊的“变量插槽”,可以让你在 ChatPromptTemplate 中占位——例如位置为 "history"。你只需把历史消息(如多条 HumanMessage、AIMessage)作为变量传入占位符对应的 key,LangChain 会自动展开这些历史对话,并融入模板指定的位置,实现真正支持多轮上下文的智能对话。

主要用途如下:

  • 插入、扩展聊天历史,丰富 LLM 对当前轮问题的理解上下文。
  • 将不同来源(数据库、缓存、临时整理等)的消息灵活组织进 prompt。
  • 用于多 Agent 对话、流程驱动型对话等高级应用场景。
  • 可以自定义组合:混合 MessagesPlaceholder 与普通消息模板灵活嵌套。

这样一来,你只需准备好历史消息和新的问题,通过格式化方法一并传入,LangChain 就能自动组装起完整、顺畅的多轮对话 prompt,大大提升对话智能和用户体验。

7.1. MessagesPlaceholder.py #

7.MessagesPlaceholder.py

#from langchain_core.prompts import ChatPromptTemplate,MessagesPlaceholder
#from langchain_openai import ChatOpenAI
#from langchain_core.messages import HumanMessage, AIMessage

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.messages import HumanMessage, AIMessage

# 历史对话,随便举个例子
history = [
    HumanMessage(content="你好"),
    AIMessage(content="你好,很高兴见到你"),
]

# 构造一个包含 MessagesPlaceholder 的模板
template = ChatPromptTemplate(
    [
        ("system", "你是一个乐于助人的 AI 助手。"),
        MessagesPlaceholder("history"),
        ("human", "{question}"),
    ]
)

# 创建 LLM 实例
llm = ChatOpenAI(model="gpt-4o")

# 使用 format_messages 注入历史消息和本次用户问题
prompt_messages = template.format_messages(
    history=history,
    question="请介绍一下你自己?",
)

print("format_messages 输出的消息:")
for msg in prompt_messages:
    print(msg)

# 使用生成的消息列表直接调用模型
response = llm.invoke(prompt_messages)
print("\n模型回复:")
print(response.content)

7.2. init.py #

langchain/init.py

# 定义本模块对外暴露的接口(即可通过 from langchain import ... 导入的名称)
+__all__ = []

7.3. prompts.py #

langchain/prompts.py

+# 导入正则表达式库
import re
+# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage
+
+# 定义提示词模板类
class PromptTemplate:
+   # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""
+
+   # 构造方法
    def __init__(self, template: str):
+       # 保存模板字符串
        self.template = template
+       # 提取模板中的变量名
        self.input_variables = self._extract_variables(template)
+
+   # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
+       # 返回PromptTemplate实例
        return cls(template=template)
+
+   # 格式化模板字符串
    def format(self, **kwargs):
+       # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
+       # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
+       # 使用format方法将变量填充到模板字符串
        return self.template.format(**kwargs)
+
+   # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
+       # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
+       # 匹配所有变量名
        matches = re.findall(pattern, template)
+       # 去重但保持顺序返回列表
+       return list(dict.fromkeys(matches))
+
+# 定义格式化消息值类
class ChatPromptValue:
+   # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

+   # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

+   # 将消息对象列表转为字符串
    def to_string(self):
+       # 新建一个用于存放字符串的列表
        parts = []
+       # 遍历每个消息
        for msg in self.messages:
+           # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
+               # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
+               # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
+               # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
+               # 其他对象直接str
                parts.append(str(msg))
+       # 用换行符拼接所有消息并返回
        return "\n".join(parts)

+   # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
+   # 基础消息提示词模板类
    """基础消息提示词模板类"""

+   # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
+       # 保存PromptTemplate到实例变量
        self.prompt = prompt

+   # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
+       # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
+       # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
+       # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
+       # 由子类方法创建对应类型消息对象
        return self._create_message(content)

+   # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
+   # 系统消息提示词模板类
    """系统消息提示词模板"""

+   # 创建SystemMessage对象
    def _create_message(self, content):
+       # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
+       # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
+   # 人类消息提示词模板类
    """人类消息提示词模板"""

+   # 创建HumanMessage对象
    def _create_message(self, content):
+       # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
+       # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
+   # AI消息提示词模板类
    """AI消息提示词模板"""

+   # 创建AIMessage对象
    def _create_message(self, content):
+       # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
+       # 返回生成的AIMessage对象
        return AIMessage(content=content)

+# 定义动态消息列表占位符类
+class MessagesPlaceholder:
+   # 在聊天模板中插入动态消息列表的占位符
+   """在聊天模板中插入动态消息列表的占位符"""
+
+   # 构造方法,存储变量名
+   def __init__(self, variable_name: str):
+       self.variable_name = variable_name
+
# 定义聊天提示词模板类
class ChatPromptTemplate:
+   # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

+   # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

+   # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
+       # 返回通过messages参数新建的实例
        return cls(messages=messages)

+   # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
+       # 格式化所有消息对象
+       formatted_messages = self._format_all_messages(input_variables)
+       # 返回ChatPromptValue对象
+       return ChatPromptValue(messages=formatted_messages)

+   # 使用提供的变量格式化模板,返回消息列表
+   def format_messages(self, **kwargs):
+       # 格式化所有消息并返回
+       return self._format_all_messages(kwargs)

+   # 提取所有输入变量
+   def _extract_input_variables(self):
+       # 用集合避免变量重复
+       variables = set()
+       # 遍历所有消息模板或对象
        for msg in self.messages:
+           # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
+               _, template_str = msg
+               # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
+               variables.update(prompt.input_variables)
+           # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
+               variables.update(msg.prompt.input_variables)
+           # 如果是占位符对象
+           elif isinstance(msg, MessagesPlaceholder):
+               variables.add(msg.variable_name)
+       # 返回变量名列表
+       return list(variables)

+   # 给所有消息模板/对象填充变量并变为消息对象列表
+   def _format_all_messages(self, variables):
+       # 存放格式化后消息
        formatted_messages = []
+       # 遍历每个消息模板或对象
        for msg in self.messages:
+           # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
+               content = prompt.format(**variables)
+               formatted_messages.append(self._create_message_from_role(role, content))
+           # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
+               formatted_messages.append(msg.format(**variables))
+           # 如果是占位符对象
+           elif isinstance(msg, MessagesPlaceholder):
+               placeholder_messages = self._coerce_placeholder_value(
+                   msg.variable_name, variables.get(msg.variable_name)
+               )
+               formatted_messages.extend(placeholder_messages)
+           # 其他情况直接追加
            else:
                formatted_messages.append(msg)
+       # 返回格式化的消息列表
        return formatted_messages

+   # 处理占位符对象的值,返回消息对象列表
+   def _coerce_placeholder_value(self, variable_name, value):
+       # 如果未传入变量,抛出异常
+       if value is None:
+           raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
+       # 如果是ChatPromptValue实例,转换为消息列表
+       if isinstance(value, ChatPromptValue):
+           return value.to_messages()
+       # 如果已经是消息对象/结构列表,则依次转换
+       if isinstance(value, list):
+           return [self._coerce_single_message(item) for item in value]
+       # 其他情况尝试单个转换
+       return [self._coerce_single_message(value)]
+
+   # 单个原始值转换为消息对象
+   def _coerce_single_message(self, value):
+       # 已是有效消息类型,直接返回
+       if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
+           return value
+       # 有type和content属性,也当消息对象直接返回
+       if hasattr(value, "type") and hasattr(value, "content"):
+           return value
+       # 字符串变为人类消息
+       if isinstance(value, str):
+           return HumanMessage(content=value)
+       # (role, content)元组转为指定角色的消息
+       if isinstance(value, tuple) and len(value) == 2:
+           role, content = value
+           return self._create_message_from_role(role, content)
+       # 字典,默认user角色
+       if isinstance(value, dict):
+           role = value.get("role", "user")
+           content = value.get("content", "")
+           return self._create_message_from_role(role, content)
+       # 其他无法识别类型,抛出异常
+       raise TypeError("无法将占位符内容转换为消息")
+
+   # 通过角色字符串和内容构建标准消息对象
+   def _create_message_from_role(self, role, content):
+       # 角色字符串全部转小写
+       normalized_role = role.lower()
+       # 系统角色
+       if normalized_role == "system":
+           return SystemMessage(content=content)
+       # 人类/用户角色
+       if normalized_role in ("human", "user"):
+           return HumanMessage(content=content)
+       # AI/assistant角色
+       if normalized_role in ("ai", "assistant"):
+           return AIMessage(content=content)
+       # 其它未知角色抛异常
+       raise ValueError(f"未知的消息角色: {role}")

8. FewShotPromptTemplate #

FewShotPromptTemplate 是 LangChain 中用于构造 few-shot learning 提示词(prompt)的模板类。 它的核心思想是:在向大语言模型(LLM)提问时,除了用户具体的问题以外,先给出若干条带“输入-输出”示例,让 LLM 学习和模仿示例的风格或逻辑,从而更好地理解实际问题的意图。这种方法被称为“few-shot prompting”。

主要功能与用法简介:

  1. 示例集合(examples)
    以列表的形式存储多条“问题-答案”或“输入-输出”内容,帮助模型理解“输入应该如何被回答”。

  2. 示例模板(example_prompt)
    每个例子的排版格式。例如可以统一描述为“问题:{question}\n答案:{answer}”,从而方便后续批量格式化。

  3. 前缀与后缀(prefix/suffix)
    在 few-shot 示例块之前或之后添加整体性的指导语、系统设定或引导描述。

  4. 分隔符(example_separator)
    用于分隔多个示例之间的内容,提升可读性。

  5. 自动格式化与变量填充
    通过 .format(user_question="...") 方法,仅需指定实际用户问题变量,即可快速生成完整的 few-shot prompt。

  6. 自动变量推断
    如果未手动指定输入变量名,类会自动分析需要的变量(如 {user_question}),避免出错。

适合场景举例:

  • 算术推理:给定算式和答案的示例,让模型模拟并回答新问题。
  • 语义改写、翻译等:展示若干输入对应输出,要求模型模仿风格进行新输入处理。
  • 文本分类、槽位抽取等任务的 few-shot 演示。

优势
通过 FewShotPromptTemplate,可以大幅提升 prompt 工程效率和复用性,让提示词拼接与变量管理变得结构化、可维护。

8.1. init.py #

langchain/init.py


+__all__ = []

8.2. FewShotPromptTemplate.py #

8.FewShotPromptTemplate.py

#from langchain_core.prompts import PromptTemplate,FewShotPromptTemplate
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate,FewShotPromptTemplate

# few-shot 示例集合
examples = [
    {"question": "1 plus 1 等于多少?", "answer": "答案是 2。"},
    {"question": "2 plus 2 等于多少?", "answer": "答案是 4。"},
]

# 每条示例的格式
example_prompt = PromptTemplate.from_template(
    "示例问题:{question}\n示例回答:{answer}"
)

# 构建 few-shot 提示词模板
few_shot_prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
    prefix="你是一个擅长算术的 AI 助手。以下是一些示例:",
    suffix="请回答用户问题:{user_question}\nAI:",
    input_variables=["user_question"],
)

# 将用户问题注入模板
formatted_prompt = few_shot_prompt.format(user_question="3 plus 5 等于多少?")
print("few-shot 提示词:\n")
print(formatted_prompt)

# 调用模型
llm = ChatOpenAI(model="gpt-4o")
response = llm.invoke(formatted_prompt)
print("\n模型回答:")
print(response.content)

8.3. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
    def __init__(self, template: str):
        # 保存模板字符串
        self.template = template
        # 提取模板中的变量名
        self.input_variables = self._extract_variables(template)

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
        return self.template.format(**kwargs)

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
+# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
+class FewShotPromptTemplate:
+   # few-shot 提示词模板说明
+   """用于构造 few-shot 提示词的模板"""
+
+   # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
+   def __init__(
+       self,
+       *,
+       examples: list[dict],  # 示例列表,元素为字典
+       example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
+       prefix: str = "",  # 提示词前缀
+       suffix: str = "",  # 提示词后缀
+       example_separator: str = "\n\n",  # 示例之间的分隔符
+       input_variables: list[str] | None = None,  # 输入变量列表
+   ):
+       # 如果未传示例,则设为默认空列表
+       self.examples = examples or []
+       # 判断 example_prompt 参数类型
+       if isinstance(example_prompt, PromptTemplate):
+           # 如果是 PromptTemplate 实例,直接赋值
+           self.example_prompt = example_prompt
+       else:
+           # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
+           self.example_prompt = PromptTemplate.from_template(example_prompt)
+       # 保存前缀
+       self.prefix = prefix
+       # 保存后缀
+       self.suffix = suffix
+       # 保存示例分隔符
+       self.example_separator = example_separator
+       # 输入变量列表,未提供则自动推断
+       self.input_variables = input_variables or self._infer_input_variables()
+
+   # 内部方法:自动推断输入变量列表
+   def _infer_input_variables(self) -> list[str]:
+       # 创建空集合用于存放变量名
+       variables = set()
+       # 提取前缀中的变量并加入集合
+       variables.update(self._extract_vars(self.prefix))
+       # 提取后缀中的变量并加入集合
+       variables.update(self._extract_vars(self.suffix))
+       # 返回变量集合的列表形式
+       return list(variables)
+
+   # 内部方法:从文本中提取所有花括号包围的变量名
+   def _extract_vars(self, text: str) -> list[str]:
+       # 如果文本为空,返回空列表
+       if not text:
+           return []
+       # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
+       pattern = r"\{([^}:]+)(?::[^}]+)?\}"
+       # 使用正则在文本中查找所有变量名
+       matches = re.findall(pattern, text)
+       # 去重并保持顺序,返回变量名列表
+       return list(dict.fromkeys(matches))
+
+   # 向示例列表中添加一个新示例(字典类型)
+   def add_example(self, example: dict):
+       """动态添加单条示例"""
+       # 在示例列表末尾追加新示例
+       self.examples.append(example)
+
+   # 格式化所有示例,返回由字符串组成的列表
+   def format_examples(self) -> list[str]:
+       """返回格式化后的示例字符串列表"""
+       # 创建空列表用于保存格式化结果
+       formatted = []
+       # 遍历所有示例
+       for example in self.examples:
+           # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
+           formatted.append(self.example_prompt.format(**example))
+       # 返回格式化后的字符串列表
+       return formatted
+
+   # 格式化 few-shot 提示串,根据输入变量生成完整提示词
+   def format(self, **kwargs) -> str:
+       """根据传入变量生成 few-shot 提示词"""
+       # 检查传入的变量是否完整,缺失就抛异常
+       missing = set(self.input_variables) - set(kwargs.keys())
+       if missing:
+           raise ValueError(f"缺少必需的变量: {missing}")
+
+       # 创建 parts 列表,用于拼接最终组成部分
+       parts: list[str] = []
+       # 如果有前缀,则格式化并加入 parts
+       if self.prefix:
+           parts.append(self._format_text(self.prefix, **kwargs))
+
+       # 格式化所有示例并拼接为块
+       example_block = self.example_separator.join(self.format_examples())
+       # 如果示例块非空,加到 parts
+       if example_block:
+           parts.append(example_block)
+
+       # 如果有后缀,则格式化并加入 parts
+       if self.suffix:
+           parts.append(self._format_text(self.suffix, **kwargs))
+
+       # 用分隔符拼接所有组成部分,返回最终结果
+       return self.example_separator.join(part for part in parts if part)
+
+   # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
+   def _format_text(self, text: str, **kwargs) -> str:
+       # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
+       temp_prompt = PromptTemplate.from_template(text)
+       return temp_prompt.format(**kwargs)

9. load_prompt #

load_prompt 函数从 JSON 文件加载自定义提示词模板,并结合 LangChain 的 ChatOpenAI 实现自动化对话。 只需准备好符合规范的 JSON 格式的模板文件(如 prompt_template.json),即可快速实现多语言、多场景的提示词动态加载与应用。 这种方式提升了提示词的可维护性和灵活性,方便根据实际需求定制和更新 AI 助手的行为。 而配合 ChatOpenAI 等模型,可以轻松将模板化的自然语言提示与大语言模型推理结合,实现自动问答等多种应用场景。

9.1. prompt_template.json #

prompt_template.json

{
    "_type": "prompt",
    "template": "你是一个乐于助人的AI助手。用户问:{question},请回答:"
}

9.2. Load_prompt.py #

9.Load_prompt.py

#from langchain_core.prompts import load_prompt
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import load_prompt

# 从 JSON 文件加载提示词模板
prompt_template = load_prompt("prompt_template.json",encoding="utf-8")

# 创建 ChatOpenAI 实例
llm = ChatOpenAI(model="gpt-4o")

# 使用加载的模板格式化提示词
formatted_prompt = prompt_template.format(question="什么是人工智能?")

print("格式化后的提示词:")
print(formatted_prompt)
print("\n" + "="*50 + "\n")

# 调用模型生成回复
result = llm.invoke(formatted_prompt)
print("AI 回复:")
print(result.content)

9.3. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
+# 导入 JSON 和路径处理
+import json
+from pathlib import Path
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
    def __init__(self, template: str):
        # 保存模板字符串
        self.template = template
        # 提取模板中的变量名
        self.input_variables = self._extract_variables(template)

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
        return self.template.format(**kwargs)

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
class FewShotPromptTemplate:
    # few-shot 提示词模板说明
    """用于构造 few-shot 提示词的模板"""

    # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
    def __init__(
        self,
        *,
        examples: list[dict],  # 示例列表,元素为字典
        example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
        prefix: str = "",  # 提示词前缀
        suffix: str = "",  # 提示词后缀
        example_separator: str = "\n\n",  # 示例之间的分隔符
        input_variables: list[str] | None = None,  # 输入变量列表
    ):
        # 如果未传示例,则设为默认空列表
        self.examples = examples or []
        # 判断 example_prompt 参数类型
        if isinstance(example_prompt, PromptTemplate):
            # 如果是 PromptTemplate 实例,直接赋值
            self.example_prompt = example_prompt
        else:
            # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 保存前缀
        self.prefix = prefix
        # 保存后缀
        self.suffix = suffix
        # 保存示例分隔符
        self.example_separator = example_separator
        # 输入变量列表,未提供则自动推断
        self.input_variables = input_variables or self._infer_input_variables()

    # 内部方法:自动推断输入变量列表
    def _infer_input_variables(self) -> list[str]:
        # 创建空集合用于存放变量名
        variables = set()
        # 提取前缀中的变量并加入集合
        variables.update(self._extract_vars(self.prefix))
        # 提取后缀中的变量并加入集合
        variables.update(self._extract_vars(self.suffix))
        # 返回变量集合的列表形式
        return list(variables)

    # 内部方法:从文本中提取所有花括号包围的变量名
    def _extract_vars(self, text: str) -> list[str]:
        # 如果文本为空,返回空列表
        if not text:
            return []
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
        pattern = r"\{([^}:]+)(?::[^}]+)?\}"
        # 使用正则在文本中查找所有变量名
        matches = re.findall(pattern, text)
        # 去重并保持顺序,返回变量名列表
        return list(dict.fromkeys(matches))

    # 向示例列表中添加一个新示例(字典类型)
    def add_example(self, example: dict):
        """动态添加单条示例"""
        # 在示例列表末尾追加新示例
        self.examples.append(example)

    # 格式化所有示例,返回由字符串组成的列表
    def format_examples(self) -> list[str]:
        """返回格式化后的示例字符串列表"""
        # 创建空列表用于保存格式化结果
        formatted = []
        # 遍历所有示例
        for example in self.examples:
            # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
            formatted.append(self.example_prompt.format(**example))
        # 返回格式化后的字符串列表
        return formatted

    # 格式化 few-shot 提示串,根据输入变量生成完整提示词
    def format(self, **kwargs) -> str:
        """根据传入变量生成 few-shot 提示词"""
        # 检查传入的变量是否完整,缺失就抛异常
        missing = set(self.input_variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"缺少必需的变量: {missing}")

        # 创建 parts 列表,用于拼接最终组成部分
        parts: list[str] = []
        # 如果有前缀,则格式化并加入 parts
        if self.prefix:
            parts.append(self._format_text(self.prefix, **kwargs))

        # 格式化所有示例并拼接为块
        example_block = self.example_separator.join(self.format_examples())
        # 如果示例块非空,加到 parts
        if example_block:
            parts.append(example_block)

        # 如果有后缀,则格式化并加入 parts
        if self.suffix:
            parts.append(self._format_text(self.suffix, **kwargs))

        # 用分隔符拼接所有组成部分,返回最终结果
        return self.example_separator.join(part for part in parts if part)

    # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
    def _format_text(self, text: str, **kwargs) -> str:
        # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
        temp_prompt = PromptTemplate.from_template(text)
        return temp_prompt.format(**kwargs)
+
+
+# 定义一个从文件加载提示词模板的函数
+def load_prompt(path: str | Path,encoding: str | None = None) -> PromptTemplate:
+   """
+   从 JSON 文件加载提示词模板
+
+   Args:
+       path: 提示词配置文件的路径(支持 .json 格式)
+
+   Returns:
+       PromptTemplate 实例
+
+   JSON 文件格式示例:
+       {
+           "_type": "prompt",
+           "template": "你好,我叫{name},你是谁?"
+       }
+   """
+   # 将传入路径转换为 Path 对象
+   file_path = Path(path)
+   # 检查文件是否存在,不存在则抛出异常
+   if not file_path.exists():
+       raise FileNotFoundError(f"提示词文件不存在: {path}")
+   # 检查文件扩展名是否为 .json,不是则抛出异常
+   if file_path.suffix != ".json":
+       raise ValueError(f"只支持 .json 格式文件,当前文件: {file_path.suffix}")
+   # 打开文件并以 utf-8 编码读取 JSON 内容到 config 变量中
+   with file_path.open(encoding=encoding) as f:
+       config = json.load(f)
+   # 从配置中获取 _type 字段,默认值为 "prompt"
+   config_type = config.get("_type", "prompt")
+   # 如果配置类型不是 "prompt",则抛出异常
+   if config_type != "prompt":
+       raise ValueError(f"不支持的提示词类型: {config_type},当前只支持 'prompt'")
+   # 获取模板字符串,若不存在则抛出异常
+   template = config.get("template")
+   if template is None:
+       raise ValueError("配置文件中缺少 'template' 字段")
+   # 使用读取到的模板字符串创建 PromptTemplate 实例并返回
+   return PromptTemplate.from_template(template)
+

10. partial #

通过 partial,你可以预先填充部分变量,生成一个新的模板对象,
只需后续传入剩余的变量即可完成格式化。这在需要复用模板、但有部分上下文已知时非常方便。
例如:以 "你是一个{role}。用户{user_name}问:{question},请回答:" 为模板,
先用 partial 填充 role="AI助手" 和 user_name="张三",
只需提供 question 就能格式化完整提示词。

10.1. partial.py #

10.partial.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate

# 创建一个包含多个变量的模板
template = PromptTemplate.from_template(
    "你是一个{role}。用户{user_name}问:{question},请回答:"
)

print("原始模板的输入变量:", template.input_variables)

# 使用 partial 方法部分填充变量
# 这里预先填充 role 和 user_name,只保留 question 需要后续提供
partial_template = template.partial(role="AI助手", user_name="张三")

print("部分填充后的输入变量:", partial_template.input_variables)
print("部分填充的变量:", partial_template.partial_variables)

# 现在只需要提供剩余的变量
formatted_prompt = partial_template.format(question="什么是人工智能?")
print("格式化后的提示词:")
print(formatted_prompt)

# 创建 ChatOpenAI 实例并调用
llm = ChatOpenAI(model="gpt-4o")
result = llm.invoke(formatted_prompt)
print("AI 回复:")
print(result.content)

10.2. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
# 导入 JSON 和路径处理
import json
from pathlib import Path
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
+   def __init__(self, template: str, partial_variables: dict = None):
        # 保存模板字符串
        self.template = template
+       # 保存部分变量(已预填充的变量)
+       self.partial_variables = partial_variables or {}
        # 提取模板中的变量名
+       all_variables = self._extract_variables(template)
+       # 从所有变量中排除已部分填充的变量
+       self.input_variables = [v for v in all_variables if v not in self.partial_variables]

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
+       # 合并部分变量和用户提供的变量
+       all_vars = {**self.partial_variables, **kwargs}
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
+       return self.template.format(**all_vars)
+   
+   # 定义部分填充模板变量的方法,返回新的模板实例
+   def partial(self, **kwargs):
+       """
+       部分填充模板变量,返回一个新的 PromptTemplate 实例
+       
+       Args:
+           **kwargs: 要部分填充的变量及其值
+       
+       Returns:
+           新的 PromptTemplate 实例,其中指定的变量已被填充
+       
+       示例:
+           template = PromptTemplate.from_template("你好,我叫{name},我来自{city}")
+           partial_template = template.partial(name="张三")
+           # 现在只需要提供 city 参数
+           result = partial_template.format(city="北京")
+       """
+       # 合并现有对象的部分变量(partial_variables)和本次要填充的新变量
+       new_partial_variables = {**self.partial_variables, **kwargs}
+       # 使用原模板字符串和更新后的部分变量,创建新的 PromptTemplate 实例
+       new_template = PromptTemplate(
+           template=self.template,
+           partial_variables=new_partial_variables
+       )
+       # 返回新的 PromptTemplate 实例
+       return new_template

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
class FewShotPromptTemplate:
    # few-shot 提示词模板说明
    """用于构造 few-shot 提示词的模板"""

    # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
    def __init__(
        self,
        *,
        examples: list[dict],  # 示例列表,元素为字典
        example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
        prefix: str = "",  # 提示词前缀
        suffix: str = "",  # 提示词后缀
        example_separator: str = "\n\n",  # 示例之间的分隔符
        input_variables: list[str] | None = None,  # 输入变量列表
    ):
        # 如果未传示例,则设为默认空列表
        self.examples = examples or []
        # 判断 example_prompt 参数类型
        if isinstance(example_prompt, PromptTemplate):
            # 如果是 PromptTemplate 实例,直接赋值
            self.example_prompt = example_prompt
        else:
            # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 保存前缀
        self.prefix = prefix
        # 保存后缀
        self.suffix = suffix
        # 保存示例分隔符
        self.example_separator = example_separator
        # 输入变量列表,未提供则自动推断
        self.input_variables = input_variables or self._infer_input_variables()

    # 内部方法:自动推断输入变量列表
    def _infer_input_variables(self) -> list[str]:
        # 创建空集合用于存放变量名
        variables = set()
        # 提取前缀中的变量并加入集合
        variables.update(self._extract_vars(self.prefix))
        # 提取后缀中的变量并加入集合
        variables.update(self._extract_vars(self.suffix))
        # 返回变量集合的列表形式
        return list(variables)

    # 内部方法:从文本中提取所有花括号包围的变量名
    def _extract_vars(self, text: str) -> list[str]:
        # 如果文本为空,返回空列表
        if not text:
            return []
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
        pattern = r"\{([^}:]+)(?::[^}]+)?\}"
        # 使用正则在文本中查找所有变量名
        matches = re.findall(pattern, text)
        # 去重并保持顺序,返回变量名列表
        return list(dict.fromkeys(matches))

    # 向示例列表中添加一个新示例(字典类型)
    def add_example(self, example: dict):
        """动态添加单条示例"""
        # 在示例列表末尾追加新示例
        self.examples.append(example)

    # 格式化所有示例,返回由字符串组成的列表
    def format_examples(self) -> list[str]:
        """返回格式化后的示例字符串列表"""
        # 创建空列表用于保存格式化结果
        formatted = []
        # 遍历所有示例
        for example in self.examples:
            # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
            formatted.append(self.example_prompt.format(**example))
        # 返回格式化后的字符串列表
        return formatted

    # 格式化 few-shot 提示串,根据输入变量生成完整提示词
    def format(self, **kwargs) -> str:
        """根据传入变量生成 few-shot 提示词"""
        # 检查传入的变量是否完整,缺失就抛异常
        missing = set(self.input_variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"缺少必需的变量: {missing}")

        # 创建 parts 列表,用于拼接最终组成部分
        parts: list[str] = []
        # 如果有前缀,则格式化并加入 parts
        if self.prefix:
            parts.append(self._format_text(self.prefix, **kwargs))

        # 格式化所有示例并拼接为块
        example_block = self.example_separator.join(self.format_examples())
        # 如果示例块非空,加到 parts
        if example_block:
            parts.append(example_block)

        # 如果有后缀,则格式化并加入 parts
        if self.suffix:
            parts.append(self._format_text(self.suffix, **kwargs))

        # 用分隔符拼接所有组成部分,返回最终结果
        return self.example_separator.join(part for part in parts if part)

    # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
    def _format_text(self, text: str, **kwargs) -> str:
        # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
        temp_prompt = PromptTemplate.from_template(text)
        return temp_prompt.format(**kwargs)


# 定义一个从文件加载提示词模板的函数
def load_prompt(path: str | Path,encoding: str | None = None) -> PromptTemplate:
    """
    从 JSON 文件加载提示词模板

    Args:
        path: 提示词配置文件的路径(支持 .json 格式)

    Returns:
        PromptTemplate 实例

    JSON 文件格式示例:
        {
            "_type": "prompt",
            "template": "你好,我叫{name},你是谁?"
        }
    """
    # 将传入路径转换为 Path 对象
    file_path = Path(path)
    # 检查文件是否存在,不存在则抛出异常
    if not file_path.exists():
        raise FileNotFoundError(f"提示词文件不存在: {path}")
    # 检查文件扩展名是否为 .json,不是则抛出异常
    if file_path.suffix != ".json":
        raise ValueError(f"只支持 .json 格式文件,当前文件: {file_path.suffix}")
    # 打开文件并以 utf-8 编码读取 JSON 内容到 config 变量中
    with file_path.open(encoding=encoding) as f:
        config = json.load(f)
    # 从配置中获取 _type 字段,默认值为 "prompt"
    config_type = config.get("_type", "prompt")
    # 如果配置类型不是 "prompt",则抛出异常
    if config_type != "prompt":
        raise ValueError(f"不支持的提示词类型: {config_type},当前只支持 'prompt'")
    # 获取模板字符串,若不存在则抛出异常
    template = config.get("template")
    if template is None:
        raise ValueError("配置文件中缺少 'template' 字段")
    # 使用读取到的模板字符串创建 PromptTemplate 实例并返回
    return PromptTemplate.from_template(template)

11. PipelinePromptTemplate #

PipelinePromptTemplate是一种支持多阶段提示词构建的高级工具,适用于需要多步处理信息后再生成最终提示内容的复杂场景。

主要概念

  • 中间模板 (prompt_templates):一系列 PromptTemplate 实例,每个负责生成管道中的一个中间输出,常用于结构化或分步梳理上下文。
  • 最终模板 (final_prompt):一个 PromptTemplate,负责整合所有中间模板的输出 (output_0, output_1, …) 以及剩余用户输入的变量,生成最终用于模型调用的提示词文本。

工作流程说明

  1. 输入变量整合
    通过自动收集所有中间模板与最终模板中需要的输入变量,PipelinePromptTemplate 能清晰地知道“哪些变量由用户输入,哪些由上游生成”。调用 input_variables 可以查看所需所有变量名。
  2. 分阶段格式化
    当调用 format() 方法时,会依次用用户输入的数据格式化每个中间模板,生成 output_0, output_1 等变量。最终这些中间结果会与用户原始输入合并,传递给 final_prompt 进行总格式化,得到最终完整提示词文本。
  3. 多用途场景
    当需要“先总结上下文,再组合进最终问题”或“分步构造链式思路”的复杂提示词编排时,用管道模板可以极大简化代码和维护成本。

这种方式常用于长链路复杂场景、分阶段思考(Chain-of-Thought, CoT)等大模型提示词工程实践。

11.1. PipelinePromptTemplate.py #

11.PipelinePromptTemplate.py

#from langchain_core.prompts import PromptTemplate
#from langchain_core.prompts.pipeline import PipelinePromptTemplate
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, PipelinePromptTemplate

# 创建第一个模板:生成问题描述
question_template = PromptTemplate.from_template(
    "用户问题:{user_question}"
)

# 创建第二个模板:生成上下文
context_template = PromptTemplate.from_template(
    "上下文:{context}"
)

# 创建最终模板:组合所有内容
final_template = PromptTemplate.from_template(
    "你是一个AI助手。\n{output_0}\n{output_1}\n请回答用户的问题。"
)

# 创建管道模板
pipeline = PipelinePromptTemplate(
    prompt_templates=[question_template, context_template],
    final_prompt=final_template
)

print("管道模板的输入变量:", pipeline.input_variables)
print("="*50)

# 格式化管道模板
formatted = pipeline.format(
    user_question="什么是Python?",
    context="Python是一种编程语言"
)

print("格式化后的提示词:")
print(formatted)
print("="*50)

# 创建 ChatOpenAI 实例并调用
llm = ChatOpenAI(model="gpt-4o")
result = llm.invoke(formatted)
print("AI 回复:")
print(result.content)

11.2. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
# 导入 JSON 和路径处理
import json
from pathlib import Path
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
    def __init__(self, template: str, partial_variables: dict = None):
        # 保存模板字符串
        self.template = template
        # 保存部分变量(已预填充的变量)
        self.partial_variables = partial_variables or {}
        # 提取模板中的变量名
        all_variables = self._extract_variables(template)
        # 从所有变量中排除已部分填充的变量
        self.input_variables = [v for v in all_variables if v not in self.partial_variables]

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
        # 合并部分变量和用户提供的变量
        all_vars = {**self.partial_variables, **kwargs}
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
        return self.template.format(**all_vars)

    # 定义部分填充模板变量的方法,返回新的模板实例
    def partial(self, **kwargs):
        """
        部分填充模板变量,返回一个新的 PromptTemplate 实例

        Args:
            **kwargs: 要部分填充的变量及其值

        Returns:
            新的 PromptTemplate 实例,其中指定的变量已被填充

        示例:
            template = PromptTemplate.from_template("你好,我叫{name},我来自{city}")
            partial_template = template.partial(name="张三")
            # 现在只需要提供 city 参数
            result = partial_template.format(city="北京")
        """
        # 合并现有对象的部分变量(partial_variables)和本次要填充的新变量
        new_partial_variables = {**self.partial_variables, **kwargs}
        # 使用原模板字符串和更新后的部分变量,创建新的 PromptTemplate 实例
        new_template = PromptTemplate(
            template=self.template,
            partial_variables=new_partial_variables
        )
        # 返回新的 PromptTemplate 实例
        return new_template

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
class FewShotPromptTemplate:
    # few-shot 提示词模板说明
    """用于构造 few-shot 提示词的模板"""

    # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
    def __init__(
        self,
        *,
        examples: list[dict],  # 示例列表,元素为字典
        example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
        prefix: str = "",  # 提示词前缀
        suffix: str = "",  # 提示词后缀
        example_separator: str = "\n\n",  # 示例之间的分隔符
        input_variables: list[str] | None = None,  # 输入变量列表
    ):
        # 如果未传示例,则设为默认空列表
        self.examples = examples or []
        # 判断 example_prompt 参数类型
        if isinstance(example_prompt, PromptTemplate):
            # 如果是 PromptTemplate 实例,直接赋值
            self.example_prompt = example_prompt
        else:
            # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 保存前缀
        self.prefix = prefix
        # 保存后缀
        self.suffix = suffix
        # 保存示例分隔符
        self.example_separator = example_separator
        # 输入变量列表,未提供则自动推断
        self.input_variables = input_variables or self._infer_input_variables()

    # 内部方法:自动推断输入变量列表
    def _infer_input_variables(self) -> list[str]:
        # 创建空集合用于存放变量名
        variables = set()
        # 提取前缀中的变量并加入集合
        variables.update(self._extract_vars(self.prefix))
        # 提取后缀中的变量并加入集合
        variables.update(self._extract_vars(self.suffix))
        # 返回变量集合的列表形式
        return list(variables)

    # 内部方法:从文本中提取所有花括号包围的变量名
    def _extract_vars(self, text: str) -> list[str]:
        # 如果文本为空,返回空列表
        if not text:
            return []
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
        pattern = r"\{([^}:]+)(?::[^}]+)?\}"
        # 使用正则在文本中查找所有变量名
        matches = re.findall(pattern, text)
        # 去重并保持顺序,返回变量名列表
        return list(dict.fromkeys(matches))

    # 向示例列表中添加一个新示例(字典类型)
    def add_example(self, example: dict):
        """动态添加单条示例"""
        # 在示例列表末尾追加新示例
        self.examples.append(example)

    # 格式化所有示例,返回由字符串组成的列表
    def format_examples(self) -> list[str]:
        """返回格式化后的示例字符串列表"""
        # 创建空列表用于保存格式化结果
        formatted = []
        # 遍历所有示例
        for example in self.examples:
            # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
            formatted.append(self.example_prompt.format(**example))
        # 返回格式化后的字符串列表
        return formatted

    # 格式化 few-shot 提示串,根据输入变量生成完整提示词
    def format(self, **kwargs) -> str:
        """根据传入变量生成 few-shot 提示词"""
        # 检查传入的变量是否完整,缺失就抛异常
        missing = set(self.input_variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"缺少必需的变量: {missing}")

        # 创建 parts 列表,用于拼接最终组成部分
        parts: list[str] = []
        # 如果有前缀,则格式化并加入 parts
        if self.prefix:
            parts.append(self._format_text(self.prefix, **kwargs))

        # 格式化所有示例并拼接为块
        example_block = self.example_separator.join(self.format_examples())
        # 如果示例块非空,加到 parts
        if example_block:
            parts.append(example_block)

        # 如果有后缀,则格式化并加入 parts
        if self.suffix:
            parts.append(self._format_text(self.suffix, **kwargs))

        # 用分隔符拼接所有组成部分,返回最终结果
        return self.example_separator.join(part for part in parts if part)

    # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
    def _format_text(self, text: str, **kwargs) -> str:
        # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
        temp_prompt = PromptTemplate.from_template(text)
        return temp_prompt.format(**kwargs)


# 定义一个从文件加载提示词模板的函数
def load_prompt(path: str | Path,encoding: str | None = None) -> PromptTemplate:
    """
    从 JSON 文件加载提示词模板

    Args:
        path: 提示词配置文件的路径(支持 .json 格式)

    Returns:
        PromptTemplate 实例

    JSON 文件格式示例:
        {
            "_type": "prompt",
            "template": "你好,我叫{name},你是谁?"
        }
    """
    # 将传入路径转换为 Path 对象
    file_path = Path(path)
    # 检查文件是否存在,不存在则抛出异常
    if not file_path.exists():
        raise FileNotFoundError(f"提示词文件不存在: {path}")
    # 检查文件扩展名是否为 .json,不是则抛出异常
    if file_path.suffix != ".json":
        raise ValueError(f"只支持 .json 格式文件,当前文件: {file_path.suffix}")
    # 打开文件并以 utf-8 编码读取 JSON 内容到 config 变量中
    with file_path.open(encoding=encoding) as f:
        config = json.load(f)
    # 从配置中获取 _type 字段,默认值为 "prompt"
    config_type = config.get("_type", "prompt")
    # 如果配置类型不是 "prompt",则抛出异常
    if config_type != "prompt":
        raise ValueError(f"不支持的提示词类型: {config_type},当前只支持 'prompt'")
    # 获取模板字符串,若不存在则抛出异常
    template = config.get("template")
    if template is None:
        raise ValueError("配置文件中缺少 'template' 字段")
    # 使用读取到的模板字符串创建 PromptTemplate 实例并返回
    return PromptTemplate.from_template(template)

+# 定义管道式提示词模板类
+class PipelinePromptTemplate:
+   """
+   管道式提示词模板,将多个模板串联
+   前一个模板的输出作为后一个模板的输入变量
+   
+   使用方式:
+       template1 = PromptTemplate.from_template("问题:{question}")
+       template2 = PromptTemplate.from_template("上下文:{context}")
+       final = PromptTemplate.from_template("{output_0}\n{output_1}\n请回答")
+       pipeline = PipelinePromptTemplate([template1, template2], final)
+       result = pipeline.format(question="...", context="...")
+   """
+   
+   # 构造方法,初始化管道式提示词模板
+   def __init__(self, prompt_templates: list[PromptTemplate], final_prompt: PromptTemplate):
+       """
+       初始化 PipelinePromptTemplate
+       
+       Args:
+           prompt_templates: 中间模板列表,每个模板的输出会作为最终模板的输入
+           final_prompt: 最终模板,使用所有中间模板的输出(output_0, output_1, ...)和用户变量
+       """
+       # 保存中间模板列表
+       self.prompt_templates = prompt_templates
+       # 保存最终模板
+       self.final_prompt = final_prompt
+       # 提取管道模板所需的全部输入变量
+       self.input_variables = self._extract_input_variables()
+   
+   # 内部方法:提取所有需要的输入变量
+   def _extract_input_variables(self):
+       """
+       提取所有需要的输入变量
+       
+       Returns:
+           输入变量列表
+       """
+       # 定义变量集合用于存储所有输入变量
+       variables = set()
+       # 遍历所有中间模板,收集其输入变量
+       for template in self.prompt_templates:
+           variables.update(template.input_variables)
+       # 收集最终模板需要的变量
+       final_vars = set(self.final_prompt.input_variables)
+       # 构造中间模板输出变量名集合(output_0, output_1, ...)
+       intermediate_outputs = {f"output_{i}" for i in range(len(self.prompt_templates))}
+       # 最终模板中去掉这些output_x后的用户变量
+       user_vars = final_vars - intermediate_outputs
+       # 合并所有输入变量
+       variables.update(user_vars)
+       # 返回变量列表
+       return list(variables)
+   
+   # 格式化管道提示词模板
+   def format(self, **kwargs):
+       """
+       格式化管道模板
+       
+       Args:
+           **kwargs: 用户提供的变量
+       
+       Returns:
+           格式化后的最终字符串
+       """
+       # 创建字典存储所有中间步骤的输出
+       intermediate_outputs = {}
+       # 逐个遍历并格式化中间模板
+       for i, template in enumerate(self.prompt_templates):
+           # 准备传递给当前模板的变量
+           template_vars = {}
+           # 取用户提供的变量,如果模板需要则填入
+           for key, value in kwargs.items():
+               if key in template.input_variables:
+                   template_vars[key] = value
+           # 用当前变量格式化模板,获得输出
+           output = template.format(**template_vars)
+           # 保存输出到中间输出字典,变量名为output_{i}
+           intermediate_outputs[f"output_{i}"] = output
+       # 合并用户变量和中间模板所有输出,准备给最终模板用
+       final_vars = {**kwargs, **intermediate_outputs}
+       # 用最终变量格式化最终模板
+       return self.final_prompt.format(**final_vars)
+

12. LengthBasedExampleSelector #

LengthBasedExampleSelector(基于长度的示例选择器)是一种用于动态选择 few-shot 示例的实用工具。
当我们用FewShotPromptTemplate构造提示词时,往往需要根据大模型输入的实际长度限制,动态选择数量合适的示例,避免超过最大长度导致截断或报错。LengthBasedExampleSelector 就能根据输入变量的实际长度,自动筛选出尽可能多的示例,使总长度不超过事先设定的 max_length。

核心原理说明:

  • 初始化时指定所有候选示例、每个示例的格式模板、以及允许的最大总长度。
  • 每当需要选择示例时,会先计算用户输入内容的长度,预估还可以分配给示例的剩余空间。
  • 按序遍历所有示例,只要本条示例加进去不会造成超长,就继续选入,直到不能再加为止。
  • 返回实际被选中的示例列表,供下游 few-shot template 拼接。

适用场景
适合 few-shot learning 场景下,模型输入最大 token/字符长度有限,又希望自动兼顾尽可能多的高质量示例,提升泛化能力与模型响应效果。

用法简要:

  1. 实例化 LengthBasedExampleSelector,提供所有示例、示例提示模板和最大长度(通常和你的大模型上下文窗口有关)。
  2. 可配合 FewShotPromptTemplate 的 example_selector 参数自动集成,无需手动处理示例数量。
  3. 对于不同长度的用户输入,能灵活调整选入的示例条数,实现“短问题多举例,长输入少举例”。

12.1. LengthBasedExampleSelector.py #

12.LengthBasedExampleSelector.py

#from langchain_core.prompts import PromptTemplate,FewShotPromptTemplate
#from langchain_core.example_selectors import LengthBasedExampleSelector
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.example_selectors import LengthBasedExampleSelector

# 创建示例列表
examples = [
    {"question": "1+1等于多少?", "answer": "答案是2"},
    {"question": "2+2等于多少?", "answer": "答案是4"},
    {"question": "3+3等于多少?", "answer": "答案是6"},
    {"question": "4+4等于多少?", "answer": "答案是8"},
    {"question": "5+5等于多少?", "answer": "答案是10"},
]

# 创建示例模板
example_prompt = PromptTemplate.from_template(
    "问题:{question}\n答案:{answer}"
)

# 创建基于长度的示例选择器
# max_length 设置为较小的值以便演示选择效果
selector = LengthBasedExampleSelector(
    examples=examples,
    example_prompt=example_prompt,
    max_length=50,  # 设置较小的最大长度以便演示
)

print("所有示例:")
for i, ex in enumerate(examples):
    formatted = example_prompt.format(**ex)
    print(f"示例{i+1}: {formatted} (长度: {selector.get_text_length(formatted)})")

print("\n" + "="*50 + "\n")

# 测试不同长度的输入
test_inputs = [
    {"user_question": "6+6等于多少?"},  # 短输入,可以选择更多示例
    {"user_question": "这是一个非常长的问题,用来测试当输入很长时,示例选择器会选择更少的示例,因为剩余的长度更少了。"},  # 长输入,选择更少示例
]

for test_input in test_inputs:
    input_text = " ".join(str(v) for v in test_input.values())
    input_length = selector.get_text_length(input_text)
    print(f"输入:{input_text}")
    print(f"输入长度:{input_length}")

    # 选择示例
    selected = selector.select_examples(test_input)
    print(f"选中的示例数量:{len(selected)}")
    print("选中的示例:")
    for i, ex in enumerate(selected):
        formatted = example_prompt.format(**ex)
        print(f"  {i+1}. {formatted}")
    print("\n" + "="*50 + "\n")

# 使用选择器与 FewShotPromptTemplate 结合
few_shot_prompt = FewShotPromptTemplate(
    example_prompt=example_prompt,
    prefix="你是一个数学助手。以下是一些示例:",
    example_selector=selector,  # 使用选择器(会自动从选择器中获取示例),
    suffix="问题:{user_question}\n答案:",
    input_variables=["user_question"],
)

print("使用 FewShotPromptTemplate + LengthBasedExampleSelector:")
formatted = few_shot_prompt.format(user_question="7+7等于多少?")
print(formatted)
print("\n" + "="*50 + "\n")

# 调用模型
llm = ChatOpenAI(model="gpt-4o")
result = llm.invoke(formatted)
print("AI 回复:")
print(result.content)

12.2. example_selectors.py #

langchain/example_selectors.py

# 导入正则表达式库
import re
# 从本地prompts模块导入PromptTemplate类
from .prompts import PromptTemplate

# 定义一个基于长度的示例选择器类
class LengthBasedExampleSelector:
    """
    基于长度的示例选择器
    根据输入长度自动选择合适数量的示例,确保总长度不超过限制
    """

    # 构造方法,初始化LengthBasedExampleSelector对象
    def __init__(
        self,
        examples: list[dict],               # 示例列表,每个元素为字典
        example_prompt: PromptTemplate | str,  # 用于格式化示例的模板,可以是PromptTemplate或字符串
        max_length: int = 2048,             # 提示词最大长度,默认为2048
        get_text_length=None,               # 可选的文本长度计算函数
    ):
        """
        初始化 LengthBasedExampleSelector

        Args:
            examples: 示例列表,每个示例是一个字典
            example_prompt: 用于格式化示例的模板(PromptTemplate 或字符串)
            max_length: 提示词的最大长度(默认 2048)
            get_text_length: 计算文本长度的函数(默认按单词数计算)
        """
        # 保存所有示例到实例变量
        self.examples = examples
        # 如果example_prompt是字符串,则构建为PromptTemplate对象
        if isinstance(example_prompt, str):
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 如果已是PromptTemplate则直接赋值
        else:
            self.example_prompt = example_prompt
        # 保存最大长度参数
        self.max_length = max_length
        # 设置文本长度计算函数,默认为内部定义的按单词数计算
        self.get_text_length = get_text_length or self._default_get_text_length
        # 计算并缓存每个示例(格式化后)的长度
        self.example_text_lengths = self._calculate_example_lengths()

    # 默认的长度计算方法,统计文本中的单词数
    def _default_get_text_length(self, text: str) -> int:
        """
        默认的长度计算函数:按单词数计算

        Args:
            text: 文本内容

        Returns:
            文本长度(单词数)
        """
        # 利用正则按空白字符分割,统计词数
        return len(re.split(r'\s+', text.strip()))

    # 计算所有示例格式化后的长度
    def _calculate_example_lengths(self) -> list[int]:
        """
        计算所有示例的长度

        Returns:
            每个示例的长度列表
        """
        # 初始化长度列表
        lengths = []
        # 对每个示例进行格式化并计算长度
        for example in self.examples:
            # 用模板对示例内容进行格式化
            formatted_example = self.example_prompt.format(**example)
            # 计算格式化后示例的长度
            length = self.get_text_length(formatted_example)
            # 记录到列表中
            lengths.append(length)
        # 返回长度列表
        return lengths

    # 添加新示例,并计算其长度
    def add_example(self, example: dict):
        """
        添加新示例

        Args:
            example: 新示例字典
        """
        # 添加到示例列表
        self.examples.append(example)
        # 用模板格式化新示例
        formatted_example = self.example_prompt.format(**example)
        # 计算格式化后的长度
        length = self.get_text_length(formatted_example)
        # 将长度加入缓存
        self.example_text_lengths.append(length)

    # 根据输入内容长度,选择合适的示例列表
    def select_examples(self, input_variables: dict) -> list[dict]:
        """
        根据输入长度选择示例

        Args:
            input_variables: 输入变量字典

        Returns:
            选中的示例列表
        """
        # 将输入的所有变量拼成一个字符串
        input_text = " ".join(str(v) for v in input_variables.values())
        # 计算输入内容的长度
        input_length = self.get_text_length(input_text)
        # 计算剩余可用长度
        remaining_length = self.max_length - input_length
        # 初始化选中的示例列表
        selected_examples = []
        # 遍历所有示例
        for i, example in enumerate(self.examples):
            # 如果剩余长度已经不足,提前停止选择
            if remaining_length <= 0:
                break
            # 获取当前示例的长度
            example_length = self.example_text_lengths[i]
            # 判断如果再添加这个示例会超过剩余长度,则停止选择
            if remaining_length - example_length < 0:
                break
            # 把当前示例加入已选择列表
            selected_examples.append(example)
            # 更新剩余可用长度
            remaining_length -= example_length
        # 返回最终选中的示例列表
        return selected_examples

12.3. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
# 导入 JSON 和路径处理
import json
from pathlib import Path
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
    def __init__(self, template: str, partial_variables: dict = None):
        # 保存模板字符串
        self.template = template
        # 保存部分变量(已预填充的变量)
        self.partial_variables = partial_variables or {}
        # 提取模板中的变量名
        all_variables = self._extract_variables(template)
        # 从所有变量中排除已部分填充的变量
        self.input_variables = [v for v in all_variables if v not in self.partial_variables]

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
        # 合并部分变量和用户提供的变量
        all_vars = {**self.partial_variables, **kwargs}
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
        return self.template.format(**all_vars)

    # 定义部分填充模板变量的方法,返回新的模板实例
    def partial(self, **kwargs):
        """
        部分填充模板变量,返回一个新的 PromptTemplate 实例

        Args:
            **kwargs: 要部分填充的变量及其值

        Returns:
            新的 PromptTemplate 实例,其中指定的变量已被填充

        示例:
            template = PromptTemplate.from_template("你好,我叫{name},我来自{city}")
            partial_template = template.partial(name="张三")
            # 现在只需要提供 city 参数
            result = partial_template.format(city="北京")
        """
        # 合并现有对象的部分变量(partial_variables)和本次要填充的新变量
        new_partial_variables = {**self.partial_variables, **kwargs}
        # 使用原模板字符串和更新后的部分变量,创建新的 PromptTemplate 实例
        new_template = PromptTemplate(
            template=self.template,
            partial_variables=new_partial_variables
        )
        # 返回新的 PromptTemplate 实例
        return new_template

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
class FewShotPromptTemplate:
    # few-shot 提示词模板说明
    """用于构造 few-shot 提示词的模板"""

    # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
    def __init__(
        self,
        *,
+       examples: list[dict] = None,  # 示例列表,元素为字典
        example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
        prefix: str = "",  # 提示词前缀
        suffix: str = "",  # 提示词后缀
        example_separator: str = "\n\n",  # 示例之间的分隔符
        input_variables: list[str] | None = None,  # 输入变量列表
+       example_selector=None,  # 示例选择器(可选)
    ):
+       # 如果提供了示例选择器,则使用选择器;否则使用提供的示例列表
+       self.example_selector = example_selector
        # 如果未传示例,则设为默认空列表
        self.examples = examples or []
        # 判断 example_prompt 参数类型
        if isinstance(example_prompt, PromptTemplate):
            # 如果是 PromptTemplate 实例,直接赋值
            self.example_prompt = example_prompt
        else:
            # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 保存前缀
        self.prefix = prefix
        # 保存后缀
        self.suffix = suffix
        # 保存示例分隔符
        self.example_separator = example_separator
        # 输入变量列表,未提供则自动推断
        self.input_variables = input_variables or self._infer_input_variables()

    # 内部方法:自动推断输入变量列表
    def _infer_input_variables(self) -> list[str]:
        # 创建空集合用于存放变量名
        variables = set()
        # 提取前缀中的变量并加入集合
        variables.update(self._extract_vars(self.prefix))
        # 提取后缀中的变量并加入集合
        variables.update(self._extract_vars(self.suffix))
        # 返回变量集合的列表形式
        return list(variables)

    # 内部方法:从文本中提取所有花括号包围的变量名
    def _extract_vars(self, text: str) -> list[str]:
        # 如果文本为空,返回空列表
        if not text:
            return []
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
        pattern = r"\{([^}:]+)(?::[^}]+)?\}"
        # 使用正则在文本中查找所有变量名
        matches = re.findall(pattern, text)
        # 去重并保持顺序,返回变量名列表
        return list(dict.fromkeys(matches))

    # 向示例列表中添加一个新示例(字典类型)
    def add_example(self, example: dict):
        """动态添加单条示例"""
        # 在示例列表末尾追加新示例
        self.examples.append(example)

    # 格式化所有示例,返回由字符串组成的列表
+   def format_examples(self, input_variables: dict = None) -> list[str]:
+       """
+       返回格式化后的示例字符串列表
+       
+       Args:
+           input_variables: 输入变量字典,用于示例选择器选择示例
+       """
+       # 如果提供了示例选择器,使用选择器选择示例
+       if self.example_selector and input_variables:
+           selected_examples = self.example_selector.select_examples(input_variables)
+       else:
+           selected_examples = self.examples
+       
        # 创建空列表用于保存格式化结果
        formatted = []
+       # 遍历选中的示例
+       for example in selected_examples:
            # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
            formatted.append(self.example_prompt.format(**example))
        # 返回格式化后的字符串列表
        return formatted

    # 格式化 few-shot 提示串,根据输入变量生成完整提示词
    def format(self, **kwargs) -> str:
+       """
+       根据传入变量生成 few-shot 提示词
+       
+       Args:
+           **kwargs: 输入变量,如果提供了 example_selector,会用于选择示例
+       """
        # 检查传入的变量是否完整,缺失就抛异常
        missing = set(self.input_variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"缺少必需的变量: {missing}")

        # 创建 parts 列表,用于拼接最终组成部分
        parts: list[str] = []
        # 如果有前缀,则格式化并加入 parts
        if self.prefix:
            parts.append(self._format_text(self.prefix, **kwargs))

        # 格式化所有示例并拼接为块
+       # 如果使用示例选择器,传递输入变量;否则不传递
+       if self.example_selector:
+           example_block = self.example_separator.join(self.format_examples(input_variables=kwargs))
+       else:
+           example_block = self.example_separator.join(self.format_examples())
        # 如果示例块非空,加到 parts
        if example_block:
            parts.append(example_block)

        # 如果有后缀,则格式化并加入 parts
        if self.suffix:
            parts.append(self._format_text(self.suffix, **kwargs))

        # 用分隔符拼接所有组成部分,返回最终结果
        return self.example_separator.join(part for part in parts if part)

    # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
    def _format_text(self, text: str, **kwargs) -> str:
        # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
        temp_prompt = PromptTemplate.from_template(text)
        return temp_prompt.format(**kwargs)


# 定义一个从文件加载提示词模板的函数
def load_prompt(path: str | Path,encoding: str | None = None) -> PromptTemplate:
    """
    从 JSON 文件加载提示词模板

    Args:
        path: 提示词配置文件的路径(支持 .json 格式)

    Returns:
        PromptTemplate 实例

    JSON 文件格式示例:
        {
            "_type": "prompt",
            "template": "你好,我叫{name},你是谁?"
        }
    """
    # 将传入路径转换为 Path 对象
    file_path = Path(path)
    # 检查文件是否存在,不存在则抛出异常
    if not file_path.exists():
        raise FileNotFoundError(f"提示词文件不存在: {path}")
    # 检查文件扩展名是否为 .json,不是则抛出异常
    if file_path.suffix != ".json":
        raise ValueError(f"只支持 .json 格式文件,当前文件: {file_path.suffix}")
    # 打开文件并以 utf-8 编码读取 JSON 内容到 config 变量中
    with file_path.open(encoding=encoding) as f:
        config = json.load(f)
    # 从配置中获取 _type 字段,默认值为 "prompt"
    config_type = config.get("_type", "prompt")
    # 如果配置类型不是 "prompt",则抛出异常
    if config_type != "prompt":
        raise ValueError(f"不支持的提示词类型: {config_type},当前只支持 'prompt'")
    # 获取模板字符串,若不存在则抛出异常
    template = config.get("template")
    if template is None:
        raise ValueError("配置文件中缺少 'template' 字段")
    # 使用读取到的模板字符串创建 PromptTemplate 实例并返回
    return PromptTemplate.from_template(template)

# 定义管道式提示词模板类
class PipelinePromptTemplate:
    """
    管道式提示词模板,将多个模板串联
    前一个模板的输出作为后一个模板的输入变量

    使用方式:
        template1 = PromptTemplate.from_template("问题:{question}")
        template2 = PromptTemplate.from_template("上下文:{context}")
        final = PromptTemplate.from_template("{output_0}\n{output_1}\n请回答")
        pipeline = PipelinePromptTemplate([template1, template2], final)
        result = pipeline.format(question="...", context="...")
    """

    # 构造方法,初始化管道式提示词模板
    def __init__(self, prompt_templates: list[PromptTemplate], final_prompt: PromptTemplate):
        """
        初始化 PipelinePromptTemplate

        Args:
            prompt_templates: 中间模板列表,每个模板的输出会作为最终模板的输入
            final_prompt: 最终模板,使用所有中间模板的输出(output_0, output_1, ...)和用户变量
        """
        # 保存中间模板列表
        self.prompt_templates = prompt_templates
        # 保存最终模板
        self.final_prompt = final_prompt
        # 提取管道模板所需的全部输入变量
        self.input_variables = self._extract_input_variables()

    # 内部方法:提取所有需要的输入变量
    def _extract_input_variables(self):
        """
        提取所有需要的输入变量

        Returns:
            输入变量列表
        """
        # 定义变量集合用于存储所有输入变量
        variables = set()
        # 遍历所有中间模板,收集其输入变量
        for template in self.prompt_templates:
            variables.update(template.input_variables)
        # 收集最终模板需要的变量
        final_vars = set(self.final_prompt.input_variables)
        # 构造中间模板输出变量名集合(output_0, output_1, ...)
        intermediate_outputs = {f"output_{i}" for i in range(len(self.prompt_templates))}
        # 最终模板中去掉这些output_x后的用户变量
        user_vars = final_vars - intermediate_outputs
        # 合并所有输入变量
        variables.update(user_vars)
        # 返回变量列表
        return list(variables)

    # 格式化管道提示词模板
    def format(self, **kwargs):
        """
        格式化管道模板

        Args:
            **kwargs: 用户提供的变量

        Returns:
            格式化后的最终字符串
        """
        # 创建字典存储所有中间步骤的输出
        intermediate_outputs = {}
        # 逐个遍历并格式化中间模板
        for i, template in enumerate(self.prompt_templates):
            # 准备传递给当前模板的变量
            template_vars = {}
            # 取用户提供的变量,如果模板需要则填入
            for key, value in kwargs.items():
                if key in template.input_variables:
                    template_vars[key] = value
            # 用当前变量格式化模板,获得输出
            output = template.format(**template_vars)
            # 保存输出到中间输出字典,变量名为output_{i}
            intermediate_outputs[f"output_{i}"] = output
        # 合并用户变量和中间模板所有输出,准备给最终模板用
        final_vars = {**kwargs, **intermediate_outputs}
        # 用最终变量格式化最终模板
        return self.final_prompt.format(**final_vars)

13. MaxMarginalRelevanceExampleSelector #

MaxMarginalRelevanceExampleSelector选择器常用于 Few-Shot Prompting 场景中。它的作用是在一批已知示例(examples)中,根据当前用户输入问题(query),自动挑选出若干个“最相关且多样性较高”的示例,用于构建 Few-Shot Prompt。

与仅考虑相关性的选择器不同,MaxMarginalRelevanceExampleSelector 采用最大边际相关性(MMR, Maximal Marginal Relevance)算法,在确保和输入问题相关的同时还会尽可能让每个被选中的示例之间互不重复、覆盖面广,从而提升大模型上下文中的信息多样性,减少示例冗余和偏见。

工作流程简述如下:

  1. 首先通过嵌入模型(如 OpenAIEmbeddings)将所有示例和输入问题向量化。
  2. 对比计算每个示例和输入问题的相似度,初步筛选。
  3. 按照最大边际相关性顺序迭代选择出指定数量(k 个)的示例,既保证它们与输入问题接近,又确保彼此之间不要过于相似。

13.1. MaxMarginalRelevanceExampleSelector.py #

13.MaxMarginalRelevanceExampleSelector.py

#from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
#from langchain_core.example_selectors import MaxMarginalRelevanceExampleSelector
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#from langchain_community.vectorstores import FAISS

from langchain.chat_models import ChatOpenAI,OpenAIEmbeddings
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.example_selectors import MaxMarginalRelevanceExampleSelector
from langchain.vectorstores import FAISS

examples = [
    {"question": "今天天气怎么样?", "answer": "今天天气晴朗,适合出门活动。"},
    {"question": "怎么做西红柿炒鸡蛋?", "answer": "先把西红柿和鸡蛋切好,鸡蛋炒熟后盛出,再炒西红柿,最后把鸡蛋倒回去一起炒匀即可。"},
    {"question": "如何快速减肥?", "answer": "合理饮食结合锻炼,每天保持运动,避免高热量食物。"},
    {"question": "手机没电了怎么办?", "answer": "用充电器充电,或者借用移动电源。"},
    {"question": "头疼该怎么办?", "answer": "多休息,如果严重可以适当吃点止痛药。"},
    {"question": "怎样养护盆栽?", "answer": "定期浇水,保持阳光,不要积水。"},
    {"question": "想学英语怎么入门?", "answer": "可以先从背单词、学基础语法和多听多说开始。"},
    {"question": "晚上失眠怎么办?", "answer": "睡前放松,避免咖啡因,可以听点轻音乐帮助入睡。"},
    {"question": "烧水壶如何清理水垢?", "answer": "可以倒入一点醋和水煮几分钟,再用清水冲洗干净。"},
    {"question": "手机上怎么截图?", "answer": "可以同时按住电源键和音量减键进行截图,不同手机略有区别。"},
]

# 创建示例模板
example_prompt = PromptTemplate.from_template(
    "问题:{question}\n答案:{answer}"
)

print("=" * 60)
print("MaxMarginalRelevanceExampleSelector 演示")
print("=" * 60)
print("\nMaxMarginalRelevanceExampleSelector 使用最大边际相关性算法")
print("来选择示例,它既考虑与查询的相似度,也考虑示例之间的多样性。")
print("这样可以避免选择过于相似的示例,提供更全面的示例集合。\n")

# 创建嵌入模型
print("正在初始化嵌入模型和向量存储...")
embeddings = OpenAIEmbeddings()

# 使用 from_examples 方法创建 MaxMarginalRelevanceExampleSelector
selector = MaxMarginalRelevanceExampleSelector.from_examples(
    examples=examples,
    embeddings=embeddings,
    vectorstore_cls=FAISS,
    k=3,       # 最终选择 3 个示例
    fetch_k=5, # 先从向量存储获取 5 个最相似的,然后选择 3 个最多样化的
    input_keys=["question"],  # 基于 "question" 字段进行相似度搜索
)

print(f"已创建选择器,包含 {len(examples)} 个示例\n")
print("=" * 60)
print("\n")

# 测试不同的问题,观察选择器如何选择示例
test_questions = [
    "有什么简单的家常菜推荐?",         # 与如何做饭、日常生活相关
    "如何改善睡眠质量?",               # 与失眠、放松相关
    "手机没信号怎么办?",                # 与手机使用问题相关
]

for i, question in enumerate(test_questions, 1):
    print(f"测试问题 {i}: {question}")
    print("-" * 60)

    # 选择示例
    selected = selector.select_examples({"question": question})
    print(f"选中的示例数量:{len(selected)}")
    print("\n选中的示例:")
    for j, ex in enumerate(selected, 1):
        formatted = example_prompt.format(**ex)
        print(f"\n示例 {j}:")
        print(formatted)

    print("\n" + "=" * 60 + "\n")

# 使用选择器与 FewShotPromptTemplate 结合
print("使用 FewShotPromptTemplate + MaxMarginalRelevanceExampleSelector:")
few_shot_prompt = FewShotPromptTemplate(
    example_prompt=example_prompt,
    prefix="你是一个乐于助人的生活小助手。以下是一些建议示例:",
    suffix="问题:{question}\n答案:",  # 使用 "question" 变量
    example_selector=selector,
    input_variables=["question"],
)

# 格式化提示词
user_question = "怎样挑西瓜比较甜?"
formatted = few_shot_prompt.format(question=user_question)
print(f"\n用户问题:{user_question}")
print("\n生成的提示词:")
print(formatted)
print("\n" + "=" * 60 + "\n")

# 调用模型(可选)
print("调用 AI 模型生成回答...")
try:
    llm = ChatOpenAI(model="gpt-4o")
    result = llm.invoke(formatted)
    print("AI 回复:")
    print(result.content)
except Exception as e:
    print(f"调用模型时出错:{e}")
    print("(这可能是由于 API 密钥未设置或网络问题)")

13.2. embeddings.py #

langchain/embeddings.py

# 导入操作系统相关模块
import os
# 导入 openai 模块
import openai
# 从 abc 模块导入 ABC 和 abstractmethod,用于定义抽象基类
from abc import ABC, abstractmethod
# 导入类型提示相关模块
from typing import List


# 定义嵌入模型的抽象基类
class Embeddings(ABC):
    """嵌入模型的抽象基类"""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """将单个文本转换为嵌入向量"""
        pass

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """将多个文本转换为嵌入向量列表"""
        pass


# 定义与 OpenAI 嵌入模型交互的类
class OpenAIEmbeddings(Embeddings):
    """OpenAI 嵌入模型集成"""

    # 初始化方法
    def __init__(self, model: str = "text-embedding-3-small", **kwargs):
        """
        初始化 OpenAIEmbeddings

        Args:
            model: 模型名称,如 "text-embedding-3-small"
            **kwargs: 其他参数(如 api_key, dimensions 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存其他参数(如 dimensions)
        self.embedding_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 将单个文本转换为嵌入向量
    def embed_query(self, text: str) -> List[float]:
        """
        将单个文本转换为嵌入向量

        Args:
            text: 要嵌入的文本

        Returns:
            List[float]: 嵌入向量
        """
        # 调用 OpenAI 的 embeddings API
        response = self.client.embeddings.create(
            model=self.model,
            input=text,
            **self.embedding_kwargs
        )
        # 返回第一个(也是唯一的)嵌入向量
        return response.data[0].embedding

    # 将多个文本转换为嵌入向量列表
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        将多个文本转换为嵌入向量列表

        Args:
            texts: 要嵌入的文本列表

        Returns:
            List[List[float]]: 嵌入向量列表
        """
        # 调用 OpenAI 的 embeddings API(支持批量输入)
        response = self.client.embeddings.create(
            model=self.model,
            input=texts,
            **self.embedding_kwargs
        )
        # 返回所有嵌入向量
        return [item.embedding for item in response.data]

13.3. example_selectors.py #

langchain/example_selectors.py

# 导入正则表达式库
import re
# 从本地prompts模块导入PromptTemplate类
from .prompts import PromptTemplate
+# 从本地vectorstores模块导入VectorStore和Document类
+from .vectorstores import VectorStore, Document
+# 导入类型提示相关模块
+from typing import Optional, List, Dict, Any
+from abc import ABC, abstractmethod

# 定义一个基于长度的示例选择器类
class LengthBasedExampleSelector:
    """
    基于长度的示例选择器
    根据输入长度自动选择合适数量的示例,确保总长度不超过限制
    """

    # 构造方法,初始化LengthBasedExampleSelector对象
    def __init__(
        self,
        examples: list[dict],               # 示例列表,每个元素为字典
        example_prompt: PromptTemplate | str,  # 用于格式化示例的模板,可以是PromptTemplate或字符串
        max_length: int = 2048,             # 提示词最大长度,默认为2048
        get_text_length=None,               # 可选的文本长度计算函数
    ):
        """
        初始化 LengthBasedExampleSelector

        Args:
            examples: 示例列表,每个示例是一个字典
            example_prompt: 用于格式化示例的模板(PromptTemplate 或字符串)
            max_length: 提示词的最大长度(默认 2048)
            get_text_length: 计算文本长度的函数(默认按单词数计算)
        """
        # 保存所有示例到实例变量
        self.examples = examples
        # 如果example_prompt是字符串,则构建为PromptTemplate对象
        if isinstance(example_prompt, str):
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 如果已是PromptTemplate则直接赋值
        else:
            self.example_prompt = example_prompt
        # 保存最大长度参数
        self.max_length = max_length
        # 设置文本长度计算函数,默认为内部定义的按单词数计算
        self.get_text_length = get_text_length or self._default_get_text_length
        # 计算并缓存每个示例(格式化后)的长度
        self.example_text_lengths = self._calculate_example_lengths()

    # 默认的长度计算方法,统计文本中的单词数
    def _default_get_text_length(self, text: str) -> int:
        """
        默认的长度计算函数:按单词数计算

        Args:
            text: 文本内容

        Returns:
            文本长度(单词数)
        """
        # 利用正则按空白字符分割,统计词数
        return len(re.split(r'\s+', text.strip()))

    # 计算所有示例格式化后的长度
    def _calculate_example_lengths(self) -> list[int]:
        """
        计算所有示例的长度

        Returns:
            每个示例的长度列表
        """
        # 初始化长度列表
        lengths = []
        # 对每个示例进行格式化并计算长度
        for example in self.examples:
            # 用模板对示例内容进行格式化
            formatted_example = self.example_prompt.format(**example)
            # 计算格式化后示例的长度
            length = self.get_text_length(formatted_example)
            # 记录到列表中
            lengths.append(length)
        # 返回长度列表
        return lengths

    # 添加新示例,并计算其长度
    def add_example(self, example: dict):
        """
        添加新示例

        Args:
            example: 新示例字典
        """
        # 添加到示例列表
        self.examples.append(example)
        # 用模板格式化新示例
        formatted_example = self.example_prompt.format(**example)
        # 计算格式化后的长度
        length = self.get_text_length(formatted_example)
        # 将长度加入缓存
        self.example_text_lengths.append(length)

    # 根据输入内容长度,选择合适的示例列表
    def select_examples(self, input_variables: dict) -> list[dict]:
        """
        根据输入长度选择示例

        Args:
            input_variables: 输入变量字典

        Returns:
            选中的示例列表
        """
        # 将输入的所有变量拼成一个字符串
        input_text = " ".join(str(v) for v in input_variables.values())
        # 计算输入内容的长度
        input_length = self.get_text_length(input_text)
        # 计算剩余可用长度
        remaining_length = self.max_length - input_length
        # 初始化选中的示例列表
        selected_examples = []
        # 遍历所有示例
        for i, example in enumerate(self.examples):
            # 如果剩余长度已经不足,提前停止选择
            if remaining_length <= 0:
                break
            # 获取当前示例的长度
            example_length = self.example_text_lengths[i]
            # 判断如果再添加这个示例会超过剩余长度,则停止选择
            if remaining_length - example_length < 0:
                break
            # 把当前示例加入已选择列表
            selected_examples.append(example)
            # 更新剩余可用长度
            remaining_length -= example_length
        # 返回最终选中的示例列表
        return selected_examples

+
+# 定义示例选择器的抽象基类
+class BaseExampleSelector(ABC):
+   """示例选择器的抽象基类"""
+   
+   @abstractmethod
+   def select_examples(self, input_variables: dict) -> list[dict]:
+       """根据输入变量选择示例"""
+       pass
+   
+   @abstractmethod
+   def add_example(self, example: dict) -> Any:
+       """添加新示例"""
+       pass
+
+
+# 辅助函数:返回字典中按键排序的值列表
+def sorted_values(values: dict) -> list:
+   """返回字典中按键排序的值列表"""
+   return [values[val] for val in sorted(values)]
+
+
+# 定义基于向量存储的示例选择器基类
+class _VectorStoreExampleSelector(BaseExampleSelector):
+   """基于向量存储的示例选择器基类"""
+   
+   def __init__(
+       self,
+       vectorstore: VectorStore,
+       k: int = 4,
+       input_keys: Optional[List[str]] = None,
+       example_keys: Optional[List[str]] = None,
+       vectorstore_kwargs: Optional[Dict[str, Any]] = None,
+   ):
+       """
+       初始化 _VectorStoreExampleSelector
+       
+       Args:
+           vectorstore: 向量存储实例
+           k: 选择的示例数量
+           input_keys: 用于搜索的输入键列表(如果提供,只使用这些键进行搜索)
+           example_keys: 用于过滤示例的键列表(如果提供,只返回这些键的示例)
+           vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
+       """
+       self.vectorstore = vectorstore
+       self.k = k
+       self.input_keys = input_keys
+       self.example_keys = example_keys
+       self.vectorstore_kwargs = vectorstore_kwargs or {}
+   
+   @staticmethod
+   def _example_to_text(example: dict, input_keys: Optional[List[str]] = None) -> str:
+       """
+       将示例字典转换为文本字符串
+       
+       Args:
+           example: 示例字典
+           input_keys: 要使用的键列表(如果提供,只使用这些键)
+       
+       Returns:
+           str: 文本字符串
+       """
+       if input_keys:
+           # 只使用指定的键
+           filtered_example = {key: example[key] for key in input_keys if key in example}
+           return " ".join(sorted_values(filtered_example))
+       # 使用所有键
+       return " ".join(sorted_values(example))
+   
+   def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
+       """
+       将文档列表转换为示例字典列表
+       
+       Args:
+           documents: 文档列表
+       
+       Returns:
+           List[dict]: 示例字典列表
+       """
+       # 从文档的元数据中获取示例
+       examples = [dict(doc.metadata) for doc in documents]
+       
+       # 如果指定了 example_keys,则过滤示例
+       if self.example_keys:
+           examples = [{k: eg[k] for k in self.example_keys if k in eg} for eg in examples]
+       
+       return examples
+   
+   def add_example(self, example: dict) -> str:
+       """
+       添加新示例到向量存储
+       
+       Args:
+           example: 示例字典
+       
+       Returns:
+           str: 添加的示例 ID
+       """
+       # 将示例转换为文本
+       text = self._example_to_text(example, self.input_keys)
+       # 添加到向量存储
+       ids = self.vectorstore.add_texts(
+           texts=[text],
+           metadatas=[example],
+           **self.vectorstore_kwargs
+       )
+       return ids[0] if ids else ""
+
+
+# 定义最大边际相关性示例选择器类
+class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
+   """
+   基于最大边际相关性(MMR)的示例选择器
+   
+   使用 MMR 算法选择示例,既考虑与查询的相似度,也考虑示例之间的多样性。
+   这样可以避免选择过于相似的示例,提供更全面的示例集合。
+   
+   参考论文:https://arxiv.org/pdf/2211.13892.pdf
+   """
+   
+   def __init__(
+       self,
+       vectorstore: VectorStore,
+       k: int = 4,
+       fetch_k: int = 20,
+       input_keys: Optional[List[str]] = None,
+       example_keys: Optional[List[str]] = None,
+       vectorstore_kwargs: Optional[Dict[str, Any]] = None,
+   ):
+       """
+       初始化 MaxMarginalRelevanceExampleSelector
+       
+       Args:
+           vectorstore: 向量存储实例
+           k: 最终选择的示例数量
+           fetch_k: 从向量存储中获取的候选示例数量(用于 MMR 算法)
+           input_keys: 用于搜索的输入键列表
+           example_keys: 用于过滤示例的键列表
+           vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
+       """
+       super().__init__(
+           vectorstore=vectorstore,
+           k=k,
+           input_keys=input_keys,
+           example_keys=example_keys,
+           vectorstore_kwargs=vectorstore_kwargs,
+       )
+       self.fetch_k = fetch_k
+   
+   def select_examples(self, input_variables: dict) -> List[dict]:
+       """
+       根据最大边际相关性选择示例
+       
+       Args:
+           input_variables: 输入变量字典
+       
+       Returns:
+           List[dict]: 选中的示例列表
+       """
+       # 将输入变量转换为文本
+       query_text = self._example_to_text(input_variables, self.input_keys)
+       
+       # 使用 MMR 搜索
+       example_docs = self.vectorstore.max_marginal_relevance_search(
+           query=query_text,
+           k=self.k,
+           fetch_k=self.fetch_k,
+           **self.vectorstore_kwargs
+       )
+       
+       # 将文档转换为示例
+       return self._documents_to_examples(example_docs)
+   
+   @classmethod
+   def from_examples(
+       cls,
+       examples: List[dict],
+       embeddings,
+       vectorstore_cls: type,
+       k: int = 4,
+       fetch_k: int = 20,
+       input_keys: Optional[List[str]] = None,
+       example_keys: Optional[List[str]] = None,
+       vectorstore_kwargs: Optional[Dict[str, Any]] = None,
+       **vectorstore_cls_kwargs: Any,
+   ) -> "MaxMarginalRelevanceExampleSelector":
+       """
+       从示例列表创建 MaxMarginalRelevanceExampleSelector
+       
+       Args:
+           examples: 示例列表
+           embeddings: 嵌入模型实例
+           vectorstore_cls: 向量存储类(如 FAISS)
+           k: 最终选择的示例数量
+           fetch_k: 从向量存储中获取的候选示例数量
+           input_keys: 用于搜索的输入键列表
+           example_keys: 用于过滤示例的键列表
+           vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
+           **vectorstore_cls_kwargs: 传递给向量存储类的额外参数
+       
+       Returns:
+           MaxMarginalRelevanceExampleSelector: 示例选择器实例
+       """
+       # 将示例转换为文本列表
+       string_examples = [
+           cls._example_to_text(eg, input_keys) for eg in examples
+       ]
+       
+       # 创建向量存储
+       vectorstore = vectorstore_cls.from_texts(
+           texts=string_examples,
+           embedding=embeddings,
+           metadatas=examples,
+           **vectorstore_cls_kwargs
+       )
+       
+       # 创建并返回选择器实例
+       return cls(
+           vectorstore=vectorstore,
+           k=k,
+           fetch_k=fetch_k,
+           input_keys=input_keys,
+           example_keys=example_keys,
+           vectorstore_kwargs=vectorstore_kwargs,
+       )
+

13.4. vectorstores.py #

langchain/vectorstores.py

# 导入必要的模块
import os
import numpy as np
from typing import List, Optional, Dict, Any
from abc import ABC, abstractmethod

# 尝试导入 faiss,如果失败则使用简化实现
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False
    print("警告:faiss 未安装,将使用简化的向量存储实现。")
    print("建议安装:pip install faiss-cpu 或 pip install faiss-gpu")


# 定义 Document 类(用于存储文档和元数据)
class Document:
    """文档类,包含内容和元数据"""

    def __init__(self, page_content: str, metadata: Optional[Dict] = None):
        """
        初始化 Document

        Args:
            page_content: 文档内容
            metadata: 文档元数据
        """
        self.page_content = page_content
        self.metadata = metadata or {}


# 定义向量存储的抽象基类
class VectorStore(ABC):
    """向量存储的抽象基类"""

    @abstractmethod
    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict]] = None,
        **kwargs: Any
    ) -> List[str]:
        """添加文本到向量存储"""
        pass

    @abstractmethod
    def similarity_search(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any
    ) -> List[Document]:
        """相似度搜索"""
        pass

    @abstractmethod
    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any
    ) -> List[Document]:
        """最大边际相关性搜索"""
        pass

    @classmethod
    @abstractmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding,
        metadatas: Optional[List[Dict]] = None,
        **kwargs: Any
    ):
        """从文本列表创建向量存储"""
        pass


# 定义 FAISS 向量存储类
class FAISS(VectorStore):
    """FAISS 向量存储实现"""

    def __init__(
        self,
        embedding,
        index=None,
        docstore: Optional[Dict[str, Document]] = None,
        index_to_docstore_id: Optional[Dict[int, str]] = None,
    ):
        """
        初始化 FAISS 向量存储

        Args:
            embedding: 嵌入模型
            index: FAISS 索引(可选)
            docstore: 文档存储字典
            index_to_docstore_id: 索引到文档 ID 的映射
        """
        self.embedding = embedding
        self.index = index
        self.docstore = docstore or {}
        self.index_to_docstore_id = index_to_docstore_id or {}
        self._use_faiss = FAISS_AVAILABLE

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[Dict]] = None,
        **kwargs: Any
    ) -> List[str]:
        """
        添加文本到向量存储

        Args:
            texts: 文本列表
            metadatas: 元数据列表
            **kwargs: 其他参数

        Returns:
            List[str]: 添加的文档 ID 列表
        """
        # 如果没有元数据,创建空列表
        if metadatas is None:
            metadatas = [{}] * len(texts)

        # 生成嵌入向量
        embeddings = self.embedding.embed_documents(texts)
        embeddings = np.array(embeddings, dtype=np.float32)

        # 创建或更新索引
        if self.index is None:
            dimension = len(embeddings[0])
            if self._use_faiss:
                # 使用真实的 FAISS
                self.index = faiss.IndexFlatL2(dimension)
            else:
                # 使用简化的实现(仅存储向量)
                self.index = "simple"
                if not hasattr(self, '_vectors'):
                    self._vectors = []

        # 添加向量到索引
        if self._use_faiss:
            self.index.add(embeddings)
        else:
            # 简化实现:直接存储向量
            if not hasattr(self, '_vectors'):
                self._vectors = []
            self._vectors.extend(embeddings.tolist())

        # 创建文档并存储
        ids = []
        start_idx = len(self.docstore)
        for i, (text, metadata) in enumerate(zip(texts, metadatas)):
            doc_id = str(start_idx + i)
            doc = Document(page_content=text, metadata=metadata)
            self.docstore[doc_id] = doc
            self.index_to_docstore_id[start_idx + i] = doc_id
            ids.append(doc_id)

        return ids

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any
    ) -> List[Document]:
        """
        相似度搜索

        Args:
            query: 查询文本
            k: 返回的文档数量

        Returns:
            List[Document]: 最相似的文档列表
        """
        # 获取查询的嵌入向量
        query_embedding = self.embedding.embed_query(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # 执行搜索
        if self._use_faiss and isinstance(self.index, faiss.Index):
            # 使用真实的 FAISS 搜索
            distances, indices = self.index.search(query_vector, k)
            # 获取文档
            docs = []
            for idx in indices[0]:
                if idx in self.index_to_docstore_id:
                    doc_id = self.index_to_docstore_id[idx]
                    docs.append(self.docstore[doc_id])
            return docs
        else:
            # 简化实现:使用余弦相似度
            if not hasattr(self, '_vectors') or len(self._vectors) == 0:
                return []

            vectors = np.array(self._vectors, dtype=np.float32)
            query_vec = np.array(query_embedding, dtype=np.float32)

            # 计算余弦相似度
            # 归一化向量
            query_norm = query_vec / (np.linalg.norm(query_vec) + 1e-10)
            vectors_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-10)

            # 计算相似度
            similarities = np.dot(vectors_norm, query_norm)

            # 获取 top-k
            top_k_indices = np.argsort(similarities)[::-1][:k]

            # 获取文档
            docs = []
            for idx in top_k_indices:
                if idx in self.index_to_docstore_id:
                    doc_id = self.index_to_docstore_id[idx]
                    docs.append(self.docstore[doc_id])
            return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any
    ) -> List[Document]:
        """
        最大边际相关性搜索

        Args:
            query: 查询文本
            k: 最终返回的文档数量
            fetch_k: 初始获取的文档数量
            lambda_mult: 多样性权重(0-1之间,越大越注重多样性)

        Returns:
            List[Document]: 选中的文档列表
        """
        # 获取查询的嵌入向量
        query_embedding = self.embedding.embed_query(query)
        query_vector = np.array([query_embedding], dtype=np.float32)  # 确保是 2D 数组 (1, dimension)

        # 先获取 fetch_k 个最相似的文档
        if self._use_faiss and isinstance(self.index, faiss.Index):
            # 使用真实的 FAISS(需要 2D 数组)
            distances, indices = self.index.search(query_vector, fetch_k)
            candidate_indices = indices[0]
        else:
            # 简化实现
            if not hasattr(self, '_vectors') or len(self._vectors) == 0:
                return []

            vectors = np.array(self._vectors, dtype=np.float32)
            query_vec = np.array(query_embedding, dtype=np.float32)

            # 计算余弦相似度
            query_norm = query_vec / (np.linalg.norm(query_vec) + 1e-10)
            vectors_norm = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-10)
            similarities = np.dot(vectors_norm, query_norm)

            # 获取 top fetch_k
            top_indices = np.argsort(similarities)[::-1][:fetch_k]
            candidate_indices = top_indices

        # 如果候选数量不足,直接返回
        if len(candidate_indices) <= k:
            docs = []
            for idx in candidate_indices:
                if idx in self.index_to_docstore_id:
                    doc_id = self.index_to_docstore_id[idx]
                    docs.append(self.docstore[doc_id])
            return docs

        # MMR 算法:选择既相似又多样化的文档
        selected_indices = []
        candidate_list = list(candidate_indices)

        # 获取候选文档的嵌入向量
        if self._use_faiss and isinstance(self.index, faiss.Index):
            # 从 FAISS 索引中获取向量(简化:使用存储的向量)
            if hasattr(self, '_vectors'):
                candidate_vectors = np.array([self._vectors[i] for i in candidate_list], dtype=np.float32)
            else:
                # 如果无法获取向量,回退到简单选择
                candidate_vectors = None
        else:
            candidate_vectors = np.array([self._vectors[i] for i in candidate_list], dtype=np.float32)

        if candidate_vectors is not None:
            # 归一化向量
            query_norm = query_vector / (np.linalg.norm(query_vector) + 1e-10)
            candidate_norm = candidate_vectors / (np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-10)

            # 计算与查询的相似度
            query_similarities = np.dot(candidate_norm, query_norm.T).flatten()

            # MMR 选择
            selected = []
            remaining = list(range(len(candidate_list)))

            # 选择第一个:最相似的
            first_idx = np.argmax(query_similarities)
            selected.append(first_idx)
            remaining.remove(first_idx)

            # 选择剩余的 k-1 个
            for _ in range(min(k - 1, len(remaining))):
                if len(remaining) == 0:
                    break

                best_score = -float('inf')
                best_idx = None

                for idx in remaining:
                    # 与查询的相似度
                    relevance = query_similarities[idx]

                    # 与已选文档的最大相似度(多样性)
                    if len(selected) > 0:
                        selected_vectors = candidate_norm[selected]
                        diversity = np.max(np.dot(candidate_norm[idx:idx+1], selected_vectors.T))
                    else:
                        diversity = 0

                    # MMR 分数:平衡相关性和多样性
                    mmr_score = lambda_mult * relevance - (1 - lambda_mult) * diversity

                    if mmr_score > best_score:
                        best_score = mmr_score
                        best_idx = idx

                if best_idx is not None:
                    selected.append(best_idx)
                    remaining.remove(best_idx)

            # 获取选中的文档
            selected_indices = [candidate_list[i] for i in selected]
        else:
            # 回退:简单选择前 k 个
            selected_indices = candidate_list[:k]

        # 返回文档
        docs = []
        for idx in selected_indices:
            if idx in self.index_to_docstore_id:
                doc_id = self.index_to_docstore_id[idx]
                docs.append(self.docstore[doc_id])
        return docs

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding,
        metadatas: Optional[List[Dict]] = None,
        **kwargs: Any
    ) -> "FAISS":
        """
        从文本列表创建 FAISS 向量存储

        Args:
            texts: 文本列表
            embedding: 嵌入模型
            metadatas: 元数据列表
            **kwargs: 其他参数

        Returns:
            FAISS: FAISS 向量存储实例
        """
        # 创建 FAISS 实例
        instance = cls(embedding=embedding)
        # 添加文本
        instance.add_texts(texts=texts, metadatas=metadatas, **kwargs)
        return instance

13.5. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
from langchain.messages import AIMessage, HumanMessage, SystemMessage
+# 从 langchain.embeddings 模块导入 OpenAIEmbeddings 类
+from langchain.embeddings import OpenAIEmbeddings

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage/SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant, system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:

    # 初始化方法
    def __init__(self, model: str = "qwen-max", **kwargs):
        """
        初始化 ChatTongyi

        Args:
            model: 模型名称,如 "qwen-max"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # 通义千问的 API base URL(使用 OpenAI 兼容模式)
        base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

14. SemanticSimilarityExampleSelector #

在实际的 Few-Shot Prompt 应用中,SemanticSimilarityExampleSelector 是最常见的基础用法。相比 MaxMarginalRelevanceExampleSelector,它仅根据输入问题与各个示例的“语义相似度”来筛选,不考虑示例间的多样性。具体而言:

  1. 嵌入编码:把每个示例的指定字段(如 question)通过嵌入模型转为向量。
  2. 相似度检索:收到用户问题后,也转为向量,然后用向量检索算法(比如 FAISS)选出和用户问题最“近”的 $k$ 个示例。
  3. 直接返回:这些示例将被用于组装到 Prompt 结构中,帮助 LLM 参考。

优点在于实现简单、检索速度快,适合问题场景比较单一、示例间差异本身较大的用途。缺点是如果示例分布高度相似,容易选出内容重复的案例,影响上下文多样性。此时可以考虑使用带多样性约束的 MaxMarginalRelevanceExampleSelector。

14.1. SemanticSimilarityExampleSelector.py #

14.SemanticSimilarityExampleSelector.py

#from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
#from langchain_core.example_selectors import SemanticSimilarityExampleSelector
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#from langchain_community.vectorstores import FAISS

from langchain.chat_models import ChatOpenAI,OpenAIEmbeddings
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.example_selectors import SemanticSimilarityExampleSelector
from langchain.vectorstores import FAISS

examples = [
    {"question": "今天天气怎么样?", "answer": "今天天气晴朗,适合出门活动。"},
    {"question": "怎么做西红柿炒鸡蛋?", "answer": "先把西红柿和鸡蛋切好,鸡蛋炒熟后盛出,再炒西红柿,最后把鸡蛋倒回去一起炒匀即可。"},
    {"question": "如何快速减肥?", "answer": "合理饮食结合锻炼,每天保持运动,避免高热量食物。"},
    {"question": "手机没电了怎么办?", "answer": "用充电器充电,或者借用移动电源。"},
    {"question": "头疼该怎么办?", "answer": "多休息,如果严重可以适当吃点止痛药。"},
    {"question": "怎样养护盆栽?", "answer": "定期浇水,保持阳光,不要积水。"},
    {"question": "想学英语怎么入门?", "answer": "可以先从背单词、学基础语法和多听多说开始。"},
    {"question": "晚上失眠怎么办?", "answer": "睡前放松,避免咖啡因,可以听点轻音乐帮助入睡。"},
    {"question": "烧水壶如何清理水垢?", "answer": "可以倒入一点醋和水煮几分钟,再用清水冲洗干净。"},
    {"question": "手机上怎么截图?", "answer": "可以同时按住电源键和音量减键进行截图,不同手机略有区别。"},
]

# 创建示例模板
example_prompt = PromptTemplate.from_template(
    "问题:{question}\n答案:{answer}"
)

print("=" * 60)
print("SemanticSimilarityExampleSelector 演示")
print("=" * 60)
print("\nSemanticSimilarityExampleSelector 使用语义相似度搜索")
print("来选择示例,它根据与查询的语义相似度选择最相关的示例。")
print("这是最直接的示例选择方法,简单高效。\n")

# 创建嵌入模型
print("正在初始化嵌入模型和向量存储...")
embeddings = OpenAIEmbeddings()

# 使用 from_examples 方法创建 SemanticSimilarityExampleSelector
selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=embeddings,
    vectorstore_cls=FAISS,
    k=3,       # 最终选择 3 个示例
    input_keys=["question"],  # 基于 "question" 字段进行相似度搜索
)

print(f"已创建选择器,包含 {len(examples)} 个示例\n")
print("=" * 60)
print("\n")

# 测试不同的问题,观察选择器如何选择示例
test_questions = [
    "有什么简单的家常菜推荐?",         # 与如何做饭、日常生活相关
    "如何改善睡眠质量?",               # 与失眠、放松相关
    "手机没信号怎么办?",                # 与手机使用问题相关
]

for i, question in enumerate(test_questions, 1):
    print(f"测试问题 {i}: {question}")
    print("-" * 60)

    # 选择示例
    selected = selector.select_examples({"question": question})
    print(f"选中的示例数量:{len(selected)}")
    print("\n选中的示例:")
    for j, ex in enumerate(selected, 1):
        formatted = example_prompt.format(**ex)
        print(f"\n示例 {j}:")
        print(formatted)

    print("\n" + "=" * 60 + "\n")

# 使用选择器与 FewShotPromptTemplate 结合
print("使用 FewShotPromptTemplate + SemanticSimilarityExampleSelector:")
few_shot_prompt = FewShotPromptTemplate(
    example_prompt=example_prompt,
    prefix="你是一个乐于助人的生活小助手。以下是一些建议示例:",
    suffix="问题:{question}\n答案:",  # 使用 "question" 变量
    example_selector=selector,
    input_variables=["question"],
)

# 格式化提示词
user_question = "怎样挑西瓜比较甜?"
formatted = few_shot_prompt.format(question=user_question)
print(f"\n用户问题:{user_question}")
print("\n生成的提示词:")
print(formatted)
print("\n" + "=" * 60 + "\n")

# 调用模型(可选)
print("调用 AI 模型生成回答...")
try:
    llm = ChatOpenAI(model="gpt-4o")
    result = llm.invoke(formatted)
    print("AI 回复:")
    print(result.content)
except Exception as e:
    print(f"调用模型时出错:{e}")
    print("(这可能是由于 API 密钥未设置或网络问题)")

14.2. example_selectors.py #

langchain/example_selectors.py

# 导入正则表达式库
import re
# 从本地prompts模块导入PromptTemplate类
from .prompts import PromptTemplate
# 从本地vectorstores模块导入VectorStore和Document类
from .vectorstores import VectorStore, Document
# 导入类型提示相关模块
from typing import Optional, List, Dict, Any
from abc import ABC, abstractmethod

# 定义一个基于长度的示例选择器类
class LengthBasedExampleSelector:
    """
    基于长度的示例选择器
    根据输入长度自动选择合适数量的示例,确保总长度不超过限制
    """

    # 构造方法,初始化LengthBasedExampleSelector对象
    def __init__(
        self,
        examples: list[dict],               # 示例列表,每个元素为字典
        example_prompt: PromptTemplate | str,  # 用于格式化示例的模板,可以是PromptTemplate或字符串
        max_length: int = 2048,             # 提示词最大长度,默认为2048
        get_text_length=None,               # 可选的文本长度计算函数
    ):
        """
        初始化 LengthBasedExampleSelector

        Args:
            examples: 示例列表,每个示例是一个字典
            example_prompt: 用于格式化示例的模板(PromptTemplate 或字符串)
            max_length: 提示词的最大长度(默认 2048)
            get_text_length: 计算文本长度的函数(默认按单词数计算)
        """
        # 保存所有示例到实例变量
        self.examples = examples
        # 如果example_prompt是字符串,则构建为PromptTemplate对象
        if isinstance(example_prompt, str):
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 如果已是PromptTemplate则直接赋值
        else:
            self.example_prompt = example_prompt
        # 保存最大长度参数
        self.max_length = max_length
        # 设置文本长度计算函数,默认为内部定义的按单词数计算
        self.get_text_length = get_text_length or self._default_get_text_length
        # 计算并缓存每个示例(格式化后)的长度
        self.example_text_lengths = self._calculate_example_lengths()

    # 默认的长度计算方法,统计文本中的单词数
    def _default_get_text_length(self, text: str) -> int:
        """
        默认的长度计算函数:按单词数计算

        Args:
            text: 文本内容

        Returns:
            文本长度(单词数)
        """
        # 利用正则按空白字符分割,统计词数
        return len(re.split(r'\s+', text.strip()))

    # 计算所有示例格式化后的长度
    def _calculate_example_lengths(self) -> list[int]:
        """
        计算所有示例的长度

        Returns:
            每个示例的长度列表
        """
        # 初始化长度列表
        lengths = []
        # 对每个示例进行格式化并计算长度
        for example in self.examples:
            # 用模板对示例内容进行格式化
            formatted_example = self.example_prompt.format(**example)
            # 计算格式化后示例的长度
            length = self.get_text_length(formatted_example)
            # 记录到列表中
            lengths.append(length)
        # 返回长度列表
        return lengths

    # 添加新示例,并计算其长度
    def add_example(self, example: dict):
        """
        添加新示例

        Args:
            example: 新示例字典
        """
        # 添加到示例列表
        self.examples.append(example)
        # 用模板格式化新示例
        formatted_example = self.example_prompt.format(**example)
        # 计算格式化后的长度
        length = self.get_text_length(formatted_example)
        # 将长度加入缓存
        self.example_text_lengths.append(length)

    # 根据输入内容长度,选择合适的示例列表
    def select_examples(self, input_variables: dict) -> list[dict]:
        """
        根据输入长度选择示例

        Args:
            input_variables: 输入变量字典

        Returns:
            选中的示例列表
        """
        # 将输入的所有变量拼成一个字符串
        input_text = " ".join(str(v) for v in input_variables.values())
        # 计算输入内容的长度
        input_length = self.get_text_length(input_text)
        # 计算剩余可用长度
        remaining_length = self.max_length - input_length
        # 初始化选中的示例列表
        selected_examples = []
        # 遍历所有示例
        for i, example in enumerate(self.examples):
            # 如果剩余长度已经不足,提前停止选择
            if remaining_length <= 0:
                break
            # 获取当前示例的长度
            example_length = self.example_text_lengths[i]
            # 判断如果再添加这个示例会超过剩余长度,则停止选择
            if remaining_length - example_length < 0:
                break
            # 把当前示例加入已选择列表
            selected_examples.append(example)
            # 更新剩余可用长度
            remaining_length -= example_length
        # 返回最终选中的示例列表
        return selected_examples


# 定义示例选择器的抽象基类
class BaseExampleSelector(ABC):
    """示例选择器的抽象基类"""

    @abstractmethod
    def select_examples(self, input_variables: dict) -> list[dict]:
        """根据输入变量选择示例"""
        pass

    @abstractmethod
    def add_example(self, example: dict) -> Any:
        """添加新示例"""
        pass


# 辅助函数:返回字典中按键排序的值列表
def sorted_values(values: dict) -> list:
    """返回字典中按键排序的值列表"""
    return [values[val] for val in sorted(values)]


# 定义基于向量存储的示例选择器基类
class _VectorStoreExampleSelector(BaseExampleSelector):
    """基于向量存储的示例选择器基类"""

    def __init__(
        self,
        vectorstore: VectorStore,
        k: int = 4,
        input_keys: Optional[List[str]] = None,
        example_keys: Optional[List[str]] = None,
        vectorstore_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        初始化 _VectorStoreExampleSelector

        Args:
            vectorstore: 向量存储实例
            k: 选择的示例数量
            input_keys: 用于搜索的输入键列表(如果提供,只使用这些键进行搜索)
            example_keys: 用于过滤示例的键列表(如果提供,只返回这些键的示例)
            vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
        """
        self.vectorstore = vectorstore
        self.k = k
        self.input_keys = input_keys
        self.example_keys = example_keys
        self.vectorstore_kwargs = vectorstore_kwargs or {}

    @staticmethod
    def _example_to_text(example: dict, input_keys: Optional[List[str]] = None) -> str:
        """
        将示例字典转换为文本字符串

        Args:
            example: 示例字典
            input_keys: 要使用的键列表(如果提供,只使用这些键)

        Returns:
            str: 文本字符串
        """
        if input_keys:
            # 只使用指定的键
            filtered_example = {key: example[key] for key in input_keys if key in example}
            return " ".join(sorted_values(filtered_example))
        # 使用所有键
        return " ".join(sorted_values(example))

    def _documents_to_examples(self, documents: List[Document]) -> List[dict]:
        """
        将文档列表转换为示例字典列表

        Args:
            documents: 文档列表

        Returns:
            List[dict]: 示例字典列表
        """
        # 从文档的元数据中获取示例
        examples = [dict(doc.metadata) for doc in documents]

        # 如果指定了 example_keys,则过滤示例
        if self.example_keys:
            examples = [{k: eg[k] for k in self.example_keys if k in eg} for eg in examples]

        return examples

    def add_example(self, example: dict) -> str:
        """
        添加新示例到向量存储

        Args:
            example: 示例字典

        Returns:
            str: 添加的示例 ID
        """
        # 将示例转换为文本
        text = self._example_to_text(example, self.input_keys)
        # 添加到向量存储
        ids = self.vectorstore.add_texts(
            texts=[text],
            metadatas=[example],
            **self.vectorstore_kwargs
        )
        return ids[0] if ids else ""


+# 定义语义相似度示例选择器类
+class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
+   """
+   基于语义相似度的示例选择器
+   
+   使用简单的相似度搜索来选择与查询最相似的示例。
+   这是最直接的示例选择方法,根据与查询的语义相似度排序选择 top-k 示例。
+   """
+   
+   def select_examples(self, input_variables: dict) -> List[dict]:
+       """
+       根据语义相似度选择示例
+       
+       Args:
+           input_variables: 输入变量字典
+       
+       Returns:
+           List[dict]: 选中的示例列表
+       """
+       # 将输入变量转换为文本
+       query_text = self._example_to_text(input_variables, self.input_keys)
+       
+       # 使用相似度搜索
+       vectorstore_kwargs = self.vectorstore_kwargs or {}
+       example_docs = self.vectorstore.similarity_search(
+           query=query_text,
+           k=self.k,
+           **vectorstore_kwargs
+       )
+       
+       # 将文档转换为示例
+       return self._documents_to_examples(example_docs)
+   
+   @classmethod
+   def from_examples(
+       cls,
+       examples: List[dict],
+       embeddings,
+       vectorstore_cls: type,
+       k: int = 4,
+       input_keys: Optional[List[str]] = None,
+       example_keys: Optional[List[str]] = None,
+       vectorstore_kwargs: Optional[Dict[str, Any]] = None,
+       **vectorstore_cls_kwargs: Any,
+   ) -> "SemanticSimilarityExampleSelector":
+       """
+       从示例列表创建 SemanticSimilarityExampleSelector
+       
+       Args:
+           examples: 示例列表
+           embeddings: 嵌入模型实例
+           vectorstore_cls: 向量存储类(如 FAISS)
+           k: 选择的示例数量
+           input_keys: 用于搜索的输入键列表
+           example_keys: 用于过滤示例的键列表
+           vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
+           **vectorstore_cls_kwargs: 传递给向量存储类的额外参数
+       
+       Returns:
+           SemanticSimilarityExampleSelector: 示例选择器实例
+       """
+       # 将示例转换为文本列表
+       string_examples = [
+           cls._example_to_text(eg, input_keys) for eg in examples
+       ]
+       
+       # 创建向量存储
+       vectorstore = vectorstore_cls.from_texts(
+           texts=string_examples,
+           embedding=embeddings,
+           metadatas=examples,
+           **vectorstore_cls_kwargs
+       )
+       
+       # 创建并返回选择器实例
+       return cls(
+           vectorstore=vectorstore,
+           k=k,
+           input_keys=input_keys,
+           example_keys=example_keys,
+           vectorstore_kwargs=vectorstore_kwargs,
+       )
+
+
# 定义最大边际相关性示例选择器类
class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
    """
    基于最大边际相关性(MMR)的示例选择器

    使用 MMR 算法选择示例,既考虑与查询的相似度,也考虑示例之间的多样性。
    这样可以避免选择过于相似的示例,提供更全面的示例集合。

    参考论文:https://arxiv.org/pdf/2211.13892.pdf
    """

    def __init__(
        self,
        vectorstore: VectorStore,
        k: int = 4,
        fetch_k: int = 20,
        input_keys: Optional[List[str]] = None,
        example_keys: Optional[List[str]] = None,
        vectorstore_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        初始化 MaxMarginalRelevanceExampleSelector

        Args:
            vectorstore: 向量存储实例
            k: 最终选择的示例数量
            fetch_k: 从向量存储中获取的候选示例数量(用于 MMR 算法)
            input_keys: 用于搜索的输入键列表
            example_keys: 用于过滤示例的键列表
            vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
        """
        super().__init__(
            vectorstore=vectorstore,
            k=k,
            input_keys=input_keys,
            example_keys=example_keys,
            vectorstore_kwargs=vectorstore_kwargs,
        )
        self.fetch_k = fetch_k

    def select_examples(self, input_variables: dict) -> List[dict]:
        """
        根据最大边际相关性选择示例

        Args:
            input_variables: 输入变量字典

        Returns:
            List[dict]: 选中的示例列表
        """
        # 将输入变量转换为文本
        query_text = self._example_to_text(input_variables, self.input_keys)

        # 使用 MMR 搜索
        example_docs = self.vectorstore.max_marginal_relevance_search(
            query=query_text,
            k=self.k,
            fetch_k=self.fetch_k,
            **self.vectorstore_kwargs
        )

        # 将文档转换为示例
        return self._documents_to_examples(example_docs)

    @classmethod
    def from_examples(
        cls,
        examples: List[dict],
        embeddings,
        vectorstore_cls: type,
        k: int = 4,
        fetch_k: int = 20,
        input_keys: Optional[List[str]] = None,
        example_keys: Optional[List[str]] = None,
        vectorstore_kwargs: Optional[Dict[str, Any]] = None,
        **vectorstore_cls_kwargs: Any,
    ) -> "MaxMarginalRelevanceExampleSelector":
        """
        从示例列表创建 MaxMarginalRelevanceExampleSelector

        Args:
            examples: 示例列表
            embeddings: 嵌入模型实例
            vectorstore_cls: 向量存储类(如 FAISS)
            k: 最终选择的示例数量
            fetch_k: 从向量存储中获取的候选示例数量
            input_keys: 用于搜索的输入键列表
            example_keys: 用于过滤示例的键列表
            vectorstore_kwargs: 传递给向量存储搜索函数的额外参数
            **vectorstore_cls_kwargs: 传递给向量存储类的额外参数

        Returns:
            MaxMarginalRelevanceExampleSelector: 示例选择器实例
        """
        # 将示例转换为文本列表
        string_examples = [
            cls._example_to_text(eg, input_keys) for eg in examples
        ]

        # 创建向量存储
        vectorstore = vectorstore_cls.from_texts(
            texts=string_examples,
            embedding=embeddings,
            metadatas=examples,
            **vectorstore_cls_kwargs
        )

        # 创建并返回选择器实例
        return cls(
            vectorstore=vectorstore,
            k=k,
            fetch_k=fetch_k,
            input_keys=input_keys,
            example_keys=example_keys,
            vectorstore_kwargs=vectorstore_kwargs,
        )

15. BaseExampleSelector #

可通过继承 BaseExampleSelector,实现任意定制化筛选算法(如基于关键词、规则、检索、Embedding 相似度等),以适应复杂/场景化的智能应用开发需求。

  • 如何根据自己的需求,实现基于关键词匹配(KeywordBasedExampleSelector)和随机选择(RandomExampleSelector)两种常见的示例选择逻辑。
  • 自定义选择器只需实现 select_examples 和 add_example 方法,便可以 plug-in 到 FewShotPromptTemplate 等上游组件中,实现高灵活性。
  • 关键词选择器适合输入与示例语料相关性较强、期望提供精准 few-shot 案例的应用场景;而随机选择器仅作为逻辑演示,实际生产场景一般结合具体策略与语义。
  • 配合 PromptTemplate/FewShotPromptTemplate ,自定义选择器能够动态为大模型准备高质量上下文示例,从而提升输出效果,是 LangChain few-shot 技术链条中的重要一环。

15.1. BaseExampleSelector.py #

15.BaseExampleSelector.py

#from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
#from langchain_core.example_selectors import BaseExampleSelector
#from langchain_openai import ChatOpenAI, OpenAIEmbeddings
#from langchain_community.vectorstores import FAISS

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, FewShotPromptTemplate
from langchain.example_selectors import BaseExampleSelector
from typing import List, Dict
import re


# 自定义示例选择器:基于关键词匹配的选择器
class KeywordBasedExampleSelector(BaseExampleSelector):
    """
    基于关键词匹配的自定义示例选择器

    这个选择器会根据输入中的关键词来选择最相关的示例。
    它计算输入和示例之间的关键词重叠度,选择重叠度最高的示例。
    """

    def __init__(
        self,
        examples: List[Dict[str, str]],
        k: int = 3,
        input_key: str = "question",
        min_keyword_match: int = 1,
    ):
        """
        初始化关键词选择器

        Args:
            examples: 示例列表
            k: 选择的示例数量
            input_key: 用于匹配的输入键名
            min_keyword_match: 最少匹配的关键词数量
        """
        self.examples = examples
        self.k = k
        self.input_key = input_key
        self.min_keyword_match = min_keyword_match

    def _extract_keywords(self, text: str) -> set:
        """
        从文本中提取关键词(中文和英文)

        Args:
            text: 输入文本

        Returns:
            set: 关键词集合
        """
        # 提取中文词汇(简单实现:按字符分割,过滤单字)
        chinese_words = re.findall(r'[\u4e00-\u9fa5]{2,}', text)
        # 提取英文单词
        english_words = re.findall(r'[a-zA-Z]+', text.lower())
        # 合并并返回
        return set(chinese_words + english_words)

    def _calculate_match_score(self, input_text: str, example_text: str) -> int:
        """
        计算输入文本和示例文本的匹配分数

        Args:
            input_text: 输入文本
            example_text: 示例文本

        Returns:
            int: 匹配分数(关键词重叠数量)
        """
        input_keywords = self._extract_keywords(input_text)
        example_keywords = self._extract_keywords(example_text)
        # 返回重叠的关键词数量
        return len(input_keywords & example_keywords)

    def select_examples(self, input_variables: Dict[str, str]) -> List[Dict]:
        """
        根据关键词匹配选择示例

        Args:
            input_variables: 输入变量字典

        Returns:
            List[Dict]: 选中的示例列表
        """
        # 获取输入文本
        input_text = input_variables.get(self.input_key, "")
        if not input_text:
            # 如果没有输入,返回前 k 个示例
            return self.examples[:self.k]

        # 计算每个示例的匹配分数
        scored_examples = []
        for example in self.examples:
            example_text = example.get(self.input_key, "")
            score = self._calculate_match_score(input_text, example_text)
            if score >= self.min_keyword_match:
                scored_examples.append((score, example))

        # 按分数排序(降序)
        scored_examples.sort(key=lambda x: x[0], reverse=True)

        # 返回前 k 个示例
        selected = [example for _, example in scored_examples[:self.k]]

        # 如果选中的示例不足 k 个,用其他示例填充
        if len(selected) < self.k:
            remaining = [ex for ex in self.examples if ex not in selected]
            selected.extend(remaining[:self.k - len(selected)])

        return selected

    def add_example(self, example: Dict[str, str]) -> None:
        """
        添加新示例

        Args:
            example: 新示例字典
        """
        self.examples.append(example)


# 自定义示例选择器:随机选择器(用于演示)
class RandomExampleSelector(BaseExampleSelector):
    """
    随机选择示例的选择器

    这个选择器随机选择示例,用于演示自定义选择器的灵活性。
    在实际应用中,可以根据需要实现任何选择逻辑。
    """

    def __init__(self, examples: List[Dict[str, str]], k: int = 3):
        """
        初始化随机选择器

        Args:
            examples: 示例列表
            k: 选择的示例数量
        """
        self.examples = examples
        self.k = k

    def select_examples(self, input_variables: Dict[str, str]) -> List[Dict]:
        """
        随机选择示例

        Args:
            input_variables: 输入变量字典(这里不使用,但为了接口一致性保留)

        Returns:
            List[Dict]: 随机选中的示例列表
        """
        import random
        # 随机选择 k 个示例(不重复)
        if len(self.examples) <= self.k:
            return self.examples.copy()
        return random.sample(self.examples, self.k)

    def add_example(self, example: Dict[str, str]) -> None:
        """
        添加新示例

        Args:
            example: 新示例字典
        """
        self.examples.append(example)


# 创建示例列表
examples = [
    {"question": "今天天气怎么样?", "answer": "今天天气晴朗,适合出门活动。"},
    {"question": "怎么做西红柿炒鸡蛋?", "answer": "先把西红柿和鸡蛋切好,鸡蛋炒熟后盛出,再炒西红柿,最后把鸡蛋倒回去一起炒匀即可。"},
    {"question": "如何快速减肥?", "answer": "合理饮食结合锻炼,每天保持运动,避免高热量食物。"},
    {"question": "手机没电了怎么办?", "answer": "用充电器充电,或者借用移动电源。"},
    {"question": "头疼该怎么办?", "answer": "多休息,如果严重可以适当吃点止痛药。"},
    {"question": "怎样养护盆栽?", "answer": "定期浇水,保持阳光,不要积水。"},
    {"question": "想学英语怎么入门?", "answer": "可以先从背单词、学基础语法和多听多说开始。"},
    {"question": "晚上失眠怎么办?", "answer": "睡前放松,避免咖啡因,可以听点轻音乐帮助入睡。"},
    {"question": "烧水壶如何清理水垢?", "answer": "可以倒入一点醋和水煮几分钟,再用清水冲洗干净。"},
    {"question": "手机上怎么截图?", "answer": "可以同时按住电源键和音量减键进行截图,不同手机略有区别。"},
]

# 创建示例模板
example_prompt = PromptTemplate.from_template(
    "问题:{question}\n答案:{answer}"
)

print("=" * 60)
print("自定义示例选择器演示")
print("=" * 60)
print("\n本演示展示了如何创建和使用自定义的示例选择器。")
print("我们将演示两种自定义选择器:")
print("1. KeywordBasedExampleSelector - 基于关键词匹配")
print("2. RandomExampleSelector - 随机选择\n")

print("=" * 60)
print("演示 1: KeywordBasedExampleSelector(基于关键词匹配)")
print("=" * 60)

# 创建基于关键词的选择器
keyword_selector = KeywordBasedExampleSelector(
    examples=examples,
    k=3,
    input_key="question",
    min_keyword_match=1,
)

print(f"\n已创建关键词选择器,包含 {len(examples)} 个示例")
print(f"将选择与输入关键词最匹配的 3 个示例\n")

# 测试不同的问题
test_questions = [
    "有什么简单的家常菜推荐?",  # 应该匹配"怎么做西红柿炒鸡蛋"
    "如何改善睡眠质量?",        # 应该匹配"晚上失眠怎么办"
    "手机没信号怎么办?",         # 应该匹配"手机没电了怎么办"
]

for i, question in enumerate(test_questions, 1):
    print(f"测试问题 {i}: {question}")
    print("-" * 60)

    # 选择示例
    selected = keyword_selector.select_examples({"question": question})
    print(f"选中的示例数量:{len(selected)}")
    print("\n选中的示例:")
    for j, ex in enumerate(selected, 1):
        formatted = example_prompt.format(**ex)
        print(f"\n示例 {j}:")
        print(formatted)

    print("\n" + "=" * 60 + "\n")

print("=" * 60)
print("演示 2: RandomExampleSelector(随机选择)")
print("=" * 60)

# 创建随机选择器
random_selector = RandomExampleSelector(
    examples=examples,
    k=3,
)

print(f"\n已创建随机选择器,包含 {len(examples)} 个示例")
print(f"将随机选择 3 个示例\n")

# 测试随机选择(运行多次看效果)
for i in range(3):
    print(f"随机选择 {i+1}:")
    print("-" * 60)
    selected = random_selector.select_examples({"question": "测试问题"})
    print(f"选中的示例数量:{len(selected)}")
    print("\n选中的示例:")
    for j, ex in enumerate(selected, 1):
        formatted = example_prompt.format(**ex)
        print(f"\n示例 {j}:")
        print(formatted)
    print("\n" + "=" * 60 + "\n")

print("=" * 60)
print("演示 3: 与 FewShotPromptTemplate 结合使用")
print("=" * 60)

# 使用自定义选择器与 FewShotPromptTemplate 结合
print("\n使用 KeywordBasedExampleSelector 与 FewShotPromptTemplate:")
few_shot_prompt = FewShotPromptTemplate(
    example_prompt=example_prompt,
    prefix="你是一个乐于助人的生活小助手。以下是一些建议示例:",
    suffix="问题:{question}\n答案:",
    example_selector=keyword_selector,  # 使用自定义选择器
    input_variables=["question"],
)

# 格式化提示词
user_question = "怎样挑西瓜比较甜?"
formatted = few_shot_prompt.format(question=user_question)
print(f"\n用户问题:{user_question}")
print("\n生成的提示词:")
print(formatted)
print("\n" + "=" * 60 + "\n")

# 调用模型(可选)
print("调用 AI 模型生成回答...")
try:
    llm = ChatOpenAI(model="gpt-4o")
    result = llm.invoke(formatted)
    print("AI 回复:")
    print(result.content)
except Exception as e:
    print(f"调用模型时出错:{e}")
    print("(这可能是由于 API 密钥未设置或网络问题)")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\n自定义示例选择器的关键点:")
print("1. 继承自 BaseExampleSelector 基类")
print("2. 实现 select_examples() 方法:根据输入变量选择示例")
print("3. 实现 add_example() 方法:添加新示例")
print("4. 可以定义任何选择逻辑:关键词匹配、随机、基于规则等")
print("5. 可以与 FewShotPromptTemplate 无缝集成")
print("\n通过自定义选择器,你可以根据具体需求实现最适合的示例选择策略!")

16. StrOutputParser #

StrOutputParser 是 LangChain 中最基础、常用的输出解析器,用于确保 LLM 的输出符合字符串格式。现实中,LLM 的输出往往已经是字符串类型,但在链式调用、多步处理或需要结构化解析时,统一采用解析器模式可以保持代码风格一致、便于扩展和维护。

应用场景包括:

  1. 直接将 LLM 的输出转为字符串(如问答、摘要、生成描述性文本时)。
  2. 与 PromptTemplate、FewShotPromptTemplate 等链式组件搭配,作为输出的最后一步,保证 Chain 的输出总是 str。
  3. 在流式(streaming)场景下,将逐步拼接起来的内容整体转为字符串处理,便于后续存储、展示或交互。
  4. 为复杂结构化 OutputParser 设计提供模板(如后续的 ListOutputParser, JsonOutputParser 等)。

使用特点:

  • 只做类型保证,不做内容修改。
  • 如果输入已经是字符串,原样返回;否则会尝试用 str() 转换。
  • 简化调试,提高代码健壮性(即使前面链路出现类型“小混乱”,也不会影响最后输出格式)。
  • 背后技术实质:其实就是一个轻量级的 str() 封装,但在链式编排体系下便于统一调用入口。

""")

16.1. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"

16.2. StrOutputParser.py #

16.StrOutputParser.py

#from langchain_core.prompts import PromptTemplate
#from langchain_core.output_parsers import StrOutputParser
#from langchain.chat_models import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import StrOutputParser

print("=" * 60)
print("StrOutputParser 演示")
print("=" * 60)
print("\nStrOutputParser 是一个简单的输出解析器,用于将 LLM 的输出")
print("转换为字符串格式。它主要用于链式调用中,确保输出是字符串类型。\n")

# 创建模型和输出解析器
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: 基本用法 - 直接解析 LLM 输出")
print("=" * 60)

# 直接调用模型
print("\n1. 直接调用模型(不使用解析器):")
response = llm.invoke("请用一句话介绍 Python 编程语言。")
print(f"输出类型:{type(response)}")
print(f"输出内容:{response.content}")
print(f"输出内容类型:{type(response.content)}")

print("\n2. 使用 StrOutputParser 解析输出:")
# 使用解析器
parsed_output = parser.parse(response.content)
print(f"解析后类型:{type(parsed_output)}")
print(f"解析后内容:{parsed_output}")

print("\n" + "=" * 60)
print("演示 2: 在链式调用中使用 StrOutputParser")
print("=" * 60)

# 创建提示模板
prompt = PromptTemplate.from_template(
    "你是一个专业的编程助手。请回答以下问题:\n问题:{question}\n回答:"
)

# 手动构建链:Prompt -> LLM -> Parser
print("\n使用手动链式调用(Prompt -> LLM -> Parser):")
questions = [
    "什么是 Python?",
    "Python 的主要特点是什么?",
    "如何学习 Python?",
]

for i, question in enumerate(questions, 1):
    print(f"\n问题 {i}: {question}")
    print("-" * 60)
    # 步骤1: 格式化提示词
    formatted_prompt = prompt.format(question=question)
    # 步骤2: 调用 LLM
    llm_response = llm.invoke(formatted_prompt)
    # 步骤3: 使用解析器解析输出
    result = parser.parse(llm_response.content)
    print(f"回答:{result}")
    print(f"输出类型:{type(result)}")  # 应该是 str

print("\n" + "=" * 60)
print("演示 3: 处理不同类型的输出")
print("=" * 60)

# 测试不同类型的输入
test_inputs = [
    "这是一个简单的字符串",
    "这是另一个字符串\n包含换行符",
    "  这是带空格的字符串  ",
]

print("\n测试 StrOutputParser 处理不同类型的输入:")
for i, test_input in enumerate(test_inputs, 1):
    print(f"\n测试 {i}:")
    print(f"输入:{repr(test_input)}")
    print(f"输入类型:{type(test_input)}")
    parsed = parser.parse(test_input)
    print(f"解析后:{repr(parsed)}")
    print(f"解析后类型:{type(parsed)}")
    print(f"是否相同:{test_input == parsed}")

print("\n" + "=" * 60)
print("演示 4: 与流式输出结合使用")
print("=" * 60)

# 流式输出
print("\n使用流式输出(stream):")
formatted_prompt = prompt.format(question="请用三句话介绍人工智能")
stream = llm.stream(formatted_prompt)

print("流式输出内容:")
full_response = ""
for chunk in stream:
    # chunk 是 AIMessage 对象,包含部分内容
    content = chunk.content if hasattr(chunk, 'content') else str(chunk)
    print(content, end="", flush=True)
    full_response += content
print()  # 换行

# 使用解析器解析完整的流式输出
parsed_stream = parser.parse(full_response)
print(f"\n解析后的完整输出:{parsed_stream}")
print(f"解析后类型:{type(parsed_stream)}")

print("\n" + "=" * 60)
print("演示 5: 实际应用场景")
print("=" * 60)

# 实际应用:文本生成和处理
print("\n应用场景:文本摘要生成")

summary_prompt = PromptTemplate.from_template(
    "请为以下文本生成一个简洁的摘要(不超过50字):\n\n{text}\n\n摘要:"
)

text_to_summarize = """
人工智能(AI)是计算机科学的一个分支,致力于创建能够执行通常需要人类智能的任务的系统。
这些任务包括学习、推理、问题解决、感知和语言理解。AI 系统可以通过机器学习、深度学习
和神经网络等技术来实现。近年来,AI 在图像识别、自然语言处理、自动驾驶等领域取得了
显著进展。
"""

print(f"\n原文:\n{text_to_summarize}")
print("\n生成摘要:")
# 手动构建链:Prompt -> LLM -> Parser
formatted_summary_prompt = summary_prompt.format(text=text_to_summarize)
summary_response = llm.invoke(formatted_summary_prompt)
summary = parser.parse(summary_response.content)
print(f"摘要:{summary}")
print(f"摘要类型:{type(summary)}")
print(f"摘要长度:{len(summary)} 字符")

print("\n" + "=" * 60)
print("演示 6: 错误处理")
print("=" * 60)

# 测试错误处理
print("\nStrOutputParser 的特点:")
print("1. 它不会修改输入内容,只是确保输出是字符串类型")
print("2. 如果输入已经是字符串,它会原样返回")
print("3. 如果输入不是字符串,它会尝试转换为字符串")

# 测试非字符串输入(如果可能)
try:
    # 注意:在实际使用中,LLM 的输出通常是字符串
    # 这里只是演示解析器的行为
    test_cases = [
        ("字符串输入", "字符串输入"),
        ("数字字符串", "123"),
        ("空字符串", ""),
    ]

    print("\n测试各种输入:")
    for description, test_input in test_cases:
        result = parser.parse(test_input)
        print(f"{description}: {repr(test_input)} -> {repr(result)}")
        assert result == test_input, "解析器应该原样返回字符串输入"

    print("\n✓ 所有测试通过:StrOutputParser 正确工作")
except Exception as e:
    print(f"\n✗ 测试失败:{e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nStrOutputParser 的主要用途:")
print("1. 确保 LLM 输出是字符串类型")
print("2. 在链式调用中统一输出格式")
print("3. 简化输出处理流程")
print("4. 与流式输出兼容")
print("\n使用场景:")
print("- 文本生成任务")
print("- 问答系统")
print("- 内容摘要")
print("- 任何需要字符串输出的场景")
print("\n注意:StrOutputParser 不会修改内容,只是确保类型为字符串。")
print("如果需要格式化或解析结构化数据,应该使用其他专门的解析器。")

16.3. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
from langchain.messages import AIMessage, HumanMessage, SystemMessage
# 从 langchain.embeddings 模块导入 OpenAIEmbeddings 类
from langchain.embeddings import OpenAIEmbeddings

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

+   # 流式调用模型生成回复的方法
+   def stream(self, input, **kwargs):
+       """
+       流式调用模型生成回复
+       
+       Args:
+           input: 输入内容,可以是字符串或消息列表
+           **kwargs: 额外的 API 参数
+       
+       Yields:
+           AIMessage: AI 的回复消息块(每次产生部分内容)
+       """
+       # 将输入数据转换为消息格式
+       messages = self._convert_input(input)
+       
+       # 构建 API 请求参数字典,启用流式输出
+       params = {
+           "model": self.model,
+           "messages": messages,
+           "stream": True,  # 启用流式输出
+           **self.model_kwargs,
+           **kwargs
+       }
+       
+       # 使用 OpenAI 客户端发起流式调用
+       stream = self.client.chat.completions.create(**params)
+       
+       # 迭代流式响应
+       for chunk in stream:
+           # 检查是否有内容增量
+           if chunk.choices and len(chunk.choices) > 0:
+               delta = chunk.choices[0].delta
+               if hasattr(delta, 'content') and delta.content:
+                   # 产生包含部分内容的 AIMessage
+                   yield AIMessage(content=delta.content)
+   
    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage/SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant, system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:

    # 初始化方法
    def __init__(self, model: str = "qwen-max", **kwargs):
        """
        初始化 ChatTongyi

        Args:
            model: 模型名称,如 "qwen-max"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # 通义千问的 API base URL(使用 OpenAI 兼容模式)
        base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

17. JsonOutputParser #

JsonOutputParser 是 LangChain 中用于将 LLM 输出解析为 JSON 结构的强大工具类。它不仅可以解析格式规范的 JSON 字符串,还具有容错能力,能应对 LLM 输出中常见的包裹 Markdown 代码块、说明性文本混杂、JSON 数组等场景。其主要应用包括提取结构化数据、自动化信息抽取、自然语言转 API 请求等。

典型特性和用法:

  1. 基础原理

    • 继承自 BaseOutputParser,parse 方法实现智能提取与转换逻辑。
    • 支持解析:
      • 纯 JSON 文本(如 {"a": 1, "b": 2})
      • Markdown 格式的 JSON 代码块(如 json ...)
      • 携带前后缀文本的 JSON 段落
      • JSON 数组、对象
  2. 常见用法

    • 结合 PromptTemplate 引导 LLM 以 JSON 输出回答,使后续处理结构化、自动化。
    • 与 Chain、Agent 等配合,将 AI 结果可靠转换为 Python dict/list,便于二次加工。
    • 可扩展,支持自定义解析策略,增强兼容能力。
  3. 注意事项

    • 强烈建议在提示词中使用 parser.get_format_instructions() 明确要求 LLM 输出规范 JSON,提升解析成功率。
    • 对于非规范输出,parse 方法会尽可能容错处理,但极度混乱的文本仍有解析失败的风险,应做好异常捕获。

17.1. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
+import json
+import re

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"

+
+# 辅助函数:从文本中提取 JSON(支持 markdown 代码块)
+def parse_json_markdown(text: str) -> Any:
+   """
+   从文本中解析 JSON,支持 markdown 代码块格式
+   
+   Args:
+       text: 可能包含 JSON 的文本
+   
+   Returns:
+       解析后的 JSON 对象
+   
+   Raises:
+       json.JSONDecodeError: 如果无法解析 JSON
+   """
+   # 去除首尾空白
+   text = text.strip()
+   
+   # 尝试匹配 markdown 代码块中的 JSON
+   # 匹配 ```json ... ``` 或 ``` ...
  • json_match = re.search(r'(?:json)?\s*\n?(.*?)\n?', text, re.DOTALL)
  • if json_match:
  • text = json_match.group(1).strip()
  • 尝试匹配 { ... } 或 [ ... ] #

  • json_match = re.search(r'({.}|$$.$$)', text, re.DOTALL)
  • if json_match:
  • text = json_match.group(1)
  • 解析 JSON #

  • return json.loads(text) + + +# 定义 JSON 输出解析器类 +class JsonOutputParser(BaseOutputParser[Any]):
  • """
  • JSON 输出解析器
  • 将 LLM 的输出解析为 JSON 对象。支持:
    • 纯 JSON 字符串
    • Markdown 代码块中的 JSON(json ...)
    • 包含 JSON 的文本(自动提取)
  • 主要用于:
    • 结构化数据提取
    • API 响应解析
    • 数据格式化
  • """
  • def init(self, pydantic_object: type = None):
  • """
  • 初始化 JsonOutputParser
  • Args:
  • pydantic_object: 可选的 Pydantic 模型类,用于验证 JSON 结构
  • """
  • self.pydantic_object = pydantic_object
  • def parse(self, text: str) -> Any:
  • """
  • 解析 JSON 输出文本
  • Args:
  • text: 包含 JSON 的文本
  • Returns:
  • Any: 解析后的 JSON 对象(字典、列表等)
  • Raises:
  • ValueError: 如果无法解析 JSON
  • """
  • try:
  • 使用辅助函数解析 JSON #

  • parsed = parse_json_markdown(text)
  • 如果指定了 Pydantic 模型,进行验证 #

  • if self.pydantic_object is not None:
  • try:
  • 尝试使用 Pydantic 验证(如果可用) #

  • if hasattr(self.pydantic_object, 'model_validate'):
  • return self.pydantic_object.model_validate(parsed)
  • elif hasattr(self.pydantic_object, 'parse_obj'):
  • return self.pydantic_object.parse_obj(parsed)
  • except Exception:
  • 如果验证失败,返回原始解析结果 #

  • pass
  • return parsed
  • except json.JSONDecodeError as e:
  • raise ValueError(f"无法解析 JSON 输出: {text[:100]}... 错误: {e}")
  • except Exception as e:
  • raise ValueError(f"解析 JSON 时出错: {e}")
  • def get_format_instructions(self) -> str:
  • """
  • 获取格式说明,用于在提示词中指导 LLM 输出 JSON 格式
  • Returns:
  • str: 格式说明文本
  • """
  • if self.pydantic_object is not None:
  • 如果有 Pydantic 模型,返回其 schema #

  • try:
  • if hasattr(self.pydantic_object, 'model_json_schema'):
  • schema = self.pydantic_object.model_json_schema()
  • elif hasattr(self.pydantic_object, 'schema'):
  • schema = self.pydantic_object.schema()
  • else:
  • schema = {}
  • return f"""请以 JSON 格式输出,格式如下: +json +{schema} + + +确保输出是有效的 JSON 格式。"""
  • except Exception:
  • pass
  • return """请以 JSON 格式输出你的回答。 + +输出格式要求: +1. 使用有效的 JSON 格式 +2. 可以使用 markdown 代码块包裹:json ... +3. 确保所有字符串都用双引号 +4. 确保 JSON 格式正确且完整 + +示例格式: +`json +{
  • "key": "value",
  • "number": 123 +} +`"""
  • def repr(self) -> str:
  • """返回解析器的字符串表示"""
  • if self.pydantic_object:
  • return f"JsonOutputParser(pydantic_object={self.pydantic_object.name})"
  • return "JsonOutputParser()" +

### 17.2. JsonOutputParser.py
17.JsonOutputParser.py
```js
#from langchain_core.prompts import PromptTemplate
#from langchain_core.output_parsers import JsonOutputParser
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import JsonOutputParser
import json

print("=" * 60)
print("JsonOutputParser 演示")
print("=" * 60)
print("\nJsonOutputParser 用于将 LLM 的输出解析为 JSON 格式。")
print("它支持多种 JSON 格式,包括纯 JSON、Markdown 代码块中的 JSON 等。\n")

# 创建模型和输出解析器
llm = ChatOpenAI(model="gpt-4o")
parser = JsonOutputParser()

print("=" * 60)
print("演示 1: 基本用法 - 解析 JSON 输出")
print("=" * 60)

# 测试不同的 JSON 格式
test_cases = [
    # 纯 JSON
    '{"name": "张三", "age": 25, "city": "北京"}',
    # Markdown 代码块中的 JSON
    '```json\n{"product": "手机", "price": 3999, "in_stock": true}\n```',
    # 包含其他文本的 JSON
    '这是产品信息:{"name": "笔记本电脑", "brand": "联想", "price": 5999}',
    # JSON 数组
    '["苹果", "香蕉", "橙子"]',
]

print("\n测试不同的 JSON 格式:")
for i, test_input in enumerate(test_cases, 1):
    print(f"\n测试 {i}:")
    print(f"输入:{test_input[:50]}...")
    try:
        result = parser.parse(test_input)
        print(f"解析成功:{json.dumps(result, ensure_ascii=False, indent=2)}")
        print(f"类型:{type(result).__name__}")
    except Exception as e:
        print(f"解析失败:{e}")

print("\n" + "=" * 60)
print("演示 2: 与 LLM 结合使用 - 提取结构化数据")
print("=" * 60)

# 创建提示模板,要求 LLM 输出 JSON
prompt = PromptTemplate.from_template(
    """你是一个数据提取助手。请从以下文本中提取信息,并以 JSON 格式输出。

文本:{text}

{format_instructions}

请提取以下信息:
- name: 姓名
- age: 年龄
- location: 地点
- interests: 兴趣列表(数组)

JSON 输出:"""
)

# 获取格式说明
format_instructions = parser.get_format_instructions()

print("\n使用 JsonOutputParser 提取结构化数据:")
test_texts = [
    "我叫李四,今年30岁,住在上海。我喜欢编程、阅读和旅行。",
    "王五,28岁,来自深圳。爱好包括音乐、电影和摄影。",
]

for i, text in enumerate(test_texts, 1):
    print(f"\n文本 {i}: {text}")
    print("-" * 60)

    # 格式化提示词
    formatted_prompt = prompt.format(
        text=text,
        format_instructions=format_instructions
    )

    # 调用 LLM
    response = llm.invoke(formatted_prompt)

    # 解析 JSON 输出
    try:
        result = parser.parse(response.content)
        print(f"提取的数据:")
        print(json.dumps(result, ensure_ascii=False, indent=2))
        print(f"数据类型:{type(result).__name__}")

        # 访问解析后的数据
        if isinstance(result, dict):
            print(f"\n访问数据:")
            print(f"  姓名:{result.get('name', 'N/A')}")
            print(f"  年龄:{result.get('age', 'N/A')}")
            print(f"  地点:{result.get('location', 'N/A')}")
            print(f"  兴趣:{result.get('interests', [])}")
    except Exception as e:
        print(f"解析失败:{e}")
        print(f"原始输出:{response.content}")

print("\n" + "=" * 60)
print("演示 3: 复杂数据结构提取")
print("=" * 60)

# 创建更复杂的提示
complex_prompt = PromptTemplate.from_template(
    """分析以下产品评论,提取关键信息并以 JSON 格式输出。

评论:{review}

{format_instructions}

请提取:
- sentiment: 情感倾向(positive/negative/neutral)
- rating: 评分(1-5)
- keywords: 关键词列表
- summary: 评论摘要

JSON 输出:"""
)

print("\n提取复杂数据结构:")
reviews = [
    "这个手机非常好用,屏幕清晰,拍照效果很棒,电池续航也不错。强烈推荐!",
    "产品质量一般,价格偏贵,客服态度不好。不太满意。",
]

for i, review in enumerate(reviews, 1):
    print(f"\n评论 {i}: {review}")
    print("-" * 60)

    formatted = complex_prompt.format(
        review=review,
        format_instructions=parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        result = parser.parse(response.content)
        print(f"提取的数据:")
        print(json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        print(f"解析失败:{e}")
        print(f"原始输出:{response.content[:200]}...")

print("\n" + "=" * 60)
print("演示 4: 批量数据处理")
print("=" * 60)

# 批量处理
batch_prompt = PromptTemplate.from_template(
    """请将以下信息转换为 JSON 数组格式。

信息列表:
{items}

{format_instructions}

每个项目应包含:
- id: 编号
- name: 名称
- category: 类别

输出 JSON 数组:"""
)

print("\n批量数据处理:")
items_text = """
1. 苹果 - 水果
2. 牛奶 - 乳制品
3. 面包 - 主食
4. 鸡蛋 - 蛋白质
"""

formatted = batch_prompt.format(
    items=items_text,
    format_instructions=parser.get_format_instructions()
)

response = llm.invoke(formatted)

try:
    result = parser.parse(response.content)
    print(f"解析结果:")
    print(json.dumps(result, ensure_ascii=False, indent=2))
    print(f"\n数据类型:{type(result).__name__}")
    if isinstance(result, list):
        print(f"数组长度:{len(result)}")
        print(f"第一个元素:{json.dumps(result[0], ensure_ascii=False) if result else 'N/A'}")
except Exception as e:
    print(f"解析失败:{e}")
    print(f"原始输出:{response.content[:300]}...")

print("\n" + "=" * 60)
print("演示 5: 错误处理")
print("=" * 60)

print("\n测试错误处理:")
invalid_inputs = [
    "这不是 JSON",
    '{"incomplete": json',
    '{"key": value}',  # 值没有引号
]

for i, invalid_input in enumerate(invalid_inputs, 1):
    print(f"\n测试 {i}: {invalid_input}")
    try:
        result = parser.parse(invalid_input)
        print(f"解析成功:{result}")
    except Exception as e:
        print(f"✓ 正确捕获错误:{type(e).__name__}: {e}")

print("\n" + "=" * 60)
print("演示 6: 实际应用场景 - API 数据格式化")
print("=" * 60)

# 实际应用:将自然语言转换为 API 请求格式
api_prompt = PromptTemplate.from_template(
    """将以下用户请求转换为 API 请求的 JSON 格式。

用户请求:{request}

{format_instructions}

API 请求应包含:
- method: HTTP 方法(GET/POST/PUT/DELETE)
- endpoint: API 端点
- params: 参数对象(如果有)
- headers: 请求头对象(如果有)

JSON 输出:"""
)

print("\n将自然语言转换为 API 请求格式:")
requests = [
    "获取用户ID为123的信息",
    "创建一个新订单,商品ID是456,数量是2",
]

for i, request in enumerate(requests, 1):
    print(f"\n请求 {i}: {request}")
    print("-" * 60)

    formatted = api_prompt.format(
        request=request,
        format_instructions=parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        result = parser.parse(response.content)
        print(f"API 请求格式:")
        print(json.dumps(result, ensure_ascii=False, indent=2))
    except Exception as e:
        print(f"解析失败:{e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nJsonOutputParser 的主要用途:")
print("1. 从 LLM 输出中提取结构化数据")
print("2. 将自然语言转换为 JSON 格式")
print("3. 解析各种 JSON 格式(纯 JSON、Markdown 代码块等)")
print("4. 数据格式化和验证")
print("\n使用场景:")
print("- 数据提取和结构化")
print("- API 请求生成")
print("- 配置文件生成")
print("- 数据转换和格式化")
print("\n注意事项:")
print("- 确保提示词中包含格式说明(get_format_instructions())")
print("- 处理可能的解析错误")
print("- 验证解析后的数据结构")

18. PydanticOutputParser #

PydanticOutputParser 是 LangChain 中用于将 LLM 输出结构化为强类型对象(Pydantic 模型)的核心工具。它不仅支持自动的数据类型转换和验证,还可以为 LLM 提供自动生成的输出格式说明,确保大模型的输出方便解析且符合需求。

主要功能

  1. 结构化输出
    利用 Pydantic 的强类型特性和字段约束,自动校验并解析 LLM 输出为 Pydantic 模型的实例。如有字段类型不符、值不合法,能及时抛出异常,提升数据安全性和调试效率。

  2. 自动格式说明
    通过 get_format_instructions() 方法生成人类可读的 JSON Schema 格式要求,作为 prompt 的一部分,大幅提升 LLM 输出格式的正确率。

  3. 灵活性
    支持 Pydantic v1 和 v2,也允许在缺少 pydantic 依赖时降级为普通的数据类,便于快速入门和跨环境演示。

典型用法场景

  • 提取结构化信息(如人物、产品、表单等)
  • 将自然语言数据自动转换为严格格式的数据结构
  • 面向 API 的接口返回、数据清洗/转换、配置解析等

简单工作原理

  • 定义 Pydantic 模型,描述需要提取的数据结构及校验规则
  • 创建 PydanticOutputParser(parser=你的模型类)
  • 在 prompt 中插入 parser.get_format_instructions() 生成的 Schema 说明
  • 用 LLM 进行推理推断后,调用 parser.parse() 自动解析为模型实例

18.1. PydanticOutputParser.py #

18.PydanticOutputParser.py

#from langchain_core.prompts import PromptTemplate
#from langchain_core.output_parsers import PydanticOutputParser
#from langchain_openai import ChatOpenAI

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser

# 尝试导入 Pydantic,如果没有安装则使用简单的数据类
try:
    from pydantic import BaseModel, Field
    PYDANTIC_AVAILABLE = True
except ImportError:
    PYDANTIC_AVAILABLE = False
    print("警告:pydantic 未安装,将使用简单的数据类进行演示。")
    print("建议安装:pip install pydantic")

    # 创建一个简单的 BaseModel 替代
    class BaseModel:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

        @classmethod
        def model_validate(cls, obj):
            return cls(**obj)

        @classmethod
        def model_json_schema(cls):
            # 简单的 schema 生成
            schema = {"properties": {}, "required": []}
            if hasattr(cls, '__annotations__'):
                for field_name, field_type in cls.__annotations__.items():
                    schema["properties"][field_name] = {"type": "string"}  # 简化处理
            return schema

    # 简单的 Field 函数,返回默认值(用于类型注解)
    def Field(description: str = "", default=None, **kwargs):
        # 如果没有提供 default,返回一个占位符对象
        if default is not None:
            return default
        # 返回一个简单的占位符,用于类型注解
        return type('Field', (), {'description': description, **kwargs})()

print("=" * 60)
print("PydanticOutputParser 演示")
print("=" * 60)
print("\nPydanticOutputParser 用于将 LLM 的输出解析为 Pydantic 模型实例。")
print("它提供了类型验证、数据验证和结构化数据提取功能。\n")

# 创建模型
llm = ChatOpenAI(model="gpt-4o")

print("=" * 60)
print("演示 1: 定义 Pydantic 模型并解析输出")
print("=" * 60)

# 定义 Pydantic 模型
class Person(BaseModel):
    """人员信息模型"""
    name: str = Field(description="姓名")
    age: int = Field(description="年龄", ge=0, le=150)
    email: str = Field(description="邮箱地址")
    city: str = Field(description="所在城市", default="未知")

class Product(BaseModel):
    """产品信息模型"""
    name: str = Field(description="产品名称")
    price: float = Field(description="价格", ge=0)
    category: str = Field(description="产品类别")
    in_stock: bool = Field(description="是否有库存", default=True)

print("\n已定义 Pydantic 模型:")
print("- Person: 人员信息(name, age, email, city)")
print("- Product: 产品信息(name, price, category, in_stock)")

# 创建解析器
person_parser = PydanticOutputParser(pydantic_object=Person)
product_parser = PydanticOutputParser(pydantic_object=Product)

print("\n" + "=" * 60)
print("演示 2: 基本用法 - 解析结构化数据")
print("=" * 60)

# 测试解析
test_json = '{"name": "张三", "age": 30, "email": "zhangsan@example.com", "city": "北京"}'
print(f"\n测试输入:{test_json}")

try:
    person = person_parser.parse(test_json)
    print(f"解析成功!")
    print(f"类型:{type(person).__name__}")
    print(f"姓名:{person.name}")
    print(f"年龄:{person.age}")
    print(f"邮箱:{person.email}")
    print(f"城市:{person.city}")
except Exception as e:
    print(f"解析失败:{e}")

print("\n" + "=" * 60)
print("演示 3: 与 LLM 结合使用 - 提取结构化数据")
print("=" * 60)

# 创建提示模板
person_prompt = PromptTemplate.from_template(
    """从以下文本中提取人员信息。

文本:{text}

{format_instructions}

请提取人员信息并以 JSON 格式输出:"""
)

print("\n从文本中提取人员信息:")
texts = [
    "我叫李四,今年28岁,邮箱是 lisi@example.com,我住在上海。",
    "王五,35岁,邮箱地址是 wangwu@test.com,来自深圳。",
]

for i, text in enumerate(texts, 1):
    print(f"\n文本 {i}: {text}")
    print("-" * 60)

    # 格式化提示词
    formatted = person_prompt.format(
        text=text,
        format_instructions=person_parser.get_format_instructions()
    )

    # 调用 LLM
    response = llm.invoke(formatted)

    # 解析为 Pydantic 模型
    try:
        person = person_parser.parse(response.content)
        print(f"提取成功!")
        print(f"  姓名:{person.name}")
        print(f"  年龄:{person.age}")
        print(f"  邮箱:{person.email}")
        print(f"  城市:{person.city}")
        print(f"  类型:{type(person).__name__}")
    except Exception as e:
        print(f"解析失败:{e}")
        print(f"原始输出:{response.content[:200]}...")

print("\n" + "=" * 60)
print("演示 4: 复杂数据结构 - 产品信息提取")
print("=" * 60)

# 产品信息提取
product_prompt = PromptTemplate.from_template(
    """从以下文本中提取产品信息。

文本:{text}

{format_instructions}

请提取产品信息并以 JSON 格式输出:"""
)

print("\n从文本中提取产品信息:")
product_texts = [
    "iPhone 15 Pro,价格是8999元,属于手机类别,目前有库存。",
    "MacBook Pro,售价12999,笔记本电脑类别,暂时缺货。",
]

for i, text in enumerate(product_texts, 1):
    print(f"\n文本 {i}: {text}")
    print("-" * 60)

    formatted = product_prompt.format(
        text=text,
        format_instructions=product_parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        product = product_parser.parse(response.content)
        print(f"提取成功!")
        print(f"  产品名称:{product.name}")
        print(f"  价格:{product.price} 元")
        print(f"  类别:{product.category}")
        print(f"  有库存:{product.in_stock}")
        print(f"  类型:{type(product).__name__}")
    except Exception as e:
        print(f"解析失败:{e}")

print("\n" + "=" * 60)
print("演示 5: 嵌套数据结构")
print("=" * 60)

# 定义嵌套模型
class Address(BaseModel):
    """地址模型"""
    street: str = Field(description="街道")
    city: str = Field(description="城市")
    zip_code: str = Field(description="邮编")

class Company(BaseModel):
    """公司模型"""
    name: str = Field(description="公司名称")
    address: Address = Field(description="公司地址")
    employee_count: int = Field(description="员工数量", ge=0)

print("\n定义嵌套模型:")
print("- Address: 地址信息")
print("- Company: 公司信息(包含地址)")

company_parser = PydanticOutputParser(pydantic_object=Company)

company_prompt = PromptTemplate.from_template(
    """从以下文本中提取公司信息。

文本:{text}

{format_instructions}

请提取公司信息并以 JSON 格式输出:"""
)

print("\n提取嵌套数据结构:")
company_text = "我们公司叫科技公司,地址在北京市中关村大街1号,邮编100000,有500名员工。"

formatted = company_prompt.format(
    text=company_text,
    format_instructions=company_parser.get_format_instructions()
)

response = llm.invoke(formatted)

try:
    company = company_parser.parse(response.content)
    print(f"提取成功!")
    print(f"  公司名称:{company.name}")
    print(f"  街道:{company.address.street}")
    print(f"  城市:{company.address.city}")
    print(f"  邮编:{company.address.zip_code}")
    print(f"  员工数:{company.employee_count}")
except Exception as e:
    print(f"解析失败:{e}")
    print(f"原始输出:{response.content[:300]}...")

print("\n" + "=" * 60)
print("演示 6: 数据验证和错误处理")
print("=" * 60)

print("\n测试数据验证:")

# 测试无效数据
invalid_cases = [
    ('{"name": "测试", "age": -5, "email": "test@test.com"}', "年龄为负数"),
    ('{"name": "测试", "age": 200, "email": "test@test.com"}', "年龄超出范围"),
    ('{"name": "测试", "age": 25}', "缺少必需字段"),
]

for invalid_json, description in invalid_cases:
    print(f"\n测试:{description}")
    print(f"输入:{invalid_json}")
    try:
        person = person_parser.parse(invalid_json)
        print(f"解析成功:{person}")
    except Exception as e:
        print(f"✓ 正确捕获验证错误:{type(e).__name__}: {e}")

print("\n" + "=" * 60)
print("演示 7: 实际应用场景 - 表单数据提取")
print("=" * 60)

# 定义表单模型
class ContactForm(BaseModel):
    """联系表单模型"""
    name: str = Field(description="姓名")
    phone: str = Field(description="电话号码")
    message: str = Field(description="留言内容")
    urgency: str = Field(description="紧急程度", default="normal")  # normal, urgent, low

print("\n联系表单数据提取:")

contact_parser = PydanticOutputParser(pydantic_object=ContactForm)

contact_prompt = PromptTemplate.from_template(
    """从以下用户留言中提取联系表单信息。

留言:{message}

{format_instructions}

请提取信息并以 JSON 格式输出:"""
)

messages = [
    "我是张三,电话13800138000,想咨询产品价格,比较紧急。",
    "王女士,手机号13900139000,询问售后服务,不着急。",
]

for i, message in enumerate(messages, 1):
    print(f"\n留言 {i}: {message}")
    print("-" * 60)

    formatted = contact_prompt.format(
        message=message,
        format_instructions=contact_parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        contact = contact_parser.parse(response.content)
        print(f"提取成功!")
        print(f"  姓名:{contact.name}")
        print(f"  电话:{contact.phone}")
        print(f"  留言:{contact.message}")
        print(f"  紧急程度:{contact.urgency}")
    except Exception as e:
        print(f"解析失败:{e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nPydanticOutputParser 的主要优势:")
print("1. 类型安全:自动类型验证和转换")
print("2. 数据验证:字段验证(范围、格式等)")
print("3. 结构化:返回强类型的模型实例")
print("4. 自动文档:从模型生成格式说明")
print("5. IDE 支持:更好的代码补全和类型提示")
print("\n使用场景:")
print("- 结构化数据提取和验证")
print("- API 请求/响应处理")
print("- 表单数据验证")
print("- 配置文件解析")
print("- 数据转换和规范化")
print("\n注意事项:")
print("- 需要安装 pydantic: pip install pydantic")
print("- 确保提示词中包含格式说明(get_format_instructions())")
print("- 处理可能的验证错误")
print("- 合理使用 Field 定义字段约束")

18.2. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
import json
import re

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"


# 辅助函数:从文本中提取 JSON(支持 markdown 代码块)
def parse_json_markdown(text: str) -> Any:
    """
    从文本中解析 JSON,支持 markdown 代码块格式

    Args:
        text: 可能包含 JSON 的文本

    Returns:
        解析后的 JSON 对象

    Raises:
        json.JSONDecodeError: 如果无法解析 JSON
    """
    # 去除首尾空白
    text = text.strip()

    # 尝试匹配 markdown 代码块中的 JSON
    # 匹配 ```json ... ``` 或 ``` ...
json_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if json_match:
    text = json_match.group(1).strip()

# 尝试匹配 { ... } 或 [ ... ]
json_match = re.search(r'(\{.*\}|$$.*$$)', text, re.DOTALL)
if json_match:
    text = json_match.group(1)

# 解析 JSON
return json.loads(text)

定义 JSON 输出解析器类 #

class JsonOutputParser(BaseOutputParser[Any]): """ JSON 输出解析器

将 LLM 的输出解析为 JSON 对象。支持:
- 纯 JSON 字符串
- Markdown 代码块中的 JSON(```json ... ```)
- 包含 JSON 的文本(自动提取)

主要用于:
- 结构化数据提取
- API 响应解析
- 数据格式化
"""

def __init__(self, pydantic_object: type = None):
    """
    初始化 JsonOutputParser

    Args:
        pydantic_object: 可选的 Pydantic 模型类,用于验证 JSON 结构
    """
    self.pydantic_object = pydantic_object

def parse(self, text: str) -> Any:
    """
    解析 JSON 输出文本

    Args:
        text: 包含 JSON 的文本

    Returns:
        Any: 解析后的 JSON 对象(字典、列表等)

    Raises:
        ValueError: 如果无法解析 JSON
    """
    try:
        # 使用辅助函数解析 JSON
        parsed = parse_json_markdown(text)

        # 如果指定了 Pydantic 模型,进行验证
        if self.pydantic_object is not None:
            try:
                # 尝试使用 Pydantic 验证(如果可用)
                if hasattr(self.pydantic_object, 'model_validate'):
                    return self.pydantic_object.model_validate(parsed)
                elif hasattr(self.pydantic_object, 'parse_obj'):
                    return self.pydantic_object.parse_obj(parsed)
            except Exception:
                # 如果验证失败,返回原始解析结果
                pass

        return parsed
    except json.JSONDecodeError as e:
        raise ValueError(f"无法解析 JSON 输出: {text[:100]}... 错误: {e}")
    except Exception as e:
        raise ValueError(f"解析 JSON 时出错: {e}")

def get_format_instructions(self) -> str:
    """
    获取格式说明,用于在提示词中指导 LLM 输出 JSON 格式

    Returns:
        str: 格式说明文本
    """
    if self.pydantic_object is not None:
        # 如果有 Pydantic 模型,返回其 schema
        try:
            if hasattr(self.pydantic_object, 'model_json_schema'):
                schema = self.pydantic_object.model_json_schema()
            elif hasattr(self.pydantic_object, 'schema'):
                schema = self.pydantic_object.schema()
            else:
                schema = {}

            return f"""请以 JSON 格式输出,格式如下:
{schema}

确保输出是有效的 JSON 格式。""" except Exception: pass

    return """请以 JSON 格式输出你的回答。

输出格式要求:

  1. 使用有效的 JSON 格式
  2. 可以使用 markdown 代码块包裹:json ...
  3. 确保所有字符串都用双引号
  4. 确保 JSON 格式正确且完整

示例格式:

{
  "key": "value",
  "number": 123
}
```"""

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        if self.pydantic_object:
            return f"JsonOutputParser(pydantic_object={self.pydantic_object.__name__})"
        return "JsonOutputParser()"

+
+# 定义 Pydantic 输出解析器类
+class PydanticOutputParser(JsonOutputParser):
+   """
+   Pydantic 输出解析器
+   
+   将 LLM 的输出解析为 Pydantic 模型实例。继承自 JsonOutputParser,
+   先解析 JSON,然后验证并转换为 Pydantic 模型。
+   
+   主要用于:
+   - 结构化数据验证
+   - 类型安全的数据提取
+   - 自动数据验证和转换
+   """
+   
+   def __init__(self, pydantic_object: type):
+       """
+       初始化 PydanticOutputParser
+       
+       Args:
+           pydantic_object: Pydantic 模型类(必需)
+       
+       Raises:
+           ValueError: 如果 pydantic_object 不是有效的 Pydantic 模型
+       """
+       if pydantic_object is None:
+           raise ValueError("PydanticOutputParser 需要一个 Pydantic 模型类")
+       
+       # 检查是否是 Pydantic 模型
+       try:
+           import pydantic
+           if not (issubclass(pydantic_object, pydantic.BaseModel) or 
+                   (hasattr(pydantic, 'v1') and issubclass(pydantic_object, pydantic.v1.BaseModel))):
+               raise ValueError(f"{pydantic_object} 不是有效的 Pydantic 模型类")
+       except ImportError:
+           # 如果没有安装 pydantic,使用更宽松的检查
+           if not hasattr(pydantic_object, '__fields__') and not hasattr(pydantic_object, 'model_fields'):
+               raise ValueError(f"{pydantic_object} 可能不是有效的 Pydantic 模型类(请确保已安装 pydantic)")
+       
+       self.pydantic_object = pydantic_object
+       # 调用父类初始化,传入 pydantic_object
+       super().__init__(pydantic_object=pydantic_object)
+   
+   def parse(self, text: str):
+       """
+       解析输出文本为 Pydantic 模型实例
+       
+       Args:
+           text: 包含 JSON 的文本
+       
+       Returns:
+           Pydantic 模型实例
+       
+       Raises:
+           ValueError: 如果无法解析 JSON 或验证失败
+       """
+       try:
+           # 先使用父类方法解析 JSON
+           json_obj = super().parse(text)
+           
+           # 转换为 Pydantic 模型实例
+           return self._parse_obj(json_obj)
+       except Exception as e:
+           raise ValueError(f"无法解析为 Pydantic 模型: {e}")
+   
+   def _parse_obj(self, obj: dict):
+       """
+       将字典对象转换为 Pydantic 模型实例
+       
+       Args:
+           obj: 字典对象
+       
+       Returns:
+           Pydantic 模型实例
+       
+       Raises:
+           ValueError: 如果验证失败
+       """
+       try:
+           import pydantic
+           
+           # 尝试使用 Pydantic v2 的 model_validate
+           if hasattr(self.pydantic_object, 'model_validate'):
+               return self.pydantic_object.model_validate(obj)
+           # 尝试使用 Pydantic v1 的 parse_obj
+           elif hasattr(self.pydantic_object, 'parse_obj'):
+               return self.pydantic_object.parse_obj(obj)
+           # 尝试直接实例化
+           else:
+               return self.pydantic_object(**obj)
+       except ImportError:
+           # 如果没有 pydantic,尝试直接实例化
+           return self.pydantic_object(**obj)
+       except Exception as e:
+           raise ValueError(f"Pydantic 验证失败: {e}")
+   
+   def _get_schema(self) -> dict:
+       """
+       获取 Pydantic 模型的 JSON Schema
+       
+       Returns:
+           dict: JSON Schema 字典
+       """
+       try:
+           # 尝试使用 Pydantic v2 的 model_json_schema
+           if hasattr(self.pydantic_object, 'model_json_schema'):
+               return self.pydantic_object.model_json_schema()
+           # 尝试使用 Pydantic v1 的 schema
+           elif hasattr(self.pydantic_object, 'schema'):
+               return self.pydantic_object.schema()
+           else:
+               return {}
+       except Exception:
+           return {}
+   
+   def get_format_instructions(self) -> str:
+       """
+       获取格式说明,包含 Pydantic 模型的 Schema
+       
+       Returns:
+           str: 格式说明文本
+       """
+       schema = self._get_schema()
+       
+       # 清理 schema,移除不必要的字段
+       reduced_schema = dict(schema)
+       if "title" in reduced_schema:
+           del reduced_schema["title"]
+       if "type" in reduced_schema:
+           del reduced_schema["type"]
+       
+       schema_str = json.dumps(reduced_schema, ensure_ascii=False, indent=2)
+       
+       return f"""请以 JSON 格式输出,必须严格遵循以下 Schema:
+
+```json
+{schema_str}
+

+ +输出要求: +1. 必须完全符合上述 Schema 结构 +2. 所有必需字段都必须提供 +3. 字段类型必须匹配(字符串、数字、布尔值等) +4. 使用有效的 JSON 格式 +5. 可以使用 markdown 代码块包裹:json ... + +确保输出是有效的 JSON,并且符合 Schema 定义。"""

  • def repr(self) -> str:
  • """返回解析器的字符串表示"""
  • return f"PydanticOutputParser(pydantic_object={self.pydantic_object.name})" +

## 19. OutputFixingParser
OutputFixingParser(输出修复解析器)是一种智能的解析工具,主要用于提升从 LLM(如 ChatGPT、GPT-4 等)返回结果时的结构化数据解析成功率。

**主要功能简介**

- **增强容错性**:当原始的输出解析器因格式错误等问题解析失败时,OutputFixingParser 会自动调用 LLM 生成修正后的输出,然后再次尝试解析。这通常能显著提高结构化解析的成功率。
- **自动修复**:通过内置或自定义的修复提示,指引 LLM 自动将输出转换为合法且可解析的目标格式(如 JSON、Pydantic、其他结构化数据)。
- **支持任意基础解析器**:OutputFixingParser 可以包装几乎所有 LangChain 原生解析器(JsonOutputParser、PydanticOutputParser 等)。
- **可控的重试策略**:可以设置最大重试次数,防止陷入死循环或产生过多 API 费用。

**典型使用场景**

- LLM 偶尔生成格式不严格、缺少字段或存在语法错误的输出,直接用 JSON/Pydantic 解析常常报错。
- 需要自动批量从 LLM 输出提取结构化信息,且对健壮性有较高要求。
- 希望最小化人工介入与后处理修复数据的工作量,提高自动化流程可靠性。

**工作流程示意**

1. 基础解析器尝试解析 LLM 的原始输出;
2. 如果解析失败,OutputFixingParser 自动将原始输出、Schema 约束等发送回 LLM,请求其修复格式;
3. 再次用基础解析器尝试解析修复后的输出,若仍失败,可多次重试(次数可设定);
4. 若所有尝试后仍失败,则抛出异常。


### 19.1. OutputFixingParser.py
19.OutputFixingParser.py
```python
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
import json

#from langchain.chat_models import ChatOpenAI
#from langchain.prompts import PromptTemplate
#from langchain.output_parsers import OutputFixingParser, JsonOutputParser, PydanticOutputParser
#from pydantic import BaseModel, Field
#import json

print("=" * 60)
print("OutputFixingParser 演示")
print("=" * 60)
print("\nOutputFixingParser 是一个智能的输出修复解析器。")
print("当基础解析器解析失败时,它会使用 LLM 自动修复输出格式,")
print("然后再次尝试解析。这大大提高了解析的成功率。\n")

# 创建模型
llm = ChatOpenAI(model="gpt-4o")

print("=" * 60)
print("演示 1: 基本用法 - 修复 JSON 解析错误")
print("=" * 60)

# 创建基础解析器
json_parser = JsonOutputParser()

# 创建修复解析器
fixing_parser = OutputFixingParser.from_llm(
    llm=llm,
    parser=json_parser,
    max_retries=2,
)

print("\n测试修复无效的 JSON 输出:")

# 测试无效的 JSON
invalid_outputs = [
    # 缺少引号
    '{name: "张三", age: 30, city: "北京"}',
    # 单引号(JSON 需要双引号)
    "{'product': '手机', 'price': 3999}",
    # 缺少逗号
    '{"name": "李四" "age": 25}',
    # 包含注释(JSON 不支持注释)
    '{"name": "王五", /* 这是注释 */ "age": 28}',
]

for i, invalid_output in enumerate(invalid_outputs, 1):
    print(f"\n测试 {i}:")
    print(f"原始输出(无效):{invalid_output}")
    print("-" * 60)

    try:
        # 直接使用基础解析器(应该失败)
        try:
            result = json_parser.parse(invalid_output)
            print(f"基础解析器:解析成功(意外)")
        except Exception as e:
            print(f"基础解析器:解析失败 ✓")
            print(f"  错误:{type(e).__name__}: {e}")

        # 使用修复解析器
        print("\n使用 OutputFixingParser:")
        print("  正在使用 LLM 修复输出...")
        try:
            result = fixing_parser.parse(invalid_output)
            print(f"  修复后解析成功!")
            print(f"  结果:{json.dumps(result, ensure_ascii=False, indent=2)}")
        except Exception as e:
            print(f"  修复失败:{type(e).__name__}: {e}")
    except Exception as e:
        print(f"修复失败:{e}")

print("\n" + "=" * 60)
print("演示 2: 修复 Pydantic 模型解析错误")
print("=" * 60)

# 定义 Pydantic 模型
class Person(BaseModel):
    """人员信息模型"""
    name: str = Field(description="姓名")
    age: int = Field(description="年龄", ge=0, le=150)
    email: str = Field(description="邮箱地址")
    city: str = Field(description="所在城市", default="未知")

# 创建 Pydantic 解析器
pydantic_parser = PydanticOutputParser(pydantic_object=Person)

# 创建修复解析器
pydantic_fixing_parser = OutputFixingParser.from_llm(
    llm=llm,
    parser=pydantic_parser,
    max_retries=2,
)

print("\n测试修复 Pydantic 模型解析错误:")

invalid_pydantic_outputs = [
    # 类型错误:年龄是字符串
    '{"name": "张三", "age": "30", "email": "zhangsan@example.com", "city": "北京"}',
    # 缺少必需字段
    '{"name": "李四", "age": 25}',
    # 字段名错误
    '{"name": "王五", "age": 28, "email_address": "wangwu@example.com", "city": "上海"}',
    # 值超出范围
    '{"name": "赵六", "age": 200, "email": "zhaoliu@example.com", "city": "深圳"}',
]

for i, invalid_output in enumerate(invalid_pydantic_outputs, 1):
    print(f"\n测试 {i}:")
    print(f"原始输出(无效):{invalid_output}")
    print("-" * 60)

    try:
        # 直接使用基础解析器(应该失败)
        try:
            result = pydantic_parser.parse(invalid_output)
            print(f"基础解析器:解析成功(意外)")
        except Exception as e:
            print(f"基础解析器:解析失败 ✓")
            print(f"  错误:{type(e).__name__}")

        # 使用修复解析器
        print("\n使用 OutputFixingParser:")
        print("  正在使用 LLM 修复输出...")
        try:
            result = pydantic_fixing_parser.parse(invalid_output)
            print(f"  修复后解析成功!")
            print(f"  姓名:{result.name}")
            print(f"  年龄:{result.age}")
            print(f"  邮箱:{result.email}")
            print(f"  城市:{result.city}")
            print(f"  类型:{type(result).__name__}")
        except Exception as e:
            print(f"  修复失败:{type(e).__name__}: {e}")
    except Exception as e:
        print(f"修复失败:{e}")

print("\n" + "=" * 60)
print("演示 3: 与 LLM 输出结合使用")
print("=" * 60)

# 创建提示模板
prompt = PromptTemplate.from_template(
    """从以下文本中提取人员信息,并以 JSON 格式输出。

文本:{text}

{format_instructions}

请提取信息并以 JSON 格式输出:"""
)

print("\n从文本中提取信息(可能输出格式不正确):")

texts = [
    "我叫张三,今年30岁,邮箱是zhangsan@example.com,住在北京。",
    "李四,28岁,邮箱lisi@test.com,上海人。",
]

for i, text in enumerate(texts, 1):
    print(f"\n文本 {i}: {text}")
    print("-" * 60)

    # 格式化提示词
    formatted = prompt.format(
        text=text,
        format_instructions=pydantic_parser.get_format_instructions()
    )

    # 调用 LLM(可能输出格式不正确)
    response = llm.invoke(formatted)
    print(f"LLM 原始输出:{response.content[:100]}...")

    # 使用修复解析器解析(自动修复格式错误)
    try:
        person = pydantic_fixing_parser.parse(response.content)
        print(f"\n✓ 解析成功(可能经过修复):")
        print(f"  姓名:{person.name}")
        print(f"  年龄:{person.age}")
        print(f"  邮箱:{person.email}")
        print(f"  城市:{person.city}")
    except Exception as e:
        print(f"\n✗ 解析失败:{e}")

print("\n" + "=" * 60)
print("演示 4: 自定义修复提示")
print("=" * 60)

# 创建自定义修复提示
custom_fix_prompt = PromptTemplate.from_template(
    """你是一个 JSON 格式修复助手。

原始格式要求:
{instructions}

原始输出(有错误):
{completion}

错误信息:
{error}

请修复输出,确保:
1. 符合原始格式要求
2. 是有效的 JSON
3. 包含所有必需字段
4. 字段类型正确

只返回修复后的 JSON,不要包含其他说明:"""
)

# 创建自定义修复解析器
custom_fixing_parser = OutputFixingParser.from_llm(
    llm=llm,
    parser=json_parser,
    prompt=custom_fix_prompt,
    max_retries=1,
)

print("\n使用自定义修复提示:")
test_output = '{name: "测试", age: "25"}'  # 无效 JSON

try:
    result = custom_fixing_parser.parse(test_output)
    print(f"修复成功:{json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
    print(f"修复失败:{e}")

print("\n" + "=" * 60)
print("演示 5: 多次重试")
print("=" * 60)

# 创建允许多次重试的解析器
retry_parser = OutputFixingParser.from_llm(
    llm=llm,
    parser=json_parser,
    max_retries=3,  # 最多重试 3 次
)

print("\n测试多次重试(max_retries=3):")
# 非常糟糕的 JSON
very_bad_json = 'name: 测试, age: 25, city: 北京'  # 完全不是 JSON

try:
    result = retry_parser.parse(very_bad_json)
    print(f"经过多次修复后成功:{json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
    print(f"即使多次重试也失败:{e}")

print("\n" + "=" * 60)
print("演示 6: 实际应用场景")
print("=" * 60)

print("\n实际应用:处理可能格式不正确的 LLM 输出")

# 定义一个更复杂的模型
class ProductReview(BaseModel):
    """产品评论模型"""
    product_name: str = Field(description="产品名称")
    rating: int = Field(description="评分", ge=1, le=5)
    pros: list[str] = Field(description="优点列表")
    cons: list[str] = Field(description="缺点列表")
    summary: str = Field(description="总结")

review_parser = PydanticOutputParser(pydantic_object=ProductReview)
review_fixing_parser = OutputFixingParser.from_llm(
    llm=llm,
    parser=review_parser,
    max_retries=2,
)

review_prompt = PromptTemplate.from_template(
    """分析以下产品评论,提取关键信息。

评论:{review}

{format_instructions}

请提取信息并以 JSON 格式输出:"""
)

reviews = [
    "这个手机非常好用,屏幕清晰,拍照效果很棒,电池续航也不错。但是价格有点贵,而且有点重。总体来说很不错。",
    "笔记本电脑性能强大,运行速度快,键盘手感好。缺点是风扇声音有点大,续航一般。",
]

for i, review in enumerate(reviews, 1):
    print(f"\n评论 {i}: {review}")
    print("-" * 60)

    formatted = review_prompt.format(
        review=review,
        format_instructions=review_parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        result = review_fixing_parser.parse(response.content)
        print(f"✓ 提取成功:")
        print(f"  产品:{result.product_name}")
        print(f"  评分:{result.rating}/5")
        print(f"  优点:{', '.join(result.pros)}")
        print(f"  缺点:{', '.join(result.cons)}")
        print(f"  总结:{result.summary}")
    except Exception as e:
        print(f"✗ 提取失败:{e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nOutputFixingParser 的主要优势:")
print("1. 自动修复:当解析失败时,自动使用 LLM 修复输出")
print("2. 提高成功率:大大提高了解析的成功率")
print("3. 可配置:可以设置最大重试次数")
print("4. 灵活:可以自定义修复提示模板")
print("5. 通用:可以包装任何输出解析器")
print("\n使用场景:")
print("- 处理可能格式不正确的 LLM 输出")
print("- 提高解析器的健壮性")
print("- 减少手动修复的工作量")
print("- 在生产环境中提高系统可靠性")
print("\n注意事项:")
print("- 会增加 API 调用次数(每次修复都需要调用 LLM)")
print("- 需要额外的成本(修复调用会消耗 tokens)")
print("- 设置合理的 max_retries 以避免无限循环")
print("- 修复提示的质量会影响修复效果")

19.2. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
import json
import re

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"


# 辅助函数:从文本中提取 JSON(支持 markdown 代码块)
def parse_json_markdown(text: str) -> Any:
    """
    从文本中解析 JSON,支持 markdown 代码块格式

    Args:
        text: 可能包含 JSON 的文本

    Returns:
        解析后的 JSON 对象

    Raises:
        json.JSONDecodeError: 如果无法解析 JSON
    """
    # 去除首尾空白
    text = text.strip()

    # 尝试匹配 markdown 代码块中的 JSON
    # 匹配 ```json ... ``` 或 ``` ...
json_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if json_match:
    text = json_match.group(1).strip()

# 尝试匹配 { ... } 或 [ ... ]
json_match = re.search(r'(\{.*\}|$$.*$$)', text, re.DOTALL)
if json_match:
    text = json_match.group(1)

# 解析 JSON
return json.loads(text)

定义 JSON 输出解析器类 #

class JsonOutputParser(BaseOutputParser[Any]): """ JSON 输出解析器

将 LLM 的输出解析为 JSON 对象。支持:
- 纯 JSON 字符串
- Markdown 代码块中的 JSON(```json ... ```)
- 包含 JSON 的文本(自动提取)

主要用于:
- 结构化数据提取
- API 响应解析
- 数据格式化
"""

def __init__(self, pydantic_object: type = None):
    """
    初始化 JsonOutputParser

    Args:
        pydantic_object: 可选的 Pydantic 模型类,用于验证 JSON 结构
    """
    self.pydantic_object = pydantic_object

def parse(self, text: str) -> Any:
    """
    解析 JSON 输出文本

    Args:
        text: 包含 JSON 的文本

    Returns:
        Any: 解析后的 JSON 对象(字典、列表等)

    Raises:
        ValueError: 如果无法解析 JSON
    """
    try:
        # 使用辅助函数解析 JSON
        parsed = parse_json_markdown(text)

        # 如果指定了 Pydantic 模型,进行验证
        if self.pydantic_object is not None:
            try:
                # 尝试使用 Pydantic 验证(如果可用)
                if hasattr(self.pydantic_object, 'model_validate'):
                    return self.pydantic_object.model_validate(parsed)
                elif hasattr(self.pydantic_object, 'parse_obj'):
                    return self.pydantic_object.parse_obj(parsed)
            except Exception:
                # 如果验证失败,返回原始解析结果
                pass

        return parsed
    except json.JSONDecodeError as e:
        raise ValueError(f"无法解析 JSON 输出: {text[:100]}... 错误: {e}")
    except Exception as e:
        raise ValueError(f"解析 JSON 时出错: {e}")

def get_format_instructions(self) -> str:
    """
    获取格式说明,用于在提示词中指导 LLM 输出 JSON 格式

    Returns:
        str: 格式说明文本
    """
    if self.pydantic_object is not None:
        # 如果有 Pydantic 模型,返回其 schema
        try:
            if hasattr(self.pydantic_object, 'model_json_schema'):
                schema = self.pydantic_object.model_json_schema()
            elif hasattr(self.pydantic_object, 'schema'):
                schema = self.pydantic_object.schema()
            else:
                schema = {}

            return f"""请以 JSON 格式输出,格式如下:

```json {schema} ```

确保输出是有效的 JSON 格式。""" except Exception: pass

    return """请以 JSON 格式输出你的回答。

输出格式要求:

  1. 使用有效的 JSON 格式
  2. 可以使用 markdown 代码块包裹:json ...
  3. 确保所有字符串都用双引号
  4. 确保 JSON 格式正确且完整

示例格式: ```json { "key": "value", "number": 123 } ```"""

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    if self.pydantic_object:
        return f"JsonOutputParser(pydantic_object={self.pydantic_object.__name__})"
    return "JsonOutputParser()"

定义 Pydantic 输出解析器类 #

class PydanticOutputParser(JsonOutputParser): """ Pydantic 输出解析器

将 LLM 的输出解析为 Pydantic 模型实例。继承自 JsonOutputParser,
先解析 JSON,然后验证并转换为 Pydantic 模型。

主要用于:
- 结构化数据验证
- 类型安全的数据提取
- 自动数据验证和转换
"""

def __init__(self, pydantic_object: type):
    """
    初始化 PydanticOutputParser

    Args:
        pydantic_object: Pydantic 模型类(必需)

    Raises:
        ValueError: 如果 pydantic_object 不是有效的 Pydantic 模型
    """
    if pydantic_object is None:
        raise ValueError("PydanticOutputParser 需要一个 Pydantic 模型类")

    # 检查是否是 Pydantic 模型
    try:
        import pydantic
        if not (issubclass(pydantic_object, pydantic.BaseModel) or 
                (hasattr(pydantic, 'v1') and issubclass(pydantic_object, pydantic.v1.BaseModel))):
            raise ValueError(f"{pydantic_object} 不是有效的 Pydantic 模型类")
    except ImportError:
        # 如果没有安装 pydantic,使用更宽松的检查
        if not hasattr(pydantic_object, '__fields__') and not hasattr(pydantic_object, 'model_fields'):
            raise ValueError(f"{pydantic_object} 可能不是有效的 Pydantic 模型类(请确保已安装 pydantic)")

    self.pydantic_object = pydantic_object
    # 调用父类初始化,传入 pydantic_object
    super().__init__(pydantic_object=pydantic_object)

def parse(self, text: str):
    """
    解析输出文本为 Pydantic 模型实例

    Args:
        text: 包含 JSON 的文本

    Returns:
        Pydantic 模型实例

    Raises:
        ValueError: 如果无法解析 JSON 或验证失败
    """
    try:
        # 先使用父类方法解析 JSON
        json_obj = super().parse(text)

        # 转换为 Pydantic 模型实例
        return self._parse_obj(json_obj)
    except Exception as e:
        raise ValueError(f"无法解析为 Pydantic 模型: {e}")

def _parse_obj(self, obj: dict):
    """
    将字典对象转换为 Pydantic 模型实例

    Args:
        obj: 字典对象

    Returns:
        Pydantic 模型实例

    Raises:
        ValueError: 如果验证失败
    """
    try:
        import pydantic

        # 尝试使用 Pydantic v2 的 model_validate
        if hasattr(self.pydantic_object, 'model_validate'):
            return self.pydantic_object.model_validate(obj)
        # 尝试使用 Pydantic v1 的 parse_obj
        elif hasattr(self.pydantic_object, 'parse_obj'):
            return self.pydantic_object.parse_obj(obj)
        # 尝试直接实例化
        else:
            return self.pydantic_object(**obj)
    except ImportError:
        # 如果没有 pydantic,尝试直接实例化
        return self.pydantic_object(**obj)
    except Exception as e:
        raise ValueError(f"Pydantic 验证失败: {e}")

def _get_schema(self) -> dict:
    """
    获取 Pydantic 模型的 JSON Schema

    Returns:
        dict: JSON Schema 字典
    """
    try:
        # 尝试使用 Pydantic v2 的 model_json_schema
        if hasattr(self.pydantic_object, 'model_json_schema'):
            return self.pydantic_object.model_json_schema()
        # 尝试使用 Pydantic v1 的 schema
        elif hasattr(self.pydantic_object, 'schema'):
            return self.pydantic_object.schema()
        else:
            return {}
    except Exception:
        return {}

def get_format_instructions(self) -> str:
    """
    获取格式说明,包含 Pydantic 模型的 Schema

    Returns:
        str: 格式说明文本
    """
    schema = self._get_schema()

    # 清理 schema,移除不必要的字段
    reduced_schema = dict(schema)
    if "title" in reduced_schema:
        del reduced_schema["title"]
    if "type" in reduced_schema:
        del reduced_schema["type"]

    schema_str = json.dumps(reduced_schema, ensure_ascii=False, indent=2)

    return f"""请以 JSON 格式输出,必须严格遵循以下 Schema:

```json {schema_str} ```

输出要求:

  1. 必须完全符合上述 Schema 结构
  2. 所有必需字段都必须提供
  3. 字段类型必须匹配(字符串、数字、布尔值等)
  4. 使用有效的 JSON 格式
  5. 可以使用 markdown 代码块包裹:json ...

确保输出是有效的 JSON,并且符合 Schema 定义。"""

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"PydanticOutputParser(pydantic_object={self.pydantic_object.__name__})"

+ +# 定义输出解析异常类 +class OutputParserException(ValueError):

  • """输出解析异常"""
  • def init(self, message: str, llm_output: str = ""):
  • super().init(message)
  • self.llm_output = llm_output + + +# 定义输出修复解析器类 +class OutputFixingParser(BaseOutputParser[T]):
  • """
  • 输出修复解析器
  • 包装一个基础解析器,当解析失败时,使用 LLM 自动修复输出。
  • 这是一个非常有用的功能,可以处理 LLM 输出格式不正确的情况。
  • 工作原理:
    1. 首先尝试使用基础解析器解析输出
    1. 如果解析失败,将错误信息和原始输出发送给 LLM
    1. LLM 根据格式说明修复输出
    1. 再次尝试解析修复后的输出
    1. 可以设置最大重试次数
  • """
  • def init(
  • self,
  • parser: BaseOutputParser[T],
  • retry_chain,
  • max_retries: int = 1,
  • ):
  • """
  • 初始化 OutputFixingParser
  • Args:
  • parser: 基础解析器
  • retry_chain: 用于修复输出的链(通常是 Prompt -> LLM -> StrOutputParser)
  • max_retries: 最大重试次数
  • """
  • self.parser = parser
  • self.retry_chain = retry_chain
  • self.max_retries = max_retries
  • @classmethod
  • def from_llm(
  • cls,
  • llm,
  • parser: BaseOutputParser[T],
  • prompt=None,
  • max_retries: int = 1,
  • ) -> "OutputFixingParser[T]":
  • """
  • 从 LLM 创建 OutputFixingParser
  • Args:
  • llm: 用于修复输出的语言模型
  • parser: 基础解析器
  • prompt: 修复提示模板(可选,有默认模板)
  • max_retries: 最大重试次数
  • Returns:
  • OutputFixingParser 实例
  • """
  • from langchain.prompts import PromptTemplate
  • from langchain.output_parsers import StrOutputParser
  • 默认修复提示模板 #

  • if prompt is None:
  • fix_template = """Instructions: +-------------- +{instructions} +-------------- +Completion: +-------------- +{completion} +-------------- + +上面的 Completion 没有满足 Instructions 中的约束要求。 +错误信息: +-------------- +{error} +-------------- + +请修复输出,确保它满足 Instructions 中的所有约束要求。只返回修复后的输出,不要包含其他内容:"""
  • prompt = PromptTemplate.from_template(fix_template)
  • 创建修复链:Prompt -> LLM -> StrOutputParser #

  • retry_chain = _SimpleChain(prompt, llm, StrOutputParser())
  • return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)
  • def parse(self, completion: str) -> T:
  • """
  • 解析输出,如果失败则尝试修复
  • Args:
  • completion: LLM 的输出文本
  • Returns:
  • T: 解析后的结果
  • Raises:
  • OutputParserException: 如果修复后仍然无法解析
  • """
  • retries = 0
  • original_completion = completion
  • while retries <= self.max_retries:
  • try:
  • 尝试使用基础解析器解析 #

  • return self.parser.parse(completion)
  • except (ValueError, OutputParserException, Exception) as e:
  • 如果已达到最大重试次数,抛出异常 #

  • if retries >= self.max_retries:
  • raise OutputParserException(
  • f"解析失败,已重试 {retries} 次: {e}",
  • llm_output=completion
  • )
  • retries += 1
  • print(f" 第 {retries} 次尝试修复...")
  • 获取格式说明(如果解析器支持) #

  • try:
  • instructions = self.parser.get_format_instructions()
  • except (AttributeError, NotImplementedError):
  • instructions = "请确保输出格式正确。"
  • 使用 LLM 修复输出 #

  • try:
  • if hasattr(self.retry_chain, 'invoke'):
  • 新式链式调用 #

  • completion = self.retry_chain.invoke({
  • "instructions": instructions,
  • "completion": completion,
  • "error": str(e),
  • })
  • elif hasattr(self.retry_chain, 'run'):
  • 旧式链式调用 #

  • completion = self.retry_chain.run(
  • instructions=instructions,
  • completion=completion,
  • error=str(e),
  • )
  • else:
  • 直接调用(作为函数) #

  • completion = self.retry_chain({
  • "instructions": instructions,
  • "completion": completion,
  • "error": str(e),
  • })
  • 确保返回的是字符串 #

  • if not isinstance(completion, str):
  • completion = str(completion)
  • except Exception as fix_error:
  • raise OutputParserException(
  • f"修复输出时出错: {fix_error}",
  • llm_output=completion
  • )
  • raise OutputParserException(
  • f"解析失败,已重试 {self.max_retries} 次",
  • llm_output=completion
  • )
  • def get_format_instructions(self) -> str:
  • """
  • 获取格式说明(委托给基础解析器)
  • Returns:
  • str: 格式说明文本
  • """
  • try:
  • return self.parser.get_format_instructions()
  • except (AttributeError, NotImplementedError):
  • return "请确保输出格式正确。"
  • def repr(self) -> str:
  • """返回解析器的字符串表示"""
  • return f"OutputFixingParser(parser={self.parser}, max_retries={self.max_retries})" + + +# 简单的链式调用包装类 +class _SimpleChain:
  • """简单的链式调用包装类,用于连接 Prompt -> LLM -> Parser"""
  • def init(self, prompt, llm, parser):
  • self.prompt = prompt
  • self.llm = llm
  • self.parser = parser
  • def invoke(self, input_dict: dict) -> str:
  • """调用链"""
  • 格式化提示词 #

  • formatted = self.prompt.format(**input_dict)
  • 调用 LLM #

  • response = self.llm.invoke(formatted)
  • 解析输出 #

  • if hasattr(response, 'content'):
  • content = response.content
  • else:
  • content = str(response)
  • 使用解析器解析 #

  • return self.parser.parse(content)
  • def run(self, **kwargs) -> str:
  • """运行链(兼容旧接口)"""
  • return self.invoke(kwargs) + `

20. RetryOutputParser #

在实际使用 LLM 时,模型可能生成不符合期望格式(如 JSON/Python 对象等)的输出,导致解析失败。如果解析器(如 JsonOutputParser/PydanticOutputParser 等)无法解析 LLM 的回复,RetryOutputParser 能自动利用 LLM 资源进行重试,直到输出可被正确解析或达到最大重试次数。

两个主要类:

  • RetryOutputParser:在底层解析失败时,用 LLM 基于同样的 prompt(以及原始 completion)让模型“修正”输出格式。
  • RetryWithErrorOutputParser:除了原始 prompt 与 completion,还会将解析错误信息一并传递给 LLM,提升修正能力。

典型用法流程:

  1. 定义一个“基础解析器”,如 JsonOutputParser。
  2. 用 LLM 及基础 parser 构造 RetryOutputParser,获得有自动容错能力的 parser。
  3. 用 .parse_with_prompt 方法解析 LLM 输出(注意不能用 .parse,后者跳过重试功能)。
  4. 如果输出不正确格式,自动尝试让 LLM 修正并重试。

适用场景:

  • 需要格式化输出,并希望在 LLM 偶发“格式崩坏”时自动补救。
  • LLM 输出复杂结构体、JSON 等,确保结构可靠。

20.1. RetryOutputParser.py #

20.RetryOutputParser.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
#from langchain_classic.output_parsers import RetryOutputParser, RetryWithErrorOutputParser
#from pydantic import BaseModel, Field

from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.output_parsers import RetryOutputParser, RetryWithErrorOutputParser, JsonOutputParser, PydanticOutputParser
from pydantic import BaseModel, Field
import json

# 简单的 PromptValue 包装类(用于演示)
class _SimplePromptValue:
    """简单的 PromptValue 包装类,用于 RetryOutputParser"""
    def __init__(self, text: str):
        self.text = text

    def to_string(self) -> str:
        """返回提示词字符串"""
        return self.text

print("=" * 60)
print("RetryOutputParser 和 RetryWithErrorOutputParser 演示")
print("=" * 60)
print("\nRetryOutputParser 和 RetryWithErrorOutputParser 是智能的重试解析器。")
print("当基础解析器解析失败时,它们会使用 LLM 重新生成输出。")
print("\n主要区别:")
print("- RetryOutputParser: 将原始 prompt 和 completion 传递给 LLM")
print("- RetryWithErrorOutputParser: 还将错误信息传递给 LLM(提供更多上下文)")
print("\n注意:这两个解析器需要使用 parse_with_prompt 方法,而不是 parse 方法。\n")

# 创建模型
llm = ChatOpenAI(model="gpt-4o")

print("=" * 60)
print("演示 1: RetryOutputParser 基本用法")
print("=" * 60)

# 创建基础解析器
json_parser = JsonOutputParser()

# 创建重试解析器
retry_parser = RetryOutputParser.from_llm(
    llm=llm,
    parser=json_parser,
    max_retries=2,
)

print("\n测试 RetryOutputParser:")

# 创建提示模板
prompt = PromptTemplate.from_template(
    """请以 JSON 格式输出以下信息:
- 姓名:{name}
- 年龄:{age}
- 城市:{city}

{format_instructions}

请输出 JSON:"""
)

# 格式化提示词
formatted_prompt = prompt.format(
    name="张三",
    age=30,
    city="北京",
    format_instructions=json_parser.get_format_instructions()
)

# 模拟一个格式不正确的 LLM 输出
invalid_completion = 'name: "张三", age: 30, city: "北京"'  # 不是有效的 JSON

print(f"\n原始 prompt: {formatted_prompt[:100]}...")
print(f"无效的 completion: {invalid_completion}")
print("-" * 60)

try:
    # 直接使用基础解析器(应该失败)
    try:
        result = json_parser.parse(invalid_completion)
        print("基础解析器:解析成功(意外)")
    except Exception as e:
        print(f"基础解析器:解析失败 ✓")
        print(f"  错误:{type(e).__name__}")

    # 使用 RetryOutputParser
    print("\n使用 RetryOutputParser:")
    print("  正在使用 LLM 重新生成输出...")

    # 创建简单的 PromptValue
    prompt_value = _SimplePromptValue(formatted_prompt)

    result = retry_parser.parse_with_prompt(invalid_completion, prompt_value)
    print(f"  重试后解析成功!")
    print(f"  结果:{json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
    print(f"  重试失败:{type(e).__name__}: {e}")

print("\n" + "=" * 60)
print("演示 2: RetryWithErrorOutputParser 基本用法")
print("=" * 60)

# 创建带错误信息的重试解析器
retry_with_error_parser = RetryWithErrorOutputParser.from_llm(
    llm=llm,
    parser=json_parser,
    max_retries=2,
)

print("\n测试 RetryWithErrorOutputParser(包含错误信息):")

invalid_completion2 = "{'name': '李四', 'age': 25, 'city': '上海'}"  # 单引号,不是有效 JSON

print(f"\n无效的 completion: {invalid_completion2}")
print("-" * 60)

try:
    # 使用 RetryWithErrorOutputParser
    print("\n使用 RetryWithErrorOutputParser:")
    print("  正在使用 LLM 重新生成输出(包含错误信息)...")

    result = retry_with_error_parser.parse_with_prompt(invalid_completion2, prompt_value)
    print(f"  重试后解析成功!")
    print(f"  结果:{json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
    print(f"  重试失败:{type(e).__name__}: {e}")

print("\n" + "=" * 60)
print("演示 3: 与 LLM 输出结合使用")
print("=" * 60)

# 定义 Pydantic 模型
class Person(BaseModel):
    """人员信息模型"""
    name: str = Field(description="姓名")
    age: int = Field(description="年龄", ge=0, le=150)
    email: str = Field(description="邮箱地址")
    city: str = Field(description="所在城市", default="未知")

# 创建 Pydantic 解析器
pydantic_parser = PydanticOutputParser(pydantic_object=Person)

# 创建重试解析器
pydantic_retry_parser = RetryWithErrorOutputParser.from_llm(
    llm=llm,
    parser=pydantic_parser,
    max_retries=2,
)

# 创建提示模板
person_prompt = PromptTemplate.from_template(
    """从以下文本中提取人员信息。

文本:{text}

{format_instructions}

请提取信息并以 JSON 格式输出:"""
)

print("\n从文本中提取信息(可能输出格式不正确):")

texts = [
    "我叫张三,今年30岁,邮箱是zhangsan@example.com,住在北京。",
    "李四,28岁,邮箱lisi@test.com,上海人。",
]

for i, text in enumerate(texts, 1):
    print(f"\n文本 {i}: {text}")
    print("-" * 60)

    # 格式化提示词
    formatted = person_prompt.format(
        text=text,
        format_instructions=pydantic_parser.get_format_instructions()
    )

    # 调用 LLM(可能输出格式不正确)
    response = llm.invoke(formatted)
    print(f"LLM 原始输出:{response.content[:150]}...")

    # 使用重试解析器解析(自动重试)
    try:
        prompt_val = _SimplePromptValue(formatted)
        person = pydantic_retry_parser.parse_with_prompt(response.content, prompt_val)
        print(f"\n✓ 解析成功(可能经过重试):")
        print(f"  姓名:{person.name}")
        print(f"  年龄:{person.age}")
        print(f"  邮箱:{person.email}")
        print(f"  城市:{person.city}")
    except Exception as e:
        print(f"\n✗ 解析失败:{e}")

print("\n" + "=" * 60)
print("演示 4: RetryOutputParser vs RetryWithErrorOutputParser")
print("=" * 60)

print("\n比较两种解析器的效果:")

# 测试用例:缺少必需字段
test_completion = '{"name": "王五", "age": 28}'  # 缺少 email 字段

print(f"\n测试 completion: {test_completion}")
print("(缺少 email 字段,Pydantic 验证会失败)")
print("-" * 60)

# 使用 RetryOutputParser
print("\n1. 使用 RetryOutputParser(不包含错误信息):")
retry_parser_pydantic = RetryOutputParser.from_llm(
    llm=llm,
    parser=pydantic_parser,
    max_retries=1,
)

try:
    prompt_val = _SimplePromptValue(formatted)
    result1 = retry_parser_pydantic.parse_with_prompt(test_completion, prompt_val)
    print(f"  ✓ 成功:{result1.name}, {result1.age}, {result1.email}")
except Exception as e:
    print(f"  ✗ 失败:{type(e).__name__}")

# 使用 RetryWithErrorOutputParser
print("\n2. 使用 RetryWithErrorOutputParser(包含错误信息):")
try:
    result2 = pydantic_retry_parser.parse_with_prompt(test_completion, prompt_val)
    print(f"  ✓ 成功:{result2.name}, {result2.age}, {result2.email}")
except Exception as e:
    print(f"  ✗ 失败:{type(e).__name__}")

print("\n" + "=" * 60)
print("演示 5: 自定义重试提示")
print("=" * 60)

# 创建自定义重试提示
custom_retry_prompt = PromptTemplate.from_template(
    """原始提示:
{prompt}

原始输出(有错误):
{completion}

错误信息:
{error}

请根据原始提示的要求和错误信息,重新生成一个正确的输出。
确保:
1. 完全符合原始提示的要求
2. 修复所有错误
3. 格式正确

只返回修复后的输出,不要包含其他说明:"""
)

# 创建自定义重试解析器
custom_retry_parser = RetryWithErrorOutputParser.from_llm(
    llm=llm,
    parser=json_parser,
    prompt=custom_retry_prompt,
    max_retries=1,
)

print("\n使用自定义重试提示:")
test_output = '{name: "测试", age: "25"}'  # 无效 JSON

try:
    prompt_val = _SimplePromptValue(formatted_prompt)
    result = custom_retry_parser.parse_with_prompt(test_output, prompt_val)
    print(f"修复成功:{json.dumps(result, ensure_ascii=False, indent=2)}")
except Exception as e:
    print(f"修复失败:{e}")

print("\n" + "=" * 60)
print("演示 6: 实际应用场景")
print("=" * 60)

print("\n实际应用:处理可能格式不正确的 LLM 输出")

# 定义一个更复杂的模型
class ProductReview(BaseModel):
    """产品评论模型"""
    product_name: str = Field(description="产品名称")
    rating: int = Field(description="评分", ge=1, le=5)
    pros: list[str] = Field(description="优点列表")
    cons: list[str] = Field(description="缺点列表")
    summary: str = Field(description="总结")

review_parser = PydanticOutputParser(pydantic_object=ProductReview)
review_retry_parser = RetryWithErrorOutputParser.from_llm(
    llm=llm,
    parser=review_parser,
    max_retries=2,
)

review_prompt = PromptTemplate.from_template(
    """分析以下产品评论,提取关键信息。

评论:{review}

{format_instructions}

请提取信息并以 JSON 格式输出:"""
)

reviews = [
    "这个手机非常好用,屏幕清晰,拍照效果很棒,电池续航也不错。但是价格有点贵,而且有点重。总体来说很不错。",
]

for i, review in enumerate(reviews, 1):
    print(f"\n评论 {i}: {review}")
    print("-" * 60)

    formatted = review_prompt.format(
        review=review,
        format_instructions=review_parser.get_format_instructions()
    )

    response = llm.invoke(formatted)

    try:
        prompt_val = _SimplePromptValue(formatted)
        result = review_retry_parser.parse_with_prompt(response.content, prompt_val)
        print(f"✓ 提取成功:")
        print(f"  产品:{result.product_name}")
        print(f"  评分:{result.rating}/5")
        print(f"  优点:{', '.join(result.pros)}")
        print(f"  缺点:{', '.join(result.cons)}")
        print(f"  总结:{result.summary}")
    except Exception as e:
        print(f"✗ 提取失败:{e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRetryOutputParser 和 RetryWithErrorOutputParser 的主要特点:")
print("1. 需要原始 prompt:使用 parse_with_prompt 方法而不是 parse 方法")
print("2. 重新生成输出:让 LLM 根据 prompt 重新生成,而不是修复现有输出")
print("3. 提供上下文:RetryWithErrorOutputParser 还提供错误信息,效果更好")
print("4. 可配置:可以设置最大重试次数和自定义提示模板")
print("\n与 OutputFixingParser 的区别:")
print("- OutputFixingParser: 修复现有输出,只需要 completion 和格式说明")
print("- RetryOutputParser: 重新生成输出,需要原始 prompt 和 completion")
print("- RetryWithErrorOutputParser: 重新生成输出,还包含错误信息")
print("\n使用场景:")
print("- 处理可能格式不正确的 LLM 输出")
print("- 需要根据原始 prompt 重新生成输出的场景")
print("- 提高解析器的健壮性")
print("\n注意事项:")
print("- 必须使用 parse_with_prompt 方法")
print("- 需要提供原始 prompt(字符串或 PromptValue 对象)")
print("- 会增加 API 调用次数和成本")
print("- 设置合理的 max_retries 以避免无限循环")

20.2. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
import json
import re

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"


# 辅助函数:从文本中提取 JSON(支持 markdown 代码块)
def parse_json_markdown(text: str) -> Any:
    """
    从文本中解析 JSON,支持 markdown 代码块格式

    Args:
        text: 可能包含 JSON 的文本

    Returns:
        解析后的 JSON 对象

    Raises:
        json.JSONDecodeError: 如果无法解析 JSON
    """
    # 去除首尾空白
    text = text.strip()

    # 尝试匹配 markdown 代码块中的 JSON
    # 匹配 ```json ... ``` 或 ``` ...
json_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if json_match:
    text = json_match.group(1).strip()

# 尝试匹配 { ... } 或 [ ... ]
json_match = re.search(r'(\{.*\}|$$.*$$)', text, re.DOTALL)
if json_match:
    text = json_match.group(1)

# 解析 JSON
return json.loads(text)

定义 JSON 输出解析器类 #

class JsonOutputParser(BaseOutputParser[Any]): """ JSON 输出解析器

将 LLM 的输出解析为 JSON 对象。支持:
- 纯 JSON 字符串
- Markdown 代码块中的 JSON(```json ... ```)
- 包含 JSON 的文本(自动提取)

主要用于:
- 结构化数据提取
- API 响应解析
- 数据格式化
"""

def __init__(self, pydantic_object: type = None):
    """
    初始化 JsonOutputParser

    Args:
        pydantic_object: 可选的 Pydantic 模型类,用于验证 JSON 结构
    """
    self.pydantic_object = pydantic_object

def parse(self, text: str) -> Any:
    """
    解析 JSON 输出文本

    Args:
        text: 包含 JSON 的文本

    Returns:
        Any: 解析后的 JSON 对象(字典、列表等)

    Raises:
        ValueError: 如果无法解析 JSON
    """
    try:
        # 使用辅助函数解析 JSON
        parsed = parse_json_markdown(text)

        # 如果指定了 Pydantic 模型,进行验证
        if self.pydantic_object is not None:
            try:
                # 尝试使用 Pydantic 验证(如果可用)
                if hasattr(self.pydantic_object, 'model_validate'):
                    return self.pydantic_object.model_validate(parsed)
                elif hasattr(self.pydantic_object, 'parse_obj'):
                    return self.pydantic_object.parse_obj(parsed)
            except Exception:
                # 如果验证失败,返回原始解析结果
                pass

        return parsed
    except json.JSONDecodeError as e:
        raise ValueError(f"无法解析 JSON 输出: {text[:100]}... 错误: {e}")
    except Exception as e:
        raise ValueError(f"解析 JSON 时出错: {e}")

def get_format_instructions(self) -> str:
    """
    获取格式说明,用于在提示词中指导 LLM 输出 JSON 格式

    Returns:
        str: 格式说明文本
    """
    if self.pydantic_object is not None:
        # 如果有 Pydantic 模型,返回其 schema
        try:
            if hasattr(self.pydantic_object, 'model_json_schema'):
                schema = self.pydantic_object.model_json_schema()
            elif hasattr(self.pydantic_object, 'schema'):
                schema = self.pydantic_object.schema()
            else:
                schema = {}

            return f"""请以 JSON 格式输出,格式如下:

```json {schema} ```

确保输出是有效的 JSON 格式。""" except Exception: pass

    return """请以 JSON 格式输出你的回答。

输出格式要求:

  1. 使用有效的 JSON 格式
  2. 可以使用 markdown 代码块包裹:json ...
  3. 确保所有字符串都用双引号
  4. 确保 JSON 格式正确且完整

示例格式: ```json { "key": "value", "number": 123 } ```"""

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    if self.pydantic_object:
        return f"JsonOutputParser(pydantic_object={self.pydantic_object.__name__})"
    return "JsonOutputParser()"

定义 Pydantic 输出解析器类 #

class PydanticOutputParser(JsonOutputParser): """ Pydantic 输出解析器

将 LLM 的输出解析为 Pydantic 模型实例。继承自 JsonOutputParser,
先解析 JSON,然后验证并转换为 Pydantic 模型。

主要用于:
- 结构化数据验证
- 类型安全的数据提取
- 自动数据验证和转换
"""

def __init__(self, pydantic_object: type):
    """
    初始化 PydanticOutputParser

    Args:
        pydantic_object: Pydantic 模型类(必需)

    Raises:
        ValueError: 如果 pydantic_object 不是有效的 Pydantic 模型
    """
    if pydantic_object is None:
        raise ValueError("PydanticOutputParser 需要一个 Pydantic 模型类")

    # 检查是否是 Pydantic 模型
    try:
        import pydantic
        if not (issubclass(pydantic_object, pydantic.BaseModel) or 
                (hasattr(pydantic, 'v1') and issubclass(pydantic_object, pydantic.v1.BaseModel))):
            raise ValueError(f"{pydantic_object} 不是有效的 Pydantic 模型类")
    except ImportError:
        # 如果没有安装 pydantic,使用更宽松的检查
        if not hasattr(pydantic_object, '__fields__') and not hasattr(pydantic_object, 'model_fields'):
            raise ValueError(f"{pydantic_object} 可能不是有效的 Pydantic 模型类(请确保已安装 pydantic)")

    self.pydantic_object = pydantic_object
    # 调用父类初始化,传入 pydantic_object
    super().__init__(pydantic_object=pydantic_object)

def parse(self, text: str):
    """
    解析输出文本为 Pydantic 模型实例

    Args:
        text: 包含 JSON 的文本

    Returns:
        Pydantic 模型实例

    Raises:
        ValueError: 如果无法解析 JSON 或验证失败
    """
    try:
        # 先使用父类方法解析 JSON
        json_obj = super().parse(text)

        # 转换为 Pydantic 模型实例
        return self._parse_obj(json_obj)
    except Exception as e:
        raise ValueError(f"无法解析为 Pydantic 模型: {e}")

def _parse_obj(self, obj: dict):
    """
    将字典对象转换为 Pydantic 模型实例

    Args:
        obj: 字典对象

    Returns:
        Pydantic 模型实例

    Raises:
        ValueError: 如果验证失败
    """
    try:
        import pydantic

        # 尝试使用 Pydantic v2 的 model_validate
        if hasattr(self.pydantic_object, 'model_validate'):
            return self.pydantic_object.model_validate(obj)
        # 尝试使用 Pydantic v1 的 parse_obj
        elif hasattr(self.pydantic_object, 'parse_obj'):
            return self.pydantic_object.parse_obj(obj)
        # 尝试直接实例化
        else:
            return self.pydantic_object(**obj)
    except ImportError:
        # 如果没有 pydantic,尝试直接实例化
        return self.pydantic_object(**obj)
    except Exception as e:
        raise ValueError(f"Pydantic 验证失败: {e}")

def _get_schema(self) -> dict:
    """
    获取 Pydantic 模型的 JSON Schema

    Returns:
        dict: JSON Schema 字典
    """
    try:
        # 尝试使用 Pydantic v2 的 model_json_schema
        if hasattr(self.pydantic_object, 'model_json_schema'):
            return self.pydantic_object.model_json_schema()
        # 尝试使用 Pydantic v1 的 schema
        elif hasattr(self.pydantic_object, 'schema'):
            return self.pydantic_object.schema()
        else:
            return {}
    except Exception:
        return {}

def get_format_instructions(self) -> str:
    """
    获取格式说明,包含 Pydantic 模型的 Schema

    Returns:
        str: 格式说明文本
    """
    schema = self._get_schema()

    # 清理 schema,移除不必要的字段
    reduced_schema = dict(schema)
    if "title" in reduced_schema:
        del reduced_schema["title"]
    if "type" in reduced_schema:
        del reduced_schema["type"]

    schema_str = json.dumps(reduced_schema, ensure_ascii=False, indent=2)

    return f"""请以 JSON 格式输出,必须严格遵循以下 Schema:

```json {schema_str} ```

输出要求:

  1. 必须完全符合上述 Schema 结构
  2. 所有必需字段都必须提供
  3. 字段类型必须匹配(字符串、数字、布尔值等)
  4. 使用有效的 JSON 格式
  5. 可以使用 markdown 代码块包裹:json ...

确保输出是有效的 JSON,并且符合 Schema 定义。"""

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"PydanticOutputParser(pydantic_object={self.pydantic_object.__name__})"

定义输出解析异常类 #

class OutputParserException(ValueError): """输出解析异常""" def init(self, message: str, llm_output: str = ""): super().init(message) self.llm_output = llm_output

定义输出修复解析器类 #

class OutputFixingParser(BaseOutputParser[T]): """ 输出修复解析器

包装一个基础解析器,当解析失败时,使用 LLM 自动修复输出。
这是一个非常有用的功能,可以处理 LLM 输出格式不正确的情况。

工作原理:
1. 首先尝试使用基础解析器解析输出
2. 如果解析失败,将错误信息和原始输出发送给 LLM
3. LLM 根据格式说明修复输出
4. 再次尝试解析修复后的输出
5. 可以设置最大重试次数
"""

def __init__(
    self,
    parser: BaseOutputParser[T],
    retry_chain,
    max_retries: int = 1,
):
    """
    初始化 OutputFixingParser

    Args:
        parser: 基础解析器
        retry_chain: 用于修复输出的链(通常是 Prompt -> LLM -> StrOutputParser)
        max_retries: 最大重试次数
    """
    self.parser = parser
    self.retry_chain = retry_chain
    self.max_retries = max_retries

@classmethod
def from_llm(
    cls,
    llm,
    parser: BaseOutputParser[T],
    prompt=None,
    max_retries: int = 1,
) -> "OutputFixingParser[T]":
    """
    从 LLM 创建 OutputFixingParser

    Args:
        llm: 用于修复输出的语言模型
        parser: 基础解析器
        prompt: 修复提示模板(可选,有默认模板)
        max_retries: 最大重试次数

    Returns:
        OutputFixingParser 实例
    """
    from langchain.prompts import PromptTemplate
    from langchain.output_parsers import StrOutputParser

    # 默认修复提示模板
    if prompt is None:
        fix_template = """Instructions:

{instructions} #

Completion: #

{completion} #

上面的 Completion 没有满足 Instructions 中的约束要求。

错误信息: #

{error} #

请修复输出,确保它满足 Instructions 中的所有约束要求。只返回修复后的输出,不要包含其他内容:""" prompt = PromptTemplate.from_template(fix_template)

    # 创建修复链:Prompt -> LLM -> StrOutputParser
    retry_chain = _SimpleChain(prompt, llm, StrOutputParser())

    return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)

def parse(self, completion: str) -> T:
    """
    解析输出,如果失败则尝试修复

    Args:
        completion: LLM 的输出文本

    Returns:
        T: 解析后的结果

    Raises:
        OutputParserException: 如果修复后仍然无法解析
    """
    retries = 0
    original_completion = completion

    while retries <= self.max_retries:
        try:
            # 尝试使用基础解析器解析
            return self.parser.parse(completion)
        except (ValueError, OutputParserException, Exception) as e:
            # 如果已达到最大重试次数,抛出异常
            if retries >= self.max_retries:
                raise OutputParserException(
                    f"解析失败,已重试 {retries} 次: {e}",
                    llm_output=completion
                )

            retries += 1
            print(f"  第 {retries} 次尝试修复...")

            # 获取格式说明(如果解析器支持)
            try:
                instructions = self.parser.get_format_instructions()
            except (AttributeError, NotImplementedError):
                instructions = "请确保输出格式正确。"

            # 使用 LLM 修复输出
            try:
                if hasattr(self.retry_chain, 'invoke'):
                    # 新式链式调用
                    completion = self.retry_chain.invoke({
                        "instructions": instructions,
                        "completion": completion,
                        "error": str(e),
                    })
                elif hasattr(self.retry_chain, 'run'):
                    # 旧式链式调用
                    completion = self.retry_chain.run(
                        instructions=instructions,
                        completion=completion,
                        error=str(e),
                    )
                else:
                    # 直接调用(作为函数)
                    completion = self.retry_chain({
                        "instructions": instructions,
                        "completion": completion,
                        "error": str(e),
                    })

                # 确保返回的是字符串
                if not isinstance(completion, str):
                    completion = str(completion)

            except Exception as fix_error:
                raise OutputParserException(
                    f"修复输出时出错: {fix_error}",
                    llm_output=completion
                )

    raise OutputParserException(
        f"解析失败,已重试 {self.max_retries} 次",
        llm_output=completion
    )

def get_format_instructions(self) -> str:
    """
    获取格式说明(委托给基础解析器)

    Returns:
        str: 格式说明文本
    """
    try:
        return self.parser.get_format_instructions()
    except (AttributeError, NotImplementedError):
        return "请确保输出格式正确。"

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"OutputFixingParser(parser={self.parser}, max_retries={self.max_retries})"

简单的链式调用包装类 #

class _SimpleChain: """简单的链式调用包装类,用于连接 Prompt -> LLM -> Parser"""

def __init__(self, prompt, llm, parser):
    self.prompt = prompt
    self.llm = llm
    self.parser = parser

def invoke(self, input_dict: dict) -> str:
    """调用链"""
    # 格式化提示词
    formatted = self.prompt.format(**input_dict)
    # 调用 LLM
    response = self.llm.invoke(formatted)
    # 解析输出
    if hasattr(response, 'content'):
        content = response.content
    else:
        content = str(response)
    # 使用解析器解析
    return self.parser.parse(content)

def run(self, **kwargs) -> str:
    """运行链(兼容旧接口)"""
    return self.invoke(kwargs)

+# 简单的 PromptValue 包装类 +class _SimplePromptValue:

  • """简单的 PromptValue 包装类,用于 RetryOutputParser"""
  • def init(self, text: str):
  • self.text = text
  • def to_string(self) -> str:
  • """返回提示词字符串"""
  • return self.text + + +# 定义重试输出解析器类 +class RetryOutputParser(BaseOutputParser[T]):
  • """
  • 重试输出解析器
  • 包装一个基础解析器,当解析失败时,使用 LLM 重新生成输出。
  • 与 OutputFixingParser 的区别:
    • RetryOutputParser 需要原始 prompt 和 completion
    • 它使用 parse_with_prompt 方法而不是 parse 方法
    • 它将原始 prompt 和 completion 都传递给 LLM,让 LLM 重新生成
  • 工作原理:
    1. 首先尝试使用基础解析器解析 completion
    1. 如果解析失败,将原始 prompt 和 completion 发送给 LLM
    1. LLM 根据 prompt 的要求重新生成输出
    1. 再次尝试解析新生成的输出
    1. 可以设置最大重试次数
  • """
  • def init(
  • self,
  • parser: BaseOutputParser[T],
  • retry_chain,
  • max_retries: int = 1,
  • ):
  • """
  • 初始化 RetryOutputParser
  • Args:
  • parser: 基础解析器
  • retry_chain: 用于重试的链(通常是 Prompt -> LLM -> StrOutputParser)
  • max_retries: 最大重试次数
  • """
  • self.parser = parser
  • self.retry_chain = retry_chain
  • self.max_retries = max_retries
  • @classmethod
  • def from_llm(
  • cls,
  • llm,
  • parser: BaseOutputParser[T],
  • prompt=None,
  • max_retries: int = 1,
  • ) -> "RetryOutputParser[T]":
  • """
  • 从 LLM 创建 RetryOutputParser
  • Args:
  • llm: 用于重试的语言模型
  • parser: 基础解析器
  • prompt: 重试提示模板(可选,有默认模板)
  • max_retries: 最大重试次数
  • Returns:
  • RetryOutputParser 实例
  • """
  • from langchain.prompts import PromptTemplate
  • from langchain.output_parsers import StrOutputParser
  • 默认重试提示模板 #

  • if prompt is None:
  • retry_template = """Prompt: +{prompt} +Completion: +{completion} + +上面的 Completion 没有满足 Prompt 中的约束要求。 +请重新生成一个满足要求的输出:"""
  • prompt = PromptTemplate.from_template(retry_template)
  • 创建重试链:Prompt -> LLM -> StrOutputParser #

  • retry_chain = _SimpleChain(prompt, llm, StrOutputParser())
  • return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)
  • def parse_with_prompt(self, completion: str, prompt_value) -> T:
  • """
  • 使用 prompt 解析输出,如果失败则尝试重试
  • Args:
  • completion: LLM 的输出文本
  • prompt_value: 原始提示词(可以是字符串或 PromptValue 对象)
  • Returns:
  • T: 解析后的结果
  • Raises:
  • OutputParserException: 如果重试后仍然无法解析
  • """
  • 将 prompt_value 转换为字符串 #

  • if hasattr(prompt_value, 'to_string'):
  • prompt_str = prompt_value.to_string()
  • else:
  • prompt_str = str(prompt_value)
  • retries = 0
  • while retries <= self.max_retries:
  • try:
  • 尝试使用基础解析器解析 #

  • return self.parser.parse(completion)
  • except (ValueError, OutputParserException, Exception) as e:
  • 如果已达到最大重试次数,抛出异常 #

  • if retries >= self.max_retries:
  • raise OutputParserException(
  • f"解析失败,已重试 {retries} 次: {e}",
  • llm_output=completion
  • )
  • retries += 1
  • print(f" 第 {retries} 次尝试重试...")
  • 使用 LLM 重新生成输出 #

  • try:
  • if hasattr(self.retry_chain, 'invoke'):
  • 新式链式调用 #

  • completion = self.retry_chain.invoke({
  • "prompt": prompt_str,
  • "completion": completion,
  • })
  • elif hasattr(self.retry_chain, 'run'):
  • 旧式链式调用 #

  • completion = self.retry_chain.run(
  • prompt=prompt_str,
  • completion=completion,
  • )
  • else:
  • 直接调用 #

  • completion = self.retry_chain({
  • "prompt": prompt_str,
  • "completion": completion,
  • })
  • 确保返回的是字符串 #

  • if not isinstance(completion, str):
  • completion = str(completion)
  • except Exception as retry_error:
  • raise OutputParserException(
  • f"重试输出时出错: {retry_error}",
  • llm_output=completion
  • )
  • raise OutputParserException(
  • f"解析失败,已重试 {self.max_retries} 次",
  • llm_output=completion
  • )
  • def parse(self, completion: str) -> T:
  • """
  • 此解析器只能通过 parse_with_prompt 方法调用
  • Raises:
  • NotImplementedError: 总是抛出此异常
  • """
  • raise NotImplementedError(
  • "RetryOutputParser 只能通过 parse_with_prompt 方法调用,"
  • "需要提供原始 prompt。"
  • )
  • def get_format_instructions(self) -> str:
  • """
  • 获取格式说明(委托给基础解析器)
  • Returns:
  • str: 格式说明文本
  • """
  • try:
  • return self.parser.get_format_instructions()
  • except (AttributeError, NotImplementedError):
  • return "请确保输出格式正确。"
  • def repr(self) -> str:
  • """返回解析器的字符串表示"""
  • return f"RetryOutputParser(parser={self.parser}, max_retries={self.max_retries})" + + +# 定义带错误信息的重试输出解析器类 +class RetryWithErrorOutputParser(BaseOutputParser[T]):
  • """
  • 带错误信息的重试输出解析器
  • 与 RetryOutputParser 类似,但会将错误信息也传递给 LLM。
  • 这为 LLM 提供了更多上下文,理论上能更好地修复输出。
  • 工作原理:
    1. 首先尝试使用基础解析器解析 completion
    1. 如果解析失败,将原始 prompt、completion 和错误信息发送给 LLM
    1. LLM 根据 prompt 的要求和错误信息重新生成输出
    1. 再次尝试解析新生成的输出
    1. 可以设置最大重试次数
  • """
  • def init(
  • self,
  • parser: BaseOutputParser[T],
  • retry_chain,
  • max_retries: int = 1,
  • ):
  • """
  • 初始化 RetryWithErrorOutputParser
  • Args:
  • parser: 基础解析器
  • retry_chain: 用于重试的链(通常是 Prompt -> LLM -> StrOutputParser)
  • max_retries: 最大重试次数
  • """
  • self.parser = parser
  • self.retry_chain = retry_chain
  • self.max_retries = max_retries
  • @classmethod
  • def from_llm(
  • cls,
  • llm,
  • parser: BaseOutputParser[T],
  • prompt=None,
  • max_retries: int = 1,
  • ) -> "RetryWithErrorOutputParser[T]":
  • """
  • 从 LLM 创建 RetryWithErrorOutputParser
  • Args:
  • llm: 用于重试的语言模型
  • parser: 基础解析器
  • prompt: 重试提示模板(可选,有默认模板)
  • max_retries: 最大重试次数
  • Returns:
  • RetryWithErrorOutputParser 实例
  • """
  • from langchain.prompts import PromptTemplate
  • from langchain.output_parsers import StrOutputParser
  • 默认重试提示模板(包含错误信息) #

  • if prompt is None:
  • retry_template = """Prompt: +{prompt} +Completion: +{completion} + +上面的 Completion 没有满足 Prompt 中的约束要求。 +错误详情: {error} +请根据错误信息重新生成一个满足要求的输出:"""
  • prompt = PromptTemplate.from_template(retry_template)
  • 创建重试链:Prompt -> LLM -> StrOutputParser #

  • retry_chain = _SimpleChain(prompt, llm, StrOutputParser())
  • return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)
  • def parse_with_prompt(self, completion: str, prompt_value) -> T:
  • """
  • 使用 prompt 解析输出,如果失败则尝试重试(包含错误信息)
  • Args:
  • completion: LLM 的输出文本
  • prompt_value: 原始提示词(可以是字符串或 PromptValue 对象)
  • Returns:
  • T: 解析后的结果
  • Raises:
  • OutputParserException: 如果重试后仍然无法解析
  • """
  • 将 prompt_value 转换为字符串 #

  • if hasattr(prompt_value, 'to_string'):
  • prompt_str = prompt_value.to_string()
  • else:
  • prompt_str = str(prompt_value)
  • retries = 0
  • while retries <= self.max_retries:
  • try:
  • 尝试使用基础解析器解析 #

  • return self.parser.parse(completion)
  • except (ValueError, OutputParserException, Exception) as e:
  • 如果已达到最大重试次数,抛出异常 #

  • if retries >= self.max_retries:
  • raise OutputParserException(
  • f"解析失败,已重试 {retries} 次: {e}",
  • llm_output=completion
  • )
  • retries += 1
  • print(f" 第 {retries} 次尝试重试(带错误信息)...")
  • 使用 LLM 重新生成输出(包含错误信息) #

  • try:
  • if hasattr(self.retry_chain, 'invoke'):
  • 新式链式调用 #

  • completion = self.retry_chain.invoke({
  • "prompt": prompt_str,
  • "completion": completion,
  • "error": str(e),
  • })
  • elif hasattr(self.retry_chain, 'run'):
  • 旧式链式调用 #

  • completion = self.retry_chain.run(
  • prompt=prompt_str,
  • completion=completion,
  • error=str(e),
  • )
  • else:
  • 直接调用 #

  • completion = self.retry_chain({
  • "prompt": prompt_str,
  • "completion": completion,
  • "error": str(e),
  • })
  • 确保返回的是字符串 #

  • if not isinstance(completion, str):
  • completion = str(completion)
  • except Exception as retry_error:
  • raise OutputParserException(
  • f"重试输出时出错: {retry_error}",
  • llm_output=completion
  • )
  • raise OutputParserException(
  • f"解析失败,已重试 {self.max_retries} 次",
  • llm_output=completion
  • )
  • def parse(self, completion: str) -> T:
  • """
  • 此解析器只能通过 parse_with_prompt 方法调用
  • Raises:
  • NotImplementedError: 总是抛出此异常
  • """
  • raise NotImplementedError(
  • "RetryWithErrorOutputParser 只能通过 parse_with_prompt 方法调用,"
  • "需要提供原始 prompt。"
  • )
  • def get_format_instructions(self) -> str:
  • """
  • 获取格式说明(委托给基础解析器)
  • Returns:
  • str: 格式说明文本
  • """
  • try:
  • return self.parser.get_format_instructions()
  • except (AttributeError, NotImplementedError):
  • return "请确保输出格式正确。"
  • def repr(self) -> str:
  • """返回解析器的字符串表示"""
  • return f"RetryWithErrorOutputParser(parser={self.parser}, max_retries={self.max_retries})" + `

21. BaseOutputParser #

21.1. BaseOutputParser.py #

21.BaseOutputParser.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser
#from langchain_core.exceptions import OutputParserException
#from pydantic import BaseModel, Field

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import BaseOutputParser, OutputParserException, JsonOutputParser
from pydantic import BaseModel, Field
import json
import re
from typing import List, Dict, Any

print("=" * 60)
print("BaseOutputParser 自定义解析器演示")
print("=" * 60)
print("\nBaseOutputParser 是创建自定义输出解析器的抽象基类。")
print("通过继承 BaseOutputParser 并实现 parse 方法,")
print("可以创建各种自定义的解析器来满足特定需求。\n")

# 创建模型
llm = ChatOpenAI(model="gpt-4o")

print("=" * 60)
print("演示 1: 布尔值解析器")
print("=" * 60)

class BooleanOutputParser(BaseOutputParser[bool]):
    """
    布尔值输出解析器

    将文本解析为布尔值。支持多种表示方式:
    - YES/NO
    - TRUE/FALSE
    - 是/否
    - 1/0
    """

    def __init__(self, true_val: str = "YES", false_val: str = "NO"):
        """
        初始化布尔值解析器

        Args:
            true_val: 表示 True 的值(默认 "YES")
            false_val: 表示 False 的值(默认 "NO")
        """
        # 使用 object.__setattr__ 来绕过可能的 Pydantic 验证
        try:
            super().__init__()
        except TypeError:
            pass  # 如果父类不需要参数

        object.__setattr__(self, 'true_val', true_val.upper())
        object.__setattr__(self, 'false_val', false_val.upper())

    def parse(self, text: str) -> bool:
        """
        解析文本为布尔值

        Args:
            text: 要解析的文本

        Returns:
            bool: 解析后的布尔值

        Raises:
            OutputParserException: 如果无法解析为布尔值
        """
        cleaned_text = text.strip().upper()

        # 支持多种表示方式
        true_values = [self.true_val, "TRUE", "YES", "是", "1", "Y"]
        false_values = [self.false_val, "FALSE", "NO", "否", "0", "N"]

        if cleaned_text in true_values:
            return True
        elif cleaned_text in false_values:
            return False
        else:
            raise OutputParserException(
                f"BooleanOutputParser 无法解析 '{text}'。"
                f"期望的值:{true_values} 或 {false_values}"
            )

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return f"请输出 {self.true_val} 或 {self.false_val}(不区分大小写)"

    def __repr__(self) -> str:
        return f"BooleanOutputParser(true_val={self.true_val}, false_val={self.false_val})"

# 测试布尔值解析器
print("\n测试布尔值解析器:")
bool_parser = BooleanOutputParser()

test_cases = [
    "YES",
    "no",
    "true",
    "FALSE",
    "是",
    "否",
    "1",
    "0",
]

for test in test_cases:
    try:
        result = bool_parser.parse(test)
        print(f"  '{test}' -> {result}")
    except Exception as e:
        print(f"  '{test}' -> 错误: {e}")

print("\n" + "=" * 60)
print("演示 2: 列表解析器")
print("=" * 60)

class ListOutputParser(BaseOutputParser[List[str]]):
    """
    列表输出解析器

    从文本中提取列表。支持多种格式:
    - 逗号分隔:apple, banana, orange
    - 换行分隔:apple\nbanana\norange
    - 编号列表:1. apple\n2. banana\n3. orange
    - JSON 数组:["apple", "banana", "orange"]
    """

    def __init__(self, separator: str = ",", remove_numbers: bool = True):
        """
        初始化列表解析器

        Args:
            separator: 分隔符(默认 ",")
            remove_numbers: 是否移除编号(默认 True)
        """
        try:
            super().__init__()
        except TypeError:
            pass

        object.__setattr__(self, 'separator', separator)
        object.__setattr__(self, 'remove_numbers', remove_numbers)

    def parse(self, text: str) -> List[str]:
        """
        解析文本为字符串列表

        Args:
            text: 要解析的文本

        Returns:
            List[str]: 解析后的列表
        """
        # 尝试解析 JSON 数组
        try:
            parsed = json.loads(text)
            if isinstance(parsed, list):
                return [str(item).strip() for item in parsed]
        except (json.JSONDecodeError, TypeError):
            pass

        # 按换行符分割
        if "\n" in text:
            items = text.split("\n")
        else:
            # 按分隔符分割
            items = text.split(self.separator)

        result = []
        for item in items:
            item = item.strip()
            if not item:
                continue

            # 移除编号(如 "1. apple" -> "apple")
            if self.remove_numbers:
                item = re.sub(r'^\d+[\.\)]\s*', '', item)

            # 移除引号
            item = item.strip('"\'')

            if item:
                result.append(item)

        return result

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return f"请以列表格式输出,使用 {self.separator} 分隔,或使用换行符分隔"

# 测试列表解析器
print("\n测试列表解析器:")
list_parser = ListOutputParser()

test_cases = [
    "apple, banana, orange",
    "apple\nbanana\norange",
    "1. apple\n2. banana\n3. orange",
    '["apple", "banana", "orange"]',
    "苹果, 香蕉, 橙子",
]

for test in test_cases:
    try:
        result = list_parser.parse(test)
        print(f"  输入: {test[:50]}...")
        print(f"  输出: {result}")
    except Exception as e:
        print(f"  错误: {e}")

print("\n" + "=" * 60)
print("演示 3: 键值对解析器")
print("=" * 60)

class KeyValueOutputParser(BaseOutputParser[Dict[str, str]]):
    """
    键值对输出解析器

    从文本中提取键值对。支持多种格式:
    - 冒号分隔:key1: value1\nkey2: value2
    - 等号分隔:key1=value1\nkey2=value2
    - JSON 对象:{"key1": "value1", "key2": "value2"}
    """

    def __init__(self, separator: str = ":"):
        """
        初始化键值对解析器

        Args:
            separator: 键值分隔符(默认 ":")
        """
        try:
            super().__init__()
        except TypeError:
            pass

        object.__setattr__(self, 'separator', separator)

    def parse(self, text: str) -> Dict[str, str]:
        """
        解析文本为键值对字典

        Args:
            text: 要解析的文本

        Returns:
            Dict[str, str]: 解析后的字典
        """
        # 尝试解析 JSON 对象
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                return {str(k): str(v) for k, v in parsed.items()}
        except (json.JSONDecodeError, TypeError):
            pass

        result = {}
        lines = text.strip().split("\n")

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # 尝试不同的分隔符
            for sep in [self.separator, "=", ":"]:
                if sep in line:
                    parts = line.split(sep, 1)
                    if len(parts) == 2:
                        key = parts[0].strip().strip('"\'')
                        value = parts[1].strip().strip('"\'')
                        result[key] = value
                        break

        return result

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return f"请以键值对格式输出,每行一个,使用 {self.separator} 分隔键和值"

# 测试键值对解析器
print("\n测试键值对解析器:")
kv_parser = KeyValueOutputParser()

test_cases = [
    "name: 张三\nage: 30\ncity: 北京",
    "name=李四\nage=25\ncity=上海",
    '{"name": "王五", "age": "28", "city": "深圳"}',
]

for test in test_cases:
    try:
        result = kv_parser.parse(test)
        print(f"  输入: {test[:50]}...")
        print(f"  输出: {result}")
    except Exception as e:
        print(f"  错误: {e}")

print("\n" + "=" * 60)
print("演示 4: 正则表达式解析器")
print("=" * 60)

class RegexOutputParser(BaseOutputParser[Dict[str, str]]):
    """
    正则表达式输出解析器

    使用正则表达式从文本中提取信息。
    可以定义多个模式来匹配不同的字段。
    """

    def __init__(self, patterns: Dict[str, str]):
        """
        初始化正则表达式解析器

        Args:
            patterns: 字段名到正则表达式模式的字典
        """
        try:
            super().__init__()
        except TypeError:
            pass

        compiled_patterns = {k: re.compile(v, re.IGNORECASE | re.MULTILINE) 
                            for k, v in patterns.items()}
        object.__setattr__(self, 'patterns', compiled_patterns)

    def parse(self, text: str) -> Dict[str, str]:
        """
        使用正则表达式解析文本

        Args:
            text: 要解析的文本

        Returns:
            Dict[str, str]: 匹配到的字段字典
        """
        result = {}
        for field, pattern in self.patterns.items():
            match = pattern.search(text)
            if match:
                # 如果有命名组,使用命名组的值
                if match.groupdict():
                    result.update(match.groupdict())
                else:
                    # 否则使用第一个捕获组或整个匹配
                    result[field] = match.group(1) if match.groups() else match.group(0)

        return result

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return f"请确保输出包含以下字段:{', '.join(self.patterns.keys())}"

# 测试正则表达式解析器
print("\n测试正则表达式解析器:")

# 定义模式:提取邮箱、电话、日期
patterns = {
    "email": r"([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})",
    "phone": r"(\d{3,4}[-.\s]?\d{7,8})",
    "date": r"(\d{4}[-/]\d{1,2}[-/]\d{1,2})",
}

regex_parser = RegexOutputParser(patterns)

test_text = """
联系方式:
邮箱:zhangsan@example.com
电话:138-0013-8000
日期:2024-01-15
"""

try:
    result = regex_parser.parse(test_text)
    print(f"  输入文本: {test_text.strip()}")
    print(f"  提取结果: {result}")
except Exception as e:
    print(f"  错误: {e}")

print("\n" + "=" * 60)
print("演示 5: 自定义格式解析器 - 提取评分和评论")
print("=" * 60)

class RatingReviewParser(BaseOutputParser[Dict[str, Any]]):
    """
    评分和评论解析器

    从文本中提取评分(1-5星)和评论内容。
    """

    def parse(self, text: str) -> Dict[str, Any]:
        """
        解析评分和评论

        Args:
            text: 要解析的文本

        Returns:
            Dict[str, Any]: 包含 rating 和 review 的字典
        """
        result = {
            "rating": None,
            "review": None
        }

        # 提取评分(1-5星)
        rating_patterns = [
            r"(\d+)\s*星",
            r"(\d+)/5",
            r"评分[::]\s*(\d+)",
            r"rating[::]\s*(\d+)",
        ]

        for pattern in rating_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                rating = int(match.group(1))
                if 1 <= rating <= 5:
                    result["rating"] = rating
                    break

        # 提取评论内容(移除评分部分)
        review_text = text
        for pattern in rating_patterns:
            review_text = re.sub(pattern, "", review_text, flags=re.IGNORECASE)

        result["review"] = review_text.strip()

        return result

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return "请以以下格式输出:评分:X星(1-5),评论:..."

# 测试评分和评论解析器
print("\n测试评分和评论解析器:")
rating_parser = RatingReviewParser()

test_cases = [
    "评分:5星,这个产品非常好用,强烈推荐!",
    "4/5 这个产品还不错,但有一些小问题。",
    "Rating: 3 一般般,没什么特别的地方。",
]

for test in test_cases:
    try:
        result = rating_parser.parse(test)
        print(f"  输入: {test}")
        print(f"  输出: 评分={result['rating']}, 评论={result['review'][:30]}...")
    except Exception as e:
        print(f"  错误: {e}")

print("\n" + "=" * 60)
print("演示 6: 与 LLM 结合使用")
print("=" * 60)

# 使用布尔值解析器
print("\n1. 使用布尔值解析器:")

bool_prompt = PromptTemplate.from_template(
    """判断以下陈述是否正确。

陈述:{statement}

{format_instructions}

请回答:"""
)

statement = "Python 是一种编程语言"

formatted = bool_prompt.format(
    statement=statement,
    format_instructions=bool_parser.get_format_instructions()
)

response = llm.invoke(formatted)
print(f"  陈述: {statement}")
print(f"  LLM 输出: {response.content}")

try:
    result = bool_parser.parse(response.content)
    print(f"  解析结果: {result}")
except Exception as e:
    print(f"  解析错误: {e}")

# 使用列表解析器
print("\n2. 使用列表解析器:")

list_prompt = PromptTemplate.from_template(
    """列出以下类别的前3个例子。

类别:{category}

{format_instructions}

请列出:"""
)

category = "水果"

formatted = list_prompt.format(
    category=category,
    format_instructions=list_parser.get_format_instructions()
)

response = llm.invoke(formatted)
print(f"  类别: {category}")
print(f"  LLM 输出: {response.content}")

try:
    result = list_parser.parse(response.content)
    print(f"  解析结果: {result}")
except Exception as e:
    print(f"  解析错误: {e}")

# 使用键值对解析器
print("\n3. 使用键值对解析器:")

kv_prompt = PromptTemplate.from_template(
    """从以下文本中提取关键信息。

文本:{text}

{format_instructions}

请提取:"""
)

text = "我叫张三,今年30岁,住在北京,职业是软件工程师"

formatted = kv_prompt.format(
    text=text,
    format_instructions=kv_parser.get_format_instructions()
)

response = llm.invoke(formatted)
print(f"  文本: {text}")
print(f"  LLM 输出: {response.content}")

try:
    result = kv_parser.parse(response.content)
    print(f"  解析结果: {result}")
except Exception as e:
    print(f"  解析错误: {e}")

print("\n" + "=" * 60)
print("演示 7: 组合使用多个解析器")
print("=" * 60)

class MultiFormatParser(BaseOutputParser[Dict[str, Any]]):
    """
    多格式解析器

    尝试多种解析方式,返回第一个成功的结果。
    """

    def __init__(self, parsers: List[BaseOutputParser]):
        """
        初始化多格式解析器

        Args:
            parsers: 解析器列表,按顺序尝试
        """
        try:
            super().__init__()
        except TypeError:
            pass

        object.__setattr__(self, 'parsers', parsers)

    def parse(self, text: str) -> Dict[str, Any]:
        """
        尝试使用多个解析器解析

        Args:
            text: 要解析的文本

        Returns:
            Dict[str, Any]: 解析结果
        """
        errors = []
        for parser in self.parsers:
            try:
                result = parser.parse(text)
                return {
                    "success": True,
                    "parser": type(parser).__name__,
                    "result": result
                }
            except Exception as e:
                errors.append(f"{type(parser).__name__}: {str(e)}")

        raise OutputParserException(
            f"所有解析器都失败了。错误:{'; '.join(errors)}"
        )

    def get_format_instructions(self) -> str:
        """获取格式说明"""
        return "请以以下任一格式输出:" + " 或 ".join([
            p.get_format_instructions() if hasattr(p, 'get_format_instructions') 
            else "标准格式" 
            for p in self.parsers
        ])

# 测试多格式解析器
print("\n测试多格式解析器:")
multi_parser = MultiFormatParser([
    JsonOutputParser(),
    kv_parser,
    list_parser,
])

test_cases = [
    '{"name": "测试", "value": "123"}',  # JSON
    "key1: value1\nkey2: value2",  # 键值对
    "item1, item2, item3",  # 列表
]

for test in test_cases:
    try:
        result = multi_parser.parse(test)
        print(f"  输入: {test[:50]}...")
        print(f"  成功解析器: {result['parser']}")
        print(f"  结果: {result['result']}")
    except Exception as e:
        print(f"  错误: {e}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\n创建自定义解析器的步骤:")
print("1. 继承 BaseOutputParser[T],其中 T 是返回类型")
print("2. 实现 parse(text: str) -> T 方法")
print("3. (可选)实现 get_format_instructions() 方法")
print("4. (可选)实现 __repr__() 方法")
print("\n演示的自定义解析器:")
print("1. BooleanOutputParser: 解析布尔值")
print("2. ListOutputParser: 解析列表")
print("3. KeyValueOutputParser: 解析键值对")
print("4. RegexOutputParser: 使用正则表达式解析")
print("5. RatingReviewParser: 解析评分和评论")
print("6. MultiFormatParser: 组合多个解析器")
print("\n使用场景:")
print("- 处理特定格式的输出")
print("- 提取结构化信息")
print("- 数据验证和转换")
print("- 多格式兼容")
print("\n注意事项:")
print("- 实现 parse 方法时必须处理异常")
print("- 提供清晰的错误信息")
print("- 考虑边界情况")
print("- 可以组合使用多个解析器")

22. LCEL表达式 #

LCEL 表达式(LangChain Expression Language)是一种可组合的链式表达式写法,可让你用 | (管道符)将多个组件(如解析器、函数、链、工具)串联起来,实现数据流式处理,极大提升可读性和模块化设计能力。

LCEL 的核心是“可运行体”Runnable,它代表了一个能够输入-输出处理的组件。你可以自定义组件或用库内置的解析器、链路,并通过 | 操作符组合它们。

LCEL 主要特性:

  • 可读性强:用 | 表示管道传递,一目了然。
  • 支持序列、分支、混合链路等复杂组合。
  • 自动处理每步的输入输出衔接,便于快速搭建数据流程。
  • 支持 invoke、batch、stream 调用模式,无缝兼容。

在本例中,我们扩展了 langchain,使其支持 LCEL 写法,并通过 RunnableSequence 实现自定义表达式链式调用。 通过 LCEL,你可以快速拼装、复用和调试复杂的数据处理链,极大提升开发效率。

22.1. init.py #

langchain/init.py

+# 导入 runnables 模块以自动启用 LCEL 支持
+import langchain.runnables

__all__ = []

22.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 RunnableSequence 类和为组件添加 | 操作符支持

class RunnableSequence:
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                result = step.invoke(result, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

22.3. LCEL.py #

22.LCEL.py

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
import json

print("=" * 60)
print("Runnable 接口演示")
print("=" * 60)
print("\nRunnable 是 LangChain 的核心接口,表示可以运行的对象。")
print("它支持同步/异步调用、批处理、流式处理等功能。\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: 基本 invoke 方法")
print("=" * 60)

# 创建提示模板
prompt = PromptTemplate.from_template(
    "你是一个专业的编程助手。请用一句话回答:{question}"
)

# 创建链
chain = prompt | llm | parser

print("\n使用 chain.invoke() 进行同步调用:")
print("-" * 60)

questions = [
    "什么是 Python?",
    "Python 的主要特点是什么?",
]

for question in questions:
    result = chain.invoke({"question": question})
    print(f"\n问题: {question}")
    print(f"回答: {result}")

print("\n" + "=" * 60)
print("演示 2: batch 批处理")
print("=" * 60)

print("\n使用 chain.batch() 进行批处理:")
print("-" * 60)

# 准备多个输入
batch_inputs = [
    {"question": "什么是 Python?"},
    {"question": "什么是机器学习?"},
    {"question": "什么是人工智能?"},
]

# 批处理调用
batch_results = chain.batch(batch_inputs)

print("\n批处理结果:")
for i, (input_data, result) in enumerate(zip(batch_inputs, batch_results), 1):
    print(f"\n{i}. 问题: {input_data['question']}")
    print(f"   回答: {result[:50]}...")

print("\n" + "=" * 60)
print("演示 3: stream 流式处理")
print("=" * 60)

print("\n使用 chain.stream() 进行流式处理:")
print("-" * 60)

# 流式调用
print("\n流式输出:")
for chunk in chain.stream({"question": "请详细介绍一下 Python 编程语言"}):
    print(chunk, end="", flush=True)
print("\n")

print("\n" + "=" * 60)
print("演示 4: 链式组合(使用 | 操作符)")
print("=" * 60)

# 创建多个步骤
step1 = PromptTemplate.from_template("将以下文本翻译成英文:{text}")
step2 = PromptTemplate.from_template("将以下英文总结成一句话:{text}")

# 组合多个链
multi_chain = step1 | llm | parser | step2 | llm | parser

print("\n使用 | 操作符组合多个步骤:")
print("multi_chain = step1 | llm | parser | step2 | llm | parser")
print("-" * 60)

test_text = "Python 是一种高级编程语言,具有简洁的语法和强大的功能。"

result = multi_chain.invoke({"text": test_text})
print(f"\n原始文本: {test_text}")
print(f"最终结果: {result}")

print("\n" + "=" * 60)
print("演示 5: 自定义函数(RunnableLambda 概念)")
print("=" * 60)

# 定义自定义函数
def to_uppercase(text: str) -> str:
    """转换为大写"""
    return text.upper()

def add_prefix(text: str) -> str:
    """添加前缀"""
    return f"回答: {text}"

# 创建包含自定义函数的链
custom_prompt = PromptTemplate.from_template("请用一句话介绍:{topic}")

# 手动组合链(演示如何将函数集成到链中)
print("\n创建包含自定义函数的链:")
print("-" * 60)
def process_chain(input_data):
    """手动组合的链,演示函数如何集成"""
    # Step 1: Prompt
    formatted = custom_prompt.format(**input_data)
    # Step 2: LLM
    response = llm.invoke(formatted)
    # Step 3: Parser
    parsed = parser.parse(response.content)
    # Step 4: Custom function
    uppercased = to_uppercase(parsed)
    # Step 5: Custom function
    prefixed = add_prefix(uppercased)
    return prefixed

topics = ["Python", "机器学习"]

for topic in topics:
    result = process_chain({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"结果: {result}")

print("\n" + "=" * 60)
print("演示 6: JSON 解析链")
print("=" * 60)

# 创建 JSON 解析链
json_prompt = PromptTemplate.from_template(
    """请以 JSON 格式输出以下信息:
- 姓名:{name}
- 年龄:{age}
- 城市:{city}

{format_instructions}

请输出 JSON:"""
)

json_parser = JsonOutputParser()
json_chain = json_prompt | llm | json_parser

print("\n使用 JSON 解析链:")
print("-" * 60)

test_cases = [
    {"name": "张三", "age": 30, "city": "北京"},
    {"name": "李四", "age": 25, "city": "上海"},
]

for case in test_cases:
    result = json_chain.invoke({
        **case,
        "format_instructions": json_parser.get_format_instructions()
    })
    print(f"\n输入: {case}")
    print(f"输出: {json.dumps(result, ensure_ascii=False, indent=2)}")

print("\n" + "=" * 60)
print("演示 7: 链式组合和重用")
print("=" * 60)

# 创建可重用的子链
translation_prompt = PromptTemplate.from_template("将以下中文翻译成英文:{text}")
translation_chain = translation_prompt | llm | parser

summary_prompt = PromptTemplate.from_template("将以下英文总结成一句话:{text}")
summary_chain = summary_prompt | llm | parser

print("\n创建可重用的子链:")
print("translation_chain = translation_prompt | llm | parser")
print("summary_chain = summary_prompt | llm | parser")
print("-" * 60)

# 手动组合链
chinese_text = "Python 是一种高级编程语言,具有简洁的语法和强大的功能。"

# 第一步:翻译
translated = translation_chain.invoke({"text": chinese_text})
print(f"\n中文原文: {chinese_text}")
print(f"翻译结果: {translated}")

# 第二步:总结
summarized = summary_chain.invoke({"text": translated})
print(f"总结结果: {summarized}")

print("\n" + "=" * 60)
print("演示 8: 错误处理")
print("=" * 60)

print("\n使用错误处理:")
print("-" * 60)

try:
    result = chain.invoke({"question": "测试问题"})
    print(f"成功: {result[:50]}...")
except Exception as e:
    print(f"错误: {e}")

print("\n" + "=" * 60)
print("演示 9: 链的属性检查")
print("=" * 60)

print("\n检查链的属性和方法:")
print("-" * 60)

print(f"chain 类型: {type(chain)}")
print(f"是否有 invoke 方法: {hasattr(chain, 'invoke')}")
print(f"是否有 batch 方法: {hasattr(chain, 'batch')}")
print(f"是否有 stream 方法: {hasattr(chain, 'stream')}")

# 测试链的组件
print(f"\n链的组件:")
if hasattr(chain, 'steps'):
    for i, step in enumerate(chain.steps, 1):
        print(f"  步骤 {i}: {type(step).__name__}")
else:
    print("  (无法直接访问步骤)")

print("\n" + "=" * 60)
print("演示 10: 实际应用场景")
print("=" * 60)

# 创建问答链
qa_prompt = PromptTemplate.from_template(
    """基于以下上下文回答问题。

上下文:{context}

问题:{question}

请根据上下文回答问题,如果上下文中没有相关信息,请说"我不知道"。"""
)

qa_chain = qa_prompt | llm | parser

print("\n使用问答链:")
print("-" * 60)

qa_cases = [
    {
        "context": "Python 是一种高级编程语言,由 Guido van Rossum 在 1991 年发布。",
        "question": "Python 是什么时候发布的?"
    },
    {
        "context": "机器学习是人工智能的一个分支,通过算法让计算机从数据中学习。",
        "question": "什么是机器学习?"
    },
]

for case in qa_cases:
    result = qa_chain.invoke(case)
    print(f"\n上下文: {case['context']}")
    print(f"问题: {case['question']}")
    print(f"回答: {result}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnable 的主要方法:")
print("1. invoke(input): 同步调用,处理单个输入")
print("2. batch(inputs): 批处理,处理多个输入")
print("3. stream(input): 流式处理,逐步产生输出")
print("4. ainvoke(input): 异步调用(如果支持)")
print("5. abatch(inputs): 异步批处理(如果支持)")
print("6. astream(input): 异步流式处理(如果支持)")
print("\nRunnable 的主要特性:")
print("1. 可组合:使用 | 操作符连接多个 Runnable")
print("2. 类型安全:支持类型提示")
print("3. 可重用:可以创建可重用的子链")
print("4. 灵活:支持自定义函数和逻辑")
print("\n常见使用模式:")
print("- prompt | llm | parser: 基本三步链")
print("- chain1 | chain2: 组合多个链")
print("- prompt | llm | function: 添加后处理")
print("\n使用场景:")
print("- 文本处理和转换")
print("- 数据提取和验证")
print("- 多步骤推理")
print("- 问答系统")
print("- 内容生成")
print("\n注意事项:")
print("- 确保每个组件都能处理前一个组件的输出")
print("- 使用类型提示提高代码可读性")
print("- 考虑错误处理和边界情况")
print("- 批处理可以提高效率")
print("- 流式处理可以提供更好的用户体验")

22.4. chat_models.py #

langchain/chat_models.py

# 导入操作系统相关模块
import os

# 导入 openai 模块
import openai
# 从 langchain.messages 模块导入 AIMessage、HumanMessage 和 SystemMessage 类
from langchain.messages import AIMessage, HumanMessage, SystemMessage
# 从 langchain.embeddings 模块导入 OpenAIEmbeddings 类
from langchain.embeddings import OpenAIEmbeddings

# 定义与 OpenAI 聊天模型交互的类
class ChatOpenAI:

    # 初始化方法
    def __init__(self, model: str = "gpt-4o", **kwargs):
        """
        初始化 ChatOpenAI

        Args:
            model: 模型名称,如 "gpt-4o"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 OPENAI_API_KEY 环境变量")

        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}

        # 创建 OpenAI 客户端实例
        self.client = openai.OpenAI(api_key=self.api_key)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)

        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""

        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 流式调用模型生成回复的方法
    def stream(self, input, **kwargs):
        """
        流式调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Yields:
            AIMessage: AI 的回复消息块(每次产生部分内容)
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)

        # 构建 API 请求参数字典,启用流式输出
        params = {
            "model": self.model,
            "messages": messages,
            "stream": True,  # 启用流式输出
            **self.model_kwargs,
            **kwargs
        }

        # 使用 OpenAI 客户端发起流式调用
        stream = self.client.chat.completions.create(**params)

        # 迭代流式响应
        for chunk in stream:
            # 检查是否有内容增量
            if chunk.choices and len(chunk.choices) > 0:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'content') and delta.content:
                    # 产生包含部分内容的 AIMessage
                    yield AIMessage(content=delta.content)

    # 内部方法,将输入转换为 OpenAI API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 OpenAI API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: OpenAI API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage/AIMessage/SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user, ai 为 assistant, system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组,且长度为 2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与 DeepSeek 聊天模型交互的类
class ChatDeepSeek:

    # 初始化方法
    def __init__(self, model: str = "deepseek-chat", **kwargs):
        """
        初始化 ChatDeepSeek

        Args:
            model: 模型名称,如 "deepseek-chat"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DEEPSEEK_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DEEPSEEK_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # DeepSeek 的 API base URL
        base_url = kwargs.get("base_url", "https://api.deepseek.com/v1")
        # 创建 OpenAI 兼容的客户端实例(DeepSeek 使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]


# 定义与通义千问(Tongyi)聊天模型交互的类
class ChatTongyi:

    # 初始化方法
    def __init__(self, model: str = "qwen-max", **kwargs):
        """
        初始化 ChatTongyi

        Args:
            model: 模型名称,如 "qwen-max"
            **kwargs: 其他参数(如 temperature, max_tokens 等)
        """
        # 设置模型名称
        self.model = model
        # 获取 api_key,优先从参数获取,否则从环境变量获取
        self.api_key = kwargs.get("api_key") or os.getenv("DASHSCOPE_API_KEY")
        # 如果没有提供 api_key,则抛出异常
        if not self.api_key:
            raise ValueError("需要提供 api_key 或设置 DASHSCOPE_API_KEY 环境变量")
        # 保存除 api_key 之外的其他参数,用于 API 调用
        self.model_kwargs = {k: v for k, v in kwargs.items() if k != "api_key"}
        # 通义千问的 API base URL(使用 OpenAI 兼容模式)
        base_url = kwargs.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        # 创建 OpenAI 兼容的客户端实例(通义千问使用 OpenAI 兼容的 API)
        self.client = openai.OpenAI(api_key=self.api_key, base_url=base_url)

    # 调用模型生成回复的方法
    def invoke(self, input, **kwargs):
        """
        调用模型生成回复

        Args:
            input: 输入内容,可以是字符串或消息列表
            **kwargs: 额外的 API 参数

        Returns:
            AIMessage: AI 的回复消息
        """
        # 将输入数据转换为消息格式
        messages = self._convert_input(input)
        # 构建 API 请求参数字典
        params = {
            "model": self.model,
            "messages": messages,
            **self.model_kwargs,
            **kwargs
        }
        # 使用 OpenAI 兼容的客户端发起 chat.completions.create 调用获取回复
        response = self.client.chat.completions.create(**params)
        # 取出返回结果中的第一个选项
        choice = response.choices[0]
        # 获取消息内容
        content = choice.message.content or ""
        # 返回一个 AIMessage 对象
        return AIMessage(content=content)

    # 内部方法,将输入转换为 API 需要的消息格式
    def _convert_input(self, input):
        """
        将输入转换为 API 需要的消息格式

        Args:
            input: 字符串、消息列表或 ChatPromptValue

        Returns:
            list[dict]: API 格式的消息列表
        """
        # 如果是 ChatPromptValue,转换为消息列表
        from langchain.prompts import ChatPromptValue
        if isinstance(input, ChatPromptValue):
            input = input.to_messages()

        # 输入为字符串时,直接封装为用户角色消息
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        # 输入为列表时,需逐个元素处理
        elif isinstance(input, list):
            messages = []
            # 遍历输入列表中的每个元素
            for msg in input:
                # 如果该元素为字符串,作为用户消息处理
                if isinstance(msg, str):
                    messages.append({"role": "user", "content": msg})
                # 如果为 HumanMessage、AIMessage 或 SystemMessage 类型,判断角色并获取内容
                elif isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                    # 判断消息类型,human 为 user,ai 为 assistant,system 为 system
                    if isinstance(msg, HumanMessage):
                        role = "user"
                    elif isinstance(msg, AIMessage):
                        role = "assistant"
                    elif isinstance(msg, SystemMessage):
                        role = "system"
                    # 获取消息内容属性
                    content = msg.content if hasattr(msg, "content") else str(msg)
                    messages.append({"role": role, "content": content})
                # 如果该元素本身为字典,直接添加
                elif isinstance(msg, dict):
                    messages.append(msg)
                # 如果是元组且长度为2,解包为 role 与 content
                elif isinstance(msg, tuple) and len(msg) == 2:
                    role, content = msg
                    messages.append({"role": role, "content": content})
            # 返回处理后的消息列表
            return messages
        else:
            # 其他输入类型,转为字符串作为 user 消息
            return [{"role": "user", "content": str(input)}]

+# 延迟导入 runnables 模块以启用 LCEL 支持
+# 使用 try-except 避免循环导入问题
+try:
+   import langchain.runnables  # noqa: F401
+except ImportError:
+   pass

22.5. prompts.py #

langchain/prompts.py

# 导入正则表达式库
import re
# 导入 JSON 和路径处理
import json
from pathlib import Path
# 从本地模块导入三类消息类型
from .messages import SystemMessage, HumanMessage, AIMessage

# 定义提示词模板类
class PromptTemplate:
    # 提示词模板类,用于格式化字符串模板
    """提示词模板类,用于格式化字符串模板"""

    # 构造方法
    def __init__(self, template: str, partial_variables: dict = None):
        # 保存模板字符串
        self.template = template
        # 保存部分变量(已预填充的变量)
        self.partial_variables = partial_variables or {}
        # 提取模板中的变量名
        all_variables = self._extract_variables(template)
        # 从所有变量中排除已部分填充的变量
        self.input_variables = [v for v in all_variables if v not in self.partial_variables]

    # 从模板字符串创建PromptTemplate实例的类方法
    @classmethod
    def from_template(cls, template: str):
        # 返回PromptTemplate实例
        return cls(template=template)

    # 格式化模板字符串
    def format(self, **kwargs):
        # 合并部分变量和用户提供的变量
        all_vars = {**self.partial_variables, **kwargs}
        # 获取缺失的变量名集合
        missing_vars = set(self.input_variables) - set(kwargs.keys())
        # 如果有变量未传入,则抛出错误
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 使用format方法将变量填充到模板字符串
        return self.template.format(**all_vars)

    # 定义部分填充模板变量的方法,返回新的模板实例
    def partial(self, **kwargs):
        """
        部分填充模板变量,返回一个新的 PromptTemplate 实例

        Args:
            **kwargs: 要部分填充的变量及其值

        Returns:
            新的 PromptTemplate 实例,其中指定的变量已被填充

        示例:
            template = PromptTemplate.from_template("你好,我叫{name},我来自{city}")
            partial_template = template.partial(name="张三")
            # 现在只需要提供 city 参数
            result = partial_template.format(city="北京")
        """
        # 合并现有对象的部分变量(partial_variables)和本次要填充的新变量
        new_partial_variables = {**self.partial_variables, **kwargs}
        # 使用原模板字符串和更新后的部分变量,创建新的 PromptTemplate 实例
        new_template = PromptTemplate(
            template=self.template,
            partial_variables=new_partial_variables
        )
        # 返回新的 PromptTemplate 实例
        return new_template

    # 提取模板变量的内部方法
    def _extract_variables(self, template: str):
        # 正则表达式匹配花括号内的变量,兼容:格式
        pattern = r'\{([^}:]+)(?::[^}]+)?\}'
        # 匹配所有变量名
        matches = re.findall(pattern, template)
        # 去重但保持顺序返回列表
        return list(dict.fromkeys(matches))

# 定义格式化消息值类
class ChatPromptValue:
    # 聊天提示词值类,包含格式化后的消息列表
    """聊天提示词值类,包含格式化后的消息列表"""

    # 构造方法,接收消息列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages

    # 将消息对象列表转为字符串
    def to_string(self):
        # 新建一个用于存放字符串的列表
        parts = []
        # 遍历每个消息
        for msg in self.messages:
            # 如果有type和content属性
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # 定义角色映射字典
                role_map = {
                    "system": "System",
                    "human": "Human",
                    "ai": "AI"
                }
                # 获取大写角色名
                role = role_map.get(msg.type, msg.type.capitalize())
                # 字符串形式添加到parts
                parts.append(f"{role}: {msg.content}")
            else:
                # 其他对象直接str
                parts.append(str(msg))
        # 用换行符拼接所有消息并返回
        return "\n".join(parts)

    # 返回消息对象列表
    def to_messages(self):
        # 直接返回消息列表
        return self.messages

# 定义基础消息提示词模板类
class BaseMessagePromptTemplate:
    # 基础消息提示词模板类
    """基础消息提示词模板类"""

    # 构造方法,必须给定PromptTemplate实例
    def __init__(self, prompt: PromptTemplate):
        # 保存PromptTemplate到实例变量
        self.prompt = prompt

    # 工厂方法:用模板字符串创建实例
    @classmethod
    def from_template(cls, template: str):
        # 先通过模板字符串新建PromptTemplate对象
        prompt = PromptTemplate.from_template(template)
        # 返回当前类的实例
        return cls(prompt=prompt)

    # 格式化当前模板,返回消息对象
    def format(self, **kwargs):
        # 先用PromptTemplate格式化内容
        content = self.prompt.format(**kwargs)
        # 由子类方法创建对应类型消息对象
        return self._create_message(content)

    # 子类必须实现此消息构建方法
    def _create_message(self, content):
        raise NotImplementedError

# 定义系统消息提示词模板类
class SystemMessagePromptTemplate(BaseMessagePromptTemplate):
    # 系统消息提示词模板类
    """系统消息提示词模板"""

    # 创建SystemMessage对象
    def _create_message(self, content):
        # 延迟导入SystemMessage类型
        from langchain.messages import SystemMessage
        # 返回生成的SystemMessage对象
        return SystemMessage(content=content)

# 定义人类消息提示词模板类
class HumanMessagePromptTemplate(BaseMessagePromptTemplate):
    # 人类消息提示词模板类
    """人类消息提示词模板"""

    # 创建HumanMessage对象
    def _create_message(self, content):
        # 延迟导入HumanMessage类型
        from langchain.messages import HumanMessage
        # 返回生成的HumanMessage对象
        return HumanMessage(content=content)

# 定义AI消息提示词模板类
class AIMessagePromptTemplate(BaseMessagePromptTemplate):
    # AI消息提示词模板类
    """AI消息提示词模板"""

    # 创建AIMessage对象
    def _create_message(self, content):
        # 延迟导入AIMessage类型
        from langchain.messages import AIMessage
        # 返回生成的AIMessage对象
        return AIMessage(content=content)

# 定义动态消息列表占位符类
class MessagesPlaceholder:
    # 在聊天模板中插入动态消息列表的占位符
    """在聊天模板中插入动态消息列表的占位符"""

    # 构造方法,存储变量名
    def __init__(self, variable_name: str):
        self.variable_name = variable_name

# 定义聊天提示词模板类
class ChatPromptTemplate:
    # 聊天提示词模板类,用于创建多轮对话的提示词
    """聊天提示词模板类,用于创建多轮对话的提示词"""

    # 构造方法,传入一个消息模板/对象列表
    def __init__(self, messages):
        # 保存消息列表到实例变量
        self.messages = messages
        # 提取所有输入变量到实例变量
        self.input_variables = self._extract_input_variables()

    # 类方法:通过消息列表创建ChatPromptTemplate实例
    @classmethod
    def from_messages(cls, messages):
        # 返回通过messages参数新建的实例
        return cls(messages=messages)

    # 调用格式化所有消息,生成ChatPromptValue对象
    def invoke(self, input_variables):
        # 格式化所有消息对象
        formatted_messages = self._format_all_messages(input_variables)
        # 返回ChatPromptValue对象
        return ChatPromptValue(messages=formatted_messages)

    # 使用提供的变量格式化模板,返回消息列表
    def format_messages(self, **kwargs):
        # 格式化所有消息并返回
        return self._format_all_messages(kwargs)

    # 提取所有输入变量
    def _extract_input_variables(self):
        # 用集合避免变量重复
        variables = set()
        # 遍历所有消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                _, template_str = msg
                # 新建PromptTemplate提取其变量
                prompt = PromptTemplate.from_template(template_str)
                variables.update(prompt.input_variables)
            # 如果是BaseMessagePromptTemplate子类实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                variables.update(msg.prompt.input_variables)
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                variables.add(msg.variable_name)
        # 返回变量名列表
        return list(variables)

    # 给所有消息模板/对象填充变量并变为消息对象列表
    def _format_all_messages(self, variables):
        # 存放格式化后消息
        formatted_messages = []
        # 遍历每个消息模板或对象
        for msg in self.messages:
            # 如果是(role, template_str)元组
            if isinstance(msg, tuple) and len(msg) == 2:
                role, template_str = msg
                prompt = PromptTemplate.from_template(template_str)
                content = prompt.format(**variables)
                formatted_messages.append(self._create_message_from_role(role, content))
            # 如果是BaseMessagePromptTemplate的实例
            elif isinstance(msg, BaseMessagePromptTemplate):
                formatted_messages.append(msg.format(**variables))
            # 如果是占位符对象
            elif isinstance(msg, MessagesPlaceholder):
                placeholder_messages = self._coerce_placeholder_value(
                    msg.variable_name, variables.get(msg.variable_name)
                )
                formatted_messages.extend(placeholder_messages)
            # 其他情况直接追加
            else:
                formatted_messages.append(msg)
        # 返回格式化的消息列表
        return formatted_messages

    # 处理占位符对象的值,返回消息对象列表
    def _coerce_placeholder_value(self, variable_name, value):
        # 如果未传入变量,抛出异常
        if value is None:
            raise ValueError(f"MessagesPlaceholder '{variable_name}' 对应变量缺失")
        # 如果是ChatPromptValue实例,转换为消息列表
        if isinstance(value, ChatPromptValue):
            return value.to_messages()
        # 如果已经是消息对象/结构列表,则依次转换
        if isinstance(value, list):
            return [self._coerce_single_message(item) for item in value]
        # 其他情况尝试单个转换
        return [self._coerce_single_message(value)]

    # 单个原始值转换为消息对象
    def _coerce_single_message(self, value):
        # 已是有效消息类型,直接返回
        if isinstance(value, (SystemMessage, HumanMessage, AIMessage)):
            return value
        # 有type和content属性,也当消息对象直接返回
        if hasattr(value, "type") and hasattr(value, "content"):
            return value
        # 字符串变为人类消息
        if isinstance(value, str):
            return HumanMessage(content=value)
        # (role, content)元组转为指定角色的消息
        if isinstance(value, tuple) and len(value) == 2:
            role, content = value
            return self._create_message_from_role(role, content)
        # 字典,默认user角色
        if isinstance(value, dict):
            role = value.get("role", "user")
            content = value.get("content", "")
            return self._create_message_from_role(role, content)
        # 其他无法识别类型,抛出异常
        raise TypeError("无法将占位符内容转换为消息")

    # 通过角色字符串和内容构建标准消息对象
    def _create_message_from_role(self, role, content):
        # 角色字符串全部转小写
        normalized_role = role.lower()
        # 系统角色
        if normalized_role == "system":
            return SystemMessage(content=content)
        # 人类/用户角色
        if normalized_role in ("human", "user"):
            return HumanMessage(content=content)
        # AI/assistant角色
        if normalized_role in ("ai", "assistant"):
            return AIMessage(content=content)
        # 其它未知角色抛异常
        raise ValueError(f"未知的消息角色: {role}")
# 定义 FewShotPromptTemplate 类,用于构造 few-shot 提示词的模板
class FewShotPromptTemplate:
    # few-shot 提示词模板说明
    """用于构造 few-shot 提示词的模板"""

    # 构造方法,初始化示例、模板、前缀、后缀、分隔符和输入变量
    def __init__(
        self,
        *,
        examples: list[dict] = None,  # 示例列表,元素为字典
        example_prompt: PromptTemplate | str,  # 示例模板,可为 PromptTemplate 对象或字符串
        prefix: str = "",  # 提示词前缀
        suffix: str = "",  # 提示词后缀
        example_separator: str = "\n\n",  # 示例之间的分隔符
        input_variables: list[str] | None = None,  # 输入变量列表
        example_selector=None,  # 示例选择器(可选)
    ):
        # 如果提供了示例选择器,则使用选择器;否则使用提供的示例列表
        self.example_selector = example_selector
        # 如果未传示例,则设为默认空列表
        self.examples = examples or []
        # 判断 example_prompt 参数类型
        if isinstance(example_prompt, PromptTemplate):
            # 如果是 PromptTemplate 实例,直接赋值
            self.example_prompt = example_prompt
        else:
            # 如果是字符串,则通过 from_template 方法实例化为 PromptTemplate
            self.example_prompt = PromptTemplate.from_template(example_prompt)
        # 保存前缀
        self.prefix = prefix
        # 保存后缀
        self.suffix = suffix
        # 保存示例分隔符
        self.example_separator = example_separator
        # 输入变量列表,未提供则自动推断
        self.input_variables = input_variables or self._infer_input_variables()

    # 内部方法:自动推断输入变量列表
    def _infer_input_variables(self) -> list[str]:
        # 创建空集合用于存放变量名
        variables = set()
        # 提取前缀中的变量并加入集合
        variables.update(self._extract_vars(self.prefix))
        # 提取后缀中的变量并加入集合
        variables.update(self._extract_vars(self.suffix))
        # 返回变量集合的列表形式
        return list(variables)

    # 内部方法:从文本中提取所有花括号包围的变量名
    def _extract_vars(self, text: str) -> list[str]:
        # 如果文本为空,返回空列表
        if not text:
            return []
        # 定义正则表达式,匹配 {变量名} 或 {变量名:格式}
        pattern = r"\{([^}:]+)(?::[^}]+)?\}"
        # 使用正则在文本中查找所有变量名
        matches = re.findall(pattern, text)
        # 去重并保持顺序,返回变量名列表
        return list(dict.fromkeys(matches))

    # 向示例列表中添加一个新示例(字典类型)
    def add_example(self, example: dict):
        """动态添加单条示例"""
        # 在示例列表末尾追加新示例
        self.examples.append(example)

    # 格式化所有示例,返回由字符串组成的列表
    def format_examples(self, input_variables: dict = None) -> list[str]:
        """
        返回格式化后的示例字符串列表

        Args:
            input_variables: 输入变量字典,用于示例选择器选择示例
        """
        # 如果提供了示例选择器,使用选择器选择示例
        if self.example_selector and input_variables:
            selected_examples = self.example_selector.select_examples(input_variables)
        else:
            selected_examples = self.examples

        # 创建空列表用于保存格式化结果
        formatted = []
        # 遍历选中的示例
        for example in selected_examples:
            # 用 example_prompt 对每个示例格式化,并添加到 formatted 列表
            formatted.append(self.example_prompt.format(**example))
        # 返回格式化后的字符串列表
        return formatted

    # 格式化 few-shot 提示串,根据输入变量生成完整提示词
    def format(self, **kwargs) -> str:
        """
        根据传入变量生成 few-shot 提示词

        Args:
            **kwargs: 输入变量,如果提供了 example_selector,会用于选择示例
        """
        # 检查传入的变量是否完整,缺失就抛异常
        missing = set(self.input_variables) - set(kwargs.keys())
        if missing:
            raise ValueError(f"缺少必需的变量: {missing}")

        # 创建 parts 列表,用于拼接最终组成部分
        parts: list[str] = []
        # 如果有前缀,则格式化并加入 parts
        if self.prefix:
            parts.append(self._format_text(self.prefix, **kwargs))

        # 格式化所有示例并拼接为块
        # 如果使用示例选择器,传递输入变量;否则不传递
        if self.example_selector:
            example_block = self.example_separator.join(self.format_examples(input_variables=kwargs))
        else:
            example_block = self.example_separator.join(self.format_examples())
        # 如果示例块非空,加到 parts
        if example_block:
            parts.append(example_block)

        # 如果有后缀,则格式化并加入 parts
        if self.suffix:
            parts.append(self._format_text(self.suffix, **kwargs))

        # 用分隔符拼接所有组成部分,返回最终结果
        return self.example_separator.join(part for part in parts if part)

    # 内部方法:用 PromptTemplate 对 text 按变量进行格式化
    def _format_text(self, text: str, **kwargs) -> str:
        # 把 text 构造为 PromptTemplate,然后用 kwargs 格式化
        temp_prompt = PromptTemplate.from_template(text)
        return temp_prompt.format(**kwargs)


# 定义一个从文件加载提示词模板的函数
def load_prompt(path: str | Path,encoding: str | None = None) -> PromptTemplate:
    """
    从 JSON 文件加载提示词模板

    Args:
        path: 提示词配置文件的路径(支持 .json 格式)

    Returns:
        PromptTemplate 实例

    JSON 文件格式示例:
        {
            "_type": "prompt",
            "template": "你好,我叫{name},你是谁?"
        }
    """
    # 将传入路径转换为 Path 对象
    file_path = Path(path)
    # 检查文件是否存在,不存在则抛出异常
    if not file_path.exists():
        raise FileNotFoundError(f"提示词文件不存在: {path}")
    # 检查文件扩展名是否为 .json,不是则抛出异常
    if file_path.suffix != ".json":
        raise ValueError(f"只支持 .json 格式文件,当前文件: {file_path.suffix}")
    # 打开文件并以 utf-8 编码读取 JSON 内容到 config 变量中
    with file_path.open(encoding=encoding) as f:
        config = json.load(f)
    # 从配置中获取 _type 字段,默认值为 "prompt"
    config_type = config.get("_type", "prompt")
    # 如果配置类型不是 "prompt",则抛出异常
    if config_type != "prompt":
        raise ValueError(f"不支持的提示词类型: {config_type},当前只支持 'prompt'")
    # 获取模板字符串,若不存在则抛出异常
    template = config.get("template")
    if template is None:
        raise ValueError("配置文件中缺少 'template' 字段")
    # 使用读取到的模板字符串创建 PromptTemplate 实例并返回
    return PromptTemplate.from_template(template)

# 定义管道式提示词模板类
class PipelinePromptTemplate:
    """
    管道式提示词模板,将多个模板串联
    前一个模板的输出作为后一个模板的输入变量

    使用方式:
        template1 = PromptTemplate.from_template("问题:{question}")
        template2 = PromptTemplate.from_template("上下文:{context}")
        final = PromptTemplate.from_template("{output_0}\n{output_1}\n请回答")
        pipeline = PipelinePromptTemplate([template1, template2], final)
        result = pipeline.format(question="...", context="...")
    """

    # 构造方法,初始化管道式提示词模板
    def __init__(self, prompt_templates: list[PromptTemplate], final_prompt: PromptTemplate):
        """
        初始化 PipelinePromptTemplate

        Args:
            prompt_templates: 中间模板列表,每个模板的输出会作为最终模板的输入
            final_prompt: 最终模板,使用所有中间模板的输出(output_0, output_1, ...)和用户变量
        """
        # 保存中间模板列表
        self.prompt_templates = prompt_templates
        # 保存最终模板
        self.final_prompt = final_prompt
        # 提取管道模板所需的全部输入变量
        self.input_variables = self._extract_input_variables()

    # 内部方法:提取所有需要的输入变量
    def _extract_input_variables(self):
        """
        提取所有需要的输入变量

        Returns:
            输入变量列表
        """
        # 定义变量集合用于存储所有输入变量
        variables = set()
        # 遍历所有中间模板,收集其输入变量
        for template in self.prompt_templates:
            variables.update(template.input_variables)
        # 收集最终模板需要的变量
        final_vars = set(self.final_prompt.input_variables)
        # 构造中间模板输出变量名集合(output_0, output_1, ...)
        intermediate_outputs = {f"output_{i}" for i in range(len(self.prompt_templates))}
        # 最终模板中去掉这些output_x后的用户变量
        user_vars = final_vars - intermediate_outputs
        # 合并所有输入变量
        variables.update(user_vars)
        # 返回变量列表
        return list(variables)

    # 格式化管道提示词模板
    def format(self, **kwargs):
        """
        格式化管道模板

        Args:
            **kwargs: 用户提供的变量

        Returns:
            格式化后的最终字符串
        """
        # 创建字典存储所有中间步骤的输出
        intermediate_outputs = {}
        # 逐个遍历并格式化中间模板
        for i, template in enumerate(self.prompt_templates):
            # 准备传递给当前模板的变量
            template_vars = {}
            # 取用户提供的变量,如果模板需要则填入
            for key, value in kwargs.items():
                if key in template.input_variables:
                    template_vars[key] = value
            # 用当前变量格式化模板,获得输出
            output = template.format(**template_vars)
            # 保存输出到中间输出字典,变量名为output_{i}
            intermediate_outputs[f"output_{i}"] = output
        # 合并用户变量和中间模板所有输出,准备给最终模板用
        final_vars = {**kwargs, **intermediate_outputs}
        # 用最终变量格式化最终模板
        return self.final_prompt.format(**final_vars)

+# 延迟导入 runnables 模块以启用 LCEL 支持
+# 使用 try-except 避免循环导入问题
+try:
+   import langchain.runnables  # noqa: F401
+except ImportError:
+   pass
+

22.6. output_parsers.py #

langchain/output_parsers.py

# 导入类型提示相关模块
from typing import Any, Generic, TypeVar
from abc import ABC, abstractmethod
import json
import re

# 定义类型变量
T = TypeVar('T')


# 定义输出解析器的抽象基类
class BaseOutputParser(ABC, Generic[T]):
    """输出解析器的抽象基类"""

    @abstractmethod
    def parse(self, text: str) -> T:
        """
        解析输出文本

        Args:
            text: 要解析的文本

        Returns:
            解析后的结果
        """
        pass


# 定义字符串输出解析器类
class StrOutputParser(BaseOutputParser[str]):
    """
    字符串输出解析器

    将 LLM 的输出解析为字符串。这是最简单的输出解析器,
    它不会修改输入内容,只是确保输出是字符串类型。

    主要用于:
    - 确保 LLM 输出是字符串类型
    - 在链式调用中统一输出格式
    - 简化输出处理流程
    """

    def parse(self, text: str) -> str:
        """
        解析输出文本(实际上只是返回原文本)

        Args:
            text: 输入文本(应该是字符串)

        Returns:
            str: 原样返回输入文本
        """
        # StrOutputParser 不会修改内容,只是确保类型为字符串
        # 如果输入不是字符串,尝试转换
        if not isinstance(text, str):
            return str(text)
        return text

    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        return "StrOutputParser()"


# 辅助函数:从文本中提取 JSON(支持 markdown 代码块)
def parse_json_markdown(text: str) -> Any:
    """
    从文本中解析 JSON,支持 markdown 代码块格式

    Args:
        text: 可能包含 JSON 的文本

    Returns:
        解析后的 JSON 对象

    Raises:
        json.JSONDecodeError: 如果无法解析 JSON
    """
    # 去除首尾空白
    text = text.strip()

    # 尝试匹配 markdown 代码块中的 JSON
    # 匹配 \`\`\`json ... ``` 或 ``` ...
json_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if json_match:
    text = json_match.group(1).strip()

# 尝试匹配 { ... } 或 [ ... ]
json_match = re.search(r'(\{.*\}|$$.*$$)', text, re.DOTALL)
if json_match:
    text = json_match.group(1)

# 解析 JSON
return json.loads(text)

定义 JSON 输出解析器类 #

class JsonOutputParser(BaseOutputParser[Any]): """ JSON 输出解析器

将 LLM 的输出解析为 JSON 对象。支持:
- 纯 JSON 字符串
- Markdown 代码块中的 JSON(\`\`\`json ... ```)
- 包含 JSON 的文本(自动提取)

主要用于:
- 结构化数据提取
- API 响应解析
- 数据格式化
"""

def __init__(self, pydantic_object: type = None):
    """
    初始化 JsonOutputParser

    Args:
        pydantic_object: 可选的 Pydantic 模型类,用于验证 JSON 结构
    """
    self.pydantic_object = pydantic_object

def parse(self, text: str) -> Any:
    """
    解析 JSON 输出文本

    Args:
        text: 包含 JSON 的文本

    Returns:
        Any: 解析后的 JSON 对象(字典、列表等)

    Raises:
        ValueError: 如果无法解析 JSON
    """
    try:
        # 使用辅助函数解析 JSON
        parsed = parse_json_markdown(text)

        # 如果指定了 Pydantic 模型,进行验证
        if self.pydantic_object is not None:
            try:
                # 尝试使用 Pydantic 验证(如果可用)
                if hasattr(self.pydantic_object, 'model_validate'):
                    return self.pydantic_object.model_validate(parsed)
                elif hasattr(self.pydantic_object, 'parse_obj'):
                    return self.pydantic_object.parse_obj(parsed)
            except Exception:
                # 如果验证失败,返回原始解析结果
                pass

        return parsed
    except json.JSONDecodeError as e:
        raise ValueError(f"无法解析 JSON 输出: {text[:100]}... 错误: {e}")
    except Exception as e:
        raise ValueError(f"解析 JSON 时出错: {e}")

def get_format_instructions(self) -> str:
    """
    获取格式说明,用于在提示词中指导 LLM 输出 JSON 格式

    Returns:
        str: 格式说明文本
    """
    if self.pydantic_object is not None:
        # 如果有 Pydantic 模型,返回其 schema
        try:
            if hasattr(self.pydantic_object, 'model_json_schema'):
                schema = self.pydantic_object.model_json_schema()
            elif hasattr(self.pydantic_object, 'schema'):
                schema = self.pydantic_object.schema()
            else:
                schema = {}

            return f"""请以 JSON 格式输出,格式如下:

```json {schema} ```

确保输出是有效的 JSON 格式。""" except Exception: pass

    return """请以 JSON 格式输出你的回答。

输出格式要求:

  1. 使用有效的 JSON 格式
  2. 可以使用 markdown 代码块包裹:```json ... `
  3. 确保所有字符串都用双引号
  4. 确保 JSON 格式正确且完整

示例格式: ```json { "key": "value", "number": 123 }


    def __repr__(self) -> str:
        """返回解析器的字符串表示"""
        if self.pydantic_object:
            return f"JsonOutputParser(pydantic_object={self.pydantic_object.__name__})"
        return "JsonOutputParser()"


# 定义 Pydantic 输出解析器类
class PydanticOutputParser(JsonOutputParser):
    """
    Pydantic 输出解析器

    将 LLM 的输出解析为 Pydantic 模型实例。继承自 JsonOutputParser,
    先解析 JSON,然后验证并转换为 Pydantic 模型。

    主要用于:
    - 结构化数据验证
    - 类型安全的数据提取
    - 自动数据验证和转换
    """

    def __init__(self, pydantic_object: type):
        """
        初始化 PydanticOutputParser

        Args:
            pydantic_object: Pydantic 模型类(必需)

        Raises:
            ValueError: 如果 pydantic_object 不是有效的 Pydantic 模型
        """
        if pydantic_object is None:
            raise ValueError("PydanticOutputParser 需要一个 Pydantic 模型类")

        # 检查是否是 Pydantic 模型
        try:
            import pydantic
            if not (issubclass(pydantic_object, pydantic.BaseModel) or 
                    (hasattr(pydantic, 'v1') and issubclass(pydantic_object, pydantic.v1.BaseModel))):
                raise ValueError(f"{pydantic_object} 不是有效的 Pydantic 模型类")
        except ImportError:
            # 如果没有安装 pydantic,使用更宽松的检查
            if not hasattr(pydantic_object, '__fields__') and not hasattr(pydantic_object, 'model_fields'):
                raise ValueError(f"{pydantic_object} 可能不是有效的 Pydantic 模型类(请确保已安装 pydantic)")

        self.pydantic_object = pydantic_object
        # 调用父类初始化,传入 pydantic_object
        super().__init__(pydantic_object=pydantic_object)

    def parse(self, text: str):
        """
        解析输出文本为 Pydantic 模型实例

        Args:
            text: 包含 JSON 的文本

        Returns:
            Pydantic 模型实例

        Raises:
            ValueError: 如果无法解析 JSON 或验证失败
        """
        try:
            # 先使用父类方法解析 JSON
            json_obj = super().parse(text)

            # 转换为 Pydantic 模型实例
            return self._parse_obj(json_obj)
        except Exception as e:
            raise ValueError(f"无法解析为 Pydantic 模型: {e}")

    def _parse_obj(self, obj: dict):
        """
        将字典对象转换为 Pydantic 模型实例

        Args:
            obj: 字典对象

        Returns:
            Pydantic 模型实例

        Raises:
            ValueError: 如果验证失败
        """
        try:
            import pydantic

            # 尝试使用 Pydantic v2 的 model_validate
            if hasattr(self.pydantic_object, 'model_validate'):
                return self.pydantic_object.model_validate(obj)
            # 尝试使用 Pydantic v1 的 parse_obj
            elif hasattr(self.pydantic_object, 'parse_obj'):
                return self.pydantic_object.parse_obj(obj)
            # 尝试直接实例化
            else:
                return self.pydantic_object(**obj)
        except ImportError:
            # 如果没有 pydantic,尝试直接实例化
            return self.pydantic_object(**obj)
        except Exception as e:
            raise ValueError(f"Pydantic 验证失败: {e}")

    def _get_schema(self) -> dict:
        """
        获取 Pydantic 模型的 JSON Schema

        Returns:
            dict: JSON Schema 字典
        """
        try:
            # 尝试使用 Pydantic v2 的 model_json_schema
            if hasattr(self.pydantic_object, 'model_json_schema'):
                return self.pydantic_object.model_json_schema()
            # 尝试使用 Pydantic v1 的 schema
            elif hasattr(self.pydantic_object, 'schema'):
                return self.pydantic_object.schema()
            else:
                return {}
        except Exception:
            return {}

    def get_format_instructions(self) -> str:
        """
        获取格式说明,包含 Pydantic 模型的 Schema

        Returns:
            str: 格式说明文本
        """
        schema = self._get_schema()

        # 清理 schema,移除不必要的字段
        reduced_schema = dict(schema)
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]

        schema_str = json.dumps(reduced_schema, ensure_ascii=False, indent=2)

        return f"""请以 JSON 格式输出,必须严格遵循以下 Schema:

\`\`\`json
{schema_str}
\`\`\`

输出要求:
1. 必须完全符合上述 Schema 结构
2. 所有必需字段都必须提供
3. 字段类型必须匹配(字符串、数字、布尔值等)
4. 使用有效的 JSON 格式
5. 可以使用 markdown 代码块包裹:\`\`\`json ...

确保输出是有效的 JSON,并且符合 Schema 定义。"""

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"PydanticOutputParser(pydantic_object={self.pydantic_object.__name__})"

定义输出解析异常类 #

class OutputParserException(ValueError): """输出解析异常""" def init(self, message: str, llm_output: str = ""): super().init(message) self.llm_output = llm_output

定义输出修复解析器类 #

class OutputFixingParser(BaseOutputParser[T]): """ 输出修复解析器

包装一个基础解析器,当解析失败时,使用 LLM 自动修复输出。
这是一个非常有用的功能,可以处理 LLM 输出格式不正确的情况。

工作原理:
1. 首先尝试使用基础解析器解析输出
2. 如果解析失败,将错误信息和原始输出发送给 LLM
3. LLM 根据格式说明修复输出
4. 再次尝试解析修复后的输出
5. 可以设置最大重试次数
"""

def __init__(
    self,
    parser: BaseOutputParser[T],
    retry_chain,
    max_retries: int = 1,
):
    """
    初始化 OutputFixingParser

    Args:
        parser: 基础解析器
        retry_chain: 用于修复输出的链(通常是 Prompt -> LLM -> StrOutputParser)
        max_retries: 最大重试次数
    """
    self.parser = parser
    self.retry_chain = retry_chain
    self.max_retries = max_retries

@classmethod
def from_llm(
    cls,
    llm,
    parser: BaseOutputParser[T],
    prompt=None,
    max_retries: int = 1,
) -> "OutputFixingParser[T]":
    """
    从 LLM 创建 OutputFixingParser

    Args:
        llm: 用于修复输出的语言模型
        parser: 基础解析器
        prompt: 修复提示模板(可选,有默认模板)
        max_retries: 最大重试次数

    Returns:
        OutputFixingParser 实例
    """
    from langchain.prompts import PromptTemplate
    from langchain.output_parsers import StrOutputParser

    # 默认修复提示模板
    if prompt is None:
        fix_template = """Instructions:

{instructions} #

Completion: #

{completion} #

上面的 Completion 没有满足 Instructions 中的约束要求。

错误信息: #

{error} #

请修复输出,确保它满足 Instructions 中的所有约束要求。只返回修复后的输出,不要包含其他内容:""" prompt = PromptTemplate.from_template(fix_template)

    # 创建修复链:Prompt -> LLM -> StrOutputParser
    retry_chain = _SimpleChain(prompt, llm, StrOutputParser())

    return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)

def parse(self, completion: str) -> T:
    """
    解析输出,如果失败则尝试修复

    Args:
        completion: LLM 的输出文本

    Returns:
        T: 解析后的结果

    Raises:
        OutputParserException: 如果修复后仍然无法解析
    """
    retries = 0
    original_completion = completion

    while retries <= self.max_retries:
        try:
            # 尝试使用基础解析器解析
            return self.parser.parse(completion)
        except (ValueError, OutputParserException, Exception) as e:
            # 如果已达到最大重试次数,抛出异常
            if retries >= self.max_retries:
                raise OutputParserException(
                    f"解析失败,已重试 {retries} 次: {e}",
                    llm_output=completion
                )

            retries += 1
            print(f"  第 {retries} 次尝试修复...")

            # 获取格式说明(如果解析器支持)
            try:
                instructions = self.parser.get_format_instructions()
            except (AttributeError, NotImplementedError):
                instructions = "请确保输出格式正确。"

            # 使用 LLM 修复输出
            try:
                if hasattr(self.retry_chain, 'invoke'):
                    # 新式链式调用
                    completion = self.retry_chain.invoke({
                        "instructions": instructions,
                        "completion": completion,
                        "error": str(e),
                    })
                elif hasattr(self.retry_chain, 'run'):
                    # 旧式链式调用
                    completion = self.retry_chain.run(
                        instructions=instructions,
                        completion=completion,
                        error=str(e),
                    )
                else:
                    # 直接调用(作为函数)
                    completion = self.retry_chain({
                        "instructions": instructions,
                        "completion": completion,
                        "error": str(e),
                    })

                # 确保返回的是字符串
                if not isinstance(completion, str):
                    completion = str(completion)

            except Exception as fix_error:
                raise OutputParserException(
                    f"修复输出时出错: {fix_error}",
                    llm_output=completion
                )

    raise OutputParserException(
        f"解析失败,已重试 {self.max_retries} 次",
        llm_output=completion
    )

def get_format_instructions(self) -> str:
    """
    获取格式说明(委托给基础解析器)

    Returns:
        str: 格式说明文本
    """
    try:
        return self.parser.get_format_instructions()
    except (AttributeError, NotImplementedError):
        return "请确保输出格式正确。"

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"OutputFixingParser(parser={self.parser}, max_retries={self.max_retries})"

简单的链式调用包装类 #

class _SimpleChain: """简单的链式调用包装类,用于连接 Prompt -> LLM -> Parser"""

def __init__(self, prompt, llm, parser):
    self.prompt = prompt
    self.llm = llm
    self.parser = parser

def invoke(self, input_dict: dict) -> str:
    """调用链"""
    # 格式化提示词
    formatted = self.prompt.format(**input_dict)
    # 调用 LLM
    response = self.llm.invoke(formatted)
    # 解析输出
    if hasattr(response, 'content'):
        content = response.content
    else:
        content = str(response)
    # 使用解析器解析
    return self.parser.parse(content)

def run(self, **kwargs) -> str:
    """运行链(兼容旧接口)"""
    return self.invoke(kwargs)

简单的 PromptValue 包装类 #

class _SimplePromptValue: """简单的 PromptValue 包装类,用于 RetryOutputParser""" def init(self, text: str): self.text = text

def to_string(self) -> str:
    """返回提示词字符串"""
    return self.text

定义重试输出解析器类 #

class RetryOutputParser(BaseOutputParser[T]): """ 重试输出解析器

包装一个基础解析器,当解析失败时,使用 LLM 重新生成输出。
与 OutputFixingParser 的区别:
- RetryOutputParser 需要原始 prompt 和 completion
- 它使用 parse_with_prompt 方法而不是 parse 方法
- 它将原始 prompt 和 completion 都传递给 LLM,让 LLM 重新生成

工作原理:
1. 首先尝试使用基础解析器解析 completion
2. 如果解析失败,将原始 prompt 和 completion 发送给 LLM
3. LLM 根据 prompt 的要求重新生成输出
4. 再次尝试解析新生成的输出
5. 可以设置最大重试次数
"""

def __init__(
    self,
    parser: BaseOutputParser[T],
    retry_chain,
    max_retries: int = 1,
):
    """
    初始化 RetryOutputParser

    Args:
        parser: 基础解析器
        retry_chain: 用于重试的链(通常是 Prompt -> LLM -> StrOutputParser)
        max_retries: 最大重试次数
    """
    self.parser = parser
    self.retry_chain = retry_chain
    self.max_retries = max_retries

@classmethod
def from_llm(
    cls,
    llm,
    parser: BaseOutputParser[T],
    prompt=None,
    max_retries: int = 1,
) -> "RetryOutputParser[T]":
    """
    从 LLM 创建 RetryOutputParser

    Args:
        llm: 用于重试的语言模型
        parser: 基础解析器
        prompt: 重试提示模板(可选,有默认模板)
        max_retries: 最大重试次数

    Returns:
        RetryOutputParser 实例
    """
    from langchain.prompts import PromptTemplate
    from langchain.output_parsers import StrOutputParser

    # 默认重试提示模板
    if prompt is None:
        retry_template = """Prompt:

{prompt} Completion: {completion}

上面的 Completion 没有满足 Prompt 中的约束要求。 请重新生成一个满足要求的输出:""" prompt = PromptTemplate.from_template(retry_template)

    # 创建重试链:Prompt -> LLM -> StrOutputParser
    retry_chain = _SimpleChain(prompt, llm, StrOutputParser())

    return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)

def parse_with_prompt(self, completion: str, prompt_value) -> T:
    """
    使用 prompt 解析输出,如果失败则尝试重试

    Args:
        completion: LLM 的输出文本
        prompt_value: 原始提示词(可以是字符串或 PromptValue 对象)

    Returns:
        T: 解析后的结果

    Raises:
        OutputParserException: 如果重试后仍然无法解析
    """
    # 将 prompt_value 转换为字符串
    if hasattr(prompt_value, 'to_string'):
        prompt_str = prompt_value.to_string()
    else:
        prompt_str = str(prompt_value)

    retries = 0

    while retries <= self.max_retries:
        try:
            # 尝试使用基础解析器解析
            return self.parser.parse(completion)
        except (ValueError, OutputParserException, Exception) as e:
            # 如果已达到最大重试次数,抛出异常
            if retries >= self.max_retries:
                raise OutputParserException(
                    f"解析失败,已重试 {retries} 次: {e}",
                    llm_output=completion
                )

            retries += 1
            print(f"  第 {retries} 次尝试重试...")

            # 使用 LLM 重新生成输出
            try:
                if hasattr(self.retry_chain, 'invoke'):
                    # 新式链式调用
                    completion = self.retry_chain.invoke({
                        "prompt": prompt_str,
                        "completion": completion,
                    })
                elif hasattr(self.retry_chain, 'run'):
                    # 旧式链式调用
                    completion = self.retry_chain.run(
                        prompt=prompt_str,
                        completion=completion,
                    )
                else:
                    # 直接调用
                    completion = self.retry_chain({
                        "prompt": prompt_str,
                        "completion": completion,
                    })

                # 确保返回的是字符串
                if not isinstance(completion, str):
                    completion = str(completion)

            except Exception as retry_error:
                raise OutputParserException(
                    f"重试输出时出错: {retry_error}",
                    llm_output=completion
                )

    raise OutputParserException(
        f"解析失败,已重试 {self.max_retries} 次",
        llm_output=completion
    )

def parse(self, completion: str) -> T:
    """
    此解析器只能通过 parse_with_prompt 方法调用

    Raises:
        NotImplementedError: 总是抛出此异常
    """
    raise NotImplementedError(
        "RetryOutputParser 只能通过 parse_with_prompt 方法调用,"
        "需要提供原始 prompt。"
    )

def get_format_instructions(self) -> str:
    """
    获取格式说明(委托给基础解析器)

    Returns:
        str: 格式说明文本
    """
    try:
        return self.parser.get_format_instructions()
    except (AttributeError, NotImplementedError):
        return "请确保输出格式正确。"

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"RetryOutputParser(parser={self.parser}, max_retries={self.max_retries})"

定义带错误信息的重试输出解析器类 #

class RetryWithErrorOutputParser(BaseOutputParser[T]): """ 带错误信息的重试输出解析器

与 RetryOutputParser 类似,但会将错误信息也传递给 LLM。
这为 LLM 提供了更多上下文,理论上能更好地修复输出。

工作原理:
1. 首先尝试使用基础解析器解析 completion
2. 如果解析失败,将原始 prompt、completion 和错误信息发送给 LLM
3. LLM 根据 prompt 的要求和错误信息重新生成输出
4. 再次尝试解析新生成的输出
5. 可以设置最大重试次数
"""

def __init__(
    self,
    parser: BaseOutputParser[T],
    retry_chain,
    max_retries: int = 1,
):
    """
    初始化 RetryWithErrorOutputParser

    Args:
        parser: 基础解析器
        retry_chain: 用于重试的链(通常是 Prompt -> LLM -> StrOutputParser)
        max_retries: 最大重试次数
    """
    self.parser = parser
    self.retry_chain = retry_chain
    self.max_retries = max_retries

@classmethod
def from_llm(
    cls,
    llm,
    parser: BaseOutputParser[T],
    prompt=None,
    max_retries: int = 1,
) -> "RetryWithErrorOutputParser[T]":
    """
    从 LLM 创建 RetryWithErrorOutputParser

    Args:
        llm: 用于重试的语言模型
        parser: 基础解析器
        prompt: 重试提示模板(可选,有默认模板)
        max_retries: 最大重试次数

    Returns:
        RetryWithErrorOutputParser 实例
    """
    from langchain.prompts import PromptTemplate
    from langchain.output_parsers import StrOutputParser

    # 默认重试提示模板(包含错误信息)
    if prompt is None:
        retry_template = """Prompt:

{prompt} Completion: {completion}

上面的 Completion 没有满足 Prompt 中的约束要求。 错误详情: {error} 请根据错误信息重新生成一个满足要求的输出:""" prompt = PromptTemplate.from_template(retry_template)

    # 创建重试链:Prompt -> LLM -> StrOutputParser
    retry_chain = _SimpleChain(prompt, llm, StrOutputParser())

    return cls(parser=parser, retry_chain=retry_chain, max_retries=max_retries)

def parse_with_prompt(self, completion: str, prompt_value) -> T:
    """
    使用 prompt 解析输出,如果失败则尝试重试(包含错误信息)

    Args:
        completion: LLM 的输出文本
        prompt_value: 原始提示词(可以是字符串或 PromptValue 对象)

    Returns:
        T: 解析后的结果

    Raises:
        OutputParserException: 如果重试后仍然无法解析
    """
    # 将 prompt_value 转换为字符串
    if hasattr(prompt_value, 'to_string'):
        prompt_str = prompt_value.to_string()
    else:
        prompt_str = str(prompt_value)

    retries = 0

    while retries <= self.max_retries:
        try:
            # 尝试使用基础解析器解析
            return self.parser.parse(completion)
        except (ValueError, OutputParserException, Exception) as e:
            # 如果已达到最大重试次数,抛出异常
            if retries >= self.max_retries:
                raise OutputParserException(
                    f"解析失败,已重试 {retries} 次: {e}",
                    llm_output=completion
                )

            retries += 1
            print(f"  第 {retries} 次尝试重试(带错误信息)...")

            # 使用 LLM 重新生成输出(包含错误信息)
            try:
                if hasattr(self.retry_chain, 'invoke'):
                    # 新式链式调用
                    completion = self.retry_chain.invoke({
                        "prompt": prompt_str,
                        "completion": completion,
                        "error": str(e),
                    })
                elif hasattr(self.retry_chain, 'run'):
                    # 旧式链式调用
                    completion = self.retry_chain.run(
                        prompt=prompt_str,
                        completion=completion,
                        error=str(e),
                    )
                else:
                    # 直接调用
                    completion = self.retry_chain({
                        "prompt": prompt_str,
                        "completion": completion,
                        "error": str(e),
                    })

                # 确保返回的是字符串
                if not isinstance(completion, str):
                    completion = str(completion)

            except Exception as retry_error:
                raise OutputParserException(
                    f"重试输出时出错: {retry_error}",
                    llm_output=completion
                )

    raise OutputParserException(
        f"解析失败,已重试 {self.max_retries} 次",
        llm_output=completion
    )

def parse(self, completion: str) -> T:
    """
    此解析器只能通过 parse_with_prompt 方法调用

    Raises:
        NotImplementedError: 总是抛出此异常
    """
    raise NotImplementedError(
        "RetryWithErrorOutputParser 只能通过 parse_with_prompt 方法调用,"
        "需要提供原始 prompt。"
    )

def get_format_instructions(self) -> str:
    """
    获取格式说明(委托给基础解析器)

    Returns:
        str: 格式说明文本
    """
    try:
        return self.parser.get_format_instructions()
    except (AttributeError, NotImplementedError):
        return "请确保输出格式正确。"

def __repr__(self) -> str:
    """返回解析器的字符串表示"""
    return f"RetryWithErrorOutputParser(parser={self.parser}, max_retries={self.max_retries})"

+# 延迟导入 runnables 模块以启用 LCEL 支持 +# 使用 try-except 避免循环导入问题 +try:

  • import langchain.runnables # noqa: F401 +except ImportError:
  • pass +

## 23. Runnable
Runnable 是 LangChain 的统一运行接口,表示“可以运行的对象”。所有实现了 `invoke`、`batch` 和 `stream` 方法的组件,都可以被当作 Runnable 使用。这种统一接口设计让 Prompt、LLM、输出解析器、流程链等不同类型的对象拥有一致的调用方式,极大简化了多组件协作与组合的复杂度。

**核心方法说明:**
- `invoke(input)`:同步处理单个输入,返回输出结果。适合一次性单步处理。
- `batch(inputs)`:批量处理一组输入,提高效率。常用于一次处理多个请求。
- `stream(input)`:流式逐步输出,支持大文本或交互式场景。

**为什么要有 Runnable?**
LangChain 致力于“可组合性”,一个大的工作流可能需要 Prompt → LLM → 解析器等多个步骤。统一接口后,可以用“管道”式的方式将多个组件串联,使用 `|` 操作符组合构建复杂链路,同时保持类型和调用方式的统一。例如:

```python
prompt | llm | output_parser

这个表达式就是将 3 个可运行对象无缝组合,从提示模板到大模型再到解析,构成一个完整的链。

常见支持 Runnable 的组件:

  • PromptTemplate(通过 LCEL 支持)
  • ChatOpenAI(LLM 模型)
  • BaseOutputParser(输出解析器)
  • RunnableSequence(可组合链)
  • 你自定义的 Runnable 实现类

应用场景举例:

  • 统一处理链条中的各个模块,实现端到端自动化
  • 将通用处理流程抽象为可复用的工作单元
  • 类型安全、结构标准化、接口简单明了
  • 支持复杂多步骤流程或分支操作

注意事项:

  • 保证输入输出类型一致
  • 捕捉异常,设计合理的错误处理逻辑
  • 善用批处理和流式处理提升效率

23.1. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
+# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

+from typing import Any, List
+from abc import ABC, abstractmethod
+
+
+# 定义 Runnable 接口
+class Runnable(ABC):
+   """Runnable 接口的抽象定义
+   
+   Runnable 是 LangChain 的核心接口,表示可以运行的对象。
+   它定义了统一的操作接口:invoke、batch、stream 等。
+   所有支持这些方法的对象都可以被视为 Runnable。
+   """
+   
+   @abstractmethod
+   def invoke(self, input_data: Any, **kwargs) -> Any:
+       """
+       同步调用,处理单个输入
+       
+       Args:
+           input_data: 输入数据
+           **kwargs: 额外的参数
+       
+       Returns:
+           处理后的输出
+       """
+       pass
+   
+   def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
+       """
+       批处理,处理多个输入(默认实现)
+       
+       Args:
+           inputs: 输入数据列表
+           **kwargs: 额外的参数
+       
+       Returns:
+           处理后的输出列表
+       """
+       return [self.invoke(input_data, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data: Any, **kwargs):
+       """
+       流式处理,逐步产生输出(默认实现)
+       
+       Args:
+           input_data: 输入数据
+           **kwargs: 额外的参数
+       
+       Yields:
+           处理后的输出块
+       """
+       result = self.invoke(input_data, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """
+       支持 | 操作符,用于组合 Runnable
+       
+       Args:
+           other: 另一个 Runnable 或可调用对象
+       
+       Returns:
+           RunnableSequence: 组合后的序列
+       """
+       if hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+
+
+class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                result = step.invoke(result, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

23.2. Runnable.py #

23.Runnable.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
#from langchain_core.runnables import Runnable
#import json
#from typing import Any, List

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
from langchain.runnables import Runnable
import json
from typing import Any, List

print("=" * 60)
print("Runnable 接口演示")
print("=" * 60)
print("\nRunnable 是 LangChain 的核心接口,表示可以运行的对象。")
print("它定义了统一的操作接口:invoke、batch、stream 等。")
print("所有支持这些方法的对象都可以被视为 Runnable。\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: Runnable 接口的基本概念")
print("=" * 60)

print("\nRunnable 接口定义的方法:")
print("1. invoke(input): 同步调用,处理单个输入")
print("2. batch(inputs): 批处理,处理多个输入")
print("3. stream(input): 流式处理,逐步产生输出")
print("\n所有支持这些方法的对象都可以被视为 Runnable:")
print("- PromptTemplate(通过 LCEL 支持)")
print("- ChatOpenAI(LLM 模型)")
print("- BaseOutputParser(输出解析器)")
print("- RunnableSequence(链式组合)")

print("\n" + "=" * 60)
print("演示 2: 检查对象是否是 Runnable")
print("=" * 60)

# 创建一些对象
prompt = PromptTemplate.from_template("请回答:{question}")
chain = prompt | llm | parser

print("\n检查对象是否支持 Runnable 方法:")
print("-" * 60)

objects_to_check = [
    ("prompt", prompt),
    ("llm", llm),
    ("parser", parser),
    ("chain", chain),
]

for name, obj in objects_to_check:
    has_invoke = hasattr(obj, 'invoke')
    has_batch = hasattr(obj, 'batch')
    has_stream = hasattr(obj, 'stream')
    is_runnable = has_invoke and has_batch and has_stream

    print(f"\n{name}:")
    print(f"  类型: {type(obj).__name__}")
    print(f"  有 invoke 方法: {has_invoke}")
    print(f"  有 batch 方法: {has_batch}")
    print(f"  有 stream 方法: {has_stream}")
    print(f"  是 Runnable: {is_runnable}")

print("\n" + "=" * 60)
print("演示 3: 自定义 Runnable 实现")
print("=" * 60)

# 创建自定义 Runnable 实现
class AddOneRunnable(Runnable):
    """自定义 Runnable:将输入加 1"""

    def invoke(self, input_data: Any, config=None, **kwargs) -> Any:
        """将输入加 1"""
        if isinstance(input_data, (int, float)):
            return input_data + 1
        elif isinstance(input_data, str):
            try:
                return str(int(input_data) + 1)
            except ValueError:
                return input_data
        return input_data

    def stream(self, input_data: Any, config=None, **kwargs):
        """流式处理:逐步输出"""
        result = self.invoke(input_data, config=config, **kwargs)
        # 模拟逐步输出
        if isinstance(result, str):
            for char in result:
                yield char
        else:
            yield result

class MultiplyRunnable(Runnable):
    """自定义 Runnable:将输入乘以指定倍数"""

    def __init__(self, multiplier: int = 2):
        self.multiplier = multiplier

    def invoke(self, input_data: Any, config=None, **kwargs) -> Any:
        """将输入乘以倍数"""
        if isinstance(input_data, (int, float)):
            return input_data * self.multiplier
        return input_data

print("\n创建自定义 Runnable:")
print("-" * 60)

# 创建自定义 Runnable 实例
add_one = AddOneRunnable()
multiply_by_2 = MultiplyRunnable(multiplier=2)

# 测试 invoke
print(f"\nadd_one.invoke(5) = {add_one.invoke(5)}")
print(f"multiply_by_2.invoke(5) = {multiply_by_2.invoke(5)}")

# 测试 batch
print(f"\nadd_one.batch([1, 2, 3]) = {add_one.batch([1, 2, 3])}")
print(f"multiply_by_2.batch([1, 2, 3]) = {multiply_by_2.batch([1, 2, 3])}")

# 测试 stream
print(f"\nadd_one.stream('123'):")
for chunk in add_one.stream('123'):
    print(f"  chunk: {chunk}")

print("\n" + "=" * 60)
print("演示 4: RunnableSequence 作为 Runnable")
print("=" * 60)

# 创建链(RunnableSequence)
chain = prompt | llm | parser

print("\nRunnableSequence 实现了 Runnable 接口:")
print("-" * 60)

print(f"chain 类型: {type(chain).__name__}")
print(f"是否有 invoke 方法: {hasattr(chain, 'invoke')}")
print(f"是否有 batch 方法: {hasattr(chain, 'batch')}")
print(f"是否有 stream 方法: {hasattr(chain, 'stream')}")

# 测试 invoke
print("\n使用 invoke 方法:")
result = chain.invoke({"question": "什么是 Python?"})
print(f"结果: {result[:50]}...")

# 测试 batch
print("\n使用 batch 方法:")
batch_inputs = [
    {"question": "什么是 Python?"},
    {"question": "什么是机器学习?"},
]
batch_results = chain.batch(batch_inputs)
for i, (input_data, result) in enumerate(zip(batch_inputs, batch_results), 1):
    print(f"{i}. {input_data['question']}: {result[:30]}...")

# 测试 stream
print("\n使用 stream 方法:")
print("流式输出:", end="")
for chunk in chain.stream({"question": "用一句话介绍 Python"}):
    print(chunk, end="", flush=True)
print()

print("\n" + "=" * 60)
print("演示 5: 组合多个 Runnable")
print("=" * 60)

# 组合自定义 Runnable
custom_chain = add_one | multiply_by_2

print("\n组合自定义 Runnable:")
print("custom_chain = add_one | multiply_by_2")
print("-" * 60)

test_inputs = [1, 2, 3, 4, 5]

print("\n测试组合链:")
for input_val in test_inputs:
    result = custom_chain.invoke(input_val)
    print(f"  {input_val} -> add_one -> multiply_by_2 = {result}")

print("\n批处理:")
results = custom_chain.batch(test_inputs)
print(f"  输入: {test_inputs}")
print(f"  输出: {results}")

print("\n" + "=" * 60)
print("演示 6: Runnable 的类型检查")
print("=" * 60)

# 类型检查函数
def is_runnable(obj: Any) -> bool:
    """检查对象是否是 Runnable"""
    return (
        hasattr(obj, 'invoke') and 
        callable(getattr(obj, 'invoke', None)) and
        hasattr(obj, 'batch') and 
        callable(getattr(obj, 'batch', None)) and
        hasattr(obj, 'stream') and 
        callable(getattr(obj, 'stream', None))
    )

print("\n检查各种对象是否是 Runnable:")
print("-" * 60)

test_objects = [
    ("prompt", prompt),
    ("llm", llm),
    ("parser", parser),
    ("chain", chain),
    ("add_one", add_one),
    ("multiply_by_2", multiply_by_2),
    ("custom_chain", custom_chain),
]

for name, obj in test_objects:
    result = is_runnable(obj)
    print(f"  {name:15} -> {result}")

print("\n" + "=" * 60)
print("演示 7: Runnable 的实际应用")
print("=" * 60)

# 创建文本处理链
class TextProcessorRunnable(Runnable):
    """文本处理 Runnable"""

    def __init__(self, operation: str = "uppercase"):
        self.operation = operation

    def invoke(self, input_data: Any, config=None, **kwargs) -> Any:
        """处理文本"""
        if not isinstance(input_data, str):
            input_data = str(input_data)

        if self.operation == "uppercase":
            return input_data.upper()
        elif self.operation == "lowercase":
            return input_data.lower()
        elif self.operation == "reverse":
            return input_data[::-1]
        elif self.operation == "capitalize":
            return input_data.capitalize()
        return input_data

# 创建文本处理链
text_prompt = PromptTemplate.from_template("请介绍:{topic}")
text_chain = text_prompt | llm | parser | TextProcessorRunnable("uppercase")

print("\n创建文本处理链:")
print("text_chain = text_prompt | llm | parser | TextProcessorRunnable('uppercase')")
print("-" * 60)

topics = ["Python", "机器学习"]

for topic in topics:
    result = text_chain.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"处理结果: {result[:50]}...")

print("\n" + "=" * 60)
print("演示 8: Runnable 的链式组合")
print("=" * 60)

# 创建多个处理步骤
step1 = AddOneRunnable()
step2 = MultiplyRunnable(multiplier=3)
step3 = AddOneRunnable()

# 组合成链
processing_chain = step1 | step2 | step3

print("\n创建处理链:")
print("processing_chain = step1 (add_one) | step2 (multiply_by_3) | step3 (add_one)")
print("-" * 60)

test_value = 5
result = processing_chain.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"步骤 1 (add_one): {test_value + 1} = {test_value + 1}")
print(f"步骤 2 (multiply_by_3): {test_value + 1} * 3 = {(test_value + 1) * 3}")
print(f"步骤 3 (add_one): {(test_value + 1) * 3} + 1 = {result}")
print(f"最终结果: {result}")

print("\n批处理:")
batch_results = processing_chain.batch([1, 2, 3, 4, 5])
print(f"  输入: [1, 2, 3, 4, 5]")
print(f"  输出: {batch_results}")

print("\n" + "=" * 60)
print("演示 9: Runnable 与 LLM 链结合")
print("=" * 60)

# 创建包含自定义处理的链
class SummaryRunnable(Runnable):
    """摘要 Runnable:提取前 N 个字符"""

    def __init__(self, max_length: int = 50):
        self.max_length = max_length

    def invoke(self, input_data: Any, config=None, **kwargs) -> Any:
        """提取摘要"""
        text = str(input_data)
        if len(text) > self.max_length:
            return text[:self.max_length] + "..."
        return text

# 创建摘要链
summary_chain = prompt | llm | parser | SummaryRunnable(max_length=30)

print("\n创建摘要链:")
print("summary_chain = prompt | llm | parser | SummaryRunnable(max_length=30)")
print("-" * 60)

questions = [
    "什么是 Python?",
    "Python 的主要特点是什么?",
]

for question in questions:
    result = summary_chain.invoke({"question": question})
    print(f"\n问题: {question}")
    print(f"摘要: {result}")

print("\n" + "=" * 60)
print("演示 10: Runnable 的通用处理模式")
print("=" * 60)

# 通用处理函数
def process_with_runnable(runnable: Any, input_data: Any, method: str = "invoke"):
    """使用 Runnable 处理数据"""
    if method == "invoke":
        return runnable.invoke(input_data)
    elif method == "batch":
        return runnable.batch(input_data if isinstance(input_data, list) else [input_data])
    elif method == "stream":
        return list(runnable.stream(input_data))
    else:
        raise ValueError(f"不支持的方法: {method}")

print("\n通用处理函数演示:")
print("-" * 60)

# 测试不同的 Runnable
runnables_to_test = [
    ("add_one", add_one),
    ("multiply_by_2", multiply_by_2),
    ("chain", chain),
]

for name, runnable in runnables_to_test:
    print(f"\n{runnable}:")
    # invoke
    if name == "chain":
        result = process_with_runnable(runnable, {"question": "测试"}, "invoke")
    else:
        result = process_with_runnable(runnable, 5, "invoke")
    print(f"  invoke: {result[:30] if isinstance(result, str) else result}...")

    # batch
    if name == "chain":
        batch_input = [{"question": "测试1"}, {"question": "测试2"}]
    else:
        batch_input = [1, 2, 3]
    batch_result = process_with_runnable(runnable, batch_input, "batch")
    print(f"  batch: {batch_result[:2] if isinstance(batch_result, list) else batch_result}...")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnable 接口的核心概念:")
print("1. 统一接口:所有 Runnable 都支持 invoke、batch、stream 方法")
print("2. 可组合:使用 | 操作符可以组合多个 Runnable")
print("3. 类型安全:支持类型提示和验证")
print("4. 灵活:可以创建自定义 Runnable 实现")
print("\nRunnable 的主要方法:")
print("- invoke(input): 同步调用,处理单个输入")
print("- batch(inputs): 批处理,处理多个输入")
print("- stream(input): 流式处理,逐步产生输出")
print("\nRunnable 的实现类:")
print("- RunnableSequence: 链式组合的 Runnable")
print("- PromptTemplate: 通过 LCEL 支持成为 Runnable")
print("- ChatOpenAI: LLM 模型是 Runnable")
print("- BaseOutputParser: 输出解析器是 Runnable")
print("- 自定义类: 可以实现 Runnable 接口")
print("\n使用场景:")
print("- 统一处理不同类型的组件")
print("- 创建可重用的处理流程")
print("- 实现复杂的多步骤处理")
print("- 提供一致的 API 接口")
print("\n注意事项:")
print("- 确保实现所有必需的方法")
print("- 保持输入输出类型的一致性")
print("- 考虑错误处理和边界情况")
print("- 利用批处理和流式处理提高效率")

24. RunnableSequence和RunnableLambda #

RunnableSequence 和 RunnableLambda是 LCEL(LangChain Expression Language)中用于构建和组合数据处理流程的关键组件。

  • RunnableSequence:允许你将多个 Runnable 串联在一起,形成一个「处理链」。数据会依次经过链中的每一步。例如,你可以先对输入进行格式化、再用大语言模型(LLM)处理、最后解析输出。这种顺序执行让多步骤处理变得清晰高效。

  • RunnableLambda:用于将任意 Python 函数包装为一个兼容 LCEL 的 Runnable 对象。这样一来,普通函数也能像链条里的“积木”那样,与其他 LangChain 组件(如 PromptTemplate、LLM、解析器等)无缝组合。这极大提升了链式开发的灵活性和重用性。

这些组件可通过「管道符」| 运算符组合,用于表达数据在链中的流动。例如:

string_chain = (
    RunnableLambda(to_uppercase)
    | RunnableLambda(add_exclamation)
    | RunnableLambda(reverse_text)
)

这意味着输入将依序经过三个步骤,每一步的输出就是下一步的输入。你也可以直接用 RunnableSequence([step1, step2, step3]) 明确创建链路。

应用场景:数据格式转换、冗余或脏数据预处理、定制化的数据提取、与 LLM 结合后的后处理抽取等。

注意事项:使用这类链式组合时,要确保各节点的输入输出类型兼容。对于涉及异步、流式或高性能需求的情况,可考虑其他更适合的 Runnable 变种。

24.1. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

+   def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
+               # 检查 step 的 invoke 方法是否接受 config 参数
+               import inspect
+               try:
+                   sig = inspect.signature(step.invoke)
+                   params = list(sig.parameters.keys())
+                   # 如果 invoke 方法接受 config 参数,传递它
+                   if 'config' in params:
+                       result = step.invoke(result, config=config, **kwargs)
+                   else:
+                       # 如果不接受 config,只传递其他 kwargs
+                       result = step.invoke(result, **kwargs)
+               except (ValueError, TypeError):
+                   # 如果无法检查签名,尝试不传递 config
+                   try:
+                       result = step.invoke(result, **kwargs)
+                   except TypeError:
+                       # 如果还是失败,尝试传递 config(某些实现可能需要)
+                       result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

+   def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
+               result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

+   def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
+       result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


+class RunnableLambda(Runnable):
+   """将 Python 可调用对象转换为 Runnable
+   
+   RunnableLambda 将普通的 Python 函数包装成 Runnable,
+   使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
+   """
+   
+   def __init__(self, func, name=None):
+       """
+       初始化 RunnableLambda
+       
+       Args:
+           func: 要包装的 Python 可调用对象
+           name: 可选的名称,用于调试和显示
+       """
+       if not callable(func):
+           raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
+       self.func = func
+       self.name = name or getattr(func, '__name__', 'lambda')
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """调用函数处理输入"""
+       # 如果函数接受 config 参数,传递它
+       import inspect
+       sig = inspect.signature(self.func)
+       params = list(sig.parameters.keys())
+       
+       if len(params) > 1 and 'config' in params:
+           return self.func(input_data, config=config, **kwargs)
+       elif len(params) > 1 and len(kwargs) > 0:
+           # 尝试传递 kwargs
+           return self.func(input_data, **kwargs)
+       else:
+           return self.func(input_data)
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __repr__(self):
+       return f"RunnableLambda({self.name})"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

24.2. RunnableSequence.py #

24.RunnableSequence.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
#from langchain_core.runnables import Runnable,RunnableSequence,RunnableLambda
#import json
#from typing import Any, List

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
from langchain.runnables import Runnable, RunnableSequence, RunnableLambda
import json
from typing import Any, List

print("=" * 60)
print("RunnableSequence 和 RunnableLambda 演示")
print("=" * 60)
print("\nRunnableSequence: 用于顺序执行多个 Runnable 的序列")
print("RunnableLambda: 将普通 Python 函数包装成 Runnable\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: RunnableSequence 的基本用法")
print("=" * 60)

# 使用 | 操作符创建序列
prompt = PromptTemplate.from_template("请回答:{question}")
chain = prompt | llm | parser

print("\n创建链:")
print("chain = prompt | llm | parser")
print(f"链类型: {type(chain).__name__}")

# 测试 invoke
print("\n使用 invoke 方法:")
result = chain.invoke({"question": "什么是 Python?"})
print(f"结果: {result[:50]}...")

# 测试 batch
print("\n使用 batch 方法:")
batch_inputs = [
    {"question": "什么是 Python?"},
    {"question": "什么是机器学习?"},
]
batch_results = chain.batch(batch_inputs)
for i, (input_data, result) in enumerate(zip(batch_inputs, batch_results), 1):
    print(f"{i}. {input_data['question']}: {result[:30]}...")

print("\n" + "=" * 60)
print("演示 2: RunnableLambda 的基本用法")
print("=" * 60)

# 创建简单的函数
def add_one(x: int) -> int:
    """将输入加 1"""
    return x + 1

def multiply_by_two(x: int) -> int:
    """将输入乘以 2"""
    return x * 2

def square(x: int) -> int:
    """计算平方"""
    return x ** 2

# 将函数包装成 RunnableLambda
add_one_runnable = RunnableLambda(add_one)
multiply_runnable = RunnableLambda(multiply_by_two)
square_runnable = RunnableLambda(square)

print("\n创建 RunnableLambda:")
print(f"add_one_runnable: {add_one_runnable}")
print(f"multiply_runnable: {multiply_runnable}")
print(f"square_runnable: {square_runnable}")

# 测试 invoke
print("\n使用 invoke 方法:")
print(f"add_one_runnable.invoke(5) = {add_one_runnable.invoke(5)}")
print(f"multiply_runnable.invoke(5) = {multiply_runnable.invoke(5)}")
print(f"square_runnable.invoke(5) = {square_runnable.invoke(5)}")

# 测试 batch
print("\n使用 batch 方法:")
inputs = [1, 2, 3, 4, 5]
print(f"输入: {inputs}")
print(f"add_one_runnable.batch({inputs}) = {add_one_runnable.batch(inputs)}")
print(f"multiply_runnable.batch({inputs}) = {multiply_runnable.batch(inputs)}")
print(f"square_runnable.batch({inputs}) = {square_runnable.batch(inputs)}")

print("\n" + "=" * 60)
print("演示 3: RunnableLambda 组合成序列")
print("=" * 60)

# 组合多个 RunnableLambda
sequence = add_one_runnable | multiply_runnable | square_runnable

print("\n创建序列:")
print("sequence = add_one_runnable | multiply_runnable | square_runnable")
print(f"序列类型: {type(sequence).__name__}")

# 测试
test_value = 3
result = sequence.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"步骤 1 (add_one): {test_value} + 1 = {test_value + 1}")
print(f"步骤 2 (multiply): {test_value + 1} * 2 = {(test_value + 1) * 2}")
print(f"步骤 3 (square): {(test_value + 1) * 2} ** 2 = {result}")
print(f"最终结果: {result}")

# 批处理
print("\n批处理:")
batch_inputs = [1, 2, 3]
batch_results = sequence.batch(batch_inputs)
print(f"输入: {batch_inputs}")
print(f"输出: {batch_results}")

print("\n" + "=" * 60)
print("演示 4: RunnableLambda 与 LLM 链结合")
print("=" * 60)

# 创建文本处理函数
def extract_first_sentence(text: str) -> str:
    """提取第一句话"""
    sentences = text.split('。')
    return sentences[0] + '。' if sentences else text

def add_prefix(text: str, prefix: str = "摘要:") -> str:
    """添加前缀"""
    return f"{prefix}{text}"

# 创建处理链
text_chain = prompt | llm | parser | RunnableLambda(extract_first_sentence) | RunnableLambda(lambda x: add_prefix(x, "摘要:"))

print("\n创建文本处理链:")
print("text_chain = prompt | llm | parser | RunnableLambda(extract_first_sentence) | RunnableLambda(add_prefix)")

questions = ["什么是 Python?", "什么是机器学习?"]

for question in questions:
    result = text_chain.invoke({"question": question})
    print(f"\n问题: {question}")
    print(f"处理结果: {result[:60]}...")

print("\n" + "=" * 60)
print("演示 5: RunnableLambda 处理复杂数据")
print("=" * 60)

# 处理字典数据
def extract_name(data: dict) -> str:
    """从字典中提取 name 字段"""
    return data.get('name', 'Unknown')

def format_greeting(name: str) -> str:
    """格式化问候语"""
    return f"你好,{name}!"

# 创建处理链
data_chain = RunnableLambda(extract_name) | RunnableLambda(format_greeting)

print("\n创建数据处理链:")
print("data_chain = RunnableLambda(extract_name) | RunnableLambda(format_greeting)")

test_data = [
    {"name": "张三", "age": 25},
    {"name": "李四", "age": 30},
    {"name": "王五", "age": 28},
]

print("\n处理数据:")
for data in test_data:
    result = data_chain.invoke(data)
    print(f"  {data} -> {result}")

print("\n" + "=" * 60)
print("演示 6: RunnableLambda 与 JSON 解析结合")
print("=" * 60)

# 创建 JSON 处理函数
def parse_json_safely(text: str) -> dict:
    """安全解析 JSON"""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON", "text": text}

def extract_fields(data: dict) -> dict:
    """提取特定字段"""
    return {
        "name": data.get("name", ""),
        "age": data.get("age", 0),
    }

# 创建 JSON 处理链
json_prompt = PromptTemplate.from_template("请以 JSON 格式返回一个人的信息,包含 name 和 age 字段。主题:{topic}")
json_chain = json_prompt | llm | parser | RunnableLambda(parse_json_safely) | RunnableLambda(extract_fields)

print("\n创建 JSON 处理链:")
print("json_chain = json_prompt | llm | parser | RunnableLambda(parse_json_safely) | RunnableLambda(extract_fields)")

topics = ["程序员", "科学家"]

for topic in topics:
    result = json_chain.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"提取的字段: {result}")

print("\n" + "=" * 60)
print("演示 7: RunnableSequence 手动创建")
print("=" * 60)

# 手动创建 RunnableSequence
manual_sequence = RunnableSequence(
    add_one_runnable,
    multiply_runnable,
    square_runnable
)

print("\n手动创建序列:")
print("manual_sequence = RunnableSequence(add_one_runnable, multiply_runnable, square_runnable)")
print(f"序列: {manual_sequence}")

# 测试
test_value = 2
result = manual_sequence.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"结果: {result}")

# 与 | 操作符创建的序列比较
auto_sequence = add_one_runnable | multiply_runnable | square_runnable
auto_result = auto_sequence.invoke(test_value)
print(f"\n使用 | 操作符创建的结果: {auto_result}")
print(f"结果相同: {result == auto_result}")

print("\n" + "=" * 60)
print("演示 8: RunnableLambda 处理列表数据")
print("=" * 60)

# 处理列表的函数
def sum_list(numbers: list) -> int:
    """计算列表总和"""
    return sum(numbers)

def double_list(numbers: list) -> list:
    """将列表中的每个元素翻倍"""
    return [x * 2 for x in numbers]

# 创建处理链
list_chain = RunnableLambda(double_list) | RunnableLambda(sum_list)

print("\n创建列表处理链:")
print("list_chain = RunnableLambda(double_list) | RunnableLambda(sum_list)")

test_lists = [
    [1, 2, 3],
    [4, 5, 6],
    [10, 20, 30],
]

print("\n处理列表:")
for test_list in test_lists:
    result = list_chain.invoke(test_list)
    print(f"  输入: {test_list}")
    print(f"  步骤 1 (double): {double_list(test_list)}")
    print(f"  步骤 2 (sum): {result}")
    print()

print("\n" + "=" * 60)
print("演示 9: RunnableLambda 条件处理")
print("=" * 60)

# 条件处理函数
def is_positive(x: int) -> bool:
    """检查是否为正数"""
    return x > 0

def process_positive(x: int) -> str:
    """处理正数"""
    return f"{x} 是正数"

def process_negative(x: int) -> str:
    """处理负数或零"""
    return f"{x} 不是正数"

def conditional_process(x: int) -> str:
    """条件处理"""
    if is_positive(x):
        return process_positive(x)
    else:
        return process_negative(x)

# 创建条件处理链
conditional_chain = RunnableLambda(conditional_process)

print("\n创建条件处理链:")
print("conditional_chain = RunnableLambda(conditional_process)")

test_values = [-5, 0, 3, 10, -1]

print("\n条件处理:")
for value in test_values:
    result = conditional_chain.invoke(value)
    print(f"  {value} -> {result}")

print("\n" + "=" * 60)
print("演示 10: RunnableLambda 与字符串处理")
print("=" * 60)

# 字符串处理函数
def to_uppercase(text: str) -> str:
    """转换为大写"""
    return text.upper()

def add_exclamation(text: str) -> str:
    """添加感叹号"""
    return f"{text}!"

def reverse_text(text: str) -> str:
    """反转文本"""
    return text[::-1]

# 创建字符串处理链
string_chain = (
    RunnableLambda(to_uppercase) |
    RunnableLambda(add_exclamation) |
    RunnableLambda(reverse_text)
)

print("\n创建字符串处理链:")
print("string_chain = RunnableLambda(to_uppercase) | RunnableLambda(add_exclamation) | RunnableLambda(reverse_text)")

test_strings = ["hello", "world", "python"]

print("\n处理字符串:")
for test_str in test_strings:
    result = string_chain.invoke(test_str)
    print(f"  输入: {test_str}")
    print(f"  步骤 1 (uppercase): {to_uppercase(test_str)}")
    print(f"  步骤 2 (exclamation): {add_exclamation(to_uppercase(test_str))}")
    print(f"  步骤 3 (reverse): {result}")
    print()

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnableSequence 的核心概念:")
print("1. 顺序执行:按顺序执行多个 Runnable")
print("2. 数据传递:前一个 Runnable 的输出作为下一个的输入")
print("3. 使用 | 操作符:可以方便地组合 Runnable")
print("4. 手动创建:也可以使用 RunnableSequence 构造函数手动创建")
print("\nRunnableLambda 的核心概念:")
print("1. 函数包装:将普通 Python 函数包装成 Runnable")
print("2. 无缝集成:可以在 LCEL 链中与其他组件组合")
print("3. 灵活处理:可以处理各种类型的数据")
print("4. 简单易用:不需要创建完整的类,只需提供函数")
print("\n使用场景:")
print("- 数据转换和处理")
print("- 条件逻辑处理")
print("- 数据提取和格式化")
print("- 与 LLM 链结合进行后处理")
print("- 创建可重用的处理步骤")
print("\n注意事项:")
print("- RunnableLambda 适合不需要流式处理的简单函数")
print("- 复杂逻辑建议创建完整的 Runnable 类")
print("- 确保函数的输入输出类型匹配链中的其他组件")
print("- 使用 RunnableSequence 可以清晰地组织处理流程")

25. RunnableParallel #

RunnableParallel 是 LCEL (LangChain Expression Language) 提供的一个强大组件,用于「并行」地执行多个 Runnable,所有分支接收同样的输入,并将每个分支的输出收集成一个字典返回。这和 RunnableSequence 顺序处理不同,RunnableParallel 适用于需要对同一输入同时提取/处理多路信息的场景。

基本原理

  • 输入:所有分支(每个都是可兼容的 Runnable 或 RunnableLambda)都接收到同一个输入数据。
  • 同时执行:每个分支各自独立执行,不受其他分支影响,实现“并行”处理。
  • 输出:返回一个字典,key 为每个分支设定的名字,value 为对应分支的输出结果。

注意:这里的“并行”指的是逻辑上的并发处理,具体是否多线程/多进程与实现有关,但多数场景直接调用即可,无须自己管理并发。

典型用法

def extract_sentiment(text):
    # 伪代码:情感极性分析
    if "好" in text:
        return "positive"
    elif "糟糕" in text:
        return "negative"
    else:
        return "neutral"

def extract_key_phrases(text):
    # 伪代码:抽取关键词
    return [w for w in text.split() if len(w) > 1]

def calculate_length(text):
    return len(text)

# 构建并行链
text_analysis_chain = RunnableParallel({
    "sentiment": RunnableLambda(extract_sentiment),
    "key_phrases": RunnableLambda(extract_key_phrases),
    "length_info": RunnableLambda(calculate_length),
    "original": RunnableLambda(lambda x: x),
})

result = text_analysis_chain.invoke("这个产品非常好用,我很喜欢!")
print(result)
# 输出示例:
# {
#   "sentiment": "positive",
#   "key_phrases": ["这个产品非常好用,我很喜欢!"],
#   "length_info": 14,
#   "original": "这个产品非常好用,我很喜欢!"
# }

典型应用场景

  • 文本分析任务:对一句话同时做情感分析、关键词提取、长度统计等。
  • 多模型结果汇总:同一输入喂给不同 LLM 或处理链,收集多路模型输出。
  • 并行特征工程:对数据同时提取多个特征。

注意事项

  • 输入一致:所有分支接收同样的输入。
  • 输出为字典:结果用 dict 封装,方便统一下游处理。
  • 可嵌套:RunnableParallel 可作为更复杂链路中的子步骤,甚至嵌套使用。
  • 批处理/流式:同样支持 .batch() 方法批量推理,提高性能。

与 RunnableSequence 的区别

  • RunnableSequence 是“顺序执行链”,数据流一环套一环;
  • RunnableParallel 是“并行分支链”,数据流一分多支,再合成 dict。

何时用 RunnableParallel

  • 想要从一个输入并行“提取/处理”出多组结果时,优选 RunnableParallel。
  • 对需要分多路创造性处理或多模型综合判断等复杂场景尤为适用。

25.1. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
    def __repr__(self):
        return f"RunnableLambda({self.name})"


+class RunnableParallel(Runnable):
+   """并行执行多个 Runnable
+   
+   RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
+   然后返回一个字典,包含每个分支的结果。
+   """
+   
+   def __init__(self, steps=None, **kwargs):
+       """
+       初始化 RunnableParallel
+       
+       Args:
+           steps: 字典,键是分支名称,值是对应的 Runnable
+           **kwargs: 也可以直接传递关键字参数,每个参数名作为键
+       """
+       if steps is None:
+           steps = {}
+       elif not isinstance(steps, dict):
+           raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")
+       
+       # 合并 kwargs 中的步骤
+       self.steps = {**steps, **kwargs}
+       
+       # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
+       processed_steps = {}
+       for key, value in self.steps.items():
+           if isinstance(value, dict):
+               # 嵌套字典自动转换为 RunnableParallel
+               processed_steps[key] = RunnableParallel(value)
+           elif hasattr(value, 'invoke') or callable(value):
+               processed_steps[key] = value
+           else:
+               raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
+       self.steps = processed_steps
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """并行执行所有步骤"""
+       results = {}
+       for key, step in self.steps.items():
+           if hasattr(step, 'invoke'):
+               # 检查 step 的 invoke 方法是否接受 config 参数
+               import inspect
+               try:
+                   sig = inspect.signature(step.invoke)
+                   params = list(sig.parameters.keys())
+                   if 'config' in params:
+                       results[key] = step.invoke(input_data, config=config, **kwargs)
+                   else:
+                       results[key] = step.invoke(input_data, **kwargs)
+               except (ValueError, TypeError):
+                   try:
+                       results[key] = step.invoke(input_data, **kwargs)
+                   except TypeError:
+                       results[key] = step.invoke(input_data, config=config, **kwargs)
+           elif callable(step):
+               results[key] = step(input_data)
+           else:
+               raise ValueError(f"步骤 '{key}' 不可调用")
+       return results
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       results = []
+       for input_data in inputs:
+           try:
+               result = self.invoke(input_data, config=config, **kwargs)
+               results.append(result)
+           except Exception as e:
+               results.append(f"错误: {e}")
+       return results
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __repr__(self):
+       return f"RunnableParallel({list(self.steps.keys())})"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

25.2. RunnableParallel.py #

25.RunnableParallel.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
#from langchain_core.runnables import Runnable,RunnableSequence,RunnableLambda,RunnableParallel
#import json
#from typing import Any, List

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
from langchain.runnables import Runnable, RunnableSequence, RunnableLambda, RunnableParallel
import json
from typing import Any, List

print("=" * 60)
print("RunnableParallel 演示")
print("=" * 60)
print("\nRunnableParallel: 用于并行执行多个 Runnable")
print("所有 Runnable 接收相同的输入,返回一个包含所有结果的字典\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: RunnableParallel 的基本用法")
print("=" * 60)

# 创建简单的函数
def add_one(x: int) -> int:
    """将输入加 1"""
    return x + 1

def multiply_by_two(x: int) -> int:
    """将输入乘以 2"""
    return x * 2

def multiply_by_three(x: int) -> int:
    """将输入乘以 3"""
    return x * 3

# 创建 RunnableLambda
add_one_runnable = RunnableLambda(add_one)
mul_two_runnable = RunnableLambda(multiply_by_two)
mul_three_runnable = RunnableLambda(multiply_by_three)

# 创建 RunnableParallel
parallel = RunnableParallel({
    "add_one": add_one_runnable,
    "mul_two": mul_two_runnable,
    "mul_three": mul_three_runnable,
})

print("\n创建 RunnableParallel:")
print("parallel = RunnableParallel({")
print("    'add_one': add_one_runnable,")
print("    'mul_two': mul_two_runnable,")
print("    'mul_three': mul_three_runnable,")
print("})")
print(f"并行对象: {parallel}")

# 测试 invoke
print("\n使用 invoke 方法:")
result = parallel.invoke(5)
print(f"输入: 5")
print(f"结果: {result}")
print(f"  - add_one: {result['add_one']}")
print(f"  - mul_two: {result['mul_two']}")
print(f"  - mul_three: {result['mul_three']}")

# 测试 batch
print("\n使用 batch 方法:")
inputs = [1, 2, 3]
results = parallel.batch(inputs)
print(f"输入: {inputs}")
print(f"结果: {results}")

print("\n" + "=" * 60)
print("演示 2: 使用字典字面量创建 RunnableParallel")
print("=" * 60)

# 在序列中使用字典字面量
sequence = add_one_runnable | {
    "mul_two": mul_two_runnable,
    "mul_three": mul_three_runnable,
}

print("\n创建序列(包含字典字面量):")
print("sequence = add_one_runnable | {")
print("    'mul_two': mul_two_runnable,")
print("    'mul_three': mul_three_runnable,")
print("}")

# 测试
test_value = 3
result = sequence.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"步骤 1 (add_one): {test_value} + 1 = {test_value + 1}")
print(f"步骤 2 (parallel):")
print(f"  - mul_two: {test_value + 1} * 2 = {result['mul_two']}")
print(f"  - mul_three: {test_value + 1} * 3 = {result['mul_three']}")
print(f"最终结果: {result}")

print("\n" + "=" * 60)
print("演示 3: RunnableParallel 与 LLM 链结合")
print("=" * 60)

# 创建多个 LLM 链
joke_prompt = PromptTemplate.from_template("请讲一个关于 {topic} 的笑话")
joke_chain = joke_prompt | llm | parser

poem_prompt = PromptTemplate.from_template("请写一首关于 {topic} 的两行诗")
poem_chain = poem_prompt | llm | parser

summary_prompt = PromptTemplate.from_template("请用一句话总结 {topic}")
summary_chain = summary_prompt | llm | parser

# 创建并行链
parallel_chain = RunnableParallel({
    "joke": joke_chain,
    "poem": poem_chain,
    "summary": summary_chain,
})

print("\n创建并行 LLM 链:")
print("parallel_chain = RunnableParallel({")
print("    'joke': joke_chain,")
print("    'poem': poem_chain,")
print("    'summary': summary_chain,")
print("})")

topics = ["Python", "机器学习"]

for topic in topics:
    print(f"\n主题: {topic}")
    print("-" * 60)
    result = parallel_chain.invoke({"topic": topic})
    print(f"笑话: {result['joke'][:50]}...")
    print(f"诗歌: {result['poem'][:50]}...")
    print(f"总结: {result['summary'][:50]}...")

print("\n" + "=" * 60)
print("演示 4: RunnableParallel 与 RunnableSequence 组合")
print("=" * 60)

# 创建复杂的链
def extract_keywords(text: str) -> str:
    """提取关键词(简化版)"""
    words = text.split()
    return ", ".join(words[:5])

def count_words(text: str) -> int:
    """统计字数"""
    return len(text.split())

# 创建处理链
process_chain = (
    PromptTemplate.from_template("请介绍:{topic}") |
    llm |
    parser |
    {
        "keywords": RunnableLambda(extract_keywords),
        "word_count": RunnableLambda(count_words),
        "original": RunnableLambda(lambda x: x),
    }
)

print("\n创建组合链:")
print("process_chain = prompt | llm | parser | {")
print("    'keywords': RunnableLambda(extract_keywords),")
print("    'word_count': RunnableLambda(count_words),")
print("    'original': RunnableLambda(lambda x: x),")
print("}")

topics = ["人工智能", "区块链"]

for topic in topics:
    result = process_chain.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"关键词: {result['keywords']}")
    print(f"字数: {result['word_count']}")
    print(f"原文: {result['original'][:50]}...")

print("\n" + "=" * 60)
print("演示 5: RunnableParallel 处理不同数据类型")
print("=" * 60)

# 创建处理不同数据类型的函数
def process_string(data: dict) -> str:
    """处理字符串"""
    return f"字符串: {data.get('text', '')}"

def process_number(data: dict) -> int:
    """处理数字"""
    return data.get('number', 0) * 2

def process_list(data: dict) -> list:
    """处理列表"""
    items = data.get('items', [])
    return [item.upper() if isinstance(item, str) else item for item in items]

# 创建并行处理链
data_processor = RunnableParallel({
    "string_result": RunnableLambda(process_string),
    "number_result": RunnableLambda(process_number),
    "list_result": RunnableLambda(process_list),
})

print("\n创建数据处理链:")
print("data_processor = RunnableParallel({")
print("    'string_result': RunnableLambda(process_string),")
print("    'number_result': RunnableLambda(process_number),")
print("    'list_result': RunnableLambda(process_list),")
print("})")

test_data = {
    "text": "hello world",
    "number": 5,
    "items": ["apple", "banana", "cherry"],
}

result = data_processor.invoke(test_data)
print(f"\n输入: {test_data}")
print(f"结果:")
print(f"  - string_result: {result['string_result']}")
print(f"  - number_result: {result['number_result']}")
print(f"  - list_result: {result['list_result']}")

print("\n" + "=" * 60)
print("演示 6: RunnableParallel 与 JSON 解析结合")
print("=" * 60)

# 创建 JSON 处理链
json_prompt = PromptTemplate.from_template("请以 JSON 格式返回关于 {topic} 的信息,包含 name、description、features 字段")
json_parser = JsonOutputParser()

# 创建多个处理链
json_chain = json_prompt | llm | json_parser

def extract_name(data: dict) -> str:
    """提取名称"""
    return data.get("name", "Unknown")

def extract_description(data: dict) -> str:
    """提取描述"""
    return data.get("description", "")

def extract_features(data: dict) -> list:
    """提取特性"""
    return data.get("features", [])

# 创建并行处理链
json_processor = json_chain | {
    "name": RunnableLambda(extract_name),
    "description": RunnableLambda(extract_description),
    "features": RunnableLambda(extract_features),
}

print("\n创建 JSON 处理链:")
print("json_processor = json_chain | {")
print("    'name': RunnableLambda(extract_name),")
print("    'description': RunnableLambda(extract_description),")
print("    'features': RunnableLambda(extract_features),")
print("}")

topics = ["Python", "机器学习"]

for topic in topics:
    result = json_processor.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"名称: {result['name']}")
    print(f"描述: {result['description'][:50]}...")
    print(f"特性: {result['features']}")

print("\n" + "=" * 60)
print("演示 7: RunnableParallel 嵌套使用")
print("=" * 60)

# 创建嵌套的并行处理
def square(x: int) -> int:
    """计算平方"""
    return x ** 2

def cube(x: int) -> int:
    """计算立方"""
    return x ** 3

# 创建嵌套的并行链
nested_parallel = RunnableLambda(add_one) | {
    "math_ops": {
        "square": RunnableLambda(square),
        "cube": RunnableLambda(cube),
    },
    "mul_ops": {
        "mul_two": mul_two_runnable,
        "mul_three": mul_three_runnable,
    },
}

print("\n创建嵌套并行链:")
print("nested_parallel = RunnableLambda(add_one) | {")
print("    'math_ops': {")
print("        'square': RunnableLambda(square),")
print("        'cube': RunnableLambda(cube),")
print("    },")
print("    'mul_ops': {")
print("        'mul_two': mul_two_runnable,")
print("        'mul_three': mul_three_runnable,")
print("    },")
print("}")

test_value = 3
result = nested_parallel.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"结果: {result}")
print(f"  - math_ops.square: {result['math_ops']['square']}")
print(f"  - math_ops.cube: {result['math_ops']['cube']}")
print(f"  - mul_ops.mul_two: {result['mul_ops']['mul_two']}")
print(f"  - mul_ops.mul_three: {result['mul_ops']['mul_three']}")

print("\n" + "=" * 60)
print("演示 8: RunnableParallel 与条件处理结合")
print("=" * 60)

# 创建条件处理函数
def is_even(x: int) -> bool:
    """检查是否为偶数"""
    return x % 2 == 0

def process_even(x: int) -> str:
    """处理偶数"""
    return f"{x} 是偶数"

def process_odd(x: int) -> str:
    """处理奇数"""
    return f"{x} 是奇数"

# 创建并行处理链
conditional_parallel = RunnableParallel({
    "is_even": RunnableLambda(is_even),
    "even_result": RunnableLambda(process_even),
    "odd_result": RunnableLambda(process_odd),
})

print("\n创建条件处理链:")
print("conditional_parallel = RunnableParallel({")
print("    'is_even': RunnableLambda(is_even),")
print("    'even_result': RunnableLambda(process_even),")
print("    'odd_result': RunnableLambda(process_odd),")
print("})")

test_values = [1, 2, 3, 4, 5]

print("\n处理结果:")
for value in test_values:
    result = conditional_parallel.invoke(value)
    print(f"  输入: {value}")
    print(f"    是偶数: {result['is_even']}")
    print(f"    偶数结果: {result['even_result']}")
    print(f"    奇数结果: {result['odd_result']}")
    print()

print("\n" + "=" * 60)
print("演示 9: RunnableParallel 批处理")
print("=" * 60)

# 创建简单的并行处理
simple_parallel = RunnableParallel({
    "add_one": add_one_runnable,
    "mul_two": mul_two_runnable,
})

print("\n创建简单并行链:")
print("simple_parallel = RunnableParallel({")
print("    'add_one': add_one_runnable,")
print("    'mul_two': mul_two_runnable,")
print("})")

# 批处理
inputs = [1, 2, 3, 4, 5]
results = simple_parallel.batch(inputs)

print(f"\n批处理输入: {inputs}")
print("批处理结果:")
for i, (input_val, result) in enumerate(zip(inputs, results), 1):
    print(f"  {i}. 输入: {input_val}")
    print(f"     add_one: {result['add_one']}")
    print(f"     mul_two: {result['mul_two']}")

print("\n" + "=" * 60)
print("演示 10: RunnableParallel 实际应用场景")
print("=" * 60)

# 创建实际应用场景:文本分析
def extract_sentiment(text: str) -> str:
    """提取情感(简化版)"""
    positive_words = ["好", "棒", "优秀", "喜欢"]
    negative_words = ["差", "坏", "糟糕", "讨厌"]
    text_lower = text.lower()
    pos_count = sum(1 for word in positive_words if word in text_lower)
    neg_count = sum(1 for word in negative_words if word in text_lower)
    if pos_count > neg_count:
        return "积极"
    elif neg_count > pos_count:
        return "消极"
    return "中性"

def extract_key_phrases(text: str) -> list:
    """提取关键短语(简化版)"""
    words = text.split()
    return words[:3]  # 返回前3个词作为关键短语

def calculate_length(text: str) -> dict:
    """计算文本长度信息"""
    return {
        "char_count": len(text),
        "word_count": len(text.split()),
        "sentence_count": text.count('。') + text.count('!') + text.count('?'),
    }

# 创建文本分析链
text_analysis_chain = RunnableParallel({
    "sentiment": RunnableLambda(extract_sentiment),
    "key_phrases": RunnableLambda(extract_key_phrases),
    "length_info": RunnableLambda(calculate_length),
    "original": RunnableLambda(lambda x: x),
})

print("\n创建文本分析链:")
print("text_analysis_chain = RunnableParallel({")
print("    'sentiment': RunnableLambda(extract_sentiment),")
print("    'key_phrases': RunnableLambda(extract_key_phrases),")
print("    'length_info': RunnableLambda(calculate_length),")
print("    'original': RunnableLambda(lambda x: x),")
print("})")

test_texts = [
    "这个产品非常好用,我很喜欢!",
    "这个服务太糟糕了,完全不推荐。",
    "这是一个中性的评价,没有特别的情感倾向。",
]

for text in test_texts:
    result = text_analysis_chain.invoke(text)
    print(f"\n文本: {text}")
    print(f"情感: {result['sentiment']}")
    print(f"关键短语: {result['key_phrases']}")
    print(f"长度信息: {result['length_info']}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnableParallel 的核心概念:")
print("1. 并行执行:同时执行多个 Runnable,使用相同的输入")
print("2. 字典返回:返回一个字典,包含每个分支的结果")
print("3. 字典字面量:可以在序列中使用字典字面量自动创建 RunnableParallel")
print("4. 灵活组合:可以与 RunnableSequence 和其他组件组合")
print("\n使用场景:")
print("- 同时执行多个独立的处理任务")
print("- 从同一输入提取不同的信息")
print("- 并行调用多个 LLM 链")
print("- 数据分析和处理")
print("- 条件分支处理")
print("\n注意事项:")
print("- 所有分支接收相同的输入")
print("- 返回结果是字典格式")
print("- 可以嵌套使用创建复杂的处理流程")
print("- 适合需要并行处理的场景,提高效率")

26. RunnablePassthrough #

RunnablePassthrough 是 LangChain 中的一个特殊可运行单元(Runnable),它的作用是原样传递输入数据,不做任何更改。这在复杂的数据处理链中非常有用,例如:

  • 在并行任务中保留原始输入,便于后续比较或合并结果;
  • 使用 .assign() 方法为输入字典添加额外字段,实现灵活的链式数据增强;
  • 组合多个处理结果时,确保原始数据不会丢失。

主要特性

  1. 数据直通:输入什么,输出就是什么,不会对数据做任何加工或副作用。
  2. 链式和并行处理支持:经常与 RunnableParallel、RunnableSequence、RunnableLambda 等结合,用于复杂链式结构中。
  3. 可为字典输入添加字段:通过 assign() 方法,扩展输入字典,动态生成新字段。
  4. 便于保留原始数据:在处理、转换、分析数据的同时,方便随时获取原始内容。

使用场景举例

  • 保留用户原始输入,在处理流程中同时获得原始文本与处理后的结果。
  • 并行运行多个推理分支,并最终输出包含所有分支结果和输入数据的复合结果。
  • 为已存在的输入数据动态增加统计、分析、提取等新字段。

26.1. RunnablePassthrough.py #

26.RunnablePassthrough.py

from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
from langchain_core.runnables import Runnable,RunnableSequence,RunnableLambda,RunnableParallel,RunnablePassthrough


#from langchain.prompts import PromptTemplate
#from langchain.chat_models import ChatOpenAI
#from langchain.output_parsers import StrOutputParser, JsonOutputParser
#from langchain.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

print("=" * 60)
print("RunnablePassthrough 演示")
print("=" * 60)
print("\nRunnablePassthrough: 传递输入数据不变或添加额外键")
print("常用于在链中保留原始输入数据,同时添加处理后的结果\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: RunnablePassthrough 的基本用法")
print("=" * 60)

# 创建简单的函数
def add_one(x: int) -> int:
    """将输入加 1"""
    return x + 1

# 创建 RunnablePassthrough
passthrough = RunnablePassthrough()

print("\n创建 RunnablePassthrough:")
print("passthrough = RunnablePassthrough()")
print(f"passthrough 对象: {passthrough}")

# 测试 invoke
print("\n使用 invoke 方法:")
test_values = [1, "hello", {"key": "value"}, [1, 2, 3]]

for value in test_values:
    result = passthrough.invoke(value)
    print(f"  输入: {value}")
    print(f"  输出: {result}")
    print(f"  是否相同: {result is value or result == value}")
    print()

print("=" * 60)
print("演示 2: RunnablePassthrough 在 RunnableParallel 中使用")
print("=" * 60)

# 在并行处理中保留原始输入
add_one_runnable = RunnableLambda(add_one)

parallel_with_passthrough = RunnableParallel({
    "original": RunnablePassthrough(),
    "modified": add_one_runnable,
})

print("\n创建并行处理链(包含 passthrough):")
print("parallel_with_passthrough = RunnableParallel({")
print("    'original': RunnablePassthrough(),")
print("    'modified': add_one_runnable,")
print("})")

# 测试
test_value = 5
result = parallel_with_passthrough.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"结果: {result}")
print(f"  - original: {result['original']}")
print(f"  - modified: {result['modified']}")

# 批处理
print("\n批处理:")
inputs = [1, 2, 3, 4, 5]
results = parallel_with_passthrough.batch(inputs)
print(f"输入: {inputs}")
for i, (input_val, result) in enumerate(zip(inputs, results), 1):
    print(f"  {i}. 输入: {input_val} -> 结果: {result}")

print("\n" + "=" * 60)
print("演示 3: RunnablePassthrough 在序列中使用")
print("=" * 60)

# 在序列中使用 passthrough 保留中间结果
def multiply_by_two(x: int) -> int:
    """将输入乘以 2"""
    return x * 2

sequence_with_passthrough = (
    RunnableLambda(add_one) |
    {
        "original": RunnablePassthrough(),
        "doubled": RunnableLambda(multiply_by_two),
    }
)

print("\n创建序列(包含 passthrough):")
print("sequence_with_passthrough = (")
print("    RunnableLambda(add_one) | {")
print("        'original': RunnablePassthrough(),")
print("        'doubled': RunnableLambda(multiply_by_two),")
print("    }")
print(")")

test_value = 3
result = sequence_with_passthrough.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"步骤 1 (add_one): {test_value} + 1 = {test_value + 1}")
print(f"步骤 2 (parallel):")
print(f"  - original: {result['original']}")
print(f"  - doubled: {result['doubled']}")
print(f"最终结果: {result}")

print("\n" + "=" * 60)
print("演示 4: RunnablePassthrough 与 LLM 链结合")
print("=" * 60)

# 创建 LLM 链,同时保留原始输入
prompt = PromptTemplate.from_template("请回答:{question}")

chain_with_passthrough = prompt | {
    "original_question": RunnablePassthrough(),
    "answer": llm | parser,
}

print("\n创建 LLM 链(包含 passthrough):")
print("chain_with_passthrough = prompt | {")
print("    'original_question': RunnablePassthrough(),")
print("    'answer': llm | parser,")
print("}")

questions = ["什么是 Python?", "什么是机器学习?"]

for question in questions:
    result = chain_with_passthrough.invoke({"question": question})
    print(f"\n问题: {question}")
    print(f"原始问题: {result['original_question']}")
    print(f"回答: {result['answer'][:50]}...")

print("\n" + "=" * 60)
print("演示 5: RunnablePassthrough.assign() 方法")
print("=" * 60)

# 使用 assign 方法添加额外字段
def calculate_sum(data: dict) -> int:
    """计算字典中数字字段的总和"""
    return sum(v for v in data.values() if isinstance(v, (int, float)))

def calculate_product(data: dict) -> int:
    """计算字典中数字字段的乘积"""
    numbers = [v for v in data.values() if isinstance(v, (int, float))]
    if not numbers:
        return 0
    result = 1
    for num in numbers:
        result *= num
    return result

# 创建包含多个字段的字典
data_with_assign = {
    "a": RunnableLambda(lambda x: x + 1),
    "b": RunnableLambda(lambda x: x * 2),
} | RunnablePassthrough.assign(
    sum=lambda inputs: calculate_sum(inputs),
    product=lambda inputs: calculate_product(inputs),
)

print("\n创建数据链(使用 assign):")
print("data_with_assign = {")
print("    'a': RunnableLambda(lambda x: x + 1),")
print("    'b': RunnableLambda(lambda x: x * 2),")
print("} | RunnablePassthrough.assign(")
print("    sum=lambda inputs: calculate_sum(inputs),")
print("    product=lambda inputs: calculate_product(inputs),")
print(")")

test_value = 3
result = data_with_assign.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"结果: {result}")
print(f"  - a: {result['a']}")
print(f"  - b: {result['b']}")
print(f"  - sum: {result['sum']}")
print(f"  - product: {result['product']}")

print("\n" + "=" * 60)
print("演示 6: RunnablePassthrough.assign() 与 LLM 链结合")
print("=" * 60)

# 创建多个 LLM 链,然后添加计算字段
joke_prompt = PromptTemplate.from_template("请讲一个关于 {topic} 的笑话")
joke_chain = joke_prompt | llm | parser

poem_prompt = PromptTemplate.from_template("请写一首关于 {topic} 的两行诗")
poem_chain = poem_prompt | llm | parser

# 使用 assign 添加总字符数
chain_with_assign = {
    "joke": joke_chain,
    "poem": poem_chain,
} | RunnablePassthrough.assign(
    total_chars=lambda inputs: len(inputs.get("joke", "")) + len(inputs.get("poem", ""))
)

print("\n创建 LLM 链(使用 assign):")
print("chain_with_assign = {")
print("    'joke': joke_chain,")
print("    'poem': poem_chain,")
print("} | RunnablePassthrough.assign(")
print("    total_chars=lambda inputs: len(inputs['joke']) + len(inputs['poem'])")
print(")")

topics = ["Python", "机器学习"]

for topic in topics:
    result = chain_with_assign.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"笑话: {result['joke'][:30]}...")
    print(f"诗歌: {result['poem'][:30]}...")
    print(f"总字符数: {result['total_chars']}")

print("\n" + "=" * 60)
print("演示 7: RunnablePassthrough 处理字典输入")
print("=" * 60)

# 处理字典输入,保留所有字段并添加新字段
def extract_name(data: dict) -> str:
    """提取名称"""
    return data.get("name", "Unknown")

def format_greeting(data: dict) -> str:
    """格式化问候语"""
    name = data.get("name", "Unknown")
    age = data.get("age", 0)
    return f"你好,{name}!你今年 {age} 岁。"

# 创建处理链
data_processor = {
    "name": RunnableLambda(extract_name),
    "greeting": RunnableLambda(format_greeting),
} | RunnablePassthrough.assign(
    is_adult=lambda inputs: inputs.get("age", 0) >= 18,
    age_group=lambda inputs: "成年" if inputs.get("age", 0) >= 18 else "未成年",
)

print("\n创建数据处理链:")
print("data_processor = {")
print("    'name': RunnableLambda(extract_name),")
print("    'greeting': RunnableLambda(format_greeting),")
print("} | RunnablePassthrough.assign(")
print("    is_adult=lambda inputs: inputs['age'] >= 18,")
print("    age_group=lambda inputs: '成年' if inputs['age'] >= 18 else '未成年',")
print(")")

test_data = [
    {"name": "张三", "age": 25},
    {"name": "李四", "age": 15},
    {"name": "王五", "age": 30},
]

for data in test_data:
    result = data_processor.invoke(data)
    print(f"\n输入: {data}")
    print(f"结果:")
    print(f"  - name: {result['name']}")
    print(f"  - greeting: {result['greeting']}")
    print(f"  - is_adult: {result['is_adult']}")
    print(f"  - age_group: {result['age_group']}")
    print(f"  - age: {result.get('age', 'N/A')}")  # 原始字段也被保留

print("\n" + "=" * 60)
print("演示 8: RunnablePassthrough 与 JSON 解析结合")
print("=" * 60)

# 创建 JSON 处理链
json_prompt = PromptTemplate.from_template("请以 JSON 格式返回关于 {topic} 的信息,包含 name 和 description 字段")
json_parser = JsonOutputParser()

json_chain = json_prompt | llm | json_parser

# 使用 assign 添加额外字段
json_with_assign = json_chain | RunnablePassthrough.assign(
    topic=lambda inputs: inputs.get("topic", "Unknown"),
    char_count=lambda inputs: len(str(inputs.get("name", ""))) + len(str(inputs.get("description", ""))),
    has_description=lambda inputs: bool(inputs.get("description")),
)

print("\n创建 JSON 处理链(使用 assign):")
print("json_with_assign = json_chain | RunnablePassthrough.assign(")
print("    topic=lambda inputs: inputs['topic'],")
print("    char_count=lambda inputs: len(str(inputs['name'])) + len(str(inputs['description'])),")
print("    has_description=lambda inputs: bool(inputs['description']),")
print(")")

topics = ["Python", "机器学习"]

for topic in topics:
    result = json_with_assign.invoke({"topic": topic})
    print(f"\n主题: {topic}")
    print(f"JSON 数据: {result.get('name', 'N/A')}, {result.get('description', 'N/A')[:30]}...")
    print(f"额外字段:")
    print(f"  - topic: {result.get('topic', 'N/A')}")
    print(f"  - char_count: {result.get('char_count', 'N/A')}")
    print(f"  - has_description: {result.get('has_description', 'N/A')}")

print("\n" + "=" * 60)
print("演示 9: RunnablePassthrough 嵌套使用")
print("=" * 60)

# 嵌套使用 passthrough
def square(x: int) -> int:
    """计算平方"""
    return x ** 2

nested_chain = (
    RunnableLambda(add_one) |
    {
        "original": RunnablePassthrough(),
        "squared": RunnableLambda(square),
        "doubled": RunnableLambda(multiply_by_two),
    } |
    RunnablePassthrough.assign(
        sum=lambda inputs: inputs.get("original", 0) + inputs.get("squared", 0) + inputs.get("doubled", 0),
        product=lambda inputs: inputs.get("original", 0) * inputs.get("squared", 0) * inputs.get("doubled", 0),
    )
)

print("\n创建嵌套链:")
print("nested_chain = (")
print("    RunnableLambda(add_one) | {")
print("        'original': RunnablePassthrough(),")
print("        'squared': RunnableLambda(square),")
print("        'doubled': RunnableLambda(multiply_by_two),")
print("    } | RunnablePassthrough.assign(")
print("        sum=lambda inputs: inputs['original'] + inputs['squared'] + inputs['doubled'],")
print("        product=lambda inputs: inputs['original'] * inputs['squared'] * inputs['doubled'],")
print("    )")
print(")")

test_value = 3
result = nested_chain.invoke(test_value)
print(f"\n输入: {test_value}")
print(f"结果: {result}")
print(f"  - original: {result['original']}")
print(f"  - squared: {result['squared']}")
print(f"  - doubled: {result['doubled']}")
print(f"  - sum: {result['sum']}")
print(f"  - product: {result['product']}")

print("\n" + "=" * 60)
print("演示 10: RunnablePassthrough 实际应用场景")
print("=" * 60)

# 实际应用:文本分析并保留原始文本
def extract_keywords(text: str) -> list:
    """提取关键词(简化版)"""
    words = text.split()
    return words[:5]

def count_words(text: str) -> int:
    """统计字数"""
    return len(text.split())

def extract_sentiment(text: str) -> str:
    """提取情感(简化版)"""
    positive_words = ["好", "棒", "优秀", "喜欢"]
    negative_words = ["差", "坏", "糟糕", "讨厌"]
    text_lower = text.lower()
    pos_count = sum(1 for word in positive_words if word in text_lower)
    neg_count = sum(1 for word in negative_words if word in text_lower)
    if pos_count > neg_count:
        return "积极"
    elif neg_count > pos_count:
        return "消极"
    return "中性"

# 创建文本分析链
# 使用 RunnablePassthrough 保留原始文本,同时进行分析
text_analysis_chain = {
    "original_text": RunnablePassthrough(),
    "keywords": RunnableLambda(extract_keywords),
    "word_count": RunnableLambda(count_words),
    "sentiment": RunnableLambda(extract_sentiment),
} | RunnablePassthrough.assign(
    analysis_summary=lambda inputs: f"文本包含 {inputs.get('word_count', 0)} 个词,情感倾向为 {inputs.get('sentiment', '未知')}",
)

print("\n创建文本分析链:")
print("text_analysis_chain = {")
print("    'original_text': RunnablePassthrough(),")
print("    'keywords': RunnableLambda(extract_keywords),")
print("    'word_count': RunnableLambda(count_words),")
print("    'sentiment': RunnableLambda(extract_sentiment),")
print("} | RunnablePassthrough.assign(")
print("    analysis_summary=lambda inputs: f\"文本包含 {inputs['word_count']} 个词,情感倾向为 {inputs['sentiment']}\",")
print(")")

test_texts = [
    "这个产品非常好用,我很喜欢!",
    "这个服务太糟糕了,完全不推荐。",
    "这是一个中性的评价,没有特别的情感倾向。",
]

for text in test_texts:
    result = text_analysis_chain.invoke(text)
    print(f"\n文本: {text}")
    print(f"原始文本: {result['original_text']}")
    print(f"关键词: {result['keywords']}")
    print(f"字数: {result['word_count']}")
    print(f"情感: {result['sentiment']}")
    print(f"分析摘要: {result['analysis_summary']}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnablePassthrough 的核心概念:")
print("1. 传递数据:直接返回输入数据,不做任何修改")
print("2. 保留原始数据:在并行处理中保留原始输入")
print("3. 添加字段:使用 assign() 方法在字典中添加额外字段")
print("4. 灵活组合:可以与序列、并行处理等组合使用")
print("\n使用场景:")
print("- 在并行处理中保留原始输入数据")
print("- 在链中添加计算字段或元数据")
print("- 保留中间处理结果")
print("- 组合多个处理结果并添加汇总信息")
print("\n注意事项:")
print("- RunnablePassthrough 直接返回输入,不做任何转换")
print("- assign() 方法只对字典输入有效")
print("- 在 assign() 中,lambda 函数接收整个输入字典")
print("- 可以嵌套使用,创建复杂的处理流程")

26.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


+class RunnablePassthrough(Runnable):
+   """传递输入数据不变或添加额外键的 Runnable
+   
+   RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
+   它常用于在链中保留原始输入数据,同时添加处理后的结果。
+   """
+   
+   def __init__(self, func=None):
+       """
+       初始化 RunnablePassthrough
+       
+       Args:
+           func: 可选的函数,会在传递数据时调用(用于副作用)
+       """
+       self.func = func
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """传递输入数据"""
+       # 如果提供了函数,调用它(用于副作用)
+       if self.func is not None:
+           if callable(self.func):
+               try:
+                   self.func(input_data, config=config, **kwargs)
+               except TypeError:
+                   self.func(input_data)
+       
+       # 直接返回输入数据
+       return input_data
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   @classmethod
+   def assign(cls, **kwargs):
+       """创建一个 RunnableAssign,在传递数据的同时添加额外字段
+       
+       Args:
+           **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
+       
+       Returns:
+           RunnableAssign: 可以添加额外字段的 Runnable
+       """
+       return RunnableAssign(**kwargs)
+   
+   def __repr__(self):
+       return "RunnablePassthrough()"
+
+
+class RunnableAssign(Runnable):
+   """在传递字典数据的同时添加额外字段的 Runnable
+   
+   RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
+   在保留原有字段的同时,添加新的字段。
+   """
+   
+   def __init__(self, **kwargs):
+       """
+       初始化 RunnableAssign
+       
+       Args:
+           **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
+       """
+       self.assignments = {}
+       for key, value in kwargs.items():
+           if isinstance(value, dict):
+               # 嵌套字典转换为 RunnableParallel
+               self.assignments[key] = RunnableParallel(value)
+           elif hasattr(value, 'invoke') or callable(value):
+               self.assignments[key] = value
+           else:
+               raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """传递输入并添加额外字段"""
+       # 确保输入是字典
+       if not isinstance(input_data, dict):
+           input_data = {"input": input_data}
+       
+       # 复制输入数据
+       result = dict(input_data)
+       
+       # 添加新字段
+       for key, assignment in self.assignments.items():
+           if hasattr(assignment, 'invoke'):
+               # 检查是否接受 config 参数
+               import inspect
+               try:
+                   sig = inspect.signature(assignment.invoke)
+                   params = list(sig.parameters.keys())
+                   if 'config' in params:
+                       result[key] = assignment.invoke(input_data, config=config, **kwargs)
+                   else:
+                       result[key] = assignment.invoke(input_data, **kwargs)
+               except (ValueError, TypeError):
+                   try:
+                       result[key] = assignment.invoke(input_data, **kwargs)
+                   except TypeError:
+                       result[key] = assignment.invoke(input_data, config=config, **kwargs)
+           elif callable(assignment):
+               result[key] = assignment(input_data)
+           else:
+               raise ValueError(f"赋值 '{key}' 不可调用")
+       
+       return result
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   def __ror__(self, other):
+       """支持从右侧使用 | 操作符(用于字典字面量)"""
+       if isinstance(other, dict):
+           # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
+           return RunnableSequence(RunnableParallel(other), self)
+       return NotImplemented
+   
+   def __repr__(self):
+       return f"RunnableAssign({list(self.assignments.keys())})"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

27. RunnableBranch #

RunnableBranch 是 LangChain 中用于实现"条件分支"逻辑的强大组件。它可以根据输入和一组条件,选择不同的处理流程,非常类似于 Python 的 if-elif-else 或 switch-case 结构。你可以将多个 (条件, 可运行对象) 依次排列,并指定一个默认分支。

核心用法特点:

  • 可以把条件函数和对应处理的 Runnable 成对提供(如 (lambda x: isinstance(x, str), lambda x: x.upper()))。
  • 条件按顺序逐一检查,第一个满足条件的分支会被触发并处理输入。
  • 如果没有条件被满足,则会执行最后一项(默认分支)。
  • 支持与 RunnableLambda, RunnableSequence 等其他 LCEL 组件组合使用,实现灵活且强大的数据流路由。

典型应用场景:

  • 根据输入的数据类型选择不同处理方式(如字符串、数字、对象等)。
  • 根据输入的内容、特征或用户角色决定不同的业务逻辑或下游模型。
  • 在聊天机器人、数据清洗、流程自动化等场景中实现智能路由分流。

注意事项:

  • 分支条件应为可调用对象,返回布尔值。
  • 分支数量要求至少有 2 个(含默认分支)。
  • 默认分支是必须的,放最后一位,用于兜底处理所有未被命中的情况。
  • 条件函数需要能适当处理异常,避免程序意外中断。

27.1. RunnableBranch.py #

27.RunnableBranch.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
#from langchain_core.runnables import RunnableLambda,RunnableBranch

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
from langchain.runnables import RunnableBranch, RunnableLambda

print("=" * 60)
print("RunnableBranch 演示")
print("=" * 60)
print("\nRunnableBranch: 根据条件选择执行分支")
print("类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: RunnableBranch 的基本用法")
print("=" * 60)

# 创建简单的分支
branch = RunnableBranch(
    (lambda x: isinstance(x, str), lambda x: x.upper()),
    (lambda x: isinstance(x, int), lambda x: x + 1),
    (lambda x: isinstance(x, float), lambda x: x * 2),
    lambda x: "unknown",
)

print("\n创建分支:")
print("branch = RunnableBranch(")
print("    (lambda x: isinstance(x, str), lambda x: x.upper()),")
print("    (lambda x: isinstance(x, int), lambda x: x + 1),")
print("    (lambda x: isinstance(x, float), lambda x: x * 2),")
print("    lambda x: 'unknown',  # 默认分支")
print(")")

# 测试不同的输入
test_inputs = ["hello", 5, 3.14, None, [1, 2, 3]]

print("\n测试不同的输入:")
for input_val in test_inputs:
    result = branch.invoke(input_val)
    print(f"  输入: {input_val!r:15} -> 输出: {result!r}")

print("\n" + "=" * 60)
print("演示 2: RunnableBranch 处理数字范围")
print("=" * 60)

# 根据数字范围选择不同的处理
number_branch = RunnableBranch(
    (lambda x: isinstance(x, (int, float)) and x < 0, lambda x: f"{x} 是负数"),
    (lambda x: isinstance(x, (int, float)) and x == 0, lambda x: f"{x} 是零"),
    (lambda x: isinstance(x, (int, float)) and 0 < x < 10, lambda x: f"{x} 是小正数"),
    (lambda x: isinstance(x, (int, float)) and x >= 10, lambda x: f"{x} 是大正数"),
    lambda x: f"{x} 不是数字",
)

print("\n创建数字范围分支:")
print("number_branch = RunnableBranch(")
print("    (lambda x: isinstance(x, (int, float)) and x < 0, lambda x: f'{x} 是负数'),")
print("    (lambda x: isinstance(x, (int, float)) and x == 0, lambda x: f'{x} 是零'),")
print("    (lambda x: isinstance(x, (int, float)) and 0 < x < 10, lambda x: f'{x} 是小正数'),")
print("    (lambda x: isinstance(x, (int, float)) and x >= 10, lambda x: f'{x} 是大正数'),")
print("    lambda x: f'{x} 不是数字',  # 默认分支")
print(")")

test_numbers = [-5, 0, 5, 15, "not a number"]

print("\n测试不同的数字:")
for num in test_numbers:
    result = number_branch.invoke(num)
    print(f"  输入: {num:15} -> 输出: {result}")

print("\n" + "=" * 60)
print("演示 3: RunnableBranch 与 RunnableLambda 结合")
print("=" * 60)

# 使用 RunnableLambda 作为分支
def is_even(x):
    """检查是否为偶数"""
    return isinstance(x, int) and x % 2 == 0

def is_odd(x):
    """检查是否为奇数"""
    return isinstance(x, int) and x % 2 == 1

def double(x):
    """将数字翻倍"""
    return x * 2

def triple(x):
    """将数字乘以3"""
    return x * 3

even_odd_branch = RunnableBranch(
    (is_even, RunnableLambda(double)),
    (is_odd, RunnableLambda(triple)),
    lambda x: f"{x} 不是整数",
)

print("\n创建奇偶分支:")
print("even_odd_branch = RunnableBranch(")
print("    (is_even, RunnableLambda(double)),")
print("    (is_odd, RunnableLambda(triple)),")
print("    lambda x: f'{x} 不是整数',  # 默认分支")
print(")")

test_values = [2, 3, 4, 5, 6, "not a number"]

print("\n测试不同的值:")
for val in test_values:
    result = even_odd_branch.invoke(val)
    print(f"  输入: {val:15} -> 输出: {result}")

print("\n" + "=" * 60)
print("演示 4: RunnableBranch 与 LLM 链结合")
print("=" * 60)

# 根据问题类型选择不同的 LLM 链
def is_technical_question(question_dict):
    """检查是否是技术问题"""
    question = question_dict.get("question", "").lower()
    technical_keywords = ["python", "编程", "代码", "算法", "技术"]
    return any(keyword in question for keyword in technical_keywords)

def is_general_question(question_dict):
    """检查是否是一般问题"""
    question = question_dict.get("question", "").lower()
    general_keywords = ["什么", "如何", "为什么", "哪里"]
    return any(keyword in question for keyword in general_keywords)

# 创建不同的提示模板
technical_prompt = PromptTemplate.from_template("你是一个技术专家。请详细回答:{question}")
general_prompt = PromptTemplate.from_template("你是一个友好的助手。请简洁回答:{question}")
default_prompt = PromptTemplate.from_template("请回答:{question}")

# 创建不同的链
technical_chain = technical_prompt | llm | parser
general_chain = general_prompt | llm | parser
default_chain = default_prompt | llm | parser

# 创建分支链
question_branch = RunnableBranch(
    (is_technical_question, technical_chain),
    (is_general_question, general_chain),
    default_chain,
)

print("\n创建问题分支链:")
print("question_branch = RunnableBranch(")
print("    (is_technical_question, technical_chain),")
print("    (is_general_question, general_chain),")
print("    default_chain,  # 默认分支")
print(")")

questions = [
    {"question": "什么是 Python?"},
    {"question": "如何学习编程?"},
    {"question": "今天天气怎么样?"},
]

for question_dict in questions:
    result = question_branch.invoke(question_dict)
    print(f"\n问题: {question_dict['question']}")
    print(f"回答: {result[:50]}...")

print("\n" + "=" * 60)
print("演示 5: RunnableBranch 处理字典数据")
print("=" * 60)

# 根据字典中的字段选择不同的处理
def has_name(data):
    """检查是否有 name 字段"""
    return isinstance(data, dict) and "name" in data

def has_age(data):
    """检查是否有 age 字段"""
    return isinstance(data, dict) and "age" in data

def format_person(data):
    """格式化人员信息"""
    name = data.get("name", "Unknown")
    age = data.get("age", 0)
    return f"姓名: {name}, 年龄: {age}"

def format_name_only(data):
    """只格式化姓名"""
    name = data.get("name", "Unknown")
    return f"姓名: {name}"

def format_age_only(data):
    """只格式化年龄"""
    age = data.get("age", 0)
    return f"年龄: {age}"

data_branch = RunnableBranch(
    (has_name, RunnableLambda(format_name_only)),
    (has_age, RunnableLambda(format_age_only)),
    lambda x: f"数据: {x}",
)

print("\n创建数据分支:")
print("data_branch = RunnableBranch(")
print("    (has_name, RunnableLambda(format_name_only)),")
print("    (has_age, RunnableLambda(format_age_only)),")
print("    lambda x: f'数据: {x}',  # 默认分支")
print(")")

test_data = [
    {"name": "张三", "age": 25},
    {"name": "李四"},
    {"age": 30},
    {"city": "北京"},
]

print("\n测试不同的数据:")
for data in test_data:
    result = data_branch.invoke(data)
    print(f"  输入: {data}")
    print(f"  输出: {result}")
    print()

print("\n" + "=" * 60)
print("演示 6: RunnableBranch 处理字符串长度")
print("=" * 60)

# 根据字符串长度选择不同的处理
def is_short_text(text):
    """检查是否是短文本"""
    return isinstance(text, str) and len(text) < 10

def is_medium_text(text):
    """检查是否是中等长度文本"""
    return isinstance(text, str) and 10 <= len(text) < 50

def is_long_text(text):
    """检查是否是长文本"""
    return isinstance(text, str) and len(text) >= 50

def summarize_short(text):
    """处理短文本"""
    return f"短文本: {text}"

def summarize_medium(text):
    """处理中等长度文本"""
    return f"中等文本({len(text)} 字符): {text[:20]}..."

def summarize_long(text):
    """处理长文本"""
    return f"长文本({len(text)} 字符): {text[:30]}..."

text_length_branch = RunnableBranch(
    (is_short_text, RunnableLambda(summarize_short)),
    (is_medium_text, RunnableLambda(summarize_medium)),
    (is_long_text, RunnableLambda(summarize_long)),
    lambda x: f"不是文本: {x}",
)

print("\n创建文本长度分支:")
print("text_length_branch = RunnableBranch(")
print("    (is_short_text, RunnableLambda(summarize_short)),")
print("    (is_medium_text, RunnableLambda(summarize_medium)),")
print("    (is_long_text, RunnableLambda(summarize_long)),")
print("    lambda x: f'不是文本: {x}',  # 默认分支")
print(")")

test_texts = [
    "短",
    "这是一个中等长度的文本示例",
    "这是一个非常长的文本示例,用于演示 RunnableBranch 如何处理不同长度的文本内容。",
    123,
]

print("\n测试不同的文本:")
for text in test_texts:
    result = text_length_branch.invoke(text)
    print(f"  输入: {text!r}")
    print(f"  输出: {result}")
    print()

print("\n" + "=" * 60)
print("演示 7: RunnableBranch 与序列组合")
print("=" * 60)

# 在序列中使用分支
def add_one(x):
    """加1"""
    return x + 1

def multiply_by_two(x):
    """乘以2"""
    return x * 2

def multiply_by_three(x):
    """乘以3"""
    return x * 3

# 创建分支
processing_branch = RunnableBranch(
    (lambda x: x < 5, RunnableLambda(multiply_by_two)),
    (lambda x: x < 10, RunnableLambda(multiply_by_three)),
    lambda x: x,  # 默认分支:不做处理
)

# 创建序列:先加1,然后根据结果选择处理方式
sequence_with_branch = RunnableLambda(add_one) | processing_branch

print("\n创建序列(包含分支):")
print("sequence_with_branch = RunnableLambda(add_one) | processing_branch")
print("processing_branch = RunnableBranch(")
print("    (lambda x: x < 5, RunnableLambda(multiply_by_two)),")
print("    (lambda x: x < 10, RunnableLambda(multiply_by_three)),")
print("    lambda x: x,  # 默认分支")
print(")")

test_values = [1, 3, 5, 8, 12]

print("\n测试不同的值:")
for val in test_values:
    result = sequence_with_branch.invoke(val)
    print(f"  输入: {val}")
    print(f"  步骤 1 (add_one): {val} + 1 = {val + 1}")
    if val + 1 < 5:
        print(f"  步骤 2 (branch < 5): {val + 1} * 2 = {result}")
    elif val + 1 < 10:
        print(f"  步骤 2 (branch < 10): {val + 1} * 3 = {result}")
    else:
        print(f"  步骤 2 (default): {val + 1} = {result}")
    print()

print("\n" + "=" * 60)
print("演示 8: RunnableBranch 批处理")
print("=" * 60)

# 创建简单的分支用于批处理
simple_branch = RunnableBranch(
    (lambda x: x > 0, lambda x: f"{x} 是正数"),
    (lambda x: x < 0, lambda x: f"{x} 是负数"),
    lambda x: f"{x} 是零",
)

print("\n创建简单分支:")
print("simple_branch = RunnableBranch(")
print("    (lambda x: x > 0, lambda x: f'{x} 是正数'),")
print("    (lambda x: x < 0, lambda x: f'{x} 是负数'),")
print("    lambda x: f'{x} 是零',  # 默认分支")
print(")")

# 批处理
inputs = [-5, 0, 3, -2, 10]
results = simple_branch.batch(inputs)

print(f"\n批处理输入: {inputs}")
print("批处理结果:")
for input_val, result in zip(inputs, results):
    print(f"  {input_val:3} -> {result}")

print("\n" + "=" * 60)
print("演示 9: RunnableBranch 复杂条件")
print("=" * 60)

# 使用复杂的条件逻辑
def is_valid_email(data):
    """检查是否是有效的邮箱格式(简化版)"""
    if not isinstance(data, dict):
        return False
    email = data.get("email", "")
    return isinstance(email, str) and "@" in email and "." in email.split("@")[1]

def is_valid_phone(data):
    """检查是否是有效的手机号(简化版)"""
    if not isinstance(data, dict):
        return False
    phone = data.get("phone", "")
    return isinstance(phone, str) and phone.isdigit() and len(phone) == 11

def format_email(data):
    """格式化邮箱信息"""
    email = data.get("email", "")
    return f"邮箱: {email}"

def format_phone(data):
    """格式化手机号信息"""
    phone = data.get("phone", "")
    return f"手机号: {phone}"

validation_branch = RunnableBranch(
    (is_valid_email, RunnableLambda(format_email)),
    (is_valid_phone, RunnableLambda(format_phone)),
    lambda x: "无效的联系方式",
)

print("\n创建验证分支:")
print("validation_branch = RunnableBranch(")
print("    (is_valid_email, RunnableLambda(format_email)),")
print("    (is_valid_phone, RunnableLambda(format_phone)),")
print("    lambda x: '无效的联系方式',  # 默认分支")
print(")")

test_contacts = [
    {"email": "user@example.com"},
    {"phone": "13800138000"},
    {"email": "invalid"},
    {"phone": "123"},
    {"name": "张三"},
]

print("\n测试不同的联系方式:")
for contact in test_contacts:
    result = validation_branch.invoke(contact)
    print(f"  输入: {contact}")
    print(f"  输出: {result}")
    print()

print("\n" + "=" * 60)
print("演示 10: RunnableBranch 实际应用场景")
print("=" * 60)

# 实际应用:根据用户类型选择不同的处理流程
def is_admin(user):
    """检查是否是管理员"""
    return isinstance(user, dict) and user.get("role") == "admin"

def is_vip(user):
    """检查是否是VIP用户"""
    return isinstance(user, dict) and user.get("role") == "vip"

def is_regular(user):
    """检查是否是普通用户"""
    return isinstance(user, dict) and user.get("role") == "regular"

def process_admin(user):
    """处理管理员"""
    name = user.get("name", "Unknown")
    return f"管理员 {name} 可以访问所有功能"

def process_vip(user):
    """处理VIP用户"""
    name = user.get("name", "Unknown")
    return f"VIP用户 {name} 可以访问高级功能"

def process_regular(user):
    """处理普通用户"""
    name = user.get("name", "Unknown")
    return f"普通用户 {name} 可以访问基本功能"

user_branch = RunnableBranch(
    (is_admin, RunnableLambda(process_admin)),
    (is_vip, RunnableLambda(process_vip)),
    (is_regular, RunnableLambda(process_regular)),
    lambda x: "未知用户类型",
)

print("\n创建用户分支:")
print("user_branch = RunnableBranch(")
print("    (is_admin, RunnableLambda(process_admin)),")
print("    (is_vip, RunnableLambda(process_vip)),")
print("    (is_regular, RunnableLambda(process_regular)),")
print("    lambda x: '未知用户类型',  # 默认分支")
print(")")

test_users = [
    {"name": "张三", "role": "admin"},
    {"name": "李四", "role": "vip"},
    {"name": "王五", "role": "regular"},
    {"name": "赵六", "role": "guest"},
]

print("\n测试不同的用户:")
for user in test_users:
    result = user_branch.invoke(user)
    print(f"  用户: {user}")
    print(f"  结果: {result}")
    print()

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnableBranch 的核心概念:")
print("1. 条件分支:根据条件选择不同的 Runnable 执行")
print("2. 顺序检查:按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行")
print("3. 默认分支:如果没有条件满足,执行默认分支")
print("4. 灵活组合:可以与序列、并行处理等组合使用")
print("\n使用场景:")
print("- 根据输入类型选择不同的处理方式")
print("- 根据数据特征选择不同的处理流程")
print("- 实现条件逻辑和路由功能")
print("- 处理多种输入格式")
print("\n注意事项:")
print("- 条件函数应该返回布尔值")
print("- 条件按顺序检查,第一个为 True 的条件会被执行")
print("- 必须提供默认分支(最后一个参数)")
print("- 条件函数应该处理可能的异常情况")

27.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


class RunnablePassthrough(Runnable):
    """传递输入数据不变或添加额外键的 Runnable

    RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
    它常用于在链中保留原始输入数据,同时添加处理后的结果。
    """

    def __init__(self, func=None):
        """
        初始化 RunnablePassthrough

        Args:
            func: 可选的函数,会在传递数据时调用(用于副作用)
        """
        self.func = func

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入数据"""
        # 如果提供了函数,调用它(用于副作用)
        if self.func is not None:
            if callable(self.func):
                try:
                    self.func(input_data, config=config, **kwargs)
                except TypeError:
                    self.func(input_data)

        # 直接返回输入数据
        return input_data

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    @classmethod
    def assign(cls, **kwargs):
        """创建一个 RunnableAssign,在传递数据的同时添加额外字段

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象

        Returns:
            RunnableAssign: 可以添加额外字段的 Runnable
        """
        return RunnableAssign(**kwargs)

    def __repr__(self):
        return "RunnablePassthrough()"


class RunnableAssign(Runnable):
    """在传递字典数据的同时添加额外字段的 Runnable

    RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
    在保留原有字段的同时,添加新的字段。
    """

    def __init__(self, **kwargs):
        """
        初始化 RunnableAssign

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
        """
        self.assignments = {}
        for key, value in kwargs.items():
            if isinstance(value, dict):
                # 嵌套字典转换为 RunnableParallel
                self.assignments[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                self.assignments[key] = value
            else:
                raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入并添加额外字段"""
        # 确保输入是字典
        if not isinstance(input_data, dict):
            input_data = {"input": input_data}

        # 复制输入数据
        result = dict(input_data)

        # 添加新字段
        for key, assignment in self.assignments.items():
            if hasattr(assignment, 'invoke'):
                # 检查是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(assignment.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
                    else:
                        result[key] = assignment.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        result[key] = assignment.invoke(input_data, **kwargs)
                    except TypeError:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
            elif callable(assignment):
                result[key] = assignment(input_data)
            else:
                raise ValueError(f"赋值 '{key}' 不可调用")

        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __ror__(self, other):
        """支持从右侧使用 | 操作符(用于字典字面量)"""
        if isinstance(other, dict):
            # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
            return RunnableSequence(RunnableParallel(other), self)
        return NotImplemented

    def __repr__(self):
        return f"RunnableAssign({list(self.assignments.keys())})"


+class RunnableBranch(Runnable):
+   """根据条件选择执行分支的 Runnable
+   
+   RunnableBranch 类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行。
+   它接受多个 (condition, Runnable) 对,按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行。
+   如果没有条件满足,则执行默认分支。
+   """
+   
+   def __init__(self, *branches):
+       """
+       初始化 RunnableBranch
+       
+       Args:
+           *branches: 多个 (condition, Runnable) 对,最后一个可以是默认分支(Runnable 或可调用对象)
+       """
+       if len(branches) < 2:
+           raise ValueError("RunnableBranch 至少需要 2 个分支(包括默认分支)")
+       
+       # 分离条件和默认分支
+       self.branches = []
+       default = None
+       
+       for i, branch in enumerate(branches):
+           if i == len(branches) - 1:
+               # 最后一个可能是默认分支
+               if isinstance(branch, tuple) and len(branch) == 2:
+                   # 仍然是 (condition, Runnable) 对
+                   condition, runnable = branch
+                   self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
+               else:
+                   # 是默认分支
+                   default = self._coerce_to_runnable(branch)
+           else:
+               # 必须是 (condition, Runnable) 对
+               if not isinstance(branch, tuple) or len(branch) != 2:
+                   raise ValueError(f"分支 {i} 必须是 (condition, Runnable) 对,但得到 {type(branch)}")
+               condition, runnable = branch
+               self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
+       
+       if default is None:
+           raise ValueError("RunnableBranch 必须提供默认分支(最后一个参数)")
+       
+       self.default = default
+   
+   def _coerce_to_condition(self, condition):
+       """将条件转换为可调用对象"""
+       if callable(condition):
+           return condition
+       elif hasattr(condition, 'invoke'):
+           # 如果已经是 Runnable,返回一个包装函数
+           def wrapped_condition(input_data):
+               result = condition.invoke(input_data)
+               return bool(result)
+           return wrapped_condition
+       else:
+           raise ValueError(f"条件必须是可调用对象或 Runnable,但得到 {type(condition)}")
+   
+   def _coerce_to_runnable(self, runnable):
+       """将值转换为 Runnable"""
+       if hasattr(runnable, 'invoke'):
+           return runnable
+       elif callable(runnable):
+           return RunnableLambda(runnable)
+       else:
+           raise ValueError(f"Runnable 必须是可调用对象或 Runnable,但得到 {type(runnable)}")
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """根据条件选择分支执行"""
+       # 按顺序检查条件
+       for condition, runnable in self.branches:
+           try:
+               # 评估条件
+               condition_result = condition(input_data)
+               if condition_result:
+                   # 条件满足,执行对应的 Runnable
+                   if hasattr(runnable, 'invoke'):
+                       import inspect
+                       try:
+                           sig = inspect.signature(runnable.invoke)
+                           params = list(sig.parameters.keys())
+                           if 'config' in params:
+                               return runnable.invoke(input_data, config=config, **kwargs)
+                           else:
+                               return runnable.invoke(input_data, **kwargs)
+                       except (ValueError, TypeError):
+                           try:
+                               return runnable.invoke(input_data, **kwargs)
+                           except TypeError:
+                               return runnable.invoke(input_data, config=config, **kwargs)
+                   elif callable(runnable):
+                       return runnable(input_data)
+           except Exception as e:
+               # 如果条件评估出错,继续下一个条件
+               continue
+       
+       # 没有条件满足,执行默认分支
+       if hasattr(self.default, 'invoke'):
+           import inspect
+           try:
+               sig = inspect.signature(self.default.invoke)
+               params = list(sig.parameters.keys())
+               if 'config' in params:
+                   return self.default.invoke(input_data, config=config, **kwargs)
+               else:
+                   return self.default.invoke(input_data, **kwargs)
+           except (ValueError, TypeError):
+               try:
+                   return self.default.invoke(input_data, **kwargs)
+               except TypeError:
+                   return self.default.invoke(input_data, config=config, **kwargs)
+       elif callable(self.default):
+           return self.default(input_data)
+       else:
+           return self.default
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   def __repr__(self):
+       return f"RunnableBranch({len(self.branches)} branches + default)"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

28. Retry #

with_retry() 是 LangChain LCEL 可运行组件(Runnable)的一个高阶方法,可为任何数据处理链路灵活地加上自动重试机制,极大增强健壮性。当链条某个环节出现异常或网络/模型调用失败时,with_retry 会自动按照设定的重试次数、等待时长和回退策略重新尝试,直至成功或达最大重试次数。
常用于模型调用、API请求、解析器等有可能偶发错误的场景。

核心用法特点:

  • with_retry() 可作用于任何 LCEL Runnable 类型(如 RunnableLambda、RunnableSequence 等)。
  • 默认提供最多 6 次重试,等待时间采用指数退避(每次失败之后等待时间递增)。
  • 支持自定义重试次数(max_attempts)、重试等待区间(min_seconds/max_seconds),甚至可传入自定义重试条件(should_retry)。
  • 如果最终仍然失败,异常会抛出,便于上层捕获处理。
  • 可以天然应用在 LLM 推理、API 调用、解析步骤、数据清洗等环节,为系统容错兜底。

典型应用场景:

  • LLM/OpenAI API 间歇性失败、网络抖动问题
  • 解析器(如 JSON 解析、内容提取)有输入异常的场合
  • 外部 HTTP 接口/数据库调用及任何易出错的环节

注意事项:

  • 建议仅对幂等操作(重复调用不会有副作用)使用重试,避免因重试带来隐藏的业务风险。
  • 可以结合 log/print 跟踪每次重试后捕获的异常,便于调试。
  • 可通过自定义 should_retry 函数限制只对某类异常(如网络超时)才重试,自定义更安全。

常用参数简介

  • max_attempts:最大尝试次数(默认6)
  • min_seconds:每次重试的最小等待时长(默认1s)
  • max_seconds:每次重试的最大等待时长(默认60s)
  • should_retry:可选,异常过滤函数,返回 True 则重试

28.1. with_retry.py #

28.with_retry.py

from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
from langchain_core.runnables import RunnableLambda,RunnableBranch

#from langchain.prompts import PromptTemplate
#from langchain.chat_models import ChatOpenAI
#from langchain.output_parsers import StrOutputParser, JsonOutputParser
#from langchain.runnables import RunnableLambda, RunnableBranch
import time
import random

print("=" * 60)
print("RunnableLambda.with_retry() 演示")
print("=" * 60)
print("\nwith_retry: 为 Runnable 添加重试功能")
print("当执行失败时会自动重试,支持指数退避和抖动\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: with_retry 的基本用法")
print("=" * 60)

# 创建一个可能失败的函数
call_count = 0

def unreliable_function(x: int) -> int:
    """不可靠的函数,前两次调用会失败"""
    global call_count
    call_count += 1
    if call_count < 3:
        raise ValueError(f"调用失败(第 {call_count} 次尝试)")
    return x * 2

# 创建 RunnableLambda
runnable = RunnableLambda(unreliable_function)

print("\n创建可能失败的 Runnable:")
print("def unreliable_function(x: int) -> int:")
print("    global call_count")
print("    call_count += 1")
print("    if call_count < 3:")
print("        raise ValueError(f'调用失败(第 {call_count} 次尝试)')")
print("    return x * 2")
print("\nrunnable = RunnableLambda(unreliable_function)")

# 添加重试功能
runnable_with_retry = runnable.with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=5,
    wait_exponential_jitter=True,
)

print("\n添加重试功能:")
print("runnable_with_retry = runnable.with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=5,")
print("    wait_exponential_jitter=True,")
print(")")

# 重置计数器
call_count = 0

print("\n测试重试功能:")
try:
    result = runnable_with_retry.invoke(5)
    print(f"  输入: 5")
    print(f"  输出: {result}")
    print(f"  总调用次数: {call_count}")
except Exception as e:
    print(f"  最终失败: {e}")

print("\n" + "=" * 60)
print("演示 2: 指定重试的异常类型")
print("=" * 60)

# 创建可能抛出不同异常的函数
def function_with_multiple_errors(x: int) -> int:
    """可能抛出多种异常的函数"""
    if x < 0:
        raise ValueError("负数")
    elif x == 0:
        raise ZeroDivisionError("零")
    elif x > 100:
        raise OverflowError("溢出")
    return x * 2

# 只重试 ValueError
runnable_value_error = RunnableLambda(function_with_multiple_errors).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=3,
)

print("\n创建只重试 ValueError 的 Runnable:")
print("runnable_value_error = RunnableLambda(function_with_multiple_errors).with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=3,")
print(")")

test_values = [-5, 0, 50, 150]

print("\n测试不同的输入:")
for val in test_values:
    try:
        result = runnable_value_error.invoke(val)
        print(f"  输入: {val:3} -> 输出: {result}")
    except ValueError as e:
        print(f"  输入: {val:3} -> ValueError (会重试): {e}")
    except ZeroDivisionError as e:
        print(f"  输入: {val:3} -> ZeroDivisionError (不会重试): {e}")
    except OverflowError as e:
        print(f"  输入: {val:3} -> OverflowError (不会重试): {e}")

print("\n" + "=" * 60)
print("演示 3: 重试多次异常类型")
print("=" * 60)

# 重试多种异常类型
runnable_multiple_errors = RunnableLambda(function_with_multiple_errors).with_retry(
    retry_if_exception_type=(ValueError, ZeroDivisionError),
    stop_after_attempt=3,
)

print("\n创建重试多种异常的 Runnable:")
print("runnable_multiple_errors = RunnableLambda(function_with_multiple_errors).with_retry(")
print("    retry_if_exception_type=(ValueError, ZeroDivisionError),")
print("    stop_after_attempt=3,")
print(")")

print("\n测试不同的输入:")
for val in test_values:
    try:
        result = runnable_multiple_errors.invoke(val)
        print(f"  输入: {val:3} -> 输出: {result}")
    except ValueError as e:
        print(f"  输入: {val:3} -> ValueError (会重试): {e}")
    except ZeroDivisionError as e:
        print(f"  输入: {val:3} -> ZeroDivisionError (会重试): {e}")
    except OverflowError as e:
        print(f"  输入: {val:3} -> OverflowError (不会重试): {e}")

print("\n" + "=" * 60)
print("演示 4: 控制重试次数")
print("=" * 60)

# 创建一个总是失败的函数(用于演示)
fail_count = 0

def always_fail_once(x: int) -> int:
    """第一次调用失败,之后成功"""
    global fail_count
    fail_count += 1
    if fail_count == 1:
        raise ValueError("第一次失败")
    return x * 2

# 测试不同的重试次数
print("\n测试不同的重试次数:")

for max_attempts in [1, 2, 3]:
    fail_count = 0
    runnable_attempts = RunnableLambda(always_fail_once).with_retry(
        retry_if_exception_type=(ValueError,),
        stop_after_attempt=max_attempts,
    )

    try:
        result = runnable_attempts.invoke(5)
        print(f"  max_attempts={max_attempts}: 成功,结果={result}, 调用次数={fail_count}")
    except ValueError as e:
        print(f"  max_attempts={max_attempts}: 失败,调用次数={fail_count}")

print("\n" + "=" * 60)
print("演示 5: 指数退避和抖动")
print("=" * 60)

# 创建一个会失败多次的函数
retry_count = 0

def fail_multiple_times(x: int) -> int:
    """前两次调用失败"""
    global retry_count
    retry_count += 1
    if retry_count < 3:
        raise ValueError(f"第 {retry_count} 次失败")
    return x * 2

# 测试带指数退避的重试
runnable_backoff = RunnableLambda(fail_multiple_times).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=5,
    wait_exponential_jitter=True,
)

print("\n创建带指数退避的重试 Runnable:")
print("runnable_backoff = RunnableLambda(fail_multiple_times).with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=5,")
print("    wait_exponential_jitter=True,  # 启用指数退避和抖动")
print(")")

retry_count = 0
start_time = time.time()

try:
    result = runnable_backoff.invoke(5)
    elapsed_time = time.time() - start_time
    print(f"\n输入: 5")
    print(f"输出: {result}")
    print(f"总调用次数: {retry_count}")
    print(f"总耗时: {elapsed_time:.2f} 秒(包含重试等待时间)")
except Exception as e:
    elapsed_time = time.time() - start_time
    print(f"\n最终失败: {e}")
    print(f"总调用次数: {retry_count}")
    print(f"总耗时: {elapsed_time:.2f} 秒")

print("\n" + "=" * 60)
print("演示 6: 在链中使用 with_retry")
print("=" * 60)

# 创建一个可能失败的链
def add_one(x: int) -> int:
    """加1"""
    return x + 1

def unreliable_multiply(x: int) -> int:
    """不可靠的乘法,可能失败"""
    if random.random() < 0.7:  # 70% 的概率失败
        raise ValueError("随机失败")
    return x * 2

# 创建链,只对可能失败的步骤添加重试
chain_with_retry = (
    RunnableLambda(add_one) |
    RunnableLambda(unreliable_multiply).with_retry(
        retry_if_exception_type=(ValueError,),
        stop_after_attempt=5,
    )
)

print("\n创建链(只对可能失败的步骤添加重试):")
print("chain_with_retry = (")
print("    RunnableLambda(add_one) |")
print("    RunnableLambda(unreliable_multiply).with_retry(")
print("        retry_if_exception_type=(ValueError,),")
print("        stop_after_attempt=5,")
print("    )")
print(")")

print("\n测试链(可能需要多次重试):")
for i in range(3):
    try:
        result = chain_with_retry.invoke(5)
        print(f"  尝试 {i+1}: 输入=5 -> 输出={result}")
        break
    except ValueError as e:
        print(f"  尝试 {i+1}: 失败 - {e}")

print("\n" + "=" * 60)
print("演示 7: with_retry 与 RunnableBranch 结合")
print("=" * 60)

# 创建一个可能失败的分支
def risky_operation(x: int) -> int:
    """有风险的操作,可能失败"""
    if x % 3 == 0:
        raise ValueError("被3整除,失败")
    return x * 3

def safe_operation(x: int) -> int:
    """安全的操作"""
    return x * 2

# 创建分支,对风险操作添加重试
branch_with_retry = RunnableBranch(
    (lambda x: x % 2 == 0, RunnableLambda(risky_operation).with_retry(
        retry_if_exception_type=(ValueError,),
        stop_after_attempt=3,
    )),
    (lambda x: x % 2 == 1, RunnableLambda(safe_operation)),
    lambda x: x,
)

print("\n创建分支(对风险操作添加重试):")
print("branch_with_retry = RunnableBranch(")
print("    (lambda x: x % 2 == 0, RunnableLambda(risky_operation).with_retry(...)),")
print("    (lambda x: x % 2 == 1, RunnableLambda(safe_operation)),")
print("    lambda x: x,")
print(")")

test_values = [2, 3, 4, 5, 6]

print("\n测试不同的输入:")
for val in test_values:
    try:
        result = branch_with_retry.invoke(val)
        print(f"  输入: {val} -> 输出: {result}")
    except ValueError as e:
        print(f"  输入: {val} -> 失败: {e}")

print("\n" + "=" * 60)
print("演示 8: with_retry 批处理")
print("=" * 60)

# 创建一个批处理场景
batch_fail_count = 0

def batch_unreliable_function(x: int) -> int:
    """批处理中可能失败的函数"""
    global batch_fail_count
    batch_fail_count += 1
    if batch_fail_count <= 2:
        raise ValueError(f"批处理失败(第 {batch_fail_count} 次)")
    return x * 2

runnable_batch_retry = RunnableLambda(batch_unreliable_function).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=3,
)

print("\n创建批处理重试 Runnable:")
print("runnable_batch_retry = RunnableLambda(batch_unreliable_function).with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=3,")
print(")")

# 批处理
batch_fail_count = 0
inputs = [1, 2, 3]

print(f"\n批处理输入: {inputs}")
try:
    results = runnable_batch_retry.batch(inputs)
    print(f"批处理结果: {results}")
    print(f"总调用次数: {batch_fail_count}")
except Exception as e:
    print(f"批处理失败: {e}")
    print(f"总调用次数: {batch_fail_count}")

print("\n" + "=" * 60)
print("演示 9: 模拟网络请求重试")
print("=" * 60)

# 模拟网络请求
network_attempt = 0

def network_request(url: str) -> str:
    """模拟网络请求,可能失败"""
    global network_attempt
    network_attempt += 1
    if network_attempt < 3:
        raise ConnectionError(f"网络连接失败(第 {network_attempt} 次尝试)")
    return f"成功获取 {url} 的内容"

# 创建带重试的网络请求 Runnable
network_runnable = RunnableLambda(network_request).with_retry(
    retry_if_exception_type=(ConnectionError,),
    stop_after_attempt=5,
    wait_exponential_jitter=True,
)

print("\n创建网络请求 Runnable(带重试):")
print("network_runnable = RunnableLambda(network_request).with_retry(")
print("    retry_if_exception_type=(ConnectionError,),")
print("    stop_after_attempt=5,")
print("    wait_exponential_jitter=True,")
print(")")

network_attempt = 0
start_time = time.time()

try:
    result = network_runnable.invoke("https://example.com")
    elapsed_time = time.time() - start_time
    print(f"\n请求 URL: https://example.com")
    print(f"结果: {result}")
    print(f"总尝试次数: {network_attempt}")
    print(f"总耗时: {elapsed_time:.2f} 秒")
except Exception as e:
    elapsed_time = time.time() - start_time
    print(f"\n最终失败: {e}")
    print(f"总尝试次数: {network_attempt}")
    print(f"总耗时: {elapsed_time:.2f} 秒")

print("\n" + "=" * 60)
print("演示 10: 实际应用场景")
print("=" * 60)

# 实际应用:API 调用重试
api_call_count = 0

def api_call(data: dict) -> dict:
    """模拟 API 调用,可能因为限流或网络问题失败"""
    global api_call_count
    api_call_count += 1
    if api_call_count < 2:
        raise ValueError("API 调用失败(可能因为限流)")
    return {"status": "success", "data": data}

# 创建带重试的 API 调用链
api_chain = RunnableLambda(api_call).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=5,
    wait_exponential_jitter=True,
)

print("\n创建 API 调用链(带重试):")
print("api_chain = RunnableLambda(api_call).with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=5,")
print("    wait_exponential_jitter=True,")
print(")")

api_call_count = 0
test_data = {"user_id": 123, "action": "get_profile"}

try:
    result = api_chain.invoke(test_data)
    print(f"\nAPI 调用数据: {test_data}")
    print(f"API 响应: {result}")
    print(f"总调用次数: {api_call_count}")
except Exception as e:
    print(f"\nAPI 调用失败: {e}")
    print(f"总调用次数: {api_call_count}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nwith_retry 的核心概念:")
print("1. 自动重试:当执行失败时自动重试")
print("2. 异常类型过滤:只重试指定的异常类型")
print("3. 重试次数控制:可以设置最大重试次数")
print("4. 指数退避:支持指数退避和抖动,避免频繁重试")
print("\n使用场景:")
print("- 网络请求可能因为临时网络问题失败")
print("- API 调用可能因为限流或服务暂时不可用失败")
print("- 数据库操作可能因为连接问题失败")
print("- 任何可能因为临时问题失败的操作")
print("\n注意事项:")
print("- 只对可能因为临时问题失败的异常进行重试")
print("- 不要对业务逻辑错误进行重试(如参数验证错误)")
print("- 合理设置重试次数,避免无限重试")
print("- 使用指数退避可以避免对服务器造成过大压力")
print("- 在链中只对可能失败的步骤添加重试,而不是整个链")

28.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented
+   
+   def with_retry(
+       self,
+       *,
+       retry_if_exception_type: tuple = (Exception,),
+       wait_exponential_jitter: bool = True,
+       stop_after_attempt: int = 3,
+       **kwargs
+   ):
+       """
+       创建一个带重试功能的 Runnable
+       
+       Args:
+           retry_if_exception_type: 需要重试的异常类型元组
+           wait_exponential_jitter: 是否使用指数退避和抖动
+           stop_after_attempt: 最大重试次数
+           **kwargs: 其他参数
+       
+       Returns:
+           RunnableRetry: 带重试功能的 Runnable
+       """
+       return RunnableRetry(
+           bound=self,
+           retry_exception_types=retry_if_exception_type,
+           wait_exponential_jitter=wait_exponential_jitter,
+           max_attempt_number=stop_after_attempt,
+           **kwargs
+       )


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


class RunnablePassthrough(Runnable):
    """传递输入数据不变或添加额外键的 Runnable

    RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
    它常用于在链中保留原始输入数据,同时添加处理后的结果。
    """

    def __init__(self, func=None):
        """
        初始化 RunnablePassthrough

        Args:
            func: 可选的函数,会在传递数据时调用(用于副作用)
        """
        self.func = func

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入数据"""
        # 如果提供了函数,调用它(用于副作用)
        if self.func is not None:
            if callable(self.func):
                try:
                    self.func(input_data, config=config, **kwargs)
                except TypeError:
                    self.func(input_data)

        # 直接返回输入数据
        return input_data

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    @classmethod
    def assign(cls, **kwargs):
        """创建一个 RunnableAssign,在传递数据的同时添加额外字段

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象

        Returns:
            RunnableAssign: 可以添加额外字段的 Runnable
        """
        return RunnableAssign(**kwargs)

    def __repr__(self):
        return "RunnablePassthrough()"


class RunnableAssign(Runnable):
    """在传递字典数据的同时添加额外字段的 Runnable

    RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
    在保留原有字段的同时,添加新的字段。
    """

    def __init__(self, **kwargs):
        """
        初始化 RunnableAssign

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
        """
        self.assignments = {}
        for key, value in kwargs.items():
            if isinstance(value, dict):
                # 嵌套字典转换为 RunnableParallel
                self.assignments[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                self.assignments[key] = value
            else:
                raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入并添加额外字段"""
        # 确保输入是字典
        if not isinstance(input_data, dict):
            input_data = {"input": input_data}

        # 复制输入数据
        result = dict(input_data)

        # 添加新字段
        for key, assignment in self.assignments.items():
            if hasattr(assignment, 'invoke'):
                # 检查是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(assignment.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
                    else:
                        result[key] = assignment.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        result[key] = assignment.invoke(input_data, **kwargs)
                    except TypeError:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
            elif callable(assignment):
                result[key] = assignment(input_data)
            else:
                raise ValueError(f"赋值 '{key}' 不可调用")

        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __ror__(self, other):
        """支持从右侧使用 | 操作符(用于字典字面量)"""
        if isinstance(other, dict):
            # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
            return RunnableSequence(RunnableParallel(other), self)
        return NotImplemented

    def __repr__(self):
        return f"RunnableAssign({list(self.assignments.keys())})"


class RunnableBranch(Runnable):
    """根据条件选择执行分支的 Runnable

    RunnableBranch 类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行。
    它接受多个 (condition, Runnable) 对,按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行。
    如果没有条件满足,则执行默认分支。
    """

    def __init__(self, *branches):
        """
        初始化 RunnableBranch

        Args:
            *branches: 多个 (condition, Runnable) 对,最后一个可以是默认分支(Runnable 或可调用对象)
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch 至少需要 2 个分支(包括默认分支)")

        # 分离条件和默认分支
        self.branches = []
        default = None

        for i, branch in enumerate(branches):
            if i == len(branches) - 1:
                # 最后一个可能是默认分支
                if isinstance(branch, tuple) and len(branch) == 2:
                    # 仍然是 (condition, Runnable) 对
                    condition, runnable = branch
                    self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
                else:
                    # 是默认分支
                    default = self._coerce_to_runnable(branch)
            else:
                # 必须是 (condition, Runnable) 对
                if not isinstance(branch, tuple) or len(branch) != 2:
                    raise ValueError(f"分支 {i} 必须是 (condition, Runnable) 对,但得到 {type(branch)}")
                condition, runnable = branch
                self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))

        if default is None:
            raise ValueError("RunnableBranch 必须提供默认分支(最后一个参数)")

        self.default = default

    def _coerce_to_condition(self, condition):
        """将条件转换为可调用对象"""
        if callable(condition):
            return condition
        elif hasattr(condition, 'invoke'):
            # 如果已经是 Runnable,返回一个包装函数
            def wrapped_condition(input_data):
                result = condition.invoke(input_data)
                return bool(result)
            return wrapped_condition
        else:
            raise ValueError(f"条件必须是可调用对象或 Runnable,但得到 {type(condition)}")

    def _coerce_to_runnable(self, runnable):
        """将值转换为 Runnable"""
        if hasattr(runnable, 'invoke'):
            return runnable
        elif callable(runnable):
            return RunnableLambda(runnable)
        else:
            raise ValueError(f"Runnable 必须是可调用对象或 Runnable,但得到 {type(runnable)}")

    def invoke(self, input_data, config=None, **kwargs):
        """根据条件选择分支执行"""
        # 按顺序检查条件
        for condition, runnable in self.branches:
            try:
                # 评估条件
                condition_result = condition(input_data)
                if condition_result:
                    # 条件满足,执行对应的 Runnable
                    if hasattr(runnable, 'invoke'):
                        import inspect
                        try:
                            sig = inspect.signature(runnable.invoke)
                            params = list(sig.parameters.keys())
                            if 'config' in params:
                                return runnable.invoke(input_data, config=config, **kwargs)
                            else:
                                return runnable.invoke(input_data, **kwargs)
                        except (ValueError, TypeError):
                            try:
                                return runnable.invoke(input_data, **kwargs)
                            except TypeError:
                                return runnable.invoke(input_data, config=config, **kwargs)
                    elif callable(runnable):
                        return runnable(input_data)
            except Exception as e:
                # 如果条件评估出错,继续下一个条件
                continue

        # 没有条件满足,执行默认分支
        if hasattr(self.default, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.default.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.default.invoke(input_data, config=config, **kwargs)
                else:
                    return self.default.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.default.invoke(input_data, **kwargs)
                except TypeError:
                    return self.default.invoke(input_data, config=config, **kwargs)
        elif callable(self.default):
            return self.default(input_data)
        else:
            return self.default

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableBranch({len(self.branches)} branches + default)"


+class RunnableRetry(Runnable):
+   """带重试功能的 Runnable
+   
+   RunnableRetry 包装另一个 Runnable,当执行失败时会自动重试。
+   支持指数退避和抖动,以及指定重试的异常类型。
+   """
+   
+   def __init__(
+       self,
+       bound,
+       retry_exception_types=(Exception,),
+       wait_exponential_jitter=True,
+       max_attempt_number=3,
+       **kwargs
+   ):
+       """
+       初始化 RunnableRetry
+       
+       Args:
+           bound: 要包装的 Runnable
+           retry_exception_types: 需要重试的异常类型元组
+           wait_exponential_jitter: 是否使用指数退避和抖动
+           max_attempt_number: 最大重试次数
+           **kwargs: 其他参数
+       """
+       self.bound = bound
+       self.retry_exception_types = retry_exception_types
+       self.wait_exponential_jitter = wait_exponential_jitter
+       self.max_attempt_number = max_attempt_number
+       self.kwargs = kwargs
+   
+   def _should_retry(self, exception):
+       """检查是否应该重试"""
+       return isinstance(exception, self.retry_exception_types)
+   
+   def _wait_time(self, attempt_number):
+       """计算等待时间(指数退避)"""
+       if not self.wait_exponential_jitter:
+           return 0
+       
+       import random
+       # 指数退避:2^attempt_number 秒,最大 10 秒
+       base_wait = min(2 ** attempt_number, 10)
+       # 添加抖动:随机 0-1 秒
+       jitter = random.uniform(0, 1)
+       return base_wait + jitter
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """执行并重试"""
+       last_exception = None
+       
+       for attempt in range(1, self.max_attempt_number + 1):
+           try:
+               # 调用原始的 Runnable
+               if hasattr(self.bound, 'invoke'):
+                   import inspect
+                   try:
+                       sig = inspect.signature(self.bound.invoke)
+                       params = list(sig.parameters.keys())
+                       if 'config' in params:
+                           return self.bound.invoke(input_data, config=config, **kwargs)
+                       else:
+                           return self.bound.invoke(input_data, **kwargs)
+                   except (ValueError, TypeError):
+                       try:
+                           return self.bound.invoke(input_data, **kwargs)
+                       except TypeError:
+                           return self.bound.invoke(input_data, config=config, **kwargs)
+               elif callable(self.bound):
+                   return self.bound(input_data)
+               else:
+                   raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")
+           except Exception as e:
+               last_exception = e
+               if not self._should_retry(e):
+                   # 不应该重试的异常,直接抛出
+                   raise
+               
+               if attempt < self.max_attempt_number:
+                   # 计算等待时间
+                   wait_time = self._wait_time(attempt)
+                   if wait_time > 0:
+                       import time
+                       time.sleep(wait_time)
+                   # 继续重试
+                   continue
+               else:
+                   # 达到最大重试次数,抛出最后一个异常
+                   raise
+       
+       # 如果所有重试都失败,抛出最后一个异常
+       if last_exception:
+           raise last_exception
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   def __repr__(self):
+       return f"RunnableRetry(bound={self.bound}, max_attempts={self.max_attempt_number})"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

29. Config #

Config——灵活的执行上下文参数容器

在 LCEL(LangChain Expression Language)范式中,Config 并不是一个专用的类,而是以普通 Python 字典(dict)的形式传递,作为配置与元数据容器,贯穿于整个 Runnable 链式调用的始终。
它允许你为每次数据链路运行,动态注入追踪标签(tags)、运行名(run_name)、元信息(metadata)、唯一 ID(run_id)、回调函数(callbacks)、并发限制(max_concurrency)等参数。这极大增强了链路的可追踪性、可监控性和灵活性,方便在大规模部署、调试、性能分析等场景中精准定位与管控每一次调用。

Config 的核心属性说明

  • run_name: 本次调用的名字,便于日志和可视化追踪
  • run_id: 调用的唯一标识(如 uuid),用于在分布式或复杂流水线中精确定位
  • tags: 标签组(list),支持对调用批量分组、过滤检索
  • metadata: 以 dict 形式存储任意业务或上下文相关的额外信息(如接口名、用户 ID、调用参数…)
  • callbacks: 可选,配置调用前后、异常等时机的回调处理逻辑
  • max_concurrency: 控制批处理时的最大并发度
  • recursion_limit: 递归嵌套时的最大层数
  • configurable: 动态传参用于运行时调整 Runnable 的行为(如模型参数温度等)

Config 的典型应用场景

  • 链路追踪 & 监控:传递 run_id, run_name, tags、metadata 实现全流程精细追踪和 API 监控分析
  • 用户分隔与定制:metadata 可用于标记当前 session、user、业务 ID,实现多租户或权限隔离
  • 动态配置:某些组件(如 LLM)可用 config 里的参数覆盖默认行为,实现运行期灵活调整
  • 并发控制:max_concurrency 可灵活限制批量处理的并发数,适配资源约束

Config 常用传递方式

  • 作为 .invoke(input, config=config)、.batch(inputs, config=config)、.stream(...) 等方法的参数字典明示传递。
  • 支持向下游链路自动透传(部分属性如 tags/metadata 会继承;run_id/run_name 只在发起环节起作用)。

注意事项

  • config 是完全可选的,不传时所有属性均有默认值(如自动生成 run_id)。
  • 推荐合理利用 tags/metadata 分类与分层管理调用,方便大规模场景下的溯源与治理。
  • 具体哪些属性被哪些组件识别与利用,可查对应文档或源码,例如 Runnable、模型/解析器不同组件的特殊处理。

29.1. Config.py #

29.Config.py

#from langchain_core.prompts import PromptTemplate
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import BaseOutputParser, JsonOutputParser,StrOutputParser
#from langchain_core.runnables import RunnableLambda,RunnableBranch
#from langchain_core.callbacks import BaseCallbackHandler
#import time
#import uuid

from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers import StrOutputParser, JsonOutputParser
from langchain.runnables import RunnableLambda, RunnableBranch
from langchain.callbacks import BaseCallbackHandler
from typing import Optional, Any
from uuid import UUID
import uuid
import time

print("=" * 60)
print("Config 演示")
print("=" * 60)
print("\nConfig 用于配置 Runnable 的执行行为")
print("包括 tags、metadata、callbacks、run_name、max_concurrency 等\n")

# 创建模型和组件
llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()

print("=" * 60)
print("演示 1: 基本 Config 用法")
print("=" * 60)

# 创建一个简单的 Runnable
def process_data(x: int) -> int:
    """处理数据"""
    return x * 2

runnable = RunnableLambda(process_data)

print("\n创建 Runnable:")
print("runnable = RunnableLambda(process_data)")

# 使用 config
config = {
    "run_name": "process_data_operation",
    "tags": ["demo", "processing"],
    "metadata": {"version": "1.0", "user": "test_user"},
}

print("\n创建 Config:")
print("config = {")
print("    'run_name': 'process_data_operation',")
print("    'tags': ['demo', 'processing'],")
print("    'metadata': {'version': '1.0', 'user': 'test_user'},")
print("}")

result = runnable.invoke(5, config=config)
print(f"\n使用 config 调用:")
print(f"  输入: 5")
print(f"  输出: {result}")
print(f"  Config: {config}")

print("\n" + "=" * 60)
print("演示 2: run_name - 运行名称")
print("=" * 60)

# run_name 用于标识特定的运行
def add_one(x: int) -> int:
    """加1"""
    return x + 1

def multiply_by_two(x: int) -> int:
    """乘以2"""
    return x * 2

# 创建链
chain = RunnableLambda(add_one) | RunnableLambda(multiply_by_two)

print("\n创建链:")
print("chain = RunnableLambda(add_one) | RunnableLambda(multiply_by_two)")

# 使用不同的 run_name
configs = [
    {"run_name": "operation_1"},
    {"run_name": "operation_2"},
    {"run_name": "operation_3"},
]

print("\n使用不同的 run_name:")
for i, config in enumerate(configs, 1):
    result = chain.invoke(5, config=config)
    print(f"  {i}. run_name='{config['run_name']}': 输入=5 -> 输出={result}")

print("\n" + "=" * 60)
print("演示 3: tags - 标签")
print("=" * 60)

# tags 用于分类和过滤
def process_text(text: str) -> str:
    """处理文本"""
    return text.upper()

text_processor = RunnableLambda(process_text)

print("\n创建文本处理 Runnable:")
print("text_processor = RunnableLambda(process_text)")

# 使用不同的 tags
configs = [
    {"tags": ["production", "text"]},
    {"tags": ["development", "text"]},
    {"tags": ["test", "text"]},
]

test_texts = ["hello", "world", "python"]

print("\n使用不同的 tags:")
for config, text in zip(configs, test_texts):
    result = text_processor.invoke(text, config=config)
    print(f"  tags={config['tags']}: 输入='{text}' -> 输出='{result}'")

print("\n" + "=" * 60)
print("演示 4: metadata - 元数据")
print("=" * 60)

# metadata 用于存储额外的信息
def calculate_sum(numbers: list) -> int:
    """计算总和"""
    return sum(numbers)

sum_calculator = RunnableLambda(calculate_sum)

print("\n创建计算器 Runnable:")
print("sum_calculator = RunnableLambda(calculate_sum)")

# 使用 metadata 存储额外信息
config = {
    "metadata": {
        "operation": "sum",
        "timestamp": time.time(),
        "user_id": 12345,
        "request_id": str(uuid.uuid4()),
    }
}

print("\n使用 metadata:")
print(f"config = {config}")

result = sum_calculator.invoke([1, 2, 3, 4, 5], config=config)
print(f"\n输入: [1, 2, 3, 4, 5]")
print(f"输出: {result}")
print(f"Metadata: {config['metadata']}")

print("\n" + "=" * 60)
print("演示 5: run_id - 运行ID")
print("=" * 60)

# run_id 用于唯一标识每次运行
def process_item(item: str) -> str:
    """处理项目"""
    return f"Processed: {item}"

processor = RunnableLambda(process_item)

print("\n创建处理器 Runnable:")
print("processor = RunnableLambda(process_item)")

# 使用 run_id
run_id_1 = uuid.uuid4()
run_id_2 = uuid.uuid4()

configs = [
    {"run_id": run_id_1, "run_name": "run_1"},
    {"run_id": run_id_2, "run_name": "run_2"},
]

print("\n使用不同的 run_id:")
for config in configs:
    result = processor.invoke("item1", config=config)
    print(f"  run_id={config['run_id']}: {result}")

print("\n" + "=" * 60)
print("演示 6: max_concurrency - 最大并发数")
print("=" * 60)

# max_concurrency 用于控制批处理的并发数
def slow_operation(x: int) -> int:
    """慢速操作"""
    time.sleep(0.1)  # 模拟慢速操作
    return x * 2

slow_runnable = RunnableLambda(slow_operation)

print("\n创建慢速操作 Runnable:")
print("slow_runnable = RunnableLambda(slow_operation)")

# 测试不同的 max_concurrency
inputs = list(range(1, 6))

print(f"\n批处理输入: {inputs}")

for max_concurrency in [1, 3, 5]:
    config = {"max_concurrency": max_concurrency}
    start_time = time.time()
    results = slow_runnable.batch(inputs, config=config)
    elapsed_time = time.time() - start_time
    print(f"\n  max_concurrency={max_concurrency}:")
    print(f"    结果: {results}")
    print(f"    耗时: {elapsed_time:.2f} 秒")

print("\n" + "=" * 60)
print("演示 7: recursion_limit - 递归限制")
print("=" * 60)

# recursion_limit 用于限制递归深度
recursion_count = 0

def recursive_function(x: int, depth: int = 0) -> int:
    """递归函数"""
    global recursion_count
    recursion_count += 1
    if depth < 3:
        return recursive_function(x + 1, depth + 1)
    return x

# 注意:这个演示需要 Runnable 支持递归
# 这里只是展示概念
print("\n递归限制的概念:")
print("recursion_limit 用于限制 Runnable 可以递归的最大次数")
print("默认值通常是 25")

print("\n" + "=" * 60)
print("演示 8: configurable - 可配置属性")
print("=" * 60)

# configurable 用于运行时配置 Runnable 的属性
def configurable_function(x: int) -> int:
    """可配置的函数"""
    # 在实际应用中,可以从 config 中读取配置
    return x * 2

configurable_runnable = RunnableLambda(configurable_function)

print("\n创建可配置 Runnable:")
print("configurable_runnable = RunnableLambda(configurable_function)")

# 使用 configurable
config = {
    "configurable": {
        "multiplier": 3,
        "operation": "multiply",
    }
}

print("\n使用 configurable:")
print(f"config = {config}")

result = configurable_runnable.invoke(5, config=config)
print(f"\n输入: 5")
print(f"输出: {result}")
print(f"Configurable: {config['configurable']}")

print("\n" + "=" * 60)
print("演示 9: 组合使用多个 Config 选项")
print("=" * 60)

# 组合使用多个配置选项
def complex_operation(x: int) -> int:
    """复杂操作"""
    return x ** 2

complex_runnable = RunnableLambda(complex_operation)

print("\n创建复杂操作 Runnable:")
print("complex_runnable = RunnableLambda(complex_operation)")

# 组合多个配置选项
config = {
    "run_name": "complex_operation",
    "run_id": uuid.uuid4(),
    "tags": ["math", "square", "production"],
    "metadata": {
        "operation": "square",
        "user": "admin",
        "timestamp": time.time(),
    },
    "configurable": {
        "precision": 2,
    },
}

print("\n组合使用多个配置选项:")
print(f"config = {config}")

result = complex_runnable.invoke(5, config=config)
print(f"\n输入: 5")
print(f"输出: {result}")
print(f"\n配置详情:")
print(f"  run_name: {config['run_name']}")
print(f"  run_id: {config['run_id']}")
print(f"  tags: {config['tags']}")
print(f"  metadata: {config['metadata']}")
print(f"  configurable: {config['configurable']}")

print("\n" + "=" * 60)
print("演示 10: 在链中使用 Config")
print("=" * 60)

# 在链中使用 config
def step1(x: int) -> int:
    """步骤1"""
    return x + 1

def step2(x: int) -> int:
    """步骤2"""
    return x * 2

def step3(x: int) -> int:
    """步骤3"""
    return x - 1

# 创建链
chain = (
    RunnableLambda(step1) |
    RunnableLambda(step2) |
    RunnableLambda(step3)
)

print("\n创建链:")
print("chain = (")
print("    RunnableLambda(step1) |")
print("    RunnableLambda(step2) |")
print("    RunnableLambda(step3)")
print(")")

# 使用 config
config = {
    "run_name": "three_step_chain",
    "tags": ["chain", "multi_step"],
    "metadata": {
        "steps": 3,
        "description": "三步处理链",
    },
}

print("\n在链中使用 config:")
print(f"config = {config}")

result = chain.invoke(5, config=config)
print(f"\n输入: 5")
print(f"步骤 1: 5 + 1 = 6")
print(f"步骤 2: 6 * 2 = 12")
print(f"步骤 3: 12 - 1 = {result}")
print(f"最终输出: {result}")

print("\n" + "=" * 60)
print("演示 11: Config 与 LLM 链结合")
print("=" * 60)

# 在 LLM 链中使用 config
prompt = PromptTemplate.from_template("请回答:{question}")

llm_chain = prompt | llm | parser

print("\n创建 LLM 链:")
print("llm_chain = prompt | llm | parser")

# 使用 config 添加标签和元数据
config = {
    "run_name": "llm_qa_chain",
    "tags": ["llm", "qa", "production"],
    "metadata": {
        "model": "gpt-4o",
        "temperature": 0.7,
        "user_id": 12345,
    },
}

print("\n使用 config:")
print(f"config = {config}")

question = "什么是 Python?"
result = llm_chain.invoke({"question": question}, config=config)

print(f"\n问题: {question}")
print(f"回答: {result[:50]}...")
print(f"\n使用的配置:")
print(f"  run_name: {config['run_name']}")
print(f"  tags: {config['tags']}")
print(f"  metadata: {config['metadata']}")

print("\n" + "=" * 60)
print("演示 12: callbacks - 回调处理器")
print("=" * 60)

# 使用 BaseCallbackHandler
from typing import Optional, Any
from uuid import UUID

class CustomCallbackHandler(BaseCallbackHandler):
    """自定义回调处理器"""

    def __init__(self):
        super().__init__()
        self.start_count = 0
        self.end_count = 0
        self.error_count = 0

    def on_run_start(
        self,
        run_id: UUID,
        name: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """运行开始时调用"""
        self.start_count += 1
        print(f"  [自定义回调] 运行开始 #{self.start_count}: name={name}, run_id={run_id}")
        super().on_run_start(run_id, name, **kwargs)

    def on_run_end(
        self,
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """运行结束时调用"""
        self.end_count += 1
        print(f"  [自定义回调] 运行结束 #{self.end_count}: run_id={run_id}")
        super().on_run_end(run_id, **kwargs)

    def on_run_error(
        self,
        run_id: UUID,
        error: Exception,
        **kwargs: Any
    ) -> None:
        """运行出错时调用"""
        self.error_count += 1
        print(f"  [自定义回调] 运行错误 #{self.error_count}: run_id={run_id}, error={error}")
        super().on_run_error(run_id, error, **kwargs)

# 创建自定义回调处理器
custom_handler = CustomCallbackHandler()

# 创建 Runnable
def process_number(x: int) -> int:
    """处理数字"""
    return x * 2

processor = RunnableLambda(process_number)

print("\n创建自定义回调处理器:")
print("class CustomCallbackHandler(BaseCallbackHandler):")
print("    def on_run_start(...): ...")
print("    def on_run_end(...): ...")
print("    def on_run_error(...): ...")

# 使用 callbacks
config = {
    "callbacks": [custom_handler],
    "run_name": "number_processing",
}

print("\n使用 callbacks:")
print("config = {")
print("    'callbacks': [custom_handler],")
print("    'run_name': 'number_processing',")
print("}")

print("\n调用 Runnable(会触发回调):")
result = processor.invoke(5, config=config)
print(f"  输入: 5")
print(f"  输出: {result}")
print(f"\n回调统计:")
print(f"  开始次数: {custom_handler.start_count}")
print(f"  结束次数: {custom_handler.end_count}")
print(f"  错误次数: {custom_handler.error_count}")

print("\n" + "=" * 60)
print("演示 13: 多个回调处理器")
print("=" * 60)

# 使用多个回调处理器
custom_handler1 = CustomCallbackHandler()
custom_handler2 = CustomCallbackHandler()

print("\n创建多个自定义回调处理器:")
print("custom_handler1 = CustomCallbackHandler()")
print("custom_handler2 = CustomCallbackHandler()")

# 创建 Runnable
def complex_operation(x: int) -> int:
    """复杂操作"""
    return x ** 2

complex_runnable = RunnableLambda(complex_operation)

# 使用多个回调处理器
config = {
    "callbacks": [custom_handler1, custom_handler2],
    "run_name": "complex_operation",
    "tags": ["demo", "callbacks"],
    "metadata": {"operation": "square"},
}

print("\n使用多个回调处理器:")
print("config = {")
print("    'callbacks': [custom_handler1, custom_handler2],")
print("    'run_name': 'complex_operation',")
print("    'tags': ['demo', 'callbacks'],")
print("    'metadata': {'operation': 'square'},")
print("}")

print("\n调用 Runnable(会触发所有回调):")
result = complex_runnable.invoke(5, config=config)
print(f"  输入: 5")
print(f"  输出: {result}")

print("\n回调统计:")
print(f"  CustomCallbackHandler 1 开始次数: {custom_handler1.start_count}")
print(f"  CustomCallbackHandler 2 开始次数: {custom_handler2.start_count}")

print("\n" + "=" * 60)
print("演示 14: Config 的实际应用场景")
print("=" * 60)

# 实际应用:API 调用追踪
def api_call(data: dict) -> dict:
    """模拟 API 调用"""
    return {"status": "success", "data": data}

api_runnable = RunnableLambda(api_call)

print("\n创建 API 调用 Runnable:")
print("api_runnable = RunnableLambda(api_call)")

# 使用 config 进行追踪
config = {
    "run_name": "user_profile_api",
    "run_id": uuid.uuid4(),
    "tags": ["api", "user_profile", "production"],
    "metadata": {
        "endpoint": "/api/user/profile",
        "method": "GET",
        "user_id": 12345,
        "request_time": time.time(),
    },
}

print("\n使用 config 进行 API 追踪:")
print(f"config = {config}")

test_data = {"user_id": 12345, "action": "get_profile"}
result = api_runnable.invoke(test_data, config=config)

print(f"\nAPI 调用数据: {test_data}")
print(f"API 响应: {result}")
print(f"\n追踪信息:")
print(f"  run_id: {config['run_id']}")
print(f"  run_name: {config['run_name']}")
print(f"  tags: {config['tags']}")
print(f"  metadata: {config['metadata']}")

print("\n" + "=" * 60)
print("演示 17: with_config - 绑定配置")
print("=" * 60)

# with_config 用于将配置绑定到 Runnable
def process_data(x: int) -> int:
    """处理数据"""
    return x * 2

runnable = RunnableLambda(process_data)

print("\n创建 Runnable:")
print("runnable = RunnableLambda(process_data)")

# 使用 with_config 绑定配置
configured_runnable = runnable.with_config(
    run_name="configured_processor",
    tags=["production", "data_processing"],
    metadata={"version": "1.0"},
)

print("\n使用 with_config 绑定配置:")
print("configured_runnable = runnable.with_config(")
print("    run_name='configured_processor',")
print("    tags=['production', 'data_processing'],")
print("    metadata={'version': '1.0'},")
print(")")

# 调用时不需要再传递配置
result = configured_runnable.invoke(5)
print(f"\n调用(自动使用绑定的配置):")
print(f"  输入: 5")
print(f"  输出: {result}")
print(f"  绑定的配置: {configured_runnable.config}")

print("\n" + "=" * 60)
print("演示 18: with_config 与回调结合")
print("=" * 60)

# 使用 with_config 绑定回调
custom_handler = CustomCallbackHandler()

print("\n创建自定义回调处理器:")
print("custom_handler = CustomCallbackHandler()")

# 绑定回调到 Runnable
runnable_with_callback = runnable.with_config(
    callbacks=[custom_handler],
    run_name="runnable_with_callback",
)

print("\n使用 with_config 绑定回调:")
print("runnable_with_callback = runnable.with_config(")
print("    callbacks=[custom_handler],")
print("    run_name='runnable_with_callback',")
print(")")

print("\n调用(会自动触发回调):")
result = runnable_with_callback.invoke(10)
print(f"  输入: 10")
print(f"  输出: {result}")
print(f"  回调开始次数: {custom_handler.start_count}")

print("\n" + "=" * 60)
print("演示 19: with_config 链式调用")
print("=" * 60)

# with_config 可以链式调用
def add_one(x: int) -> int:
    """加1"""
    return x + 1

def multiply_by_two(x: int) -> int:
    """乘以2"""
    return x * 2

# 创建链并绑定配置
chain = RunnableLambda(add_one) | RunnableLambda(multiply_by_two)

configured_chain = chain.with_config(
    run_name="configured_chain",
    tags=["math", "chain"],
).with_config(
    metadata={"operation": "add_then_multiply"},
)

print("\n创建链并链式绑定配置:")
print("configured_chain = chain.with_config(")
print("    run_name='configured_chain',")
print("    tags=['math', 'chain'],")
print(").with_config(")
print("    metadata={'operation': 'add_then_multiply'},")
print(")")

result = configured_chain.invoke(5)
print(f"\n调用:")
print(f"  输入: 5")
print(f"  输出: {result}")
print(f"  绑定的配置: {configured_chain.config}")

print("\n" + "=" * 60)
print("演示 20: with_config 覆盖配置")
print("=" * 60)

# with_config 绑定的配置可以被调用时传入的配置覆盖
base_runnable = RunnableLambda(process_data)

configured_base = base_runnable.with_config(
    run_name="base_processor",
    tags=["base"],
    metadata={"base": True},
)

print("\n创建基础配置的 Runnable:")
print("configured_base = base_runnable.with_config(")
print("    run_name='base_processor',")
print("    tags=['base'],")
print("    metadata={'base': True},")
print(")")

# 调用时传入的配置会覆盖绑定的配置
result1 = configured_base.invoke(5)
print(f"\n使用绑定的配置调用:")
print(f"  输入: 5")
print(f"  输出: {result1}")
print(f"  使用的配置: {configured_base.config}")

# 传入新配置(会覆盖绑定的配置)
result2 = configured_base.invoke(5, config={"run_name": "override_processor", "tags": ["override"]})
print(f"\n使用覆盖的配置调用:")
print(f"  输入: 5")
print(f"  输出: {result2}")
print(f"  注意:传入的配置会覆盖绑定的配置")

print("\n" + "=" * 60)
print("演示 21: with_config 在链中使用")
print("=" * 60)

# 在链的特定步骤使用 with_config
def step1(x: int) -> int:
    """步骤1"""
    return x + 1

def step2(x: int) -> int:
    """步骤2"""
    return x * 2

def step3(x: int) -> int:
    """步骤3"""
    return x - 1

# 只对特定步骤绑定配置
chain_with_config = (
    RunnableLambda(step1).with_config(run_name="step1", tags=["step1"]) |
    RunnableLambda(step2).with_config(run_name="step2", tags=["step2"]) |
    RunnableLambda(step3).with_config(run_name="step3", tags=["step3"])
)

print("\n创建链(每个步骤绑定配置):")
print("chain_with_config = (")
print("    RunnableLambda(step1).with_config(run_name='step1', tags=['step1']) |")
print("    RunnableLambda(step2).with_config(run_name='step2', tags=['step2']) |")
print("    RunnableLambda(step3).with_config(run_name='step3', tags=['step3'])")
print(")")

result = chain_with_config.invoke(5)
print(f"\n调用:")
print(f"  输入: 5")
print(f"  输出: {result}")

print("\n" + "=" * 60)
print("演示 22: with_config 与 with_retry 结合")
print("=" * 60)

# 结合使用 with_config 和 with_retry
def unreliable_function(x: int) -> int:
    """不可靠的函数"""
    import random
    if random.random() < 0.5:
        raise ValueError("随机失败")
    return x * 2

unreliable_runnable = RunnableLambda(unreliable_function)

# 先绑定配置,再添加重试
configured_with_retry = unreliable_runnable.with_config(
    run_name="unreliable_with_config",
    tags=["unreliable", "retry"],
).with_retry(
    retry_if_exception_type=(ValueError,),
    stop_after_attempt=3,
)

print("\n创建带配置和重试的 Runnable:")
print("configured_with_retry = unreliable_runnable.with_config(")
print("    run_name='unreliable_with_config',")
print("    tags=['unreliable', 'retry'],")
print(").with_retry(")
print("    retry_if_exception_type=(ValueError,),")
print("    stop_after_attempt=3,")
print(")")

print("\n调用(可能需要重试):")
for i in range(3):
    try:
        result = configured_with_retry.invoke(5)
        print(f"  尝试 {i+1}: 成功,输出={result}")
        break
    except ValueError as e:
        print(f"  尝试 {i+1}: 失败 - {e}")

print("\n" + "=" * 60)
print("演示 23: with_config 实际应用场景")
print("=" * 60)

# 实际应用:为不同的环境创建不同的配置
def api_call(data: dict) -> dict:
    """API 调用"""
    return {"status": "success", "data": data}

api_runnable = RunnableLambda(api_call)

# 为生产环境创建配置
production_runnable = api_runnable.with_config(
    run_name="production_api",
    tags=["production", "api"],
    metadata={"environment": "production", "timeout": 30},
)

# 为开发环境创建配置
development_runnable = api_runnable.with_config(
    run_name="development_api",
    tags=["development", "api"],
    metadata={"environment": "development", "timeout": 10},
)

print("\n为不同环境创建配置:")
print("production_runnable = api_runnable.with_config(")
print("    run_name='production_api',")
print("    tags=['production', 'api'],")
print("    metadata={'environment': 'production', 'timeout': 30},")
print(")")
print("\ndevelopment_runnable = api_runnable.with_config(")
print("    run_name='development_api',")
print("    tags=['development', 'api'],")
print("    metadata={'environment': 'development', 'timeout': 10},")
print(")")

test_data = {"user_id": 12345}

print("\n生产环境调用:")
result1 = production_runnable.invoke(test_data)
print(f"  输入: {test_data}")
print(f"  输出: {result1}")
print(f"  配置: {production_runnable.config}")

print("\n开发环境调用:")
result2 = development_runnable.invoke(test_data)
print(f"  输入: {test_data}")
print(f"  输出: {result2}")
print(f"  配置: {development_runnable.config}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nConfig 的核心属性:")
print("1. run_name: 用于标识特定的运行(不继承)")
print("2. run_id: 唯一标识每次调用")
print("3. tags: 用于分类和过滤调用")
print("4. metadata: 存储额外的元数据信息")
print("5. callbacks: 用于回调处理")
print("6. max_concurrency: 控制批处理的最大并发数")
print("7. recursion_limit: 限制递归的最大深度")
print("8. configurable: 运行时配置 Runnable 的属性")
print("\nwith_config 方法:")
print("- 将配置绑定到 Runnable,返回一个新的 Runnable")
print("- 绑定的配置会在每次调用时自动使用")
print("- 可以链式调用,逐步添加配置")
print("- 调用时传入的配置会覆盖绑定的配置")
print("- 适合为不同的环境或场景创建预配置的 Runnable")
print("\n使用场景:")
print("- 追踪和调试:使用 run_name 和 run_id 追踪特定调用")
print("- 分类和过滤:使用 tags 对调用进行分类")
print("- 元数据存储:使用 metadata 存储额外的上下文信息")
print("- 性能控制:使用 max_concurrency 控制并发数")
print("- 运行时配置:使用 configurable 动态配置属性")
print("- 预配置 Runnable:使用 with_config 为不同场景创建预配置的 Runnable")
print("\n注意事项:")
print("- Config 是可选的,所有属性都有默认值")
print("- Config 可以传递给 invoke、batch、stream 等方法")
print("- Config 中的信息可以用于追踪、调试和监控")
print("- 合理使用 tags 和 metadata 可以更好地组织和管理调用")
print("- with_config 绑定的配置可以被调用时传入的配置覆盖")
print("- 可以结合 with_retry 等方法使用,创建功能丰富的 Runnable")

29.2. callbacks.py #

langchain\callbacks.py

# 实现 BaseCallbackHandler 用于回调处理

from typing import Any, Dict, List, Optional
from uuid import UUID


class BaseCallbackHandler:
    """基础回调处理器

    BaseCallbackHandler 用于处理 Runnable 执行过程中的各种事件。
    可以重写各种方法来处理不同的事件,如开始、结束、错误等。
    """

    def __init__(self):
        """初始化回调处理器"""
        self.events = []

    def on_run_start(
        self,
        run_id: UUID,
        name: Optional[str] = None,
        **kwargs: Any
    ) -> None:
        """运行开始时调用"""
        event = {
            "type": "on_run_start",
            "run_id": run_id,
            "name": name,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_run_end(
        self,
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """运行结束时调用"""
        event = {
            "type": "on_run_end",
            "run_id": run_id,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_run_error(
        self,
        run_id: UUID,
        error: Exception,
        **kwargs: Any
    ) -> None:
        """运行出错时调用"""
        event = {
            "type": "on_run_error",
            "run_id": run_id,
            "error": str(error),
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """LLM 开始时调用"""
        event = {
            "type": "on_llm_start",
            "run_id": run_id,
            "prompts": prompts,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_llm_end(
        self,
        response: Any,
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """LLM 结束时调用"""
        event = {
            "type": "on_llm_end",
            "run_id": run_id,
            "response": str(response)[:100] if response else None,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_llm_error(
        self,
        error: Exception,
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """LLM 出错时调用"""
        event = {
            "type": "on_llm_error",
            "run_id": run_id,
            "error": str(error),
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """链开始时调用"""
        event = {
            "type": "on_chain_start",
            "run_id": run_id,
            "inputs": inputs,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """链结束时调用"""
        event = {
            "type": "on_chain_end",
            "run_id": run_id,
            "outputs": outputs,
            "kwargs": kwargs,
        }
        self.events.append(event)

    def on_chain_error(
        self,
        error: Exception,
        run_id: UUID,
        **kwargs: Any
    ) -> None:
        """链出错时调用"""
        event = {
            "type": "on_chain_error",
            "run_id": run_id,
            "error": str(error),
            "kwargs": kwargs,
        }
        self.events.append(event)

    def get_events(self) -> List[Dict[str, Any]]:
        """获取所有事件"""
        return self.events

    def clear_events(self) -> None:
        """清除所有事件"""
        self.events = []

29.3. runnables.py #

langchain\runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_retry(
        self,
        *,
        retry_if_exception_type: tuple = (Exception,),
        wait_exponential_jitter: bool = True,
        stop_after_attempt: int = 3,
        **kwargs
    ):
        """
        创建一个带重试功能的 Runnable

        Args:
            retry_if_exception_type: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            stop_after_attempt: 最大重试次数
            **kwargs: 其他参数

        Returns:
            RunnableRetry: 带重试功能的 Runnable
        """
        return RunnableRetry(
            bound=self,
            retry_exception_types=retry_if_exception_type,
            wait_exponential_jitter=wait_exponential_jitter,
            max_attempt_number=stop_after_attempt,
            **kwargs
        )

    def with_config(
        self,
        config=None,
        **kwargs
    ):
        """
        将配置绑定到 Runnable,返回一个新的 Runnable

        Args:
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数(会合并到 config 中)

        Returns:
            RunnableBinding: 绑定配置后的 Runnable
        """
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        merged_config = {**config, **kwargs}
        return RunnableBinding(bound=self, config=merged_config)


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


class RunnablePassthrough(Runnable):
    """传递输入数据不变或添加额外键的 Runnable

    RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
    它常用于在链中保留原始输入数据,同时添加处理后的结果。
    """

    def __init__(self, func=None):
        """
        初始化 RunnablePassthrough

        Args:
            func: 可选的函数,会在传递数据时调用(用于副作用)
        """
        self.func = func

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入数据"""
        # 如果提供了函数,调用它(用于副作用)
        if self.func is not None:
            if callable(self.func):
                try:
                    self.func(input_data, config=config, **kwargs)
                except TypeError:
                    self.func(input_data)

        # 直接返回输入数据
        return input_data

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    @classmethod
    def assign(cls, **kwargs):
        """创建一个 RunnableAssign,在传递数据的同时添加额外字段

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象

        Returns:
            RunnableAssign: 可以添加额外字段的 Runnable
        """
        return RunnableAssign(**kwargs)

    def __repr__(self):
        return "RunnablePassthrough()"


class RunnableAssign(Runnable):
    """在传递字典数据的同时添加额外字段的 Runnable

    RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
    在保留原有字段的同时,添加新的字段。
    """

    def __init__(self, **kwargs):
        """
        初始化 RunnableAssign

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
        """
        self.assignments = {}
        for key, value in kwargs.items():
            if isinstance(value, dict):
                # 嵌套字典转换为 RunnableParallel
                self.assignments[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                self.assignments[key] = value
            else:
                raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入并添加额外字段"""
        # 确保输入是字典
        if not isinstance(input_data, dict):
            input_data = {"input": input_data}

        # 复制输入数据
        result = dict(input_data)

        # 添加新字段
        for key, assignment in self.assignments.items():
            if hasattr(assignment, 'invoke'):
                # 检查是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(assignment.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
                    else:
                        result[key] = assignment.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        result[key] = assignment.invoke(input_data, **kwargs)
                    except TypeError:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
            elif callable(assignment):
                result[key] = assignment(input_data)
            else:
                raise ValueError(f"赋值 '{key}' 不可调用")

        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __ror__(self, other):
        """支持从右侧使用 | 操作符(用于字典字面量)"""
        if isinstance(other, dict):
            # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
            return RunnableSequence(RunnableParallel(other), self)
        return NotImplemented

    def __repr__(self):
        return f"RunnableAssign({list(self.assignments.keys())})"


class RunnableBranch(Runnable):
    """根据条件选择执行分支的 Runnable

    RunnableBranch 类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行。
    它接受多个 (condition, Runnable) 对,按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行。
    如果没有条件满足,则执行默认分支。
    """

    def __init__(self, *branches):
        """
        初始化 RunnableBranch

        Args:
            *branches: 多个 (condition, Runnable) 对,最后一个可以是默认分支(Runnable 或可调用对象)
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch 至少需要 2 个分支(包括默认分支)")

        # 分离条件和默认分支
        self.branches = []
        default = None

        for i, branch in enumerate(branches):
            if i == len(branches) - 1:
                # 最后一个可能是默认分支
                if isinstance(branch, tuple) and len(branch) == 2:
                    # 仍然是 (condition, Runnable) 对
                    condition, runnable = branch
                    self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
                else:
                    # 是默认分支
                    default = self._coerce_to_runnable(branch)
            else:
                # 必须是 (condition, Runnable) 对
                if not isinstance(branch, tuple) or len(branch) != 2:
                    raise ValueError(f"分支 {i} 必须是 (condition, Runnable) 对,但得到 {type(branch)}")
                condition, runnable = branch
                self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))

        if default is None:
            raise ValueError("RunnableBranch 必须提供默认分支(最后一个参数)")

        self.default = default

    def _coerce_to_condition(self, condition):
        """将条件转换为可调用对象"""
        if callable(condition):
            return condition
        elif hasattr(condition, 'invoke'):
            # 如果已经是 Runnable,返回一个包装函数
            def wrapped_condition(input_data):
                result = condition.invoke(input_data)
                return bool(result)
            return wrapped_condition
        else:
            raise ValueError(f"条件必须是可调用对象或 Runnable,但得到 {type(condition)}")

    def _coerce_to_runnable(self, runnable):
        """将值转换为 Runnable"""
        if hasattr(runnable, 'invoke'):
            return runnable
        elif callable(runnable):
            return RunnableLambda(runnable)
        else:
            raise ValueError(f"Runnable 必须是可调用对象或 Runnable,但得到 {type(runnable)}")

    def invoke(self, input_data, config=None, **kwargs):
        """根据条件选择分支执行"""
        # 按顺序检查条件
        for condition, runnable in self.branches:
            try:
                # 评估条件
                condition_result = condition(input_data)
                if condition_result:
                    # 条件满足,执行对应的 Runnable
                    if hasattr(runnable, 'invoke'):
                        import inspect
                        try:
                            sig = inspect.signature(runnable.invoke)
                            params = list(sig.parameters.keys())
                            if 'config' in params:
                                return runnable.invoke(input_data, config=config, **kwargs)
                            else:
                                return runnable.invoke(input_data, **kwargs)
                        except (ValueError, TypeError):
                            try:
                                return runnable.invoke(input_data, **kwargs)
                            except TypeError:
                                return runnable.invoke(input_data, config=config, **kwargs)
                    elif callable(runnable):
                        return runnable(input_data)
            except Exception as e:
                # 如果条件评估出错,继续下一个条件
                continue

        # 没有条件满足,执行默认分支
        if hasattr(self.default, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.default.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.default.invoke(input_data, config=config, **kwargs)
                else:
                    return self.default.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.default.invoke(input_data, **kwargs)
                except TypeError:
                    return self.default.invoke(input_data, config=config, **kwargs)
        elif callable(self.default):
            return self.default(input_data)
        else:
            return self.default

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableBranch({len(self.branches)} branches + default)"


class RunnableRetry(Runnable):
    """带重试功能的 Runnable

    RunnableRetry 包装另一个 Runnable,当执行失败时会自动重试。
    支持指数退避和抖动,以及指定重试的异常类型。
    """

    def __init__(
        self,
        bound,
        retry_exception_types=(Exception,),
        wait_exponential_jitter=True,
        max_attempt_number=3,
        **kwargs
    ):
        """
        初始化 RunnableRetry

        Args:
            bound: 要包装的 Runnable
            retry_exception_types: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            max_attempt_number: 最大重试次数
            **kwargs: 其他参数
        """
        self.bound = bound
        self.retry_exception_types = retry_exception_types
        self.wait_exponential_jitter = wait_exponential_jitter
        self.max_attempt_number = max_attempt_number
        self.kwargs = kwargs

    def _should_retry(self, exception):
        """检查是否应该重试"""
        return isinstance(exception, self.retry_exception_types)

    def _wait_time(self, attempt_number):
        """计算等待时间(指数退避)"""
        if not self.wait_exponential_jitter:
            return 0

        import random
        # 指数退避:2^attempt_number 秒,最大 10 秒
        base_wait = min(2 ** attempt_number, 10)
        # 添加抖动:随机 0-1 秒
        jitter = random.uniform(0, 1)
        return base_wait + jitter

    def invoke(self, input_data, config=None, **kwargs):
        """执行并重试"""
        last_exception = None

        for attempt in range(1, self.max_attempt_number + 1):
            try:
                # 调用原始的 Runnable
                if hasattr(self.bound, 'invoke'):
                    import inspect
                    try:
                        sig = inspect.signature(self.bound.invoke)
                        params = list(sig.parameters.keys())
                        if 'config' in params:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                        else:
                            return self.bound.invoke(input_data, **kwargs)
                    except (ValueError, TypeError):
                        try:
                            return self.bound.invoke(input_data, **kwargs)
                        except TypeError:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                elif callable(self.bound):
                    return self.bound(input_data)
                else:
                    raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")
            except Exception as e:
                last_exception = e
                if not self._should_retry(e):
                    # 不应该重试的异常,直接抛出
                    raise

                if attempt < self.max_attempt_number:
                    # 计算等待时间
                    wait_time = self._wait_time(attempt)
                    if wait_time > 0:
                        import time
                        time.sleep(wait_time)
                    # 继续重试
                    continue
                else:
                    # 达到最大重试次数,抛出最后一个异常
                    raise

        # 如果所有重试都失败,抛出最后一个异常
        if last_exception:
            raise last_exception

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableRetry(bound={self.bound}, max_attempts={self.max_attempt_number})"


class RunnableBinding(Runnable):
    """绑定配置的 Runnable

    RunnableBinding 包装另一个 Runnable,并绑定配置。
    当调用绑定的 Runnable 时,会自动使用绑定的配置。
    """

    def __init__(self, bound, config=None, **kwargs):
        """
        初始化 RunnableBinding

        Args:
            bound: 要包装的 Runnable
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数
        """
        self.bound = bound
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        self.config = {**config, **kwargs}

    def invoke(self, input_data, config=None, **kwargs):
        """执行并合并配置"""
        # 合并绑定的配置和传入的配置(传入的配置优先)
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        # 合并 kwargs 到 config
        merged_config.update(kwargs)

        # 调用原始的 Runnable
        if hasattr(self.bound, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.bound.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
                else:
                    return self.bound.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.bound.invoke(input_data, **kwargs)
                except TypeError:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
        elif callable(self.bound):
            return self.bound(input_data)
        else:
            raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_config(self, config=None, **kwargs):
        """添加更多配置(合并到现有配置)"""
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        merged_config.update(kwargs)
        return RunnableBinding(bound=self.bound, config=merged_config)

    def __repr__(self):
        return f"RunnableBinding(bound={self.bound}, config={self.config})"


def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
    from langchain.prompts import PromptTemplate
    from langchain.chat_models import ChatOpenAI
    from langchain.output_parsers import BaseOutputParser

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

30. ConfigurableField #

通过 ConfigurableField,我们可以构建支持自定义参数的组件,无需重新定义/实例化类,只需在 .invoke 调用时通过 config 传参即可动态调整行为。
这对于封装 LLM、可配置处理链、流程变体切换等场景非常实用。

一般搭配 .configurable_fields() 方法使用,将某些字段设为“可配置”,并指定它们的类型、说明等。这样,调用时可以灵活传入参数。

  • 支持的数据类型包括 int, float, str 等任意注解。
  • 可以用作单一选项、具有默认值或多选项(具体参考下方 ConfigurableFieldSingleOption, ConfigurableFieldMultiOption)。
  • 结合 LLM, Prompt, OutputParser、各类自定义处理器等,都可以利用 ConfigurableField 实现运行时动态参数化与行为变换。

30.1. Configurable.py #

30.Configurable.py

from langchain.runnables import RunnableLambda, ConfigurableField
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser

#from langchain.prompts import PromptTemplate
#from langchain.chat_models import ChatOpenAI
#from langchain.output_parsers import StrOutputParser, JsonOutputParser
#from langchain.callbacks import BaseCallbackHandler
#from langchain.runnables import RunnableLambda,ConfigurableField
import time
import uuid

llm = ChatOpenAI(model="gpt-4o")
parser = StrOutputParser()
print("=" * 60)
print("ConfigurableField 演示")
print("=" * 60)
print("\nConfigurableField: 用于在运行时动态配置 Runnable 的属性")
print("允许通过 config 参数在调用时动态改变 Runnable 的行为\n")

# 注意:某些演示需要 LLM 支持,如果导入失败将跳过

print("=" * 60)
print("演示 1: ConfigurableField 的基本用法")
print("=" * 60)

# 创建一个可配置的 Runnable
class ConfigurableProcessor(RunnableLambda):
    """可配置的处理器"""

    def __init__(self, multiplier=1):
        self._multiplier = multiplier
        # 创建内部处理函数
        def process_func(x: int) -> int:
            return x * self._multiplier
        super().__init__(process_func)

    def invoke(self, input_data, config=None, **kwargs):
        """调用方法"""
        # 从 config 中获取可配置字段
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'multiplier' in configurable:
                self._multiplier = configurable['multiplier']
        # 使用当前的 multiplier 计算结果
        if isinstance(input_data, int):
            return input_data * self._multiplier
        return input_data

processor = ConfigurableProcessor()

# 使用 configurable_fields 定义可配置字段
configurable_processor = processor.configurable_fields(
    multiplier=ConfigurableField(
        id="multiplier",
        name="Multiplier",
        description="用于乘以输入的数字",
        annotation=int,
    )
)

print("\n创建可配置的处理器:")
print("configurable_processor = processor.configurable_fields(")
print("    multiplier=ConfigurableField(")
print("        id='multiplier',")
print("        name='Multiplier',")
print("        description='用于乘以输入的数字',")
print("        annotation=int,")
print("    )")
print(")")

# 使用默认配置
print("\n使用默认配置:")
result1 = configurable_processor.invoke(5)
print(f"  输入: 5")
print(f"  输出: {result1}")

# 使用自定义配置
print("\n使用自定义配置:")
result2 = configurable_processor.invoke(5, config={"configurable": {"multiplier": 3}})
print(f"  输入: 5")
print(f"  配置: {{'configurable': {{'multiplier': 3}}}}")
print(f"  输出: {result2}")

print("\n" + "=" * 60)
print("演示 2: ConfigurableField 与 LLM 结合")
print("=" * 60)

# 创建一个可配置温度的 LLM 包装器
class ConfigurableLLM(RunnableLambda):
    """可配置的 LLM 包装器"""

    def __init__(self, llm, temperature=0.7):
        self.llm = llm
        self._temperature = temperature
        # 创建包装函数
        def llm_func(input_data):
            return input_data
        super().__init__(llm_func)

    def invoke(self, input_data, config=None, **kwargs):
        """调用 LLM"""
        # 从 config 中获取温度
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'temperature' in configurable:
                self._temperature = configurable['temperature']
                # 更新 LLM 的温度(如果支持)
                if hasattr(self.llm, 'temperature'):
                    self.llm.temperature = self._temperature

        # 调用 LLM
        if hasattr(self.llm, 'invoke'):
            return self.llm.invoke(input_data, config=config, **kwargs)
        return input_data

configurable_llm = ConfigurableLLM(llm)

# 使用 configurable_fields
llm_with_config = configurable_llm.configurable_fields(
    temperature=ConfigurableField(
        id="temperature",
        name="Temperature",
        description="LLM 的温度参数,控制输出的随机性",
        annotation=float,
    )
)

print("\n创建可配置温度的 LLM:")
print("llm_with_config = configurable_llm.configurable_fields(")
print("    temperature=ConfigurableField(")
print("        id='temperature',")
print("        name='Temperature',")
print("        description='LLM 的温度参数,控制输出的随机性',")
print("        annotation=float,")
print("    )")
print(")")

# 使用不同的温度
if PromptTemplate is not None and llm is not None:
    prompt = PromptTemplate.from_template("请用一句话介绍:{topic}")

    print("\n使用不同的温度配置:")
    print("(注意:此演示需要 LLM API 支持,如果失败请检查 API 配置)")
    try:
        for temp in [0.1, 0.7, 1.0]:
            # 先格式化提示
            formatted = prompt.format(topic="Python")
            # 然后调用 LLM
            result = llm_with_config.invoke(
                formatted,
                config={"configurable": {"temperature": temp}}
            )
            # 解析结果
            if parser is not None:
                if hasattr(result, 'content'):
                    parsed = parser.parse(result.content)
                else:
                    parsed = str(result)
            else:
                parsed = str(result)
            print(f"  温度={temp}: {parsed[:50]}...")
    except Exception as e:
        print(f"  (演示失败:{e},可能需要配置 LLM API)")
else:
    print("\n(跳过此演示,需要 PromptTemplate 和 LLM 支持)")

print("\n" + "=" * 60)
print("演示 3: ConfigurableField 多个字段")
print("=" * 60)

# 创建有多个可配置字段的处理器
class MultiConfigProcessor(RunnableLambda):
    """多配置处理器"""

    def __init__(self, multiplier=1, offset=0, prefix=""):
        self._multiplier = multiplier
        self._offset = offset
        self._prefix = prefix
        # 创建处理函数
        def process_func(x: int) -> str:
            result = x * self._multiplier + self._offset
            return f"{self._prefix}{result}"
        super().__init__(process_func)

    def invoke(self, input_data, config=None, **kwargs):
        """调用方法"""
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'multiplier' in configurable:
                self._multiplier = configurable['multiplier']
            if 'offset' in configurable:
                self._offset = configurable['offset']
            if 'prefix' in configurable:
                self._prefix = configurable['prefix']
        # 使用当前配置计算结果
        if isinstance(input_data, int):
            result = input_data * self._multiplier + self._offset
            return f"{self._prefix}{result}"
        return input_data

multi_processor = MultiConfigProcessor()

# 定义多个可配置字段
multi_config_processor = multi_processor.configurable_fields(
    multiplier=ConfigurableField(
        id="multiplier",
        name="Multiplier",
        description="乘数",
        annotation=int,
    ),
    offset=ConfigurableField(
        id="offset",
        name="Offset",
        description="偏移量",
        annotation=int,
    ),
    prefix=ConfigurableField(
        id="prefix",
        name="Prefix",
        description="前缀字符串",
        annotation=str,
    ),
)

print("\n创建多配置处理器:")
print("multi_config_processor = multi_processor.configurable_fields(")
print("    multiplier=ConfigurableField(...),")
print("    offset=ConfigurableField(...),")
print("    prefix=ConfigurableField(...),")
print(")")

# 使用不同的配置
configs = [
    {"configurable": {"multiplier": 2, "offset": 0, "prefix": ""}},
    {"configurable": {"multiplier": 3, "offset": 1, "prefix": "结果: "}},
    {"configurable": {"multiplier": 4, "offset": -1, "prefix": "输出="}},
]

print("\n使用不同的配置:")
for i, config in enumerate(configs, 1):
    result = multi_config_processor.invoke(5, config=config)
    print(f"  配置 {i}: {result}")

print("\n" + "=" * 60)
print("演示 4: ConfigurableField 与 with_config 结合")
print("=" * 60)

# 使用 with_config 绑定可配置字段的值
configured_processor = multi_config_processor.with_config(
    configurable={"multiplier": 5, "offset": 10, "prefix": "计算: "}
)

print("\n使用 with_config 绑定配置:")
print("configured_processor = multi_config_processor.with_config(")
print("    configurable={'multiplier': 5, 'offset': 10, 'prefix': '计算: '}")
print(")")

# 调用时自动使用绑定的配置
result = configured_processor.invoke(5)
print(f"\n调用(自动使用绑定的配置):")
print(f"  输入: 5")
print(f"  输出: {result}")

print("\n" + "=" * 60)
print("演示 5: ConfigurableField 在链中使用")
print("=" * 60)

# 在链中使用可配置字段
def add_one(x: int) -> int:
    """加1"""
    return x + 1

def multiply(x: int, multiplier: int = 2) -> int:
    """乘以倍数"""
    return x * multiplier

# 创建可配置的乘法函数
class ConfigurableMultiply(RunnableLambda):
    """可配置的乘法"""

    def __init__(self, multiplier=2):
        self._multiplier = multiplier
        def multiply_func(x: int) -> int:
            return x * self._multiplier
        super().__init__(multiply_func)

    def invoke(self, input_data, config=None, **kwargs):
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'multiplier' in configurable:
                self._multiplier = configurable['multiplier']
        if isinstance(input_data, int):
            return input_data * self._multiplier
        return input_data

configurable_multiply = ConfigurableMultiply().configurable_fields(
    multiplier=ConfigurableField(
        id="multiplier",
        name="Multiplier",
        description="乘数",
        annotation=int,
    )
)

# 创建链
chain = RunnableLambda(add_one) | configurable_multiply

print("\n创建链(包含可配置字段):")
print("chain = RunnableLambda(add_one) | configurable_multiply")

# 使用不同的配置
print("\n使用不同的配置:")
for multiplier in [2, 3, 5]:
    result = chain.invoke(5, config={"configurable": {"multiplier": multiplier}})
    print(f"  multiplier={multiplier}: 输入=5 -> 输出={result}")

print("\n" + "=" * 60)
print("演示 6: ConfigurableField 与 RunnableBranch 结合")
print("=" * 60)

# 使用可配置字段控制分支选择
class ConfigurableBranch(RunnableLambda):
    """可配置的分支"""

    def __init__(self, operation="add"):
        self._operation = operation
        def branch_func(x: int) -> int:
            if self._operation == "add":
                return x + 1
            elif self._operation == "multiply":
                return x * 2
            elif self._operation == "square":
                return x ** 2
            else:
                return x
        super().__init__(branch_func)

    def invoke(self, input_data, config=None, **kwargs):
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'operation' in configurable:
                self._operation = configurable['operation']

        if isinstance(input_data, int):
            if self._operation == "add":
                return input_data + 1
            elif self._operation == "multiply":
                return input_data * 2
            elif self._operation == "square":
                return input_data ** 2
        return input_data

configurable_branch = ConfigurableBranch().configurable_fields(
    operation=ConfigurableField(
        id="operation",
        name="Operation",
        description="操作类型:add, multiply, square",
        annotation=str,
    )
)

print("\n创建可配置分支:")
print("configurable_branch = ConfigurableBranch().configurable_fields(")
print("    operation=ConfigurableField(...)")
print(")")

# 使用不同的操作
operations = ["add", "multiply", "square"]

print("\n使用不同的操作:")
for op in operations:
    result = configurable_branch.invoke(5, config={"configurable": {"operation": op}})
    print(f"  operation={op}: 输入=5 -> 输出={result}")

print("\n" + "=" * 60)
print("演示 7: ConfigurableField 批处理")
print("=" * 60)

# 批处理中使用可配置字段
print("\n批处理中使用可配置字段:")

inputs = [1, 2, 3, 4, 5]

# 使用默认配置
results1 = configurable_multiply.batch(inputs)
print(f"默认配置 (multiplier=2):")
print(f"  输入: {inputs}")
print(f"  输出: {results1}")

# 使用自定义配置
results2 = configurable_multiply.batch(inputs, config={"configurable": {"multiplier": 3}})
print(f"\n自定义配置 (multiplier=3):")
print(f"  输入: {inputs}")
print(f"  输出: {results2}")

print("\n" + "=" * 60)
print("演示 8: ConfigurableField 实际应用场景")
print("=" * 60)

# 实际应用:可配置的文本处理
class ConfigurableTextProcessor(RunnableLambda):
    """可配置的文本处理器"""

    def __init__(self, case="lower", max_length=None):
        self._case = case
        self._max_length = max_length
        def process_func(text: str) -> str:
            # 应用大小写转换
            if self._case == "upper":
                text = text.upper()
            elif self._case == "lower":
                text = text.lower()
            elif self._case == "title":
                text = text.title()

            # 应用长度限制
            if self._max_length and len(text) > self._max_length:
                text = text[:self._max_length] + "..."

            return text
        super().__init__(process_func)

    def invoke(self, input_data, config=None, **kwargs):
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'case' in configurable:
                self._case = configurable['case']
            if 'max_length' in configurable:
                self._max_length = configurable['max_length']

        # 使用当前配置处理文本
        if isinstance(input_data, str):
            text = input_data
            # 应用大小写转换
            if self._case == "upper":
                text = text.upper()
            elif self._case == "lower":
                text = text.lower()
            elif self._case == "title":
                text = text.title()

            # 应用长度限制
            if self._max_length and len(text) > self._max_length:
                text = text[:self._max_length] + "..."

            return text
        return input_data

text_processor = ConfigurableTextProcessor().configurable_fields(
    case=ConfigurableField(
        id="case",
        name="Case",
        description="文本大小写:upper, lower, title",
        annotation=str,
    ),
    max_length=ConfigurableField(
        id="max_length",
        name="Max Length",
        description="最大长度(None 表示无限制)",
        annotation=int,
    ),
)

print("\n创建可配置的文本处理器:")
print("text_processor = ConfigurableTextProcessor().configurable_fields(")
print("    case=ConfigurableField(...),")
print("    max_length=ConfigurableField(...),")
print(")")

test_text = "Hello World, This is a Test"

print(f"\n测试文本: {test_text}")

# 使用不同的配置
configs = [
    {"configurable": {"case": "upper", "max_length": None}},
    {"configurable": {"case": "lower", "max_length": 20}},
    {"configurable": {"case": "title", "max_length": 15}},
]

print("\n使用不同的配置:")
for i, config in enumerate(configs, 1):
    result = text_processor.invoke(test_text, config=config)
    print(f"  配置 {i}: {result}")

print("\n" + "=" * 60)
print("演示 9: ConfigurableField 与 PromptTemplate 结合")
print("=" * 60)

# 创建可配置的提示模板
class ConfigurablePrompt(RunnableLambda):
    """可配置的提示模板"""

    def __init__(self, template="请回答:{question}"):
        self._template = template
        def format_func(input_data):
            if isinstance(input_data, dict):
                return self._template.format(**input_data)
            return self._template.format(question=str(input_data))
        super().__init__(format_func)

    def format(self, **kwargs):
        """格式化提示"""
        return self._template.format(**kwargs)

    def invoke(self, input_data, config=None, **kwargs):
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'template' in configurable:
                self._template = configurable['template']

        if isinstance(input_data, dict):
            return self.format(**input_data)
        return self.format(question=str(input_data))

configurable_prompt = ConfigurablePrompt().configurable_fields(
    template=ConfigurableField(
        id="template",
        name="Template",
        description="提示模板",
        annotation=str,
    )
)

print("\n创建可配置的提示模板:")
print("configurable_prompt = ConfigurablePrompt().configurable_fields(")
print("    template=ConfigurableField(...)")
print(")")

# 使用不同的模板
templates = [
    "请简洁回答:{question}",
    "请详细回答:{question}",
    "请用一句话回答:{question}",
]

print("\n使用不同的模板:")
for i, template in enumerate(templates, 1):
    result = configurable_prompt.invoke(
        {"question": "什么是 Python?"},
        config={"configurable": {"template": template}}
    )
    print(f"  模板 {i}: {result}")

print("\n" + "=" * 60)
print("演示 10: ConfigurableField 组合使用")
print("=" * 60)

# 组合多个可配置字段
class ComplexProcessor(RunnableLambda):
    """复杂的可配置处理器"""

    def __init__(self, multiplier=1, offset=0, prefix="", suffix=""):
        self._multiplier = multiplier
        self._offset = offset
        self._prefix = prefix
        self._suffix = suffix
        def process_func(x: int) -> str:
            result = x * self._multiplier + self._offset
            return f"{self._prefix}{result}{self._suffix}"
        super().__init__(process_func)

    def invoke(self, input_data, config=None, **kwargs):
        if config and 'configurable' in config:
            configurable = config['configurable']
            if 'multiplier' in configurable:
                self._multiplier = configurable['multiplier']
            if 'offset' in configurable:
                self._offset = configurable['offset']
            if 'prefix' in configurable:
                self._prefix = configurable['prefix']
            if 'suffix' in configurable:
                self._suffix = configurable['suffix']

        # 使用当前配置计算结果
        if isinstance(input_data, int):
            result = input_data * self._multiplier + self._offset
            return f"{self._prefix}{result}{self._suffix}"
        return input_data

complex_processor = ComplexProcessor().configurable_fields(
    multiplier=ConfigurableField(
        id="multiplier",
        name="Multiplier",
        description="乘数",
        annotation=int,
    ),
    offset=ConfigurableField(
        id="offset",
        name="Offset",
        description="偏移量",
        annotation=int,
    ),
    prefix=ConfigurableField(
        id="prefix",
        name="Prefix",
        description="前缀",
        annotation=str,
    ),
    suffix=ConfigurableField(
        id="suffix",
        name="Suffix",
        description="后缀",
        annotation=str,
    ),
)

print("\n创建复杂的可配置处理器:")
print("complex_processor = ComplexProcessor().configurable_fields(")
print("    multiplier=ConfigurableField(...),")
print("    offset=ConfigurableField(...),")
print("    prefix=ConfigurableField(...),")
print("    suffix=ConfigurableField(...),")
print(")")

# 使用完整配置
config = {
    "configurable": {
        "multiplier": 3,
        "offset": 5,
        "prefix": "结果: ",
        "suffix": " (处理完成)",
    }
}

result = complex_processor.invoke(10, config=config)
print(f"\n使用完整配置:")
print(f"  输入: 10")
print(f"  配置: {config['configurable']}")
print(f"  输出: {result}")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nConfigurableField 的核心概念:")
print("1. 运行时配置:允许在调用时动态配置 Runnable 的属性")
print("2. 字段定义:使用 ConfigurableField 定义可配置的字段")
print("3. 配置传递:通过 config['configurable'] 传递配置值")
print("4. 灵活组合:可以定义多个可配置字段")
print("\n使用场景:")
print("- 动态调整 LLM 参数(如 temperature、max_tokens)")
print("- 运行时选择不同的处理策略")
print("- 根据环境或用户需求调整行为")
print("- 创建可重用的、可配置的组件")
print("\n注意事项:")
print("- ConfigurableField 的 id 用于在 config['configurable'] 中引用")
print("- 可以定义多个可配置字段")
print("- 配置值在调用时通过 config 参数传递")
print("- 可以结合 with_config 方法绑定配置")
print("- 适合需要根据运行时条件调整行为的场景")

30.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_retry(
        self,
        *,
        retry_if_exception_type: tuple = (Exception,),
        wait_exponential_jitter: bool = True,
        stop_after_attempt: int = 3,
        **kwargs
    ):
        """
        创建一个带重试功能的 Runnable

        Args:
            retry_if_exception_type: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            stop_after_attempt: 最大重试次数
            **kwargs: 其他参数

        Returns:
            RunnableRetry: 带重试功能的 Runnable
        """
        return RunnableRetry(
            bound=self,
            retry_exception_types=retry_if_exception_type,
            wait_exponential_jitter=wait_exponential_jitter,
            max_attempt_number=stop_after_attempt,
            **kwargs
        )

    def with_config(
        self,
        config=None,
        **kwargs
    ):
        """
        将配置绑定到 Runnable,返回一个新的 Runnable

        Args:
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数(会合并到 config 中)

        Returns:
            RunnableBinding: 绑定配置后的 Runnable
        """
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        merged_config = {**config, **kwargs}
        return RunnableBinding(bound=self, config=merged_config)
+   
+   def configurable_fields(self, **kwargs):
+       """
+       配置可配置字段
+       
+       Args:
+           **kwargs: ConfigurableField 实例,键是字段名,值是 ConfigurableField
+       
+       Returns:
+           RunnableConfigurableFields: 可配置字段的 Runnable
+       """
+       return RunnableConfigurableFields(bound=self, fields=kwargs)


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


class RunnablePassthrough(Runnable):
    """传递输入数据不变或添加额外键的 Runnable

    RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
    它常用于在链中保留原始输入数据,同时添加处理后的结果。
    """

    def __init__(self, func=None):
        """
        初始化 RunnablePassthrough

        Args:
            func: 可选的函数,会在传递数据时调用(用于副作用)
        """
        self.func = func

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入数据"""
        # 如果提供了函数,调用它(用于副作用)
        if self.func is not None:
            if callable(self.func):
                try:
                    self.func(input_data, config=config, **kwargs)
                except TypeError:
                    self.func(input_data)

        # 直接返回输入数据
        return input_data

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    @classmethod
    def assign(cls, **kwargs):
        """创建一个 RunnableAssign,在传递数据的同时添加额外字段

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象

        Returns:
            RunnableAssign: 可以添加额外字段的 Runnable
        """
        return RunnableAssign(**kwargs)

    def __repr__(self):
        return "RunnablePassthrough()"


class RunnableAssign(Runnable):
    """在传递字典数据的同时添加额外字段的 Runnable

    RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
    在保留原有字段的同时,添加新的字段。
    """

    def __init__(self, **kwargs):
        """
        初始化 RunnableAssign

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
        """
        self.assignments = {}
        for key, value in kwargs.items():
            if isinstance(value, dict):
                # 嵌套字典转换为 RunnableParallel
                self.assignments[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                self.assignments[key] = value
            else:
                raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入并添加额外字段"""
        # 确保输入是字典
        if not isinstance(input_data, dict):
            input_data = {"input": input_data}

        # 复制输入数据
        result = dict(input_data)

        # 添加新字段
        for key, assignment in self.assignments.items():
            if hasattr(assignment, 'invoke'):
                # 检查是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(assignment.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
                    else:
                        result[key] = assignment.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        result[key] = assignment.invoke(input_data, **kwargs)
                    except TypeError:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
            elif callable(assignment):
                result[key] = assignment(input_data)
            else:
                raise ValueError(f"赋值 '{key}' 不可调用")

        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __ror__(self, other):
        """支持从右侧使用 | 操作符(用于字典字面量)"""
        if isinstance(other, dict):
            # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
            return RunnableSequence(RunnableParallel(other), self)
        return NotImplemented

    def __repr__(self):
        return f"RunnableAssign({list(self.assignments.keys())})"


class RunnableBranch(Runnable):
    """根据条件选择执行分支的 Runnable

    RunnableBranch 类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行。
    它接受多个 (condition, Runnable) 对,按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行。
    如果没有条件满足,则执行默认分支。
    """

    def __init__(self, *branches):
        """
        初始化 RunnableBranch

        Args:
            *branches: 多个 (condition, Runnable) 对,最后一个可以是默认分支(Runnable 或可调用对象)
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch 至少需要 2 个分支(包括默认分支)")

        # 分离条件和默认分支
        self.branches = []
        default = None

        for i, branch in enumerate(branches):
            if i == len(branches) - 1:
                # 最后一个可能是默认分支
                if isinstance(branch, tuple) and len(branch) == 2:
                    # 仍然是 (condition, Runnable) 对
                    condition, runnable = branch
                    self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
                else:
                    # 是默认分支
                    default = self._coerce_to_runnable(branch)
            else:
                # 必须是 (condition, Runnable) 对
                if not isinstance(branch, tuple) or len(branch) != 2:
                    raise ValueError(f"分支 {i} 必须是 (condition, Runnable) 对,但得到 {type(branch)}")
                condition, runnable = branch
                self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))

        if default is None:
            raise ValueError("RunnableBranch 必须提供默认分支(最后一个参数)")

        self.default = default

    def _coerce_to_condition(self, condition):
        """将条件转换为可调用对象"""
        if callable(condition):
            return condition
        elif hasattr(condition, 'invoke'):
            # 如果已经是 Runnable,返回一个包装函数
            def wrapped_condition(input_data):
                result = condition.invoke(input_data)
                return bool(result)
            return wrapped_condition
        else:
            raise ValueError(f"条件必须是可调用对象或 Runnable,但得到 {type(condition)}")

    def _coerce_to_runnable(self, runnable):
        """将值转换为 Runnable"""
        if hasattr(runnable, 'invoke'):
            return runnable
        elif callable(runnable):
            return RunnableLambda(runnable)
        else:
            raise ValueError(f"Runnable 必须是可调用对象或 Runnable,但得到 {type(runnable)}")

    def invoke(self, input_data, config=None, **kwargs):
        """根据条件选择分支执行"""
        # 按顺序检查条件
        for condition, runnable in self.branches:
            try:
                # 评估条件
                condition_result = condition(input_data)
                if condition_result:
                    # 条件满足,执行对应的 Runnable
                    if hasattr(runnable, 'invoke'):
                        import inspect
                        try:
                            sig = inspect.signature(runnable.invoke)
                            params = list(sig.parameters.keys())
                            if 'config' in params:
                                return runnable.invoke(input_data, config=config, **kwargs)
                            else:
                                return runnable.invoke(input_data, **kwargs)
                        except (ValueError, TypeError):
                            try:
                                return runnable.invoke(input_data, **kwargs)
                            except TypeError:
                                return runnable.invoke(input_data, config=config, **kwargs)
                    elif callable(runnable):
                        return runnable(input_data)
            except Exception as e:
                # 如果条件评估出错,继续下一个条件
                continue

        # 没有条件满足,执行默认分支
        if hasattr(self.default, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.default.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.default.invoke(input_data, config=config, **kwargs)
                else:
                    return self.default.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.default.invoke(input_data, **kwargs)
                except TypeError:
                    return self.default.invoke(input_data, config=config, **kwargs)
        elif callable(self.default):
            return self.default(input_data)
        else:
            return self.default

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableBranch({len(self.branches)} branches + default)"


class RunnableRetry(Runnable):
    """带重试功能的 Runnable

    RunnableRetry 包装另一个 Runnable,当执行失败时会自动重试。
    支持指数退避和抖动,以及指定重试的异常类型。
    """

    def __init__(
        self,
        bound,
        retry_exception_types=(Exception,),
        wait_exponential_jitter=True,
        max_attempt_number=3,
        **kwargs
    ):
        """
        初始化 RunnableRetry

        Args:
            bound: 要包装的 Runnable
            retry_exception_types: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            max_attempt_number: 最大重试次数
            **kwargs: 其他参数
        """
        self.bound = bound
        self.retry_exception_types = retry_exception_types
        self.wait_exponential_jitter = wait_exponential_jitter
        self.max_attempt_number = max_attempt_number
        self.kwargs = kwargs

    def _should_retry(self, exception):
        """检查是否应该重试"""
        return isinstance(exception, self.retry_exception_types)

    def _wait_time(self, attempt_number):
        """计算等待时间(指数退避)"""
        if not self.wait_exponential_jitter:
            return 0

        import random
        # 指数退避:2^attempt_number 秒,最大 10 秒
        base_wait = min(2 ** attempt_number, 10)
        # 添加抖动:随机 0-1 秒
        jitter = random.uniform(0, 1)
        return base_wait + jitter

    def invoke(self, input_data, config=None, **kwargs):
        """执行并重试"""
        last_exception = None

        for attempt in range(1, self.max_attempt_number + 1):
            try:
                # 调用原始的 Runnable
                if hasattr(self.bound, 'invoke'):
                    import inspect
                    try:
                        sig = inspect.signature(self.bound.invoke)
                        params = list(sig.parameters.keys())
                        if 'config' in params:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                        else:
                            return self.bound.invoke(input_data, **kwargs)
                    except (ValueError, TypeError):
                        try:
                            return self.bound.invoke(input_data, **kwargs)
                        except TypeError:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                elif callable(self.bound):
                    return self.bound(input_data)
                else:
                    raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")
            except Exception as e:
                last_exception = e
                if not self._should_retry(e):
                    # 不应该重试的异常,直接抛出
                    raise

                if attempt < self.max_attempt_number:
                    # 计算等待时间
                    wait_time = self._wait_time(attempt)
                    if wait_time > 0:
                        import time
                        time.sleep(wait_time)
                    # 继续重试
                    continue
                else:
                    # 达到最大重试次数,抛出最后一个异常
                    raise

        # 如果所有重试都失败,抛出最后一个异常
        if last_exception:
            raise last_exception

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableRetry(bound={self.bound}, max_attempts={self.max_attempt_number})"


class RunnableBinding(Runnable):
    """绑定配置的 Runnable

    RunnableBinding 包装另一个 Runnable,并绑定配置。
    当调用绑定的 Runnable 时,会自动使用绑定的配置。
    """

    def __init__(self, bound, config=None, **kwargs):
        """
        初始化 RunnableBinding

        Args:
            bound: 要包装的 Runnable
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数
        """
        self.bound = bound
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        self.config = {**config, **kwargs}

    def invoke(self, input_data, config=None, **kwargs):
        """执行并合并配置"""
        # 合并绑定的配置和传入的配置(传入的配置优先)
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        # 合并 kwargs 到 config
        merged_config.update(kwargs)

        # 调用原始的 Runnable
        if hasattr(self.bound, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.bound.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
                else:
                    return self.bound.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.bound.invoke(input_data, **kwargs)
                except TypeError:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
        elif callable(self.bound):
            return self.bound(input_data)
        else:
            raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_config(self, config=None, **kwargs):
        """添加更多配置(合并到现有配置)"""
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        merged_config.update(kwargs)
        return RunnableBinding(bound=self.bound, config=merged_config)

    def __repr__(self):
        return f"RunnableBinding(bound={self.bound}, config={self.config})"


+class RunnableConfigurableFields(Runnable):
+   """可配置字段的 Runnable
+   
+   RunnableConfigurableFields 允许在运行时动态配置 Runnable 的属性。
+   它通过 configurable_fields 方法创建,使用 ConfigurableField 定义可配置的字段。
+   """
+   
+   def __init__(self, bound, fields):
+       """
+       初始化 RunnableConfigurableFields
+       
+       Args:
+           bound: 要包装的 Runnable
+           fields: 可配置字段字典,键是字段名,值是 ConfigurableField
+       """
+       self.bound = bound
+       self.fields = fields
+   
+   def invoke(self, input_data, config=None, **kwargs):
+       """执行并应用可配置字段"""
+       # 从 config 中获取可配置字段的值
+       if config and 'configurable' in config:
+           configurable = config['configurable']
+           # 应用可配置字段的值到 bound
+           for field_name, field_def in self.fields.items():
+               if field_def.id in configurable:
+                   # 如果 bound 有该属性,设置它
+                   if hasattr(self.bound, field_name):
+                       setattr(self.bound, field_name, configurable[field_def.id])
+                   # 如果 bound 是字典,更新它
+                   elif isinstance(self.bound, dict):
+                       self.bound[field_name] = configurable[field_def.id]
+       
+       # 调用原始的 Runnable
+       if hasattr(self.bound, 'invoke'):
+           import inspect
+           try:
+               sig = inspect.signature(self.bound.invoke)
+               params = list(sig.parameters.keys())
+               if 'config' in params:
+                   return self.bound.invoke(input_data, config=config, **kwargs)
+               else:
+                   return self.bound.invoke(input_data, **kwargs)
+           except (ValueError, TypeError):
+               try:
+                   return self.bound.invoke(input_data, **kwargs)
+               except TypeError:
+                   return self.bound.invoke(input_data, config=config, **kwargs)
+       elif callable(self.bound):
+           return self.bound(input_data)
+       else:
+           raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")
+   
+   def batch(self, inputs, config=None, **kwargs):
+       """批处理"""
+       return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
+   
+   def stream(self, input_data, config=None, **kwargs):
+       """流式处理"""
+       result = self.invoke(input_data, config=config, **kwargs)
+       yield result
+   
+   def __or__(self, other):
+       """支持 | 操作符"""
+       if isinstance(other, dict):
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   def with_config(self, config=None, **kwargs):
+       """添加更多配置"""
+       if config is None:
+           config = {}
+       merged_config = {**config, **kwargs}
+       return RunnableConfigurableFields(bound=self.bound, fields=self.fields)
+   
+   def __repr__(self):
+       return f"RunnableConfigurableFields(bound={self.bound}, fields={list(self.fields.keys())})"
+
+
def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
+   try:
+       from langchain.prompts import PromptTemplate
+       from langchain.chat_models import ChatOpenAI
+       from langchain.output_parsers import BaseOutputParser
+   except ImportError:
+       # 如果导入失败,跳过 LCEL 支持设置
+       return

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

    # 绑定方法到类
    PromptTemplate.__or__ = _prompt_or
    PromptTemplate.invoke = _prompt_invoke
    ChatOpenAI.__or__ = _llm_or
    BaseOutputParser.__or__ = _parser_or
    BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()

+
+# ConfigurableField 相关类
+from collections import namedtuple
+
+ConfigurableField = namedtuple(
+   'ConfigurableField',
+   ['id', 'name', 'description', 'annotation', 'is_shared'],
+   defaults=(None, None, None, False)
+)
+"""可配置字段
+
+ConfigurableField 用于定义可以在运行时配置的字段。
+它通常与 configurable_fields 方法一起使用,允许在运行时动态配置 Runnable 的属性。
+
+Args:
+   id: 字段的唯一标识符
+   name: 字段的名称(可选)
+   description: 字段的描述(可选)
+   annotation: 字段的类型注解(可选)
+   is_shared: 字段是否共享(可选,默认为 False)
+"""
+
+ConfigurableFieldSingleOption = namedtuple(
+   'ConfigurableFieldSingleOption',
+   ['id', 'name', 'description', 'annotation', 'default', 'is_shared'],
+   defaults=(None, None, None, None, False)
+)
+"""单选项可配置字段
+
+ConfigurableFieldSingleOption 用于定义有默认值的可配置字段。
+
+Args:
+   id: 字段的唯一标识符
+   name: 字段的名称(可选)
+   description: 字段的描述(可选)
+   annotation: 字段的类型注解(可选)
+   default: 默认值
+   is_shared: 字段是否共享(可选,默认为 False)
+"""
+
+ConfigurableFieldMultiOption = namedtuple(
+   'ConfigurableFieldMultiOption',
+   ['id', 'name', 'description', 'annotation', 'options', 'default', 'is_shared'],
+   defaults=(None, None, None, None, None, False)
+)
+"""多选项可配置字段
+
+ConfigurableFieldMultiOption 用于定义有多个选项的可配置字段。
+
+Args:
+   id: 字段的唯一标识符
+   name: 字段的名称(可选)
+   description: 字段的描述(可选)
+   annotation: 字段的类型注解(可选)
+   options: 选项字典
+   default: 默认值(可选)
+   is_shared: 字段是否共享(可选,默认为 False)
+"""
+

31. ChatMessageHistory #

ChatMessageHistory是用于管理和存储对话消息历史的核心类,常用于多轮对话、聊天机器人、上下文管理等场景。

ChatMessageHistory 具有以下主要功能:

  • 消息存储:支持将用户消息(如 HumanMessage)、AI 消息(如 AIMessage)以对象形式依次写入历史,确保消息顺序和类型的准确。
  • 增删查操作:可以通过 add_message/add_messages 方法添加单条或多条消息,通过 messages 属性一次性获取全部历史消息,通过 clear() 方法清空历史。
  • 便捷方法:如 add_user_message 和 add_ai_message,可直接传入字符串自动转换为标准消息对象,简化代码。

设计思想:

  • 满足多类型消息(用户、AI、系统提示等)的统一管理需求,方便扩展。
  • 支持多种后端:默认实现为内存存储,实际生产应用可扩展为数据库、Redis 甚至分布式持久化。
  • 能为每个用户/会话独立维护历史,实现多会话隔离。

典型应用举例:

  • 聊天机器人需要了解上下文,对历史消息逐条回溯;
  • 将历史对话注入到 Prompt 中提升 LLM 理解能力;
  • 支持临时和持久会话(如按 user_id 分组),优化用户体验。

通常使用方法:

from langchain.chat_message_histories import ChatMessageHistory
history = ChatMessageHistory()
history.add_user_message("你好!")
history.add_ai_message("你好,请问有什么可以帮您?")
print([msg.content for msg in history.messages])
history.clear()

31.1. init.py #

langchain/init.py

# 导入 runnables 模块以自动启用 LCEL 支持
import langchain.runnables

# 导出 chat_message_histories
+from langchain.chat_message_histories import ChatMessageHistory
+
+__all__ = ["ChatMessageHistory"]

31.2. chat_message_histories.py #

langchain/chat_message_histories.py

"""Chat message history implementations.

This module provides implementations for storing and managing chat message history.
"""

from typing import List, Sequence
from abc import ABC, abstractmethod

try:
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
    )
    from langchain_core.chat_history import BaseChatMessageHistory
except ImportError:
    # 如果 langchain_core 不可用,定义基础类
    from langchain.messages import AIMessage, BaseMessage, HumanMessage

    class BaseChatMessageHistory(ABC):
        """抽象基类,用于存储聊天消息历史"""

        @property
        @abstractmethod
        def messages(self) -> List[BaseMessage]:
            """获取所有消息"""
            pass

        @abstractmethod
        def add_message(self, message: BaseMessage) -> None:
            """添加一条消息"""
            pass

        @abstractmethod
        def add_messages(self, messages: Sequence[BaseMessage]) -> None:
            """批量添加消息"""
            pass

        @abstractmethod
        def clear(self) -> None:
            """清空所有消息"""
            pass

        def add_user_message(self, message: HumanMessage | str) -> None:
            """便捷方法:添加用户消息"""
            if isinstance(message, str):
                self.add_message(HumanMessage(content=message))
            else:
                self.add_message(message)

        def add_ai_message(self, message: AIMessage | str) -> None:
            """便捷方法:添加 AI 消息"""
            if isinstance(message, str):
                self.add_message(AIMessage(content=message))
            else:
                self.add_message(message)


class ChatMessageHistory(BaseChatMessageHistory):
    """内存实现的聊天消息历史记录

    ChatMessageHistory 是一个简单的内存实现,将消息存储在内存中的列表中。
    它适合用于开发、测试和简单的应用场景。

    对于生产环境,建议使用持久化存储实现(如数据库、Redis 等)。

    示例:
        ```python
        from langchain.chat_message_histories import ChatMessageHistory
        from langchain.messages import HumanMessage, AIMessage

        # 创建历史记录实例
        history = ChatMessageHistory()

        # 添加用户消息
        history.add_user_message("你好")

        # 添加 AI 消息
        history.add_ai_message("你好!有什么可以帮助你的吗?")

        # 获取所有消息
        messages = history.messages

        # 清空历史记录
        history.clear()
"""

def __init__(self):
    """初始化 ChatMessageHistory

    创建一个空的消息历史记录。
    """
    self._messages: List[BaseMessage] = []

@property
def messages(self) -> List[BaseMessage]:
    """获取所有消息

    Returns:
        消息列表
    """
    return self._messages.copy()  # 返回副本,避免外部修改

def add_message(self, message: BaseMessage) -> None:
    """添加一条消息

    Args:
        message: 要添加的消息对象
    """
    self._messages.append(message)

def add_messages(self, messages: Sequence[BaseMessage]) -> None:
    """批量添加消息

    这是推荐的方法,因为它可以一次性添加多条消息,避免多次调用 add_message。

    Args:
        messages: 要添加的消息序列
    """
    self._messages.extend(messages)

def add_user_message(self, message: HumanMessage | str) -> None:
    """便捷方法:添加用户消息

    Args:
        message: 用户消息字符串或 HumanMessage 对象
    """
    if isinstance(message, str):
        self.add_message(HumanMessage(content=message))
    else:
        self.add_message(message)

def add_ai_message(self, message: AIMessage | str) -> None:
    """便捷方法:添加 AI 消息

    Args:
        message: AI 消息字符串或 AIMessage 对象
    """
    if isinstance(message, str):
        self.add_message(AIMessage(content=message))
    else:
        self.add_message(message)

def clear(self) -> None:
    """清空所有消息"""
    self._messages = []

def __len__(self) -> int:
    """返回消息数量"""
    return len(self._messages)

def __repr__(self) -> str:
    """返回字符串表示"""
    return f"ChatMessageHistory(messages={len(self._messages)})"

all = ["ChatMessageHistory", "BaseChatMessageHistory"]


### 31.3. ChatMessageHistory.py
31.ChatMessageHistory.py
```js

#from langchain.runnables import RunnableLambda
#from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
#from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
#from langchain_openai import ChatOpenAI
#from langchain_core.output_parsers import StrOutputParser
#from langchain_community.chat_message_histories import ChatMessageHistory


from langchain.chat_message_histories import ChatMessageHistory
from langchain.messages import HumanMessage, AIMessage, SystemMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chat_models import ChatOpenAI


print("=" * 60)
print("ChatMessageHistory 演示")
print("=" * 60)
print("\nChatMessageHistory: 用于存储和管理聊天消息历史记录")
print("它提供了添加、获取和清空消息的方法\n")

print("=" * 60)
print("演示 1: ChatMessageHistory 的基本用法")
print("=" * 60)

# 创建 ChatMessageHistory 实例
history = ChatMessageHistory()

print("\n创建 ChatMessageHistory 实例:")
print("history = ChatMessageHistory()")

# 添加用户消息
history.add_user_message("你好,我是张三")
print("\n添加用户消息:")
print("history.add_user_message('你好,我是张三')")

# 添加 AI 消息
history.add_ai_message("你好张三,很高兴认识你!")
print("history.add_ai_message('你好张三,很高兴认识你!')")

# 查看消息
print("\n查看所有消息:")
print(f"history.messages: {len(history.messages)} 条消息")
for i, msg in enumerate(history.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 2: 使用 add_message 方法")
print("=" * 60)

# 使用 add_message 添加消息对象
if HumanMessage and AIMessage:
    history2 = ChatMessageHistory()

    print("\n使用 add_message 添加消息对象:")
    history2.add_message(HumanMessage(content="Python 是什么?"))
    print("history2.add_message(HumanMessage(content='Python 是什么?'))")

    history2.add_message(AIMessage(content="Python 是一种高级编程语言。"))
    print("history2.add_message(AIMessage(content='Python 是一种高级编程语言。'))")

    print("\n消息列表:")
    for i, msg in enumerate(history2.messages, 1):
        print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 3: 使用 add_messages 批量添加")
print("=" * 60)

# 使用 add_messages 批量添加消息
if HumanMessage and AIMessage:
    history3 = ChatMessageHistory()

    print("\n使用 add_messages 批量添加消息:")
    messages = [
        HumanMessage(content="什么是机器学习?"),
        AIMessage(content="机器学习是人工智能的一个分支。"),
        HumanMessage(content="能举个例子吗?"),
        AIMessage(content="比如图像识别、语音识别等。"),
    ]

    history3.add_messages(messages)
    print(f"history3.add_messages({len(messages)} 条消息)")

    print(f"\n消息总数: {len(history3.messages)}")
    for i, msg in enumerate(history3.messages, 1):
        print(f"  {i}. {type(msg).__name__}: {msg.content[:50]}...")

print("\n" + "=" * 60)
print("演示 4: 清空消息历史")
print("=" * 60)

# 清空消息
history4 = ChatMessageHistory()
history4.add_user_message("测试消息 1")
history4.add_ai_message("回复 1")
history4.add_user_message("测试消息 2")

print("\n清空前:")
print(f"  消息数量: {len(history4.messages)}")

history4.clear()
print("\n执行 history4.clear()")
print(f"清空后消息数量: {len(history4.messages)}")

print("\n" + "=" * 60)
print("演示 5: 与 ChatPromptTemplate 结合使用")
print("=" * 60)

if ChatPromptTemplate and MessagesPlaceholder and ChatOpenAI:
    # 创建历史记录
    chat_history = ChatMessageHistory()
    chat_history.add_user_message("你好")
    chat_history.add_ai_message("你好!有什么可以帮助你的吗?")

    print("\n创建聊天历史:")
    print("chat_history = ChatMessageHistory()")
    print("chat_history.add_user_message('你好')")
    print("chat_history.add_ai_message('你好!有什么可以帮助你的吗?')")

    # 创建包含历史记录的提示模板
    template = ChatPromptTemplate.from_messages([
        ("system", "你是一个友好的 AI 助手。"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    print("\n创建提示模板:")
    print("template = ChatPromptTemplate.from_messages([")
    print("    ('system', '你是一个友好的 AI 助手。'),")
    print("    MessagesPlaceholder(variable_name='history'),")
    print("    ('human', '{question}'),")
    print("])")

    # 使用历史记录格式化提示
    prompt_messages = template.format_messages(
        history=chat_history.messages,
        question="请介绍一下你自己"
    )

    print("\n格式化提示(包含历史记录):")
    print(f"  消息数量: {len(prompt_messages)}")
    for i, msg in enumerate(prompt_messages, 1):
        msg_type = type(msg).__name__
        content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
        print(f"  {i}. {msg_type}: {content}")

    # 调用 LLM(如果需要)
    print("\n(注意:实际调用 LLM 需要配置 API 密钥)")
else:
    print("\n(跳过此演示,需要 ChatPromptTemplate 和 ChatOpenAI 支持)")

print("\n" + "=" * 60)
print("演示 6: 多轮对话管理")
print("=" * 60)

# 模拟多轮对话
conversation_history = ChatMessageHistory()

print("\n模拟多轮对话:")

# 第一轮
conversation_history.add_user_message("我想学习 Python")
if AIMessage:
    conversation_history.add_ai_message("很好!Python 是一门很好的编程语言。你想从哪方面开始?")
    print("  用户: 我想学习 Python")
    print("  AI: 很好!Python 是一门很好的编程语言。你想从哪方面开始?")

# 第二轮
conversation_history.add_user_message("基础语法")
if AIMessage:
    conversation_history.add_ai_message("好的,基础语法包括变量、数据类型、控制流等。")
    print("  用户: 基础语法")
    print("  AI: 好的,基础语法包括变量、数据类型、控制流等。")

# 第三轮
conversation_history.add_user_message("能详细说说变量吗?")
if AIMessage:
    conversation_history.add_ai_message("变量是用来存储数据的容器,在 Python 中不需要声明类型。")
    print("  用户: 能详细说说变量吗?")
    print("  AI: 变量是用来存储数据的容器,在 Python 中不需要声明类型。")

print(f"\n对话历史记录(共 {len(conversation_history.messages)} 条消息):")
for i, msg in enumerate(conversation_history.messages, 1):
    role = "用户" if hasattr(msg, 'content') and "用户" in str(type(msg)) or "Human" in type(msg).__name__ else "AI"
    content = msg.content[:60] + "..." if len(msg.content) > 60 else msg.content
    print(f"  {i}. [{role}]: {content}")

print("\n" + "=" * 60)
print("演示 7: 获取消息历史")
print("=" * 60)

# 获取消息历史
history5 = ChatMessageHistory()
history5.add_user_message("问题 1")
history5.add_ai_message("回答 1")
history5.add_user_message("问题 2")
history5.add_ai_message("回答 2")

print("\n获取消息历史:")
print("history5.messages")
print(f"  返回类型: {type(history5.messages)}")
print(f"  消息数量: {len(history5.messages)}")

print("\n遍历消息:")
for i, msg in enumerate(history5.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 8: 消息历史的状态管理")
print("=" * 60)

# 创建多个独立的会话历史
session1 = ChatMessageHistory()
session2 = ChatMessageHistory()

session1.add_user_message("会话 1 的消息")
session1.add_ai_message("会话 1 的回复")

session2.add_user_message("会话 2 的消息")
session2.add_ai_message("会话 2 的回复")

print("\n创建多个独立的会话:")
print("session1 = ChatMessageHistory()")
print("session2 = ChatMessageHistory()")

print(f"\nsession1 消息数: {len(session1.messages)}")
print(f"session2 消息数: {len(session2.messages)}")

print("\n会话 1 的消息:")
for i, msg in enumerate(session1.messages, 1):
    print(f"  {i}. {msg.content}")

print("\n会话 2 的消息:")
for i, msg in enumerate(session2.messages, 1):
    print(f"  {i}. {msg.content}")

print("\n" + "=" * 60)
print("演示 9: 实际应用场景 - 聊天机器人")
print("=" * 60)

# 模拟聊天机器人
bot_history = ChatMessageHistory()

print("\n模拟聊天机器人对话:")

# 初始化系统消息(如果需要)
if SystemMessage:
    # 注意:ChatMessageHistory 通常不直接存储 SystemMessage
    # 但在实际应用中,系统消息会在提示模板中处理
    pass

# 用户输入
user_inputs = [
    "你好",
    "今天天气怎么样?",
    "谢谢你的回答",
]

ai_responses = [
    "你好!很高兴为你服务。",
    "抱歉,我无法获取实时天气信息。建议查看天气预报应用。",
    "不客气!有其他问题随时问我。",
]

for user_input, ai_response in zip(user_inputs, ai_responses):
    bot_history.add_user_message(user_input)
    bot_history.add_ai_message(ai_response)
    print(f"  用户: {user_input}")
    print(f"  机器人: {ai_response}")

print(f"\n完整对话历史({len(bot_history.messages)} 条消息):")
for i in range(0, len(bot_history.messages), 2):
    if i + 1 < len(bot_history.messages):
        user_msg = bot_history.messages[i]
        ai_msg = bot_history.messages[i + 1]
        print(f"  轮次 {i//2 + 1}:")
        print(f"    用户: {user_msg.content}")
        print(f"    机器人: {ai_msg.content}")

print("\n" + "=" * 60)
print("演示 10: 消息历史的持久化(概念演示)")
print("=" * 60)

print("\n在实际应用中,ChatMessageHistory 可以:")
print("1. 存储在内存中(当前演示)")
print("2. 存储在数据库中(如 PostgreSQL、MongoDB)")
print("3. 存储在 Redis 中(适合分布式应用)")
print("4. 存储在文件中(适合单机应用)")

print("\n示例:内存存储(当前实现)")
memory_history = ChatMessageHistory()
memory_history.add_user_message("消息 1")
memory_history.add_user_message("消息 2")
print(f"  消息存储在内存中,程序重启后会丢失")
print(f"  当前消息数: {len(memory_history.messages)}")

print("\n示例:持久化存储(需要实现)")
print("  # 可以继承 BaseChatMessageHistory 实现自定义存储")
print("  # class DatabaseChatMessageHistory(BaseChatMessageHistory):")
print("  #     def add_message(self, message):")
print("  #         # 保存到数据库")
print("  #         pass")
print("  #     def messages(self):")
print("  #         # 从数据库读取")
print("  #         return []")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nChatMessageHistory 的核心功能:")
print("1. 消息存储:使用 add_message、add_messages 存储消息")
print("2. 消息获取:通过 messages 属性获取所有消息")
print("3. 消息清空:使用 clear() 方法清空历史")
print("4. 便捷方法:add_user_message、add_ai_message 快速添加消息")
print("\n使用场景:")
print("- 聊天机器人:维护对话上下文")
print("- 多轮对话:跟踪对话历史")
print("- 会话管理:为不同用户/会话维护独立历史")
print("- 上下文注入:将历史消息注入到提示模板中")
print("\n注意事项:")
print("- 内存实现:当前演示使用内存存储,程序重启后丢失")
print("- 持久化:生产环境应使用持久化存储(数据库、Redis 等)")
print("- 性能:批量添加消息时使用 add_messages 而不是多次 add_message")
print("- 会话隔离:为不同会话创建独立的 ChatMessageHistory 实例")

32. RunnableWithMessageHistory #

RunnableWithMessageHistory 能让你的链(Chain)或可运行对象(Runnable)自动携带和管理每个会话的历史消息,实现多轮对话的上下文记忆。其典型应用场景包括聊天机器人、智能客服等。

基本原理

  • RunnableWithMessageHistory 本质上是一个包装器(wrapper),他可以包装任何 Runnable(包括 LLM 调用、Prompt 链条等)。
  • 你只需要提供一个“会话历史存储”的获取函数(比如根据 session_id 返回对应用户的历史),LangChain 会自动插入/提取历史消息。
  • 每次调用时,你传入一个 config,其 configurable.session_id 字段用于区分不同用户/会话。

用法要点

  1. 定义历史存储:你需要有某种消息历史对象(如 ChatMessageHistory)。
  2. 包装你的链/LLM:用 RunnableWithMessageHistory 包一层,指明历史采集点(如 input_messages_key、history_messages_key)。
  3. 传入 session_id:每次运行时传递唯一 session_id,实现面向多用户或多会话的历史隔离。
  4. 支持自定义存储:不仅可以用内存存储,也可以持久化到数据库、Redis 等,确保会话上下文长期有效。

适用场景

  • 聊天机器人、对话助手等需要了解“上下文”的产品
  • 需要根据历史多轮问答推理或记忆
  • 多用户多会话下,确保各自的历史隔离
  • 在 PromptTemplate/Pipeline/LLM 链条中自动注入对话历史

32.1. RunnableWithMessageHistory.py #

32.RunnableWithMessageHistory..py


#from langchain_core.messages import HumanMessage, AIMessage
#from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
#from langchain_core.output_parsers import StrOutputParser
#from langchain_core.runnables import RunnableLambda
#from langchain_core.runnables.history import RunnableWithMessageHistory
#from langchain_community.chat_message_histories import ChatMessageHistory
#from langchain_openai import ChatOpenAI

from langchain.messages import HumanMessage, AIMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.output_parsers import StrOutputParser
from langchain.runnables import RunnableWithMessageHistory, RunnableLambda
from langchain.chat_message_histories import ChatMessageHistory
from langchain.chat_models import ChatOpenAI


print("=" * 60)
print("演示 1: RunnableWithMessageHistory 的基本用法")
print("=" * 60)

# 创建会话历史存储
store = {}

def get_session_history(session_id: str):
    """获取或创建会话历史"""
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

# 创建简单的链(不使用 LLM,只返回固定回复)
def simple_chain_func(input_data):
    """简单的链,返回固定回复"""
    if isinstance(input_data, dict):
        question = input_data.get("question", "")
        return AIMessage(content=f"AI回复: 我收到了你的问题:{question}")
    elif isinstance(input_data, str):
        return AIMessage(content=f"AI回复: 我收到了你的消息:{input_data}")
    elif isinstance(input_data, list):
        # 如果是消息列表,提取最后一条
        if len(input_data) > 0:
            last_msg = input_data[-1]
            if hasattr(last_msg, 'content'):
                return AIMessage(content=f"AI回复: 我收到了你的消息:{last_msg.content}")
    return AIMessage(content="AI回复: 我收到了你的消息")

# 将函数包装成 Runnable(langchain_core 要求 runnable 必须是 Runnable 对象)
if RunnableLambda is not None:
    simple_chain = RunnableLambda(simple_chain_func)
else:
    simple_chain = simple_chain_func

# 创建带历史记录的链
chain_with_history = RunnableWithMessageHistory(
    simple_chain,
    get_session_history,
    input_messages_key="question",
)

print("\n创建带历史记录的链:")
print("chain_with_history = RunnableWithMessageHistory(")
print("    simple_chain,")
print("    get_session_history,")
print("    input_messages_key='question',")
print(")")

# 第一次调用
print("\n第一次调用(session_id='user1'):")
result1 = chain_with_history.invoke(
    {"question": "你好"},
    config={"configurable": {"session_id": "user1"}}
)
print(f"  输入: {{'question': '你好'}}")
print(f"  输出: {result1.content if hasattr(result1, 'content') else result1}")
print(f"  历史记录: {len(store['user1'].messages)} 条消息")

# 第二次调用(应该能看到历史)
print("\n第二次调用(相同 session_id):")
result2 = chain_with_history.invoke(
    {"question": "你还记得我吗?"},
    config={"configurable": {"session_id": "user1"}}
)
print(f"  输入: {{'question': '你还记得我吗?'}}")
print(f"  输出: {result2.content if hasattr(result2, 'content') else result2}")
print(f"  历史记录: {len(store['user1'].messages)} 条消息")
print(f"  历史消息:")
for i, msg in enumerate(store['user1'].messages, 1):
    msg_type = type(msg).__name__
    print(f"    {i}. {msg_type}: {msg.content}")

print("\n" + "=" * 60)
print("演示 2: 与 ChatPromptTemplate 结合使用")
print("=" * 60)

if ChatOpenAI and ChatPromptTemplate and MessagesPlaceholder:
    # 创建提示模板
    prompt = ChatPromptTemplate.from_messages([
        ("system", "你是一个友好的 AI 助手。"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    # 创建 LLM 链
    llm = ChatOpenAI(model="gpt-4o")
    parser = StrOutputParser()
    chain = prompt | llm | parser

    # 创建带历史记录的链
    llm_chain_with_history = RunnableWithMessageHistory(
        chain,
        get_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )

    print("\n创建带历史记录的 LLM 链:")
    print("llm_chain_with_history = RunnableWithMessageHistory(")
    print("    chain,")
    print("    get_session_history,")
    print("    input_messages_key='question',")
    print("    history_messages_key='history',")
    print(")")

    # 第一次调用
    print("\n第一次调用(session_id='user2'):")
    print("  ⚠️  注意:LLM API 调用可能需要 5-30 秒,请耐心等待...")
    print("  (如果长时间无响应,可能是网络问题或 API 配置问题)")
    try:
        result1 = llm_chain_with_history.invoke(
            {"question": "你好,我是张三"},
            config={"configurable": {"session_id": "user2"}}
        )
        print(f"  ✅ 调用成功!")
        print(f"  输入: {{'question': '你好,我是张三'}}")
        print(f"  输出: {result1[:100] if isinstance(result1, str) else str(result1)[:100]}...")
        print(f"  历史记录: {len(store['user2'].messages)} 条消息")

        # 第二次调用
        print("\n第二次调用(相同 session_id):")
        print("  ⚠️  注意:LLM API 调用可能需要 5-30 秒,请耐心等待...")
        result2 = llm_chain_with_history.invoke(
            {"question": "你还记得我的名字吗?"},
            config={"configurable": {"session_id": "user2"}}
        )
        print(f"  ✅ 调用成功!")
        print(f"  输入: {{'question': '你还记得我的名字吗?'}}")
        print(f"  输出: {result2[:100] if isinstance(result2, str) else str(result2)[:100]}...")
        print(f"  历史记录: {len(store['user2'].messages)} 条消息")
    except KeyboardInterrupt:
        print("\n  ⚠️  用户中断了调用")
    except Exception as e:
        print(f"\n  ❌ 调用失败:{type(e).__name__}: {e}")
        print(f"  💡 可能的原因:")
        print(f"     - 缺少 OPENAI_API_KEY 环境变量")
        print(f"     - 网络连接问题")
        print(f"     - API 服务暂时不可用")
        print(f"     - API 密钥无效或已过期")
else:
    print("\n(跳过此演示,需要 ChatOpenAI 和 ChatPromptTemplate 支持)")

print("\n" + "=" * 60)
print("演示 3: 多个会话管理")
print("=" * 60)

# 创建多个会话
print("\n创建多个独立的会话:")

# 会话 1
result1 = chain_with_history.invoke(
    {"question": "会话1的问题1"},
    config={"configurable": {"session_id": "session1"}}
)
print(f"  会话1 - 问题1: {result1.content if hasattr(result1, 'content') else result1}")

result2 = chain_with_history.invoke(
    {"question": "会话1的问题2"},
    config={"configurable": {"session_id": "session1"}}
)
print(f"  会话1 - 问题2: {result2.content if hasattr(result2, 'content') else result2}")

# 会话 2
result3 = chain_with_history.invoke(
    {"question": "会话2的问题1"},
    config={"configurable": {"session_id": "session2"}}
)
print(f"  会话2 - 问题1: {result3.content if hasattr(result3, 'content') else result3}")

# 查看各会话的历史
print("\n各会话的历史记录:")
for session_id in ["session1", "session2"]:
    if session_id in store:
        history = store[session_id]
        print(f"  {session_id}: {len(history.messages)} 条消息")
        for i, msg in enumerate(history.messages, 1):
            msg_type = type(msg).__name__
            print(f"    {i}. {msg_type}: {msg.content[:50]}...")

print("\n" + "=" * 60)
print("演示 4: 字符串输入")
print("=" * 60)

# 使用字符串输入
def string_chain_func(input_data):
    """处理字符串输入的链"""
    if isinstance(input_data, str):
        return AIMessage(content=f"回复: {input_data}")
    elif isinstance(input_data, list):
        if len(input_data) > 0:
            last_msg = input_data[-1]
            if hasattr(last_msg, 'content'):
                return AIMessage(content=f"回复: {last_msg.content}")
    return AIMessage(content="回复: 未知输入")

# 将函数包装成 RunnableLambda
if RunnableLambda is not None:
    string_chain = RunnableLambda(string_chain_func)
else:
    string_chain = string_chain_func

string_chain_with_history = RunnableWithMessageHistory(
    string_chain,
    get_session_history,
)

print("\n创建处理字符串输入的链:")
print("string_chain_with_history = RunnableWithMessageHistory(")
print("    string_chain,")
print("    get_session_history,")
print(")")

# 调用
print("\n使用字符串输入:")
result1 = string_chain_with_history.invoke(
    "你好",
    config={"configurable": {"session_id": "string_session"}}
)
print(f"  输入: '你好'")
print(f"  输出: {result1.content if hasattr(result1, 'content') else result1}")

result2 = string_chain_with_history.invoke(
    "再见",
    config={"configurable": {"session_id": "string_session"}}
)
print(f"  输入: '再见'")
print(f"  输出: {result2.content if hasattr(result2, 'content') else result2}")

print(f"\n历史记录: {len(store['string_session'].messages)} 条消息")

print("\n" + "=" * 60)
print("演示 5: 消息列表输入")
print("=" * 60)

# 使用消息列表输入
def message_list_chain_func(input_data):
    """处理消息列表输入的链"""
    if isinstance(input_data, list):
        if len(input_data) > 0:
            last_msg = input_data[-1]
            if hasattr(last_msg, 'content'):
                return AIMessage(content=f"回复: {last_msg.content}")
    return AIMessage(content="回复: 未知输入")

# 将函数包装成 RunnableLambda
if RunnableLambda is not None:
    message_list_chain = RunnableLambda(message_list_chain_func)
else:
    message_list_chain = message_list_chain_func

message_list_chain_with_history = RunnableWithMessageHistory(
    message_list_chain,
    get_session_history,
)

print("\n创建处理消息列表输入的链:")
print("message_list_chain_with_history = RunnableWithMessageHistory(")
print("    message_list_chain,")
print("    get_session_history,")
print(")")

# 调用
print("\n使用消息列表输入:")
result1 = message_list_chain_with_history.invoke(
    [HumanMessage(content="消息1")],
    config={"configurable": {"session_id": "message_session"}}
)
print(f"  输入: [HumanMessage(content='消息1')]")
print(f"  输出: {result1.content if hasattr(result1, 'content') else result1}")

result2 = message_list_chain_with_history.invoke(
    [HumanMessage(content="消息2")],
    config={"configurable": {"session_id": "message_session"}}
)
print(f"  输入: [HumanMessage(content='消息2')]")
print(f"  输出: {result2.content if hasattr(result2, 'content') else result2}")

print(f"\n历史记录: {len(store['message_session'].messages)} 条消息")

print("\n" + "=" * 60)
print("演示 6: 批处理")
print("=" * 60)

# 批处理
print("\n使用批处理:")
results = chain_with_history.batch(
    [
        {"question": "批处理问题1"},
        {"question": "批处理问题2"},
        {"question": "批处理问题3"},
    ],
    config={"configurable": {"session_id": "batch_session"}}
)

for i, result in enumerate(results, 1):
    content = result.content if hasattr(result, 'content') else result
    print(f"  问题 {i}: {content}")

print(f"\n历史记录: {len(store['batch_session'].messages)} 条消息")

print("\n" + "=" * 60)
print("演示 7: 流式处理")
print("=" * 60)

# 流式处理
print("\n使用流式处理:")
stream = chain_with_history.stream(
    {"question": "流式问题"},
    config={"configurable": {"session_id": "stream_session"}}
)

print("  流式输出:")
for chunk in stream:
    content = chunk.content if hasattr(chunk, 'content') else chunk
    print(f"    {content}")

print(f"\n历史记录: {len(store['stream_session'].messages)} 条消息")

print("\n" + "=" * 60)
print("演示 8: 实际应用场景 - 多轮对话")
print("=" * 60)

# 模拟多轮对话(simple_chain 已经是 RunnableLambda)
conversation_chain = RunnableWithMessageHistory(
    simple_chain,
    get_session_history,
    input_messages_key="question",
)

print("\n模拟多轮对话:")

questions = [
    "你好",
    "我的名字是李四",
    "你还记得我的名字吗?",
    "谢谢",
]

for i, question in enumerate(questions, 1):
    result = conversation_chain.invoke(
        {"question": question},
        config={"configurable": {"session_id": "conversation"}}
    )
    print(f"  轮次 {i}:")
    print(f"    用户: {question}")
    content = result.content if hasattr(result, 'content') else result
    print(f"    AI: {content}")

print(f"\n完整对话历史({len(store['conversation'].messages)} 条消息):")
for i, msg in enumerate(store['conversation'].messages, 1):
    msg_type = type(msg).__name__
    print(f"  {i}. {msg_type}: {msg.content}")

print("\n" + "=" * 60)
print("演示 9: 自定义会话历史工厂")
print("=" * 60)

# 自定义会话历史工厂(支持多个参数)
def get_session_history_with_user(session_id: str, user_id: str = None):
    """支持 user_id 的会话历史工厂"""
    key = f"{user_id}_{session_id}" if user_id else session_id
    if key not in store:
        store[key] = ChatMessageHistory()
    return store[key]

# custom_chain 使用 simple_chain(已经是 RunnableLambda)
custom_chain = RunnableWithMessageHistory(
    simple_chain,
    get_session_history_with_user,
    input_messages_key="question",
)

print("\n使用自定义会话历史工厂:")
print("def get_session_history_with_user(session_id: str, user_id: str = None):")
print("    key = f'{user_id}_{session_id}' if user_id else session_id")
print("    ...")

# 不同用户的会话
result1 = custom_chain.invoke(
    {"question": "用户A的问题"},
    config={"configurable": {"session_id": "session1", "user_id": "userA"}}
)
print(f"\n用户A的会话: {result1.content if hasattr(result1, 'content') else result1}")

result2 = custom_chain.invoke(
    {"question": "用户B的问题"},
    config={"configurable": {"session_id": "session1", "user_id": "userB"}}
)
print(f"用户B的会话: {result2.content if hasattr(result2, 'content') else result2}")

print("\n各用户的历史记录:")
for key in store:
    if "user" in key:
        print(f"  {key}: {len(store[key].messages)} 条消息")

print("\n" + "=" * 60)
print("演示 10: 清空会话历史")
print("=" * 60)

# 创建会话并添加消息(simple_chain 已经是 RunnableLambda)
test_chain = RunnableWithMessageHistory(
    simple_chain,
    get_session_history,
    input_messages_key="question",
)

print("\n创建会话并添加消息:")
test_chain.invoke(
    {"question": "消息1"},
    config={"configurable": {"session_id": "clear_test"}}
)
test_chain.invoke(
    {"question": "消息2"},
    config={"configurable": {"session_id": "clear_test"}}
)

print(f"清空前: {len(store['clear_test'].messages)} 条消息")

# 清空历史
store['clear_test'].clear()
print(f"清空后: {len(store['clear_test'].messages)} 条消息")

# 继续对话
result = test_chain.invoke(
    {"question": "新消息"},
    config={"configurable": {"session_id": "clear_test"}}
)
print(f"新消息后: {len(store['clear_test'].messages)} 条消息")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nRunnableWithMessageHistory 的核心功能:")
print("1. 自动管理历史:自动读取和更新聊天消息历史")
print("2. 会话隔离:通过 session_id 区分不同会话")
print("3. 消息注入:自动将历史消息注入到输入中")
print("4. 历史更新:自动将用户输入和 AI 回复添加到历史")
print("\n使用场景:")
print("- 多轮对话:维护对话上下文")
print("- 会话管理:为不同用户/会话维护独立历史")
print("- 聊天机器人:自动管理对话历史")
print("- 上下文感知:让 LLM 能够记住之前的对话")
print("\n注意事项:")
print("- 必须通过 config 传递 session_id")
print("- get_session_history 函数负责创建或获取会话历史")
print("- 支持字典、字符串、消息列表等多种输入格式")
print("- 生产环境应使用持久化存储(数据库、Redis 等)")
print("- 可以通过自定义 get_session_history 支持多参数(如 user_id)")
print("- 使用 langchain_core 时,runnable 必须是 Runnable 对象(使用 RunnableLambda 包装函数)")

32.2. runnables.py #

langchain/runnables.py

# 实现 LCEL (LangChain Expression Language) 支持
# 提供 Runnable 接口、RunnableSequence 类和为组件添加 | 操作符支持

from typing import Any, List
from abc import ABC, abstractmethod


# 定义 Runnable 接口
class Runnable(ABC):
    """Runnable 接口的抽象定义

    Runnable 是 LangChain 的核心接口,表示可以运行的对象。
    它定义了统一的操作接口:invoke、batch、stream 等。
    所有支持这些方法的对象都可以被视为 Runnable。
    """

    @abstractmethod
    def invoke(self, input_data: Any, **kwargs) -> Any:
        """
        同步调用,处理单个输入

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Returns:
            处理后的输出
        """
        pass

    def batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        批处理,处理多个输入(默认实现)

        Args:
            inputs: 输入数据列表
            **kwargs: 额外的参数

        Returns:
            处理后的输出列表
        """
        return [self.invoke(input_data, **kwargs) for input_data in inputs]

    def stream(self, input_data: Any, **kwargs):
        """
        流式处理,逐步产生输出(默认实现)

        Args:
            input_data: 输入数据
            **kwargs: 额外的参数

        Yields:
            处理后的输出块
        """
        result = self.invoke(input_data, **kwargs)
        yield result

    def __or__(self, other):
        """
        支持 | 操作符,用于组合 Runnable

        Args:
            other: 另一个 Runnable 或可调用对象

        Returns:
            RunnableSequence: 组合后的序列
        """
        if hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_retry(
        self,
        *,
        retry_if_exception_type: tuple = (Exception,),
        wait_exponential_jitter: bool = True,
        stop_after_attempt: int = 3,
        **kwargs
    ):
        """
        创建一个带重试功能的 Runnable

        Args:
            retry_if_exception_type: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            stop_after_attempt: 最大重试次数
            **kwargs: 其他参数

        Returns:
            RunnableRetry: 带重试功能的 Runnable
        """
        return RunnableRetry(
            bound=self,
            retry_exception_types=retry_if_exception_type,
            wait_exponential_jitter=wait_exponential_jitter,
            max_attempt_number=stop_after_attempt,
            **kwargs
        )

    def with_config(
        self,
        config=None,
        **kwargs
    ):
        """
        将配置绑定到 Runnable,返回一个新的 Runnable

        Args:
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数(会合并到 config 中)

        Returns:
            RunnableBinding: 绑定配置后的 Runnable
        """
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        merged_config = {**config, **kwargs}
        return RunnableBinding(bound=self, config=merged_config)

    def configurable_fields(self, **kwargs):
        """
        配置可配置字段

        Args:
            **kwargs: ConfigurableField 实例,键是字段名,值是 ConfigurableField

        Returns:
            RunnableConfigurableFields: 可配置字段的 Runnable
        """
        return RunnableConfigurableFields(bound=self, fields=kwargs)


class RunnableSequence(Runnable):
    """可运行序列,用于 LCEL 链式调用"""

    def __init__(self, *steps):
        """初始化序列"""
        self.steps = list(steps)

    def invoke(self, input_data, config=None, **kwargs):
        """执行序列"""
        result = input_data
        for step in self.steps:
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    # 如果 invoke 方法接受 config 参数,传递它
                    if 'config' in params:
                        result = step.invoke(result, config=config, **kwargs)
                    else:
                        # 如果不接受 config,只传递其他 kwargs
                        result = step.invoke(result, **kwargs)
                except (ValueError, TypeError):
                    # 如果无法检查签名,尝试不传递 config
                    try:
                        result = step.invoke(result, **kwargs)
                    except TypeError:
                        # 如果还是失败,尝试传递 config(某些实现可能需要)
                        result = step.invoke(result, config=config, **kwargs)
            elif callable(step):
                result = step(result)
            else:
                raise ValueError(f"步骤 {step} 不可调用")
        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理(简化版本)"""
        # 对于流式处理,我们返回最终结果
        # 实际实现中,应该逐步产生输出
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(*self.steps, RunnableParallel(other))
        return RunnableSequence(*self.steps, other)

    def __repr__(self):
        return f"RunnableSequence({len(self.steps)} steps)"


class RunnableLambda(Runnable):
    """将 Python 可调用对象转换为 Runnable

    RunnableLambda 将普通的 Python 函数包装成 Runnable,
    使其可以在 LCEL 链中使用,并支持 invoke、batch、stream 等方法。
    """

    def __init__(self, func, name=None):
        """
        初始化 RunnableLambda

        Args:
            func: 要包装的 Python 可调用对象
            name: 可选的名称,用于调试和显示
        """
        if not callable(func):
            raise ValueError(f"func 必须是可调用对象,但得到 {type(func)}")
        self.func = func
        self.name = name or getattr(func, '__name__', 'lambda')

    def invoke(self, input_data, config=None, **kwargs):
        """调用函数处理输入"""
        # 如果函数接受 config 参数,传递它
        import inspect
        sig = inspect.signature(self.func)
        params = list(sig.parameters.keys())

        if len(params) > 1 and 'config' in params:
            return self.func(input_data, config=config, **kwargs)
        elif len(params) > 1 and len(kwargs) > 0:
            # 尝试传递 kwargs
            return self.func(input_data, **kwargs)
        else:
            return self.func(input_data)

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableLambda({self.name})"


class RunnableParallel(Runnable):
    """并行执行多个 Runnable

    RunnableParallel 并行执行多个 Runnable,它们接收相同的输入,
    然后返回一个字典,包含每个分支的结果。
    """

    def __init__(self, steps=None, **kwargs):
        """
        初始化 RunnableParallel

        Args:
            steps: 字典,键是分支名称,值是对应的 Runnable
            **kwargs: 也可以直接传递关键字参数,每个参数名作为键
        """
        if steps is None:
            steps = {}
        elif not isinstance(steps, dict):
            raise ValueError(f"steps 必须是字典,但得到 {type(steps)}")

        # 合并 kwargs 中的步骤
        self.steps = {**steps, **kwargs}

        # 验证所有值都是 Runnable 或可调用对象,如果是字典则转换为 RunnableParallel
        processed_steps = {}
        for key, value in self.steps.items():
            if isinstance(value, dict):
                # 嵌套字典自动转换为 RunnableParallel
                processed_steps[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                processed_steps[key] = value
            else:
                raise ValueError(f"步骤 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")
        self.steps = processed_steps

    def invoke(self, input_data, config=None, **kwargs):
        """并行执行所有步骤"""
        results = {}
        for key, step in self.steps.items():
            if hasattr(step, 'invoke'):
                # 检查 step 的 invoke 方法是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(step.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
                    else:
                        results[key] = step.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        results[key] = step.invoke(input_data, **kwargs)
                    except TypeError:
                        results[key] = step.invoke(input_data, config=config, **kwargs)
            elif callable(step):
                results[key] = step(input_data)
            else:
                raise ValueError(f"步骤 '{key}' 不可调用")
        return results

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        results = []
        for input_data in inputs:
            try:
                result = self.invoke(input_data, config=config, **kwargs)
                results.append(result)
            except Exception as e:
                results.append(f"错误: {e}")
        return results

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __repr__(self):
        return f"RunnableParallel({list(self.steps.keys())})"


class RunnablePassthrough(Runnable):
    """传递输入数据不变或添加额外键的 Runnable

    RunnablePassthrough 类似于恒等函数,但可以配置为在输出中添加额外的键(如果输入是字典)。
    它常用于在链中保留原始输入数据,同时添加处理后的结果。
    """

    def __init__(self, func=None):
        """
        初始化 RunnablePassthrough

        Args:
            func: 可选的函数,会在传递数据时调用(用于副作用)
        """
        self.func = func

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入数据"""
        # 如果提供了函数,调用它(用于副作用)
        if self.func is not None:
            if callable(self.func):
                try:
                    self.func(input_data, config=config, **kwargs)
                except TypeError:
                    self.func(input_data)

        # 直接返回输入数据
        return input_data

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    @classmethod
    def assign(cls, **kwargs):
        """创建一个 RunnableAssign,在传递数据的同时添加额外字段

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象

        Returns:
            RunnableAssign: 可以添加额外字段的 Runnable
        """
        return RunnableAssign(**kwargs)

    def __repr__(self):
        return "RunnablePassthrough()"


class RunnableAssign(Runnable):
    """在传递字典数据的同时添加额外字段的 Runnable

    RunnableAssign 是 RunnablePassthrough 的扩展,它接收一个字典输入,
    在保留原有字段的同时,添加新的字段。
    """

    def __init__(self, **kwargs):
        """
        初始化 RunnableAssign

        Args:
            **kwargs: 键值对,键是字段名,值可以是 Runnable 或可调用对象
        """
        self.assignments = {}
        for key, value in kwargs.items():
            if isinstance(value, dict):
                # 嵌套字典转换为 RunnableParallel
                self.assignments[key] = RunnableParallel(value)
            elif hasattr(value, 'invoke') or callable(value):
                self.assignments[key] = value
            else:
                raise ValueError(f"赋值 '{key}' 必须是 Runnable、可调用对象或字典,但得到 {type(value)}")

    def invoke(self, input_data, config=None, **kwargs):
        """传递输入并添加额外字段"""
        # 确保输入是字典
        if not isinstance(input_data, dict):
            input_data = {"input": input_data}

        # 复制输入数据
        result = dict(input_data)

        # 添加新字段
        for key, assignment in self.assignments.items():
            if hasattr(assignment, 'invoke'):
                # 检查是否接受 config 参数
                import inspect
                try:
                    sig = inspect.signature(assignment.invoke)
                    params = list(sig.parameters.keys())
                    if 'config' in params:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
                    else:
                        result[key] = assignment.invoke(input_data, **kwargs)
                except (ValueError, TypeError):
                    try:
                        result[key] = assignment.invoke(input_data, **kwargs)
                    except TypeError:
                        result[key] = assignment.invoke(input_data, config=config, **kwargs)
            elif callable(assignment):
                result[key] = assignment(input_data)
            else:
                raise ValueError(f"赋值 '{key}' 不可调用")

        return result

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __ror__(self, other):
        """支持从右侧使用 | 操作符(用于字典字面量)"""
        if isinstance(other, dict):
            # 字典字面量先转换为 RunnableParallel,然后与 RunnableAssign 组合
            return RunnableSequence(RunnableParallel(other), self)
        return NotImplemented

    def __repr__(self):
        return f"RunnableAssign({list(self.assignments.keys())})"


class RunnableBranch(Runnable):
    """根据条件选择执行分支的 Runnable

    RunnableBranch 类似于 if-else 或 switch 语句,根据条件选择不同的 Runnable 执行。
    它接受多个 (condition, Runnable) 对,按顺序检查条件,第一个为 True 的条件对应的 Runnable 会被执行。
    如果没有条件满足,则执行默认分支。
    """

    def __init__(self, *branches):
        """
        初始化 RunnableBranch

        Args:
            *branches: 多个 (condition, Runnable) 对,最后一个可以是默认分支(Runnable 或可调用对象)
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch 至少需要 2 个分支(包括默认分支)")

        # 分离条件和默认分支
        self.branches = []
        default = None

        for i, branch in enumerate(branches):
            if i == len(branches) - 1:
                # 最后一个可能是默认分支
                if isinstance(branch, tuple) and len(branch) == 2:
                    # 仍然是 (condition, Runnable) 对
                    condition, runnable = branch
                    self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))
                else:
                    # 是默认分支
                    default = self._coerce_to_runnable(branch)
            else:
                # 必须是 (condition, Runnable) 对
                if not isinstance(branch, tuple) or len(branch) != 2:
                    raise ValueError(f"分支 {i} 必须是 (condition, Runnable) 对,但得到 {type(branch)}")
                condition, runnable = branch
                self.branches.append((self._coerce_to_condition(condition), self._coerce_to_runnable(runnable)))

        if default is None:
            raise ValueError("RunnableBranch 必须提供默认分支(最后一个参数)")

        self.default = default

    def _coerce_to_condition(self, condition):
        """将条件转换为可调用对象"""
        if callable(condition):
            return condition
        elif hasattr(condition, 'invoke'):
            # 如果已经是 Runnable,返回一个包装函数
            def wrapped_condition(input_data):
                result = condition.invoke(input_data)
                return bool(result)
            return wrapped_condition
        else:
            raise ValueError(f"条件必须是可调用对象或 Runnable,但得到 {type(condition)}")

    def _coerce_to_runnable(self, runnable):
        """将值转换为 Runnable"""
        if hasattr(runnable, 'invoke'):
            return runnable
        elif callable(runnable):
            return RunnableLambda(runnable)
        else:
            raise ValueError(f"Runnable 必须是可调用对象或 Runnable,但得到 {type(runnable)}")

    def invoke(self, input_data, config=None, **kwargs):
        """根据条件选择分支执行"""
        # 按顺序检查条件
        for condition, runnable in self.branches:
            try:
                # 评估条件
                condition_result = condition(input_data)
                if condition_result:
                    # 条件满足,执行对应的 Runnable
                    if hasattr(runnable, 'invoke'):
                        import inspect
                        try:
                            sig = inspect.signature(runnable.invoke)
                            params = list(sig.parameters.keys())
                            if 'config' in params:
                                return runnable.invoke(input_data, config=config, **kwargs)
                            else:
                                return runnable.invoke(input_data, **kwargs)
                        except (ValueError, TypeError):
                            try:
                                return runnable.invoke(input_data, **kwargs)
                            except TypeError:
                                return runnable.invoke(input_data, config=config, **kwargs)
                    elif callable(runnable):
                        return runnable(input_data)
            except Exception as e:
                # 如果条件评估出错,继续下一个条件
                continue

        # 没有条件满足,执行默认分支
        if hasattr(self.default, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.default.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.default.invoke(input_data, config=config, **kwargs)
                else:
                    return self.default.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.default.invoke(input_data, **kwargs)
                except TypeError:
                    return self.default.invoke(input_data, config=config, **kwargs)
        elif callable(self.default):
            return self.default(input_data)
        else:
            return self.default

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableBranch({len(self.branches)} branches + default)"


class RunnableRetry(Runnable):
    """带重试功能的 Runnable

    RunnableRetry 包装另一个 Runnable,当执行失败时会自动重试。
    支持指数退避和抖动,以及指定重试的异常类型。
    """

    def __init__(
        self,
        bound,
        retry_exception_types=(Exception,),
        wait_exponential_jitter=True,
        max_attempt_number=3,
        **kwargs
    ):
        """
        初始化 RunnableRetry

        Args:
            bound: 要包装的 Runnable
            retry_exception_types: 需要重试的异常类型元组
            wait_exponential_jitter: 是否使用指数退避和抖动
            max_attempt_number: 最大重试次数
            **kwargs: 其他参数
        """
        self.bound = bound
        self.retry_exception_types = retry_exception_types
        self.wait_exponential_jitter = wait_exponential_jitter
        self.max_attempt_number = max_attempt_number
        self.kwargs = kwargs

    def _should_retry(self, exception):
        """检查是否应该重试"""
        return isinstance(exception, self.retry_exception_types)

    def _wait_time(self, attempt_number):
        """计算等待时间(指数退避)"""
        if not self.wait_exponential_jitter:
            return 0

        import random
        # 指数退避:2^attempt_number 秒,最大 10 秒
        base_wait = min(2 ** attempt_number, 10)
        # 添加抖动:随机 0-1 秒
        jitter = random.uniform(0, 1)
        return base_wait + jitter

    def invoke(self, input_data, config=None, **kwargs):
        """执行并重试"""
        last_exception = None

        for attempt in range(1, self.max_attempt_number + 1):
            try:
                # 调用原始的 Runnable
                if hasattr(self.bound, 'invoke'):
                    import inspect
                    try:
                        sig = inspect.signature(self.bound.invoke)
                        params = list(sig.parameters.keys())
                        if 'config' in params:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                        else:
                            return self.bound.invoke(input_data, **kwargs)
                    except (ValueError, TypeError):
                        try:
                            return self.bound.invoke(input_data, **kwargs)
                        except TypeError:
                            return self.bound.invoke(input_data, config=config, **kwargs)
                elif callable(self.bound):
                    return self.bound(input_data)
                else:
                    raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")
            except Exception as e:
                last_exception = e
                if not self._should_retry(e):
                    # 不应该重试的异常,直接抛出
                    raise

                if attempt < self.max_attempt_number:
                    # 计算等待时间
                    wait_time = self._wait_time(attempt)
                    if wait_time > 0:
                        import time
                        time.sleep(wait_time)
                    # 继续重试
                    continue
                else:
                    # 达到最大重试次数,抛出最后一个异常
                    raise

        # 如果所有重试都失败,抛出最后一个异常
        if last_exception:
            raise last_exception

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def __repr__(self):
        return f"RunnableRetry(bound={self.bound}, max_attempts={self.max_attempt_number})"


class RunnableBinding(Runnable):
    """绑定配置的 Runnable

    RunnableBinding 包装另一个 Runnable,并绑定配置。
    当调用绑定的 Runnable 时,会自动使用绑定的配置。
    """

    def __init__(self, bound, config=None, **kwargs):
        """
        初始化 RunnableBinding

        Args:
            bound: 要包装的 Runnable
            config: 要绑定的配置字典
            **kwargs: 额外的配置参数
        """
        self.bound = bound
        if config is None:
            config = {}
        # 合并 kwargs 到 config
        self.config = {**config, **kwargs}

    def invoke(self, input_data, config=None, **kwargs):
        """执行并合并配置"""
        # 合并绑定的配置和传入的配置(传入的配置优先)
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        # 合并 kwargs 到 config
        merged_config.update(kwargs)

        # 调用原始的 Runnable
        if hasattr(self.bound, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.bound.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
                else:
                    return self.bound.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.bound.invoke(input_data, **kwargs)
                except TypeError:
                    return self.bound.invoke(input_data, config=merged_config, **kwargs)
        elif callable(self.bound):
            return self.bound(input_data)
        else:
            raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_config(self, config=None, **kwargs):
        """添加更多配置(合并到现有配置)"""
        merged_config = {**self.config}
        if config:
            merged_config.update(config)
        merged_config.update(kwargs)
        return RunnableBinding(bound=self.bound, config=merged_config)

    def __repr__(self):
        return f"RunnableBinding(bound={self.bound}, config={self.config})"


class RunnableConfigurableFields(Runnable):
    """可配置字段的 Runnable

    RunnableConfigurableFields 允许在运行时动态配置 Runnable 的属性。
    它通过 configurable_fields 方法创建,使用 ConfigurableField 定义可配置的字段。
    """

    def __init__(self, bound, fields):
        """
        初始化 RunnableConfigurableFields

        Args:
            bound: 要包装的 Runnable
            fields: 可配置字段字典,键是字段名,值是 ConfigurableField
        """
        self.bound = bound
        self.fields = fields

    def invoke(self, input_data, config=None, **kwargs):
        """执行并应用可配置字段"""
        # 从 config 中获取可配置字段的值
        if config and 'configurable' in config:
            configurable = config['configurable']
            # 应用可配置字段的值到 bound
            for field_name, field_def in self.fields.items():
                if field_def.id in configurable:
                    # 如果 bound 有该属性,设置它
                    if hasattr(self.bound, field_name):
                        setattr(self.bound, field_name, configurable[field_def.id])
                    # 如果 bound 是字典,更新它
                    elif isinstance(self.bound, dict):
                        self.bound[field_name] = configurable[field_def.id]

        # 调用原始的 Runnable
        if hasattr(self.bound, 'invoke'):
            import inspect
            try:
                sig = inspect.signature(self.bound.invoke)
                params = list(sig.parameters.keys())
                if 'config' in params:
                    return self.bound.invoke(input_data, config=config, **kwargs)
                else:
                    return self.bound.invoke(input_data, **kwargs)
            except (ValueError, TypeError):
                try:
                    return self.bound.invoke(input_data, **kwargs)
                except TypeError:
                    return self.bound.invoke(input_data, config=config, **kwargs)
        elif callable(self.bound):
            return self.bound(input_data)
        else:
            raise ValueError(f"bound 必须是 Runnable 或可调用对象,但得到 {type(self.bound)}")

    def batch(self, inputs, config=None, **kwargs):
        """批处理"""
        return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]

    def stream(self, input_data, config=None, **kwargs):
        """流式处理"""
        result = self.invoke(input_data, config=config, **kwargs)
        yield result

    def __or__(self, other):
        """支持 | 操作符"""
        if isinstance(other, dict):
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def with_config(self, config=None, **kwargs):
        """添加更多配置"""
        if config is None:
            config = {}
        merged_config = {**config, **kwargs}
        return RunnableConfigurableFields(bound=self.bound, fields=self.fields)

    def __repr__(self):
        return f"RunnableConfigurableFields(bound={self.bound}, fields={list(self.fields.keys())})"


def setup_lcel_support():
    """设置 LCEL 支持,为组件添加 | 操作符和 invoke 方法"""
+   # 分别导入,避免一个失败导致全部失败
+   PromptTemplate = None
+   ChatPromptTemplate = None
+   ChatOpenAI = None
+   BaseOutputParser = None
+   
    try:
        from langchain.prompts import PromptTemplate
+   except ImportError:
+       pass
+   
+   try:
+       from langchain.prompts import ChatPromptTemplate
+   except ImportError:
+       pass
+   
+   try:
        from langchain.chat_models import ChatOpenAI
+   except ImportError:
+       pass
+   
+   try:
        from langchain.output_parsers import BaseOutputParser
    except ImportError:
+       pass

    # 为 PromptTemplate 添加 LCEL 支持
    def _prompt_or(self, other):
        """PromptTemplate 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _prompt_invoke(self, input_data, **kwargs):
        """PromptTemplate 的 invoke 方法"""
        if isinstance(input_data, dict):
            return self.format(**input_data)
        elif isinstance(input_data, str):
            # 如果输入是字符串,尝试作为单个变量
            if self.input_variables:
                return self.format(**{self.input_variables[0]: input_data})
            return input_data
        else:
            return str(input_data)

+   # 为 ChatPromptTemplate 添加 LCEL 支持
+   def _chat_prompt_or(self, other):
+       """ChatPromptTemplate 的 | 操作符"""
+       if isinstance(other, dict):
+           # 字典字面量自动转换为 RunnableParallel
+           return RunnableSequence(self, RunnableParallel(other))
+       elif hasattr(other, 'invoke') or callable(other):
+           return RunnableSequence(self, other)
+       return NotImplemented
+   
+   def _chat_prompt_invoke(self, input_data, **kwargs):
+       """ChatPromptTemplate 的 invoke 方法"""
+       if isinstance(input_data, dict):
+           # 使用 format_messages 方法
+           if hasattr(self, 'format_messages'):
+               return self.format_messages(**input_data)
+           else:
+               return self.format(**input_data)
+       elif isinstance(input_data, str):
+           # 如果输入是字符串,尝试作为单个变量
+           if hasattr(self, 'input_variables') and self.input_variables:
+               if hasattr(self, 'format_messages'):
+                   return self.format_messages(**{self.input_variables[0]: input_data})
+               else:
+                   return self.format(**{self.input_variables[0]: input_data})
+           return input_data
+       else:
+           return str(input_data)
+   
    # 为 ChatOpenAI 添加 LCEL 支持
    def _llm_or(self, other):
        """ChatOpenAI 的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or hasattr(other, 'parse') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    # 为输出解析器添加 LCEL 支持
    def _parser_or(self, other):
        """输出解析器的 | 操作符"""
        if isinstance(other, dict):
            # 字典字面量自动转换为 RunnableParallel
            return RunnableSequence(self, RunnableParallel(other))
        elif hasattr(other, 'invoke') or callable(other):
            return RunnableSequence(self, other)
        return NotImplemented

    def _parser_invoke(self, input_data, **kwargs):
        """输出解析器的 invoke 方法"""
        if hasattr(input_data, 'content'):
            # 如果是消息对象,提取内容
            return self.parse(input_data.content)
        elif isinstance(input_data, str):
            return self.parse(input_data)
        else:
            return self.parse(str(input_data))

+   # 绑定方法到类(只绑定已成功导入的类)
+   if PromptTemplate is not None:
+       PromptTemplate.__or__ = _prompt_or
+       PromptTemplate.invoke = _prompt_invoke
+   
+   if ChatPromptTemplate is not None:
+       ChatPromptTemplate.__or__ = _chat_prompt_or
+       ChatPromptTemplate.invoke = _chat_prompt_invoke
+   
+   if ChatOpenAI is not None:
+       ChatOpenAI.__or__ = _llm_or
+   
+   if BaseOutputParser is not None:
+       BaseOutputParser.__or__ = _parser_or
+       BaseOutputParser.invoke = _parser_invoke


# 自动设置 LCEL 支持
setup_lcel_support()


+# RunnableWithMessageHistory 实现
+class RunnableWithMessageHistory(RunnableBinding):
+   """带消息历史的 Runnable
+   
+   RunnableWithMessageHistory 包装另一个 Runnable,并自动管理聊天消息历史。
+   它会自动读取历史消息、注入到输入中,并在执行后更新历史记录。
+   
+   使用示例:
+       ```python
+       from langchain.chat_message_histories import ChatMessageHistory
+       from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+       from langchain.chat_models import ChatOpenAI
+       from langchain.runnables import RunnableWithMessageHistory
+       
+       # 创建会话历史存储
+       store = {}
+       
+       def get_session_history(session_id: str):
+           if session_id not in store:
+               store[session_id] = ChatMessageHistory()
+           return store[session_id]
+       
+       # 创建链
+       prompt = ChatPromptTemplate.from_messages([
+           ("system", "你是一个助手"),
+           MessagesPlaceholder(variable_name="history"),
+           ("human", "{question}"),
+       ])
+       chain = prompt | ChatOpenAI()
+       
+       # 包装链以支持消息历史
+       chain_with_history = RunnableWithMessageHistory(
+           chain,
+           get_session_history,
+           input_messages_key="question",
+           history_messages_key="history",
+       )
+       
+       # 使用 session_id 调用
+       result = chain_with_history.invoke(
+           {"question": "你好"},
+           config={"configurable": {"session_id": "user123"}}
+       )
+
  • """
  • def init(
  • self,
  • runnable,
  • get_session_history,
  • *,
  • input_messages_key: str = None,
  • output_messages_key: str = None,
  • history_messages_key: str = None,
  • **kwargs
  • ):
  • """
  • 初始化 RunnableWithMessageHistory
  • Args:
  • runnable: 要包装的 Runnable
  • get_session_history: 获取会话历史的函数,接受 session_id 参数
  • input_messages_key: 输入字典中消息的键(如果输入是字典)
  • output_messages_key: 输出字典中消息的键(如果输出是字典)
  • history_messages_key: 历史消息的键(如果输入是字典)
  • **kwargs: 其他参数
  • """
  • self.runnable = runnable
  • self.get_session_history = get_session_history
  • self.input_messages_key = input_messages_key
  • self.output_messages_key = output_messages_key
  • self.history_messages_key = history_messages_key
  • 调用父类初始化 #

  • super().init(bound=self, **kwargs)
  • def invoke(self, input_data, config=None, **kwargs):
  • """执行并管理消息历史"""
  • 从 config 中获取 session_id #

  • session_id = None
  • if config and 'configurable' in config:
  • session_id = config['configurable'].get('session_id')
  • if not session_id:
  • 如果没有 session_id,直接调用原始 runnable #

  • return self._invoke_runnable(input_data, config=config, **kwargs)
  • 获取会话历史 #

  • history = self.get_session_history(session_id)
  • 保存原始输入(用于后续更新历史) #

  • original_input = input_data
  • 准备输入 #

  • if isinstance(input_data, dict):
  • 如果输入是字典,注入历史消息 #

  • if self.history_messages_key:
  • input_data = input_data.copy()
  • input_data[self.history_messages_key] = history.messages
  • elif self.input_messages_key:
  • 如果只有 input_messages_key,将当前输入和历史合并 #

  • current_input = input_data.get(self.input_messages_key, "")
  • if isinstance(current_input, str):
  • from langchain.messages import HumanMessage
  • current_input = HumanMessage(content=current_input)
  • all_messages = list(history.messages) + [current_input]
  • input_data = input_data.copy()
  • input_data[self.input_messages_key] = all_messages
  • elif isinstance(input_data, str):
  • 如果输入是字符串,转换为消息列表 #

  • from langchain.messages import HumanMessage
  • all_messages = list(history.messages) + [HumanMessage(content=input_data)]
  • input_data = all_messages
  • elif isinstance(input_data, list):
  • 如果输入是消息列表,合并历史 #

  • all_messages = list(history.messages) + input_data
  • input_data = all_messages
  • 调用原始 runnable #

  • result = self._invoke_runnable(input_data, config=config, **kwargs)
  • 更新历史记录(使用原始输入,避免重复添加) #

  • self._update_history(history, original_input, result)
  • return result
  • def _invoke_runnable(self, input_data, config=None, **kwargs):
  • """调用原始 runnable"""
  • if hasattr(self.runnable, 'invoke'):
  • import inspect
  • try:
  • sig = inspect.signature(self.runnable.invoke)
  • params = list(sig.parameters.keys())
  • if 'config' in params:
  • return self.runnable.invoke(input_data, config=config, **kwargs)
  • else:
  • return self.runnable.invoke(input_data, **kwargs)
  • except (ValueError, TypeError):
  • try:
  • return self.runnable.invoke(input_data, **kwargs)
  • except TypeError:
  • return self.runnable.invoke(input_data, config=config, **kwargs)
  • elif callable(self.runnable):
  • return self.runnable(input_data)
  • else:
  • raise ValueError(f"runnable 必须是 Runnable 或可调用对象,但得到 {type(self.runnable)}")
  • def _update_history(self, history, input_data, result):
  • """更新历史记录"""
  • 获取历史消息数量(用于判断哪些是新消息) #

  • history_count_before = len(history.messages)
  • 添加用户输入到历史(只添加新的消息) #

  • if isinstance(input_data, dict):
  • if self.input_messages_key:
  • user_input = input_data.get(self.input_messages_key)
  • if isinstance(user_input, str):
  • history.add_user_message(user_input)
  • elif isinstance(user_input, list) and len(user_input) > history_count_before:
  • 只添加新的消息(历史消息之后的部分) #

  • new_messages = user_input[history_count_before:]
  • for msg in new_messages:
  • if hasattr(msg, 'content'):
  • history.add_message(msg)
  • elif isinstance(input_data, str):
  • history.add_user_message(input_data)
  • elif isinstance(input_data, list):
  • 只添加新的消息(历史消息之后的部分) #

  • new_messages = input_data[history_count_before:]
  • for msg in new_messages:
  • if hasattr(msg, 'content'):
  • history.add_message(msg)
  • 添加 AI 回复到历史 #

  • if isinstance(result, dict):
  • if self.output_messages_key:
  • ai_output = result.get(self.output_messages_key)
  • if isinstance(ai_output, str):
  • history.add_ai_message(ai_output)
  • elif hasattr(ai_output, 'content'):
  • history.add_ai_message(ai_output)
  • elif isinstance(ai_output, list) and len(ai_output) > 0:
  • for msg in ai_output:
  • if hasattr(msg, 'content'):
  • history.add_message(msg)
  • elif isinstance(result, str):
  • history.add_ai_message(result)
  • elif hasattr(result, 'content'):
  • 如果是消息对象(如 AIMessage) #

  • history.add_ai_message(result)
  • elif isinstance(result, list):
  • for msg in result:
  • if hasattr(msg, 'content'):
  • history.add_message(msg)
  • def batch(self, inputs, config=None, **kwargs):
  • """批处理"""
  • return [self.invoke(input_data, config=config, **kwargs) for input_data in inputs]
  • def stream(self, input_data, config=None, **kwargs):
  • """流式处理"""
  • result = self.invoke(input_data, config=config, **kwargs)
  • yield result
  • def repr(self):
  • return f"RunnableWithMessageHistory(runnable={self.runnable}, input_key={self.input_messages_key}, history_key={self.history_messages_key})" + +

    ConfigurableField 相关类 #

    from collections import namedtuple

ConfigurableField = namedtuple( 'ConfigurableField', ['id', 'name', 'description', 'annotation', 'is_shared'], defaults=(None, None, None, False) ) """可配置字段

ConfigurableField 用于定义可以在运行时配置的字段。 它通常与 configurable_fields 方法一起使用,允许在运行时动态配置 Runnable 的属性。

Args: id: 字段的唯一标识符 name: 字段的名称(可选) description: 字段的描述(可选) annotation: 字段的类型注解(可选) is_shared: 字段是否共享(可选,默认为 False) """

ConfigurableFieldSingleOption = namedtuple( 'ConfigurableFieldSingleOption', ['id', 'name', 'description', 'annotation', 'default', 'is_shared'], defaults=(None, None, None, None, False) ) """单选项可配置字段

ConfigurableFieldSingleOption 用于定义有默认值的可配置字段。

Args: id: 字段的唯一标识符 name: 字段的名称(可选) description: 字段的描述(可选) annotation: 字段的类型注解(可选) default: 默认值 is_shared: 字段是否共享(可选,默认为 False) """

ConfigurableFieldMultiOption = namedtuple( 'ConfigurableFieldMultiOption', ['id', 'name', 'description', 'annotation', 'options', 'default', 'is_shared'], defaults=(None, None, None, None, None, False) ) """多选项可配置字段

ConfigurableFieldMultiOption 用于定义有多个选项的可配置字段。

Args: id: 字段的唯一标识符 name: 字段的名称(可选) description: 字段的描述(可选) annotation: 字段的类型注解(可选) options: 选项字典 default: 默认值(可选) is_shared: 字段是否共享(可选,默认为 False) """


## 33. SQLChatMessageHistory
### 33.1. SQLChatMessageHistory.py
33.SQLChatMessageHistory.py
```js
#from langchain_core.messages import HumanMessage, AIMessage
#from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
#from langchain_core.output_parsers import StrOutputParser
#from langchain_core.runnables import RunnableLambda
#from langchain_core.runnables.history import RunnableWithMessageHistory
#from langchain_community.chat_message_histories import SQLChatMessageHistory, ChatMessageHistory
#from langchain_openai import ChatOpenAI

from langchain.chat_message_histories import SQLChatMessageHistory, ChatMessageHistory
from langchain.messages import HumanMessage, AIMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.output_parsers import StrOutputParser
from langchain.runnables import RunnableWithMessageHistory, RunnableLambda
from langchain.chat_models import ChatOpenAI

print("=" * 60)
print("SQLChatMessageHistory 演示")
print("=" * 60)
print("\nSQLChatMessageHistory: 将聊天消息历史存储在 SQL 数据库中")
print("它支持持久化存储,适合生产环境使用\n")

print("=" * 60)
print("演示 1: SQLChatMessageHistory 的基本用法(内存数据库)")
print("=" * 60)

# 使用 SQLite 内存数据库
history = SQLChatMessageHistory(
    session_id="user1",
    connection_string="sqlite:///:memory:"
)

print("\n创建 SQLChatMessageHistory 实例(内存数据库):")
print("history = SQLChatMessageHistory(")
print("    session_id='user1',")
print("    connection_string='sqlite:///:memory:'")
print(")")

# 添加消息
history.add_user_message("你好,我是张三")
history.add_ai_message("你好张三,很高兴认识你!")

print("\n添加消息:")
print("history.add_user_message('你好,我是张三')")
print("history.add_ai_message('你好张三,很高兴认识你!')")

# 查看消息
print("\n查看所有消息:")
print(f"history.messages: {len(history.messages)} 条消息")
for i, msg in enumerate(history.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 2: 使用文件数据库(持久化存储)")
print("=" * 60)

# 使用 SQLite 文件数据库
db_file = "chat_history.db"
if os.path.exists(db_file):
    os.remove(db_file)  # 删除旧数据库

file_history = SQLChatMessageHistory(
    session_id="user2",
    connection_string=f"sqlite:///{db_file}"
)

print("\n创建 SQLChatMessageHistory 实例(文件数据库):")
print(f"file_history = SQLChatMessageHistory(")
print(f"    session_id='user2',")
print(f"    connection_string='sqlite:///{db_file}'")
print(")")

# 添加消息
file_history.add_user_message("这是持久化的消息 1")
file_history.add_ai_message("这是持久化的回复 1")
file_history.add_user_message("这是持久化的消息 2")

print("\n添加消息:")
print("file_history.add_user_message('这是持久化的消息 1')")
print("file_history.add_ai_message('这是持久化的回复 1')")
print("file_history.add_user_message('这是持久化的消息 2')")

print(f"\n消息数量: {len(file_history.messages)}")
print("\n消息列表:")
for i, msg in enumerate(file_history.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

# 关闭并重新打开,验证持久化
print("\n关闭数据库连接...")
file_history.close()

print("重新打开数据库连接...")
file_history2 = SQLChatMessageHistory(
    session_id="user2",
    connection_string=f"sqlite:///{db_file}"
)

print(f"\n重新打开后消息数量: {len(file_history2.messages)}")
print("消息仍然存在(持久化成功):")
for i, msg in enumerate(file_history2.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 3: 多个会话管理")
print("=" * 60)

# 创建多个会话
session1 = SQLChatMessageHistory(
    session_id="session1",
    connection_string="sqlite:///:memory:"
)

session2 = SQLChatMessageHistory(
    session_id="session2",
    connection_string="sqlite:///:memory:"
)

print("\n创建多个独立的会话:")
print("session1 = SQLChatMessageHistory(session_id='session1', ...)")
print("session2 = SQLChatMessageHistory(session_id='session2', ...)")

# 为不同会话添加消息
session1.add_user_message("会话1的消息1")
session1.add_ai_message("会话1的回复1")

session2.add_user_message("会话2的消息1")
session2.add_ai_message("会话2的回复1")

print("\n为不同会话添加消息:")
print(f"session1 消息数量: {len(session1.messages)}")
print(f"session2 消息数量: {len(session2.messages)}")

print("\n会话1的消息:")
for i, msg in enumerate(session1.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n会话2的消息:")
for i, msg in enumerate(session2.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 4: 使用 add_messages 批量添加")
print("=" * 60)

batch_history = SQLChatMessageHistory(
    session_id="batch_session",
    connection_string="sqlite:///:memory:"
)

messages_to_add = [
    HumanMessage(content="批量消息1"),
    AIMessage(content="批量回复1"),
    HumanMessage(content="批量消息2"),
    AIMessage(content="批量回复2"),
]

print("\n使用 add_messages 批量添加消息:")
batch_history.add_messages(messages_to_add)
print(f"batch_history.add_messages({len(messages_to_add)} 条消息)")

print(f"\n消息总数: {len(batch_history.messages)}")
for i, msg in enumerate(batch_history.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 5: 清空消息历史")
print("=" * 60)

clear_history = SQLChatMessageHistory(
    session_id="clear_session",
    connection_string="sqlite:///:memory:"
)

clear_history.add_user_message("消息1")
clear_history.add_ai_message("回复1")
clear_history.add_user_message("消息2")

print("\n清空前:")
print(f"  消息数量: {len(clear_history.messages)}")

clear_history.clear()
print("\n执行 clear_history.clear()")
print(f"清空后消息数量: {len(clear_history.messages)}")

print("\n" + "=" * 60)
print("演示 6: 与 RunnableWithMessageHistory 结合使用")
print("=" * 60)

# 创建简单的链
def simple_chain_func(input_data):
    """简单的链,返回固定回复"""
    if isinstance(input_data, dict):
        question = input_data.get("question", "")
        return AIMessage(content=f"AI回复: 我收到了你的问题:{question}")
    elif isinstance(input_data, str):
        return AIMessage(content=f"AI回复: 我收到了你的消息:{input_data}")
    elif isinstance(input_data, list):
        if len(input_data) > 0:
            last_msg = input_data[-1]
            if hasattr(last_msg, 'content'):
                return AIMessage(content=f"AI回复: 我收到了你的消息:{last_msg.content}")
    return AIMessage(content="AI回复: 我收到了你的消息")

# 将函数包装成 RunnableLambda
if RunnableLambda is not None:
    simple_chain = RunnableLambda(simple_chain_func)
else:
    simple_chain = simple_chain_func

# 创建会话历史工厂函数
def get_sql_session_history(session_id: str):
    """获取或创建 SQL 会话历史"""
    return SQLChatMessageHistory(
        session_id=session_id,
        connection_string="sqlite:///:memory:"
    )

if RunnableWithMessageHistory is not None:
    # 创建带历史记录的链
    chain_with_history = RunnableWithMessageHistory(
        simple_chain,
        get_sql_session_history,
        input_messages_key="question",
    )

    print("\n创建带历史记录的链(使用 SQLChatMessageHistory):")
    print("chain_with_history = RunnableWithMessageHistory(")
    print("    simple_chain,")
    print("    get_sql_session_history,")
    print("    input_messages_key='question',")
    print(")")

    # 第一次调用
    print("\n第一次调用(session_id='sql_user1'):")
    result1 = chain_with_history.invoke(
        {"question": "你好"},
        config={"configurable": {"session_id": "sql_user1"}}
    )
    print(f"  输入: {{'question': '你好'}}")
    print(f"  输出: {result1.content if hasattr(result1, 'content') else result1}")

    # 获取历史记录
    sql_history = get_sql_session_history("sql_user1")
    print(f"  历史记录: {len(sql_history.messages)} 条消息")

    # 第二次调用
    print("\n第二次调用(相同 session_id):")
    result2 = chain_with_history.invoke(
        {"question": "你还记得我吗?"},
        config={"configurable": {"session_id": "sql_user1"}}
    )
    print(f"  输入: {{'question': '你还记得我吗?'}}")
    print(f"  输出: {result2.content if hasattr(result2, 'content') else result2}")

    sql_history2 = get_sql_session_history("sql_user1")
    print(f"  历史记录: {len(sql_history2.messages)} 条消息")
    print("\n历史消息:")
    for i, msg in enumerate(sql_history2.messages, 1):
        print(f"  {i}. {type(msg).__name__}: {msg.content}")
else:
    print("\n(跳过此演示,需要 RunnableWithMessageHistory 支持)")

print("\n" + "=" * 60)
print("演示 7: 自定义表名")
print("=" * 60)

# 使用自定义表名
custom_table_history = SQLChatMessageHistory(
    session_id="custom_table_user",
    connection_string="sqlite:///:memory:",
    table_name="custom_messages"
)

print("\n创建 SQLChatMessageHistory 实例(自定义表名):")
print("custom_table_history = SQLChatMessageHistory(")
print("    session_id='custom_table_user',")
print("    connection_string='sqlite:///:memory:',")
print("    table_name='custom_messages'")
print(")")

custom_table_history.add_user_message("自定义表的消息")
custom_table_history.add_ai_message("自定义表的回复")

print(f"\n消息数量: {len(custom_table_history.messages)}")
for i, msg in enumerate(custom_table_history.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

print("\n" + "=" * 60)
print("演示 8: 持久化存储的实际应用")
print("=" * 60)

# 使用文件数据库进行持久化
persistent_db = "persistent_chat.db"
if os.path.exists(persistent_db):
    os.remove(persistent_db)

persistent_history = SQLChatMessageHistory(
    session_id="persistent_user",
    connection_string=f"sqlite:///{persistent_db}"
)

print("\n创建持久化存储的 SQLChatMessageHistory:")
print(f"persistent_history = SQLChatMessageHistory(")
print(f"    session_id='persistent_user',")
print(f"    connection_string='sqlite:///{persistent_db}'")
print(")")

# 添加消息
persistent_history.add_user_message("持久化消息1")
persistent_history.add_ai_message("持久化回复1")
persistent_history.add_user_message("持久化消息2")

print("\n添加消息:")
print(f"  消息数量: {len(persistent_history.messages)}")

# 关闭连接
persistent_history.close()
print("\n关闭数据库连接...")

# 重新打开,验证持久化
print("重新打开数据库连接...")
persistent_history2 = SQLChatMessageHistory(
    session_id="persistent_user",
    connection_string=f"sqlite:///{persistent_db}"
)

print(f"重新打开后消息数量: {len(persistent_history2.messages)}")
print("\n持久化的消息:")
for i, msg in enumerate(persistent_history2.messages, 1):
    print(f"  {i}. {type(msg).__name__}: {msg.content}")

# 清理
persistent_history2.close()
if os.path.exists(persistent_db):
    os.remove(persistent_db)
    print(f"\n清理:删除数据库文件 {persistent_db}")

print("\n" + "=" * 60)
print("演示 9: 多轮对话管理")
print("=" * 60)

# 模拟多轮对话
conversation_history = SQLChatMessageHistory(
    session_id="conversation",
    connection_string="sqlite:///:memory:"
)

print("\n模拟多轮对话:")

# 第一轮
conversation_history.add_user_message("我想学习 Python")
conversation_history.add_ai_message("很好!Python 是一门很好的编程语言。你想从哪方面开始?")
print("  用户: 我想学习 Python")
print("  AI: 很好!Python 是一门很好的编程语言。你想从哪方面开始?")

# 第二轮
conversation_history.add_user_message("基础语法")
conversation_history.add_ai_message("好的,基础语法包括变量、数据类型、控制流等。")
print("  用户: 基础语法")
print("  AI: 好的,基础语法包括变量、数据类型、控制流等。")

# 第三轮
conversation_history.add_user_message("能详细说说变量吗?")
conversation_history.add_ai_message("变量是用来存储数据的容器,在 Python 中不需要声明类型。")
print("  用户: 能详细说说变量吗?")
print("  AI: 变量是用来存储数据的容器,在 Python 中不需要声明类型。")

print(f"\n对话历史记录(共 {len(conversation_history.messages)} 条消息):")
for i, msg in enumerate(conversation_history.messages, 1):
    role = "用户" if "Human" in type(msg).__name__ else "AI"
    print(f"  {i}. [{role}]: {msg.content}")

print("\n" + "=" * 60)
print("演示 10: 与 ChatPromptTemplate 结合使用(如果可用)")
print("=" * 60)

if ChatOpenAI and ChatPromptTemplate and MessagesPlaceholder and RunnableWithMessageHistory:
    # 创建会话历史工厂函数(使用文件数据库)
    llm_db_file = "llm_chat_history.db"
    if os.path.exists(llm_db_file):
        os.remove(llm_db_file)

    def get_llm_session_history(session_id: str):
        """获取或创建 LLM 会话历史(使用文件数据库)"""
        return SQLChatMessageHistory(
            session_id=session_id,
            connection_string=f"sqlite:///{llm_db_file}"
        )

    # 创建提示模板
    prompt = ChatPromptTemplate.from_messages([
        ("system", "你是一个友好的 AI 助手。"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ])

    # 创建 LLM 链
    llm = ChatOpenAI(model="gpt-4o")
    parser = StrOutputParser()
    chain = prompt | llm | parser

    # 创建带历史记录的链
    llm_chain_with_history = RunnableWithMessageHistory(
        chain,
        get_llm_session_history,
        input_messages_key="question",
        history_messages_key="history",
    )

    print("\n创建带历史记录的 LLM 链(使用 SQLChatMessageHistory):")
    print("llm_chain_with_history = RunnableWithMessageHistory(")
    print("    chain,")
    print("    get_llm_session_history,")
    print("    input_messages_key='question',")
    print("    history_messages_key='history',")
    print(")")

    # 第一次调用
    print("\n第一次调用(session_id='llm_user'):")
    print("  ⚠️  注意:LLM API 调用可能需要 5-30 秒,请耐心等待...")
    try:
        result1 = llm_chain_with_history.invoke(
            {"question": "你好,我是李四"},
            config={"configurable": {"session_id": "llm_user"}}
        )
        print(f"  ✅ 调用成功!")
        print(f"  输入: {{'question': '你好,我是李四'}}")
        print(f"  输出: {result1[:100] if isinstance(result1, str) else str(result1)[:100]}...")

        # 验证持久化
        sql_history = get_llm_session_history("llm_user")
        print(f"  历史记录: {len(sql_history.messages)} 条消息")

        # 第二次调用
        print("\n第二次调用(相同 session_id):")
        print("  ⚠️  注意:LLM API 调用可能需要 5-30 秒,请耐心等待...")
        result2 = llm_chain_with_history.invoke(
            {"question": "你还记得我的名字吗?"},
            config={"configurable": {"session_id": "llm_user"}}
        )
        print(f"  ✅ 调用成功!")
        print(f"  输入: {{'question': '你还记得我的名字吗?'}}")
        print(f"  输出: {result2[:100] if isinstance(result2, str) else str(result2)[:100]}...")

        sql_history2 = get_llm_session_history("llm_user")
        print(f"  历史记录: {len(sql_history2.messages)} 条消息")

        # 关闭并重新打开,验证持久化
        sql_history2.close()
        print("\n关闭数据库连接...")
        print("重新打开数据库连接...")
        sql_history3 = get_llm_session_history("llm_user")
        print(f"重新打开后历史记录: {len(sql_history3.messages)} 条消息")
        print("(消息已持久化到数据库)")
        sql_history3.close()

        # 清理
        if os.path.exists(llm_db_file):
            os.remove(llm_db_file)
            print(f"\n清理:删除数据库文件 {llm_db_file}")
    except KeyboardInterrupt:
        print("\n  ⚠️  用户中断了调用(这是正常的,可以继续其他演示)")
    except Exception as e:
        print(f"\n  ❌ 调用失败:{type(e).__name__}: {e}")
        print(f"  💡 可能的原因:")
        print(f"     - 缺少 OPENAI_API_KEY 环境变量")
        print(f"     - 网络连接问题")
        print(f"     - API 服务暂时不可用")
else:
    print("\n(跳过此演示,需要 ChatOpenAI 和 ChatPromptTemplate 支持)")

print("\n" + "=" * 60)
print("总结")
print("=" * 60)
print("\nSQLChatMessageHistory 的核心功能:")
print("1. 持久化存储:消息存储在 SQL 数据库中,程序重启后不会丢失")
print("2. 会话隔离:通过 session_id 区分不同会话")
print("3. 数据库支持:支持 SQLite、PostgreSQL、MySQL 等 SQL 数据库")
print("4. 批量操作:支持 add_messages 批量添加消息")
print("5. 消息管理:支持添加、获取、清空消息")
print("\n使用场景:")
print("- 生产环境:需要持久化存储的聊天应用")
print("- 多用户系统:为不同用户维护独立的对话历史")
print("- 数据分析:可以查询和分析历史对话数据")
print("- 备份恢复:数据库可以备份和恢复")
print("\n注意事项:")
print("- 需要安装 sqlalchemy(如果使用 langchain_community 的实现)")
print("- 使用文件数据库时,需要确保有写入权限")
print("- 生产环境建议使用 PostgreSQL 或 MySQL 等专业数据库")
print("- 定期清理旧数据,避免数据库过大")
print("- 使用连接池可以提高性能")
print("- 考虑使用异步模式(AsyncEngine)以提高并发性能")

访问验证

请输入访问令牌

Token不正确,请重新输入