导航菜单

  • 1.vector
  • 2.milvus
  • 3.pymilvus
  • 4.rag
  • 5.rag_measure
  • ragflow
  • heapq
  • HNSW
  • cosine_similarity
  • math
  • typing
  • etcd
  • minio
  • collections
  • jieba
  • random
  • beautifulsoup4
  • chromadb
  • sentence_transformers
  • numpy
  • lxml
  • openpyxl
  • PyMuPDF
  • python-docx
  • requests
  • python-pptx
  • text_splitter
  • all-MiniLM-L6-v2
  • openai
  • llm
  • BPETokenizer
  • Flask
  • RAGAS
  • BagofWords
  • langchain
  • Pydantic
  • abc
  • faiss
  • MMR
  • scikit-learn
  • Runnable
  • PromptEngineering
  • dataclasses
  • LaTeX
  • rank_bm25
  • TF-IDF
  • asyncio
  • sqlalchemy
  • fastapi
  • Starlette
  • uvicorn
  • argparse
  • Generic
  • ssl
  • urllib
  • python-dotenv
  • RRF
  • CrossEncoder
  • Lost-in-the-middle
  • Jinja2
  • logger
  • io
  • venv
  • concurrent
  • parameter
  • SSE
  • 1. 什么是CrossEncoder
    • 1.1 为什么需要CrossEncoder
    • 1.2 直观理解
  • 2. 前置知识
    • 2.1 什么是文本编码(Embedding)
    • 2.2 什么是语义相似度
    • 2.3 什么是Bi-Encoder(双编码器)
      • 2.3.1 工作原理
      • 2.3.2 优缺点
      • 2.3.3 思路演示
      • 2.3.4 代码实现
      • 2.3.5 实现原理
  • 3. CrossEncoder vs Bi-Encoder:核心区别
    • 3.1 架构对比
      • 3.1.1 Bi-Encoder架构
      • 3.1.2 CrossEncoder架构
    • 3.2 关键差异总结
    • 3.3 为什么CrossEncoder更准确
  • 4. CrossEncoder的工作原理
    • 4.1 输入处理
    • 4.2 模型结构
    • 4.3 工作流程
  • 5. 实战:使用CrossEncoder进行文档重排序
    • 5.1 环境准备
    • 5.2 示例1:基础使用 - 计算文本相关性
    • 5.3 示例2:批量处理多个文档
    • 5.4 示例3:完整的RAG重排序流程
    • 5.5 示例4:问答系统中的答案选择
  • 6. 何时使用CrossEncoder
    • 7.1 适合使用CrossEncoder的场景
    • 7.2 不适合使用CrossEncoder的场景
  • 7. 常用预训练模型推荐
    • 7.1 模型选择指南
    • 7.2 模型使用示例
  • 8. 最佳实践和技巧
    • 8.1 两阶段检索策略
    • 8.2 批量处理优化
    • 8.3 分数阈值过滤
  • 9. 常见问题解答
    • 9.1 Q1: CrossEncoder和Bi-Encoder可以一起使用吗?
    • 9.2 Q2: CrossEncoder的分数范围是多少?
    • 9.3 Q3: 可以自己训练CrossEncoder吗?
    • 9.4 Q4: CrossEncoder在CPU上运行会很慢吗?
    • 9.5 Q5: 如何选择合适的模型?

1. 什么是CrossEncoder #

CrossEncoder(交叉编码器)是一种用于计算两个文本之间相关性分数的深度学习模型。它的核心特点是:将查询和文档一起输入模型,让它们"深度交互",从而得到更准确的相关性判断。

1.1 为什么需要CrossEncoder #

在信息检索和RAG系统中,我们经常需要判断:

  • 用户查询和文档是否相关?
  • 哪个文档最符合用户的需求?
  • 多个候选答案中哪个最好?

传统的简单方法(如关键词匹配)往往不够准确。CrossEncoder通过深度学习,能够理解语义层面的相关性,给出更精准的判断。

1.2 直观理解 #

想象你在面试候选人:

  • Bi-Encoder方式:分别看简历和岗位要求,然后对比(可能遗漏细节)
  • CrossEncoder方式:把简历和岗位要求放在一起,让它们"对话",全面评估匹配度(更准确)

CrossEncoder就是让查询和文档"深度对话",从而得到更准确的相关性分数。

2. 前置知识 #

