1. 什么是CrossEncoder #
CrossEncoder(交叉编码器)是一种用于计算两个文本之间相关性分数的深度学习模型。它的核心特点是:将查询和文档一起输入模型,让它们"深度交互",从而得到更准确的相关性判断。
1.1 为什么需要CrossEncoder #
在信息检索和RAG系统中,我们经常需要判断:
- 用户查询和文档是否相关?
- 哪个文档最符合用户的需求?
- 多个候选答案中哪个最好?
传统的简单方法(如关键词匹配)往往不够准确。CrossEncoder通过深度学习,能够理解语义层面的相关性,给出更精准的判断。
1.2 直观理解 #
想象你在面试候选人:
- Bi-Encoder方式:分别看简历和岗位要求,然后对比(可能遗漏细节)
- CrossEncoder方式:把简历和岗位要求放在一起,让它们"对话",全面评估匹配度(更准确)
CrossEncoder就是让查询和文档"深度对话",从而得到更准确的相关性分数。
2. 前置知识 #
在学习CrossEncoder之前,我们需要了解一些基础概念。
2.1 什么是文本编码(Embedding) #
文本编码是将文本转换为数值向量的过程。这些向量能够:
- 表示文本的语义信息
- 用于计算文本之间的相似度
- 作为机器学习模型的输入
简单例子:
# 文本编码的简单示例
文本1: "我喜欢编程" → 向量1: [0.2, 0.5, -0.1, 0.8, ...]
文本2: "我爱写代码" → 向量2: [0.3, 0.4, -0.2, 0.7, ...]
# 这两个向量应该很相似,因为它们语义相近2.2 什么是语义相似度 #
语义相似度是指两个文本在意思上的相似程度,而不仅仅是字面上的相似。
例子:
- "我喜欢编程" 和 "我爱写代码" → 语义相似度高(意思相同)
- "我喜欢编程" 和 "编程喜欢我" → 语义相似度低(意思不同,虽然字相同)
CrossEncoder能够很好地理解语义相似度,这是它的核心优势。
2.3 什么是Bi-Encoder(双编码器) #
Bi-Encoder是CrossEncoder的"兄弟",理解它有助于理解CrossEncoder。
2.3.1 工作原理 #
- 独立编码:查询和文档分别通过编码器,得到各自的向量
- 计算相似度:使用余弦相似度或点积计算两个向量的相似度
- 输出分数:相似度分数就是相关性分数
2.3.2 优缺点 #
优点:
- 速度快:文档可以预先编码并存储,查询时只需编码查询
- 适合大规模检索:可以快速从百万级文档中找出候选
缺点:
- 精度相对较低:查询和文档没有直接交互,可能遗漏细节
- 无法捕捉复杂的语义关系
2.3.3 思路演示 #
# 步骤1:分别编码查询和文档
查询向量 = 编码器("什么是机器学习")
文档1向量 = 编码器("机器学习是人工智能的一个分支")
文档2向量 = 编码器("今天天气很好")
# 步骤2:计算相似度
相似度1 = 计算相似度(查询向量, 文档1向量) # 应该很高
相似度2 = 计算相似度(查询向量, 文档2向量) # 应该很低
# 步骤3:根据相似度排序
# 文档1应该排在文档2前面2.3.4 代码实现 #
# 导入 HuggingFaceEmbeddings,用于文本编码
from langchain_huggingface import HuggingFaceEmbeddings
# 导入 numpy,用于向量运算
import numpy as np
# 初始化文本编码器,运行在 CPU 上
embeddings = HuggingFaceEmbeddings(model_kwargs={"device": "cpu"})
# 定义查询文本
query = "什么是机器学习"
# 定义第一个待匹配文档
doc1 = "机器学习是人工智能的一个分支"
# 定义第二个待匹配文档
doc2 = "今天天气很好"
# 对查询文本进行编码,得到向量表示
query_vector = embeddings.embed_query(query)
# 对文档1进行编码,得到向量表示
doc1_vector = embeddings.embed_query(doc1)
# 对文档2进行编码,得到向量表示
doc2_vector = embeddings.embed_query(doc2)
# 定义余弦相似度计算函数
def cosine_similarity(vec1, vec2):
# 将输入转为 numpy 数组
vec1, vec2 = np.array(vec1), np.array(vec2)
# 计算余弦相似度并返回
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
# 计算查询和文档1的相似度(预计很高)
similarity1 = cosine_similarity(query_vector, doc1_vector)
# 计算查询和文档2的相似度(预计很低)
similarity2 = cosine_similarity(query_vector, doc2_vector)
# 创建包含文档和对应相似度的元组列表
docs = [(doc1, similarity1), (doc2, similarity2)]
# 按相似度从高到低排序
sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)
# 打印查询内容
print(f"查询: {query}\n")
# 打印文档1内容
print(f"文档1: {doc1}")
# 打印文档1的相似度分数
print(f"相似度1: {similarity1:.4f}\n")
# 打印文档2内容
print(f"文档2: {doc2}")
# 打印文档2的相似度分数
print(f"相似度2: {similarity2:.4f}\n")
# 打印排序结果
print("排序结果:")
# 逐个打印排序后的文档及相似度
for i, (doc, score) in enumerate(sorted_docs, 1):
print(f"{i}. {doc} (相似度: {score:.4f})")2.3.5 实现原理 #
# CrossEncoder:同时输入查询和文档,直接得到相关性分数
# 定义CrossEncoder的评分函数(仅基于共同字符数量,不涉及具体词语匹配)
def cross_encoder_score(query, document):
# CrossEncoder:只考虑查询和文档的共同字符,计算相关性分数
"""CrossEncoder:只考虑查询和文档的共同字符,计算相关性分数"""
# 将查询转换为字符集合
query_chars = set(query)
# 将文档转换为字符集合
doc_chars = set(document)
# 获取查询和文档的共同字符集合
common_chars = query_chars & doc_chars
# 只计算共同字符数量,并乘以系数0.5作为分数
char_score = len(common_chars) * 0.5
# 得到最终分数
score = char_score
# 返回分数
return score
# 定义查询字符串
query = "什么是机器学习"
# 定义第一个文档
doc1 = "机器学习是人工智能的一个分支"
# 定义第二个文档
doc2 = "我喜欢学习机器学习"
# 步骤1:构造查询-文档对,打包成元组列表
pairs = [(query, doc1), (query, doc2)]
# 步骤2:直接计算每个查询-文档对的相关性分数(不需要分别编码)
scores = [cross_encoder_score(q, d) for q, d in pairs]
# 步骤3:将文档和分数打包,并按分数从高到低排序
results = sorted(zip([doc1, doc2], scores), key=lambda x: x[1], reverse=True)
# 打印查询内容
print(f"查询: {query}\n")
# 打印第一个文档内容
print(f"文档1: {doc1}")
# 打印第一个文档分数
print(f"分数1: {scores[0]:.4f}\n")
# 打印第二个文档内容
print(f"文档2: {doc2}")
# 打印第二个文档分数
print(f"分数2: {scores[1]:.4f}\n")
# 打印排序结果标题
print("排序结果:")
# 依次枚举排序结果,输出每个文档及其分数
for i, (doc, score) in enumerate(results, 1):
print(f"{i}. {doc} (分数: {score:.4f})")
3. CrossEncoder vs Bi-Encoder:核心区别 #
理解两者的区别是掌握CrossEncoder的关键。
3.1 架构对比 #
3.1.1 Bi-Encoder架构 #
查询文本 → [编码器] → 查询向量 ──┐
├─→ 相似度计算 → 分数
文档文本 → [编码器] → 文档向量 ──┘特点:查询和文档独立编码,然后计算相似度
3.1.2 CrossEncoder架构 #
查询文本 ──┐
├─→ [拼接] → [联合编码器] → 分类/回归头 → 分数
文档文本 ──┘特点:查询和文档一起编码,直接输出分数
3.2 关键差异总结 #
| 维度 | Bi-Encoder | CrossEncoder |
|---|---|---|
| 编码方式 | 独立编码 | 联合编码 |
| 交互程度 | 无交互(仅通过相似度计算) | 深度交互(在编码过程中) |
| 精度 | 中等 | 极高 |
| 速度 | 极快(可预编码) | 较慢(需实时计算) |
| 适用场景 | 大规模召回(百万级) | 小规模精排(<1000) |
3.3 为什么CrossEncoder更准确 #
CrossEncoder在编码过程中,查询和文档可以"看到"彼此,模型能够:
- 捕捉细粒度匹配:发现"机器学习"和"ML"的对应关系
- 理解上下文:理解"苹果"在不同语境下的含义
- 识别复杂关系:理解因果关系、对比关系等
这就是为什么CrossEncoder精度更高的原因。
4. CrossEncoder的工作原理 #
4.1 输入处理 #
CrossEncoder将查询和文档拼接成一个序列:
[CLS] 查询文本 [SEP] 文档文本 [SEP]说明:
[CLS]:特殊标记,用于分类任务[SEP]:分隔符,用于分隔查询和文档- 整个序列作为一个整体输入模型
4.2 模型结构 #
CrossEncoder通常包含三个部分:
基础编码器:预训练的Transformer模型(如BERT、RoBERTa)
- 负责理解文本的语义
- 输出文本的向量表示
分类/回归头:
- 分类头:输出0-1之间的概率分数
- 回归头:输出连续的相关性分数
输出:直接得到相关性分数,无需额外的相似度计算
4.3 工作流程 #
输入:[CLS] 查询 [SEP] 文档 [SEP]
↓
Tokenizer(分词)
↓
Transformer编码器(理解语义)
↓
分类/回归头(输出分数)
↓
相关性分数(0-1之间或连续值)5. 实战:使用CrossEncoder进行文档重排序 #
现在让我们通过完整的代码示例,学习如何使用CrossEncoder。
5.1 环境准备 #
首先,我们需要安装必要的库:
# 安装sentence-transformers库
# 在命令行运行: pip install sentence-transformers
# 注意:首次运行会自动下载预训练模型,可能需要一些时间5.2 示例1:基础使用 - 计算文本相关性 #
这是一个最简单的例子,展示如何使用CrossEncoder计算两个文本的相关性:
# 导入CrossEncoder类
from sentence_transformers import CrossEncoder
# 加载预训练的CrossEncoder模型
# 这个模型专门用于计算文本相关性
# 首次运行会自动从HuggingFace下载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 定义查询和文档
# 查询:用户的问题
query = "什么是机器学习"
# 文档:候选答案或文档内容
document1 = "机器学习是人工智能的一个分支,通过算法让计算机从数据中学习"
document2 = "今天天气很好,适合出去散步"
# 将查询和文档组成文本对
# 格式:[查询, 文档]
pair1 = [query, document1]
pair2 = [query, document2]
# 使用模型预测相关性分数
# 分数越高表示越相关,通常在0-1之间
score1 = model.predict(pair1)
score2 = model.predict(pair2)
# 打印结果
print(f"查询: {query}")
print(f"\n文档1: {document1}")
print(f"相关性分数: {score1:.4f}")
print(f"\n文档2: {document2}")
print(f"相关性分数: {score2:.4f}")
# 判断哪个文档更相关
if score1 > score2:
print(f"\n文档1更相关(分数高 {score1:.4f} > {score2:.4f})")
else:
print(f"\n文档2更相关(分数高 {score2:.4f} > {score1:.4f})")5.3 示例2:批量处理多个文档 #
在实际应用中,我们通常需要同时评估多个文档。CrossEncoder支持批量处理,这样更高效:
# 导入CrossEncoder类
from sentence_transformers import CrossEncoder
# 加载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 定义查询
query = "Python如何读取文件"
# 定义多个候选文档
documents = [
"Python使用open()函数可以打开和读取文件,例如:f = open('file.txt', 'r')",
"Python是一种高级编程语言,具有简洁的语法",
"读取文件时可以使用with语句自动关闭文件,这是最佳实践",
"今天学习了Python的基础语法,感觉很有趣",
"文件操作包括读取、写入、追加等模式,'r'表示只读模式"
]
# 创建查询-文档对列表
# 每个元素都是[查询, 文档]的格式
pairs = [[query, doc] for doc in documents]
# 批量预测所有文档的相关性分数
# 返回一个numpy数组,包含每个文档的分数
scores = model.predict(pairs)
# 将文档和分数组合在一起
# zip函数将两个列表组合成元组列表
doc_scores = list(zip(documents, scores))
# 按分数从高到低排序
# reverse=True表示降序排列
doc_scores_sorted = sorted(doc_scores, key=lambda x: x[1], reverse=True)
# 打印排序结果
print(f"查询: {query}\n")
print("=" * 60)
print("文档相关性排序结果(从高到低):")
print("=" * 60)
# 遍历排序后的结果,打印每个文档及其分数
for i, (doc, score) in enumerate(doc_scores_sorted, 1):
print(f"\n排名 {i} (分数: {score:.4f}):")
print(f" {doc}")5.4 示例3:完整的RAG重排序流程 #
这是一个更完整的例子,模拟RAG系统中的两阶段检索流程:
# 导入必要的库
from sentence_transformers import CrossEncoder
import numpy as np
# 模拟一个简单的Bi-Encoder(实际应用中会使用真实的嵌入模型)
# 这里我们用一个简化的函数来模拟Bi-Encoder的检索结果
def simple_bi_encoder_search(query, all_documents, top_k=10):
"""
模拟Bi-Encoder的快速检索过程
参数:
query: 用户查询
all_documents: 所有文档列表
top_k: 返回前k个候选文档
返回:
前k个候选文档列表
"""
# 在实际应用中,这里会使用嵌入模型计算相似度
# 这里我们简化处理,假设已经按相关性排序
return all_documents[:top_k]
# 加载CrossEncoder模型
# 这个模型专门用于重排序任务
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 模拟知识库中的文档
# 在实际应用中,这些文档来自你的知识库或数据库
knowledge_base = [
"Python是一种高级编程语言,由Guido van Rossum创建",
"Python的语法简洁清晰,适合初学者学习",
"Python支持多种编程范式,包括面向对象和函数式编程",
"使用Python的open()函数可以读取文件,例如:with open('file.txt', 'r') as f: content = f.read()",
"Python的列表、字典、元组是常用的数据结构",
"机器学习是人工智能的一个分支,通过算法从数据中学习",
"深度学习使用神经网络来模拟人脑的学习过程",
"Python的pandas库是数据分析的重要工具",
"今天天气很好,适合出去散步",
"Python的装饰器是高级特性,可以增强函数功能",
"文件操作包括读取、写入、追加等模式",
"Python的异常处理使用try-except语句"
]
def rag_reranking_pipeline(query, knowledge_base, k_initial=10, k_final=5):
"""
完整的RAG重排序流程
参数:
query: 用户查询
knowledge_base: 知识库中的所有文档
k_initial: 第一阶段检索的候选文档数量
k_final: 最终返回的文档数量
返回:
重排序后的文档列表和分数
"""
print("=" * 60)
print("RAG两阶段检索流程")
print("=" * 60)
# 阶段1:使用Bi-Encoder快速召回候选文档
# 在实际应用中,这里会使用向量数据库进行快速检索
print(f"\n【阶段1】Bi-Encoder快速召回(Top {k_initial})")
candidates = simple_bi_encoder_search(query, knowledge_base, top_k=k_initial)
print(f"检索到 {len(candidates)} 个候选文档")
# 阶段2:使用CrossEncoder对候选文档进行精排
print(f"\n【阶段2】CrossEncoder精排(选出Top {k_final})")
# 创建查询-文档对
pairs = [[query, doc] for doc in candidates]
# 批量计算相关性分数
scores = cross_encoder.predict(pairs)
# 将文档和分数组合
doc_scores = list(zip(candidates, scores))
# 按分数从高到低排序
reranked = sorted(doc_scores, key=lambda x: x[1], reverse=True)
# 返回前k_final个文档
return reranked[:k_final]
# 测试完整的流程
query = "Python如何读取文件"
results = rag_reranking_pipeline(query, knowledge_base, k_initial=10, k_final=5)
# 打印最终结果
print("\n" + "=" * 60)
print("最终重排序结果")
print("=" * 60)
for i, (doc, score) in enumerate(results, 1):
print(f"\n排名 {i} (分数: {score:.4f}):")
print(f" {doc}")5.5 示例4:问答系统中的答案选择 #
CrossEncoder也可以用于从多个候选答案中选择最佳答案:
# 导入CrossEncoder类
from sentence_transformers import CrossEncoder
# 导入numpy用于数组操作
import numpy as np
# 加载模型
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 定义问题
question = "Python中如何读取文件?"
# 定义多个候选答案
candidate_answers = [
"使用open()函数,例如:f = open('file.txt', 'r'); content = f.read(); f.close()",
"Python是一种编程语言",
"可以使用with语句:with open('file.txt', 'r') as f: content = f.read()",
"文件读取是编程中的基本操作",
"read()方法可以读取文件的全部内容"
]
def select_best_answer(question, candidate_answers, model):
"""
从多个候选答案中选择最佳答案
参数:
question: 问题
candidate_answers: 候选答案列表
model: CrossEncoder模型
返回:
最佳答案和分数
"""
# 创建问题-答案对
pairs = [[question, answer] for answer in candidate_answers]
# 计算每个答案的相关性分数
scores = model.predict(pairs)
# 找到分数最高的答案的索引
best_idx = np.argmax(scores)
# 返回最佳答案和分数
return candidate_answers[best_idx], scores[best_idx]
# 选择最佳答案
best_answer, best_score = select_best_answer(question, candidate_answers, model)
# 打印结果
print(f"问题: {question}\n")
print("=" * 60)
print("所有候选答案及其分数:")
print("=" * 60)
# 创建问题-答案对并计算分数
pairs = [[question, answer] for answer in candidate_answers]
scores = model.predict(pairs)
# 打印所有答案和分数
for i, (answer, score) in enumerate(zip(candidate_answers, scores), 1):
marker = " 最佳答案" if score == best_score else ""
print(f"\n答案 {i} (分数: {score:.4f}){marker}:")
print(f" {answer}")
print(f"\n" + "=" * 60)
print(f"选择的最佳答案(分数: {best_score:.4f}):")
print(f" {best_answer}")6. 何时使用CrossEncoder #
7.1 适合使用CrossEncoder的场景 #
候选文档数量较少(通常 < 1000)
- CrossEncoder需要为每个查询-文档对计算,数量太多会很慢
- 适合在Bi-Encoder召回后的精排阶段使用
对精度要求很高
- 需要准确判断文档相关性
- 答案选择、重排序等关键任务
可以接受稍高的延迟
- CrossEncoder比Bi-Encoder慢,但通常仍在可接受范围内
- 对于非实时场景,这个延迟是可以接受的
有GPU资源
- CrossEncoder在GPU上运行会快很多
- 如果没有GPU,CPU也可以运行,只是会慢一些
7.2 不适合使用CrossEncoder的场景 #
需要实时响应(< 100ms)
- CrossEncoder计算需要时间,可能无法满足极低延迟要求
候选文档数量巨大(> 10000)
- 计算量会非常大,耗时过长
- 应该先用Bi-Encoder缩小候选集
计算资源非常有限
- CrossEncoder需要加载模型,占用内存
- 如果资源紧张,可能无法运行
只需要粗粒度排序
- 如果精度要求不高,Bi-Encoder可能就足够了
7. 常用预训练模型推荐 #
7.1 模型选择指南 #
不同的模型有不同的特点,选择合适的模型很重要:
| 模型名称 | 适用场景 | 参数量 | 特点 |
|---|---|---|---|
cross-encoder/ms-marco-MiniLM-L-6-v2 |
通用重排序 | 22M | ⭐ 推荐新手使用,速度快,精度好 |
BAAI/bge-reranker-base |
中文/英文重排序 | 110M | 支持中文,适合中文场景 |
cross-encoder/stsb-roberta-large |
语义相似度 | 355M | 精度高,但速度慢 |
cross-encoder/nli-deberta-v3-base |
自然语言推理 | 184M | 逻辑推理能力强 |
7.2 模型使用示例 #
# 导入CrossEncoder
from sentence_transformers import CrossEncoder
# 示例1:使用通用重排序模型(推荐新手)
model1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
print("模型1加载成功:通用重排序模型")
# 示例2:使用中文重排序模型(如果处理中文)
# model2 = CrossEncoder('BAAI/bge-reranker-base')
# print("模型2加载成功:中文重排序模型")
# 测试模型
query = "什么是Python"
document = "Python是一种高级编程语言"
score = model1.predict([[query, document]])
print(f"相关性分数: {score[0]:.4f}")建议:
- 新手推荐使用
cross-encoder/ms-marco-MiniLM-L-6-v2 - 处理中文内容时,考虑使用
BAAI/bge-reranker-base - 根据实际需求选择合适的模型
8. 最佳实践和技巧 #
8.1 两阶段检索策略 #
这是最常用的策略,结合Bi-Encoder和CrossEncoder的优势:
# 两阶段检索的完整示例
from sentence_transformers import CrossEncoder, SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# 加载模型
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# 知识库文档
documents = [
"Python是一种高级编程语言",
"机器学习是人工智能的分支",
"使用Python可以读取文件",
"今天天气很好",
"深度学习使用神经网络"
]
def two_stage_retrieval(query, documents, k_initial=10, k_final=5):
"""
两阶段检索:Bi-Encoder召回 + CrossEncoder精排
参数:
query: 查询
documents: 文档列表
k_initial: 第一阶段召回数量
k_final: 最终返回数量
"""
# 阶段1:Bi-Encoder快速召回
print("阶段1:Bi-Encoder快速召回...")
query_emb = bi_encoder.encode(query)
doc_embs = bi_encoder.encode(documents)
similarities = cosine_similarity([query_emb], doc_embs)[0]
# 获取Top k_initial个候选
top_indices = np.argsort(similarities)[::-1][:k_initial]
candidates = [documents[i] for i in top_indices]
print(f"召回 {len(candidates)} 个候选文档")
# 阶段2:CrossEncoder精排
print("阶段2:CrossEncoder精排...")
pairs = [[query, doc] for doc in candidates]
scores = cross_encoder.predict(pairs)
# 排序并返回Top k_final
doc_scores = list(zip(candidates, scores))
reranked = sorted(doc_scores, key=lambda x: x[1], reverse=True)
return reranked[:k_final]
# 测试
query = "Python如何读取文件"
results = two_stage_retrieval(query, documents, k_initial=5, k_final=3)
print("\n最终结果:")
for i, (doc, score) in enumerate(results, 1):
print(f"{i}. (分数: {score:.4f}) {doc}")8.2 批量处理优化 #
批量处理可以显著提高效率:
# 批量处理的优势示例
from sentence_transformers import CrossEncoder
import time
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "什么是机器学习"
documents = [f"文档{i}:关于机器学习的相关内容" for i in range(20)]
# 方式1:逐个处理(慢)
print("方式1:逐个处理")
start_time = time.time()
scores1 = []
for doc in documents:
score = model.predict([[query, doc]])[0]
scores1.append(score)
time1 = time.time() - start_time
print(f"耗时: {time1:.4f} 秒")
# 方式2:批量处理(快)
print("\n方式2:批量处理")
start_time = time.time()
pairs = [[query, doc] for doc in documents]
scores2 = model.predict(pairs)
time2 = time.time() - start_time
print(f"耗时: {time2:.4f} 秒")
print(f"\n性能提升: {time1/time2:.2f}x 倍")
print("建议:尽量使用批量处理")8.3 分数阈值过滤 #
可以根据分数阈值过滤掉明显不相关的文档:
# 使用分数阈值过滤不相关文档
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
query = "Python如何读取文件"
documents = [
"使用open()函数可以读取文件",
"Python是一种编程语言",
"今天天气很好",
"文件读取使用with语句更安全",
"机器学习很有趣"
]
# 设置分数阈值(低于此分数的文档被认为不相关)
threshold = 0.3
# 计算分数
pairs = [[query, doc] for doc in documents]
scores = model.predict(pairs)
# 过滤并排序
filtered_results = []
for doc, score in zip(documents, scores):
if score >= threshold:
filtered_results.append((doc, score))
# 按分数排序
filtered_results.sort(key=lambda x: x[1], reverse=True)
print(f"查询: {query}")
print(f"阈值: {threshold}")
print(f"\n过滤前: {len(documents)} 个文档")
print(f"过滤后: {len(filtered_results)} 个文档")
print("\n相关文档(分数 >= 阈值):")
for i, (doc, score) in enumerate(filtered_results, 1):
print(f" {i}. (分数: {score:.4f}) {doc}")9. 常见问题解答 #
9.1 Q1: CrossEncoder和Bi-Encoder可以一起使用吗? #
A: 可以!这是最佳实践。先用Bi-Encoder快速召回大量候选(如1000个),再用CrossEncoder对候选进行精排(选出Top 10),这样既保证了速度,又保证了精度。
9.2 Q2: CrossEncoder的分数范围是多少? #
A: 这取决于模型。大多数模型输出0-1之间的分数,分数越高表示越相关。但有些模型可能输出其他范围的分数,使用时需要注意。
9.3 Q3: 可以自己训练CrossEncoder吗? #
A: 可以,但需要大量的标注数据。对于大多数应用场景,使用预训练模型就足够了。只有在有特定领域需求且数据充足时,才考虑自己训练。
9.4 Q4: CrossEncoder在CPU上运行会很慢吗? #
A: 对于少量文档(< 100),CPU运行速度还可以接受。但如果文档数量较多,建议使用GPU,速度会快很多。
9.5 Q5: 如何选择合适的模型? #
A:
- 新手推荐:
cross-encoder/ms-marco-MiniLM-L-6-v2(通用、快速) - 中文场景:
BAAI/bge-reranker-base - 高精度需求:
cross-encoder/stsb-roberta-large(但速度较慢)