在学习CrossEncoder之前,我们需要了解一些基础概念。

2.1 什么是文本编码(Embedding) #

文本编码是将文本转换为数值向量的过程。这些向量能够:

  • 表示文本的语义信息
  • 用于计算文本之间的相似度
  • 作为机器学习模型的输入

简单例子:

# 文本编码的简单示例
文本1: "我喜欢编程"  →  向量1: [0.2, 0.5, -0.1, 0.8, ...]
文本2: "我爱写代码"  →  向量2: [0.3, 0.4, -0.2, 0.7, ...]
# 这两个向量应该很相似,因为它们语义相近

2.2 什么是语义相似度 #

语义相似度是指两个文本在意思上的相似程度,而不仅仅是字面上的相似。

例子:

  • "我喜欢编程" 和 "我爱写代码" → 语义相似度高(意思相同)
  • "我喜欢编程" 和 "编程喜欢我" → 语义相似度低(意思不同,虽然字相同)

CrossEncoder能够很好地理解语义相似度,这是它的核心优势。

2.3 什么是Bi-Encoder(双编码器) #

Bi-Encoder是CrossEncoder的"兄弟",理解它有助于理解CrossEncoder。

2.3.1 工作原理 #

  1. 独立编码:查询和文档分别通过编码器,得到各自的向量
  2. 计算相似度:使用余弦相似度或点积计算两个向量的相似度
  3. 输出分数:相似度分数就是相关性分数

2.3.2 优缺点 #

优点:

  • 速度快:文档可以预先编码并存储,查询时只需编码查询
  • 适合大规模检索:可以快速从百万级文档中找出候选

缺点:

  • 精度相对较低:查询和文档没有直接交互,可能遗漏细节
  • 无法捕捉复杂的语义关系

2.3.3 思路演示 #

# 步骤1:分别编码查询和文档
查询向量 = 编码器("什么是机器学习")
文档1向量 = 编码器("机器学习是人工智能的一个分支")
文档2向量 = 编码器("今天天气很好")

# 步骤2:计算相似度
相似度1 = 计算相似度(查询向量, 文档1向量)  # 应该很高
相似度2 = 计算相似度(查询向量, 文档2向量)  # 应该很低

# 步骤3:根据相似度排序
# 文档1应该排在文档2前面

2.3.4 代码实现 #

# 导入 HuggingFaceEmbeddings,用于文本编码
from langchain_huggingface import HuggingFaceEmbeddings
# 导入 numpy,用于向量运算
import numpy as np

# 初始化文本编码器,运行在 CPU 上
embeddings = HuggingFaceEmbeddings(model_kwargs={"device": "cpu"})

# 定义查询文本
query = "什么是机器学习"
# 定义第一个待匹配文档
doc1 = "机器学习是人工智能的一个分支"
# 定义第二个待匹配文档
doc2 = "今天天气很好"

# 对查询文本进行编码,得到向量表示
query_vector = embeddings.embed_query(query)
# 对文档1进行编码,得到向量表示
doc1_vector = embeddings.embed_query(doc1)
# 对文档2进行编码,得到向量表示
doc2_vector = embeddings.embed_query(doc2)

# 定义余弦相似度计算函数
def cosine_similarity(vec1, vec2):
    # 将输入转为 numpy 数组
    vec1, vec2 = np.array(vec1), np.array(vec2)
    # 计算余弦相似度并返回
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

# 计算查询和文档1的相似度(预计很高)
similarity1 = cosine_similarity(query_vector, doc1_vector)
# 计算查询和文档2的相似度(预计很低)
similarity2 = cosine_similarity(query_vector, doc2_vector)

# 创建包含文档和对应相似度的元组列表
docs = [(doc1, similarity1), (doc2, similarity2)]
# 按相似度从高到低排序
sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)

# 打印查询内容
print(f"查询: {query}\n")
# 打印文档1内容
print(f"文档1: {doc1}")
# 打印文档1的相似度分数
print(f"相似度1: {similarity1:.4f}\n")
# 打印文档2内容
print(f"文档2: {doc2}")
# 打印文档2的相似度分数
print(f"相似度2: {similarity2:.4f}\n")
# 打印排序结果
print("排序结果:")
# 逐个打印排序后的文档及相似度
for i, (doc, score) in enumerate(sorted_docs, 1):
    print(f"{i}. {doc} (相似度: {score:.4f})")

2.3.5 实现原理 #

# CrossEncoder:同时输入查询和文档,直接得到相关性分数
# 定义CrossEncoder的评分函数(仅基于共同字符数量,不涉及具体词语匹配)
def cross_encoder_score(query, document):
    # CrossEncoder:只考虑查询和文档的共同字符,计算相关性分数
    """CrossEncoder:只考虑查询和文档的共同字符,计算相关性分数"""
    # 将查询转换为字符集合
    query_chars = set(query)
    # 将文档转换为字符集合
    doc_chars = set(document)
    # 获取查询和文档的共同字符集合
    common_chars = query_chars & doc_chars

    # 只计算共同字符数量,并乘以系数0.5作为分数
    char_score = len(common_chars) * 0.5

    # 得到最终分数
    score = char_score
    # 返回分数
    return score

# 定义查询字符串
query = "什么是机器学习"
# 定义第一个文档
doc1 = "机器学习是人工智能的一个分支"
# 定义第二个文档
doc2 = "我喜欢学习机器学习"

# 步骤1:构造查询-文档对,打包成元组列表
pairs = [(query, doc1), (query, doc2)]

# 步骤2:直接计算每个查询-文档对的相关性分数(不需要分别编码)
scores = [cross_encoder_score(q, d) for q, d in pairs]

# 步骤3:将文档和分数打包,并按分数从高到低排序
results = sorted(zip([doc1, doc2], scores), key=lambda x: x[1], reverse=True)

# 打印查询内容
print(f"查询: {query}\n")
# 打印第一个文档内容
print(f"文档1: {doc1}")
# 打印第一个文档分数
print(f"分数1: {scores[0]:.4f}\n")
# 打印第二个文档内容
print(f"文档2: {doc2}")
# 打印第二个文档分数
print(f"分数2: {scores[1]:.4f}\n")
# 打印排序结果标题
print("排序结果:")
# 依次枚举排序结果,输出每个文档及其分数
for i, (doc, score) in enumerate(results, 1):
    print(f"{i}. {doc} (分数: {score:.4f})")

3. CrossEncoder vs Bi-Encoder:核心区别 #

理解两者的区别是掌握CrossEncoder的关键。

3.1 架构对比 #

3.1.1 Bi-Encoder架构 #

查询文本 → [编码器] → 查询向量 ──┐
                                  ├─→ 相似度计算 → 分数
文档文本 → [编码器] → 文档向量 ──┘

特点:查询和文档独立编码,然后计算相似度

3.1.2 CrossEncoder架构 #

查询文本 ──┐
           ├─→ [拼接] → [联合编码器] → 分类/回归头 → 分数
文档文本 ──┘

特点:查询和文档一起编码,直接输出分数

3.2 关键差异总结 #

维度 Bi-Encoder CrossEncoder
编码方式 独立编码 联合编码
交互程度 无交互(仅通过相似度计算) 深度交互(在编码过程中)
精度 中等 极高
速度 极快(可预编码) 较慢(需实时计算)
适用场景 大规模召回(百万级) 小规模精排(<1000)

3.3 为什么CrossEncoder更准确 #

CrossEncoder在编码过程中,查询和文档可以"看到"彼此,模型能够:

  1. 捕捉细粒度匹配:发现"机器学习"和"ML"的对应关系
  2. 理解上下文:理解"苹果"在不同语境下的含义
  3. 识别复杂关系:理解因果关系、对比关系等

这就是为什么CrossEncoder精度更高的原因。

4. CrossEncoder的工作原理 #

4.1 输入处理 #

CrossEncoder将查询和文档拼接成一个序列:

[CLS] 查询文本 [SEP] 文档文本 [SEP]

说明:

  • [CLS]:特殊标记,用于分类任务
  • [SEP]:分隔符,用于分隔查询和文档
  • 整个序列作为一个整体输入模型

4.2 模型结构 #

CrossEncoder通常包含三个部分:

  1. 基础编码器:预训练的Transformer模型(如BERT、RoBERTa)

    • 负责理解文本的语义
    • 输出文本的向量表示
  2. 分类/回归头:

    • 分类头:输出0-1之间的概率分数
    • 回归头:输出连续的相关性分数
  3. 输出:直接得到相关性分数,无需额外的相似度计算

4.3 工作流程 #

输入:[CLS] 查询 [SEP] 文档 [SEP]
  ↓
Tokenizer(分词)
  ↓
Transformer编码器(理解语义)
  ↓
分类/回归头(输出分数)
  ↓
相关性分数(0-1之间或连续值)

5. 实战:使用CrossEncoder进行文档重排序 #

现在让我们通过完整的代码示例,学习如何使用CrossEncoder。

5.1 环境准备 #

首先,我们需要安装必要的库:

# 安装sentence-transformers库
# 在命令行运行: pip install sentence-transformers

# 注意:首次运行会自动下载预训练模型,可能需要一些时间

5.2 示例1:基础使用 - 计算文本相关性 #

这是一个最简单的例子,展示如何使用CrossEncoder计算两个文本的相关性:

# 导入CrossEncoder类
from sentence_transformers import CrossEncoder

# 加载预训练的CrossEncoder模型
# 这个模型专门用于计算文本相关性
# 首次运行会自动从HuggingFace下载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 定义查询和文档
# 查询:用户的问题
query = "什么是机器学习"
# 文档:候选答案或文档内容
document1 = "机器学习是人工智能的一个分支,通过算法让计算机从数据中学习"
document2 = "今天天气很好,适合出去散步"

# 将查询和文档组成文本对
# 格式:[查询, 文档]
pair1 = [query, document1]
pair2 = [query, document2]

# 使用模型预测相关性分数
# 分数越高表示越相关,通常在0-1之间
score1 = model.predict(pair1)
score2 = model.predict(pair2)

# 打印结果
print(f"查询: {query}")
print(f"\n文档1: {document1}")
print(f"相关性分数: {score1:.4f}")
print(f"\n文档2: {document2}")
print(f"相关性分数: {score2:.4f}")

# 判断哪个文档更相关
if score1 > score2:
    print(f"\n文档1更相关(分数高 {score1:.4f} > {score2:.4f})")
else:
    print(f"\n文档2更相关(分数高 {score2:.4f} > {score1:.4f})")

5.3 示例2:批量处理多个文档 #

在实际应用中,我们通常需要同时评估多个文档。CrossEncoder支持批量处理,这样更高效:

# 导入CrossEncoder类
from sentence_transformers import CrossEncoder

# 加载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 定义查询
query = "Python如何读取文件"

# 定义多个候选文档
documents = [
    "Python使用open()函数可以打开和读取文件,例如:f = open('file.txt', 'r')",
    "Python是一种高级编程语言,具有简洁的语法",
    "读取文件时可以使用with语句自动关闭文件,这是最佳实践",
    "今天学习了Python的基础语法,感觉很有趣",
    "文件操作包括读取、写入、追加等模式,'r'表示只读模式"
]

# 创建查询-文档对列表
# 每个元素都是[查询, 文档]的格式
pairs = [[query, doc] for doc in documents]

# 批量预测所有文档的相关性分数
# 返回一个numpy数组,包含每个文档的分数
scores = model.predict(pairs)

# 将文档和分数组合在一起
# zip函数将两个列表组合成元组列表
doc_scores = list(zip(documents, scores))

# 按分数从高到低排序
# reverse=True表示降序排列
doc_scores_sorted = sorted(doc_scores, key=lambda x: x[1], reverse=True)

# 打印排序结果
print(f"查询: {query}\n")
print("=" * 60)
print("文档相关性排序结果(从高到低):")
print("=" * 60)

# 遍历排序后的结果,打印每个文档及其分数
for i, (doc, score) in enumerate(doc_scores_sorted, 1):
    print(f"\n排名 {i} (分数: {score:.4f}):")
    print(f"  {doc}")

5.4 示例3:完整的RAG重排序流程 #

这是一个更完整的例子,模拟RAG系统中的两阶段检索流程:

# 导入必要的库
from sentence_transformers import CrossEncoder
import numpy as np

# 模拟一个简单的Bi-Encoder(实际应用中会使用真实的嵌入模型)
# 这里我们用一个简化的函数来模拟Bi-Encoder的检索结果
def simple_bi_encoder_search(query, all_documents, top_k=10):
    """
    模拟Bi-Encoder的快速检索过程

    参数:
        query: 用户查询
        all_documents: 所有文档列表
        top_k: 返回前k个候选文档

    返回:
        前k个候选文档列表
    """
    # 在实际应用中,这里会使用嵌入模型计算相似度
    # 这里我们简化处理,假设已经按相关性排序
    return all_documents[:top_k]

# 加载CrossEncoder模型
# 这个模型专门用于重排序任务
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 模拟知识库中的文档
# 在实际应用中,这些文档来自你的知识库或数据库
knowledge_base = [
    "Python是一种高级编程语言,由Guido van Rossum创建",
    "Python的语法简洁清晰,适合初学者学习",
    "Python支持多种编程范式,包括面向对象和函数式编程",
    "使用Python的open()函数可以读取文件,例如:with open('file.txt', 'r') as f: content = f.read()",
    "Python的列表、字典、元组是常用的数据结构",
    "机器学习是人工智能的一个分支,通过算法从数据中学习",
    "深度学习使用神经网络来模拟人脑的学习过程",
    "Python的pandas库是数据分析的重要工具",
    "今天天气很好,适合出去散步",
    "Python的装饰器是高级特性,可以增强函数功能",
    "文件操作包括读取、写入、追加等模式",
    "Python的异常处理使用try-except语句"
]

def rag_reranking_pipeline(query, knowledge_base, k_initial=10, k_final=5):
    """
    完整的RAG重排序流程

    参数:
        query: 用户查询
        knowledge_base: 知识库中的所有文档
        k_initial: 第一阶段检索的候选文档数量
        k_final: 最终返回的文档数量

    返回:
        重排序后的文档列表和分数
    """
    print("=" * 60)
    print("RAG两阶段检索流程")
    print("=" * 60)

    # 阶段1:使用Bi-Encoder快速召回候选文档
    # 在实际应用中,这里会使用向量数据库进行快速检索
    print(f"\n【阶段1】Bi-Encoder快速召回(Top {k_initial})")
    candidates = simple_bi_encoder_search(query, knowledge_base, top_k=k_initial)
    print(f"检索到 {len(candidates)} 个候选文档")

    # 阶段2:使用CrossEncoder对候选文档进行精排
    print(f"\n【阶段2】CrossEncoder精排(选出Top {k_final})")

    # 创建查询-文档对
    pairs = [[query, doc] for doc in candidates]

    # 批量计算相关性分数
    scores = cross_encoder.predict(pairs)

    # 将文档和分数组合
    doc_scores = list(zip(candidates, scores))

    # 按分数从高到低排序
    reranked = sorted(doc_scores, key=lambda x: x[1], reverse=True)

    # 返回前k_final个文档
    return reranked[:k_final]

# 测试完整的流程
query = "Python如何读取文件"
results = rag_reranking_pipeline(query, knowledge_base, k_initial=10, k_final=5)

# 打印最终结果
print("\n" + "=" * 60)
print("最终重排序结果")
print("=" * 60)
for i, (doc, score) in enumerate(results, 1):
    print(f"\n排名 {i} (分数: {score:.4f}):")
    print(f"  {doc}")

5.5 示例4:问答系统中的答案选择 #

CrossEncoder也可以用于从多个候选答案中选择最佳答案:

# 导入CrossEncoder类
from sentence_transformers import CrossEncoder
# 导入numpy用于数组操作
import numpy as np

# 加载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 定义问题
question = "Python中如何读取文件?"

# 定义多个候选答案
candidate_answers = [
    "使用open()函数,例如:f = open('file.txt', 'r'); content = f.read(); f.close()",
    "Python是一种编程语言",
    "可以使用with语句:with open('file.txt', 'r') as f: content = f.read()",
    "文件读取是编程中的基本操作",
    "read()方法可以读取文件的全部内容"
]

def select_best_answer(question, candidate_answers, model):
    """
    从多个候选答案中选择最佳答案

    参数:
        question: 问题
        candidate_answers: 候选答案列表
        model: CrossEncoder模型

    返回:
        最佳答案和分数
    """
    # 创建问题-答案对
    pairs = [[question, answer] for answer in candidate_answers]

    # 计算每个答案的相关性分数
    scores = model.predict(pairs)

    # 找到分数最高的答案的索引
    best_idx = np.argmax(scores)

    # 返回最佳答案和分数
    return candidate_answers[best_idx], scores[best_idx]

# 选择最佳答案
best_answer, best_score = select_best_answer(question, candidate_answers, model)

# 打印结果
print(f"问题: {question}\n")
print("=" * 60)
print("所有候选答案及其分数:")
print("=" * 60)

# 创建问题-答案对并计算分数
pairs = [[question, answer] for answer in candidate_answers]
scores = model.predict(pairs)

# 打印所有答案和分数
for i, (answer, score) in enumerate(zip(candidate_answers, scores), 1):
    marker = " 最佳答案" if score == best_score else ""
    print(f"\n答案 {i} (分数: {score:.4f}){marker}:")
    print(f"  {answer}")

print(f"\n" + "=" * 60)
print(f"选择的最佳答案(分数: {best_score:.4f}):")
print(f"  {best_answer}")

6. 何时使用CrossEncoder #

7.1 适合使用CrossEncoder的场景 #

  1. 候选文档数量较少(通常 < 1000)

    • CrossEncoder需要为每个查询-文档对计算,数量太多会很慢
    • 适合在Bi-Encoder召回后的精排阶段使用
  2. 对精度要求很高

    • 需要准确判断文档相关性
    • 答案选择、重排序等关键任务
  3. 可以接受稍高的延迟

    • CrossEncoder比Bi-Encoder慢,但通常仍在可接受范围内
    • 对于非实时场景,这个延迟是可以接受的
  4. 有GPU资源

    • CrossEncoder在GPU上运行会快很多
    • 如果没有GPU,CPU也可以运行,只是会慢一些

7.2 不适合使用CrossEncoder的场景 #

  1. 需要实时响应(< 100ms)

    • CrossEncoder计算需要时间,可能无法满足极低延迟要求
  2. 候选文档数量巨大(> 10000)

    • 计算量会非常大,耗时过长
    • 应该先用Bi-Encoder缩小候选集
  3. 计算资源非常有限

    • CrossEncoder需要加载模型,占用内存
    • 如果资源紧张,可能无法运行
  4. 只需要粗粒度排序

    • 如果精度要求不高,Bi-Encoder可能就足够了

7. 常用预训练模型推荐 #

7.1 模型选择指南 #

不同的模型有不同的特点,选择合适的模型很重要:

模型名称 适用场景 参数量 特点
cross-encoder/ms-marco-MiniLM-L-6-v2 通用重排序 22M ⭐ 推荐新手使用,速度快,精度好
BAAI/bge-reranker-base 中文/英文重排序 110M 支持中文,适合中文场景
cross-encoder/stsb-roberta-large 语义相似度 355M 精度高,但速度慢
cross-encoder/nli-deberta-v3-base 自然语言推理 184M 逻辑推理能力强

7.2 模型使用示例 #

# 导入CrossEncoder
from sentence_transformers import CrossEncoder

# 示例1:使用通用重排序模型(推荐新手)
model1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
print("模型1加载成功:通用重排序模型")

# 示例2:使用中文重排序模型(如果处理中文)
# model2 = CrossEncoder('BAAI/bge-reranker-base')
# print("模型2加载成功:中文重排序模型")

# 测试模型
query = "什么是Python"
document = "Python是一种高级编程语言"
score = model1.predict([[query, document]])
print(f"相关性分数: {score[0]:.4f}")

建议:

  • 新手推荐使用 cross-encoder/ms-marco-MiniLM-L-6-v2
  • 处理中文内容时,考虑使用 BAAI/bge-reranker-base
  • 根据实际需求选择合适的模型

8. 最佳实践和技巧 #

8.1 两阶段检索策略 #

这是最常用的策略,结合Bi-Encoder和CrossEncoder的优势:

# 两阶段检索的完整示例
from sentence_transformers import CrossEncoder, SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# 加载模型
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# 知识库文档
documents = [
    "Python是一种高级编程语言",
    "机器学习是人工智能的分支",
    "使用Python可以读取文件",
    "今天天气很好",
    "深度学习使用神经网络"
]

def two_stage_retrieval(query, documents, k_initial=10, k_final=5):
    """
    两阶段检索:Bi-Encoder召回 + CrossEncoder精排

    参数:
        query: 查询
        documents: 文档列表
        k_initial: 第一阶段召回数量
        k_final: 最终返回数量
    """
    # 阶段1:Bi-Encoder快速召回
    print("阶段1:Bi-Encoder快速召回...")
    query_emb = bi_encoder.encode(query)
    doc_embs = bi_encoder.encode(documents)
    similarities = cosine_similarity([query_emb], doc_embs)[0]

    # 获取Top k_initial个候选
    top_indices = np.argsort(similarities)[::-1][:k_initial]
    candidates = [documents[i] for i in top_indices]
    print(f"召回 {len(candidates)} 个候选文档")

    # 阶段2:CrossEncoder精排
    print("阶段2:CrossEncoder精排...")
    pairs = [[query, doc] for doc in candidates]
    scores = cross_encoder.predict(pairs)

    # 排序并返回Top k_final
    doc_scores = list(zip(candidates, scores))
    reranked = sorted(doc_scores, key=lambda x: x[1], reverse=True)

    return reranked[:k_final]

# 测试
query = "Python如何读取文件"
results = two_stage_retrieval(query, documents, k_initial=5, k_final=3)

print("\n最终结果:")
for i, (doc, score) in enumerate(results, 1):
    print(f"{i}. (分数: {score:.4f}) {doc}")

8.2 批量处理优化 #

批量处理可以显著提高效率:

# 批量处理的优势示例
from sentence_transformers import CrossEncoder
import time

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "什么是机器学习"
documents = [f"文档{i}:关于机器学习的相关内容" for i in range(20)]

# 方式1:逐个处理(慢)
print("方式1:逐个处理")
start_time = time.time()
scores1 = []
for doc in documents:
    score = model.predict([[query, doc]])[0]
    scores1.append(score)
time1 = time.time() - start_time
print(f"耗时: {time1:.4f} 秒")

# 方式2:批量处理(快)
print("\n方式2:批量处理")
start_time = time.time()
pairs = [[query, doc] for doc in documents]
scores2 = model.predict(pairs)
time2 = time.time() - start_time
print(f"耗时: {time2:.4f} 秒")

print(f"\n性能提升: {time1/time2:.2f}x 倍")
print("建议:尽量使用批量处理")

8.3 分数阈值过滤 #

可以根据分数阈值过滤掉明显不相关的文档:

# 使用分数阈值过滤不相关文档
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "Python如何读取文件"
documents = [
    "使用open()函数可以读取文件",
    "Python是一种编程语言",
    "今天天气很好",
    "文件读取使用with语句更安全",
    "机器学习很有趣"
]

# 设置分数阈值(低于此分数的文档被认为不相关)
threshold = 0.3

# 计算分数
pairs = [[query, doc] for doc in documents]
scores = model.predict(pairs)

# 过滤并排序
filtered_results = []
for doc, score in zip(documents, scores):
    if score >= threshold:
        filtered_results.append((doc, score))

# 按分数排序
filtered_results.sort(key=lambda x: x[1], reverse=True)

print(f"查询: {query}")
print(f"阈值: {threshold}")
print(f"\n过滤前: {len(documents)} 个文档")
print(f"过滤后: {len(filtered_results)} 个文档")
print("\n相关文档(分数 >= 阈值):")
for i, (doc, score) in enumerate(filtered_results, 1):
    print(f"  {i}. (分数: {score:.4f}) {doc}")

9. 常见问题解答 #

9.1 Q1: CrossEncoder和Bi-Encoder可以一起使用吗? #

A: 可以!这是最佳实践。先用Bi-Encoder快速召回大量候选(如1000个),再用CrossEncoder对候选进行精排(选出Top 10),这样既保证了速度,又保证了精度。

9.2 Q2: CrossEncoder的分数范围是多少? #

A: 这取决于模型。大多数模型输出0-1之间的分数,分数越高表示越相关。但有些模型可能输出其他范围的分数,使用时需要注意。

9.3 Q3: 可以自己训练CrossEncoder吗? #

A: 可以,但需要大量的标注数据。对于大多数应用场景,使用预训练模型就足够了。只有在有特定领域需求且数据充足时,才考虑自己训练。

9.4 Q4: CrossEncoder在CPU上运行会很慢吗? #

A: 对于少量文档(< 100),CPU运行速度还可以接受。但如果文档数量较多,建议使用GPU,速度会快很多。

9.5 Q5: 如何选择合适的模型? #

A:

  • 新手推荐:cross-encoder/ms-marco-MiniLM-L-6-v2(通用、快速)
  • 中文场景:BAAI/bge-reranker-base
  • 高精度需求:cross-encoder/stsb-roberta-large(但速度较慢)
← 上一节 cosine_similarity 下一节 dataclasses →

访问验证

请输入访问令牌

Token不正确,请重新输入