1.RAG简介 #
RAG(Retrieval-Augmented Generation,检索式增强生成)是一种结合了“检索”与“生成”两种能力的自然语言处理技术,常用于问答系统、智能助手等场景。
1.1 安装依赖 #
uv add PyMuPDF beautifulsoup4 python-docx openpyxl python-pptx lxml sentence_transformers numpy chromadb requests rank_bm251.2 大模型的局限性 #
- 知识更新滞后:大模型的训练数据具有时效性,难以及时反映最新的事实和动态信息。
- 缺乏专有领域知识:对于企业内部、行业专属等私有知识,大模型通常无法覆盖,难以满足专业化需求。
- 存在内容幻觉:大模型有时会生成看似合理但实际错误的信息,容易造成误导和错误决策。
1.3 为什么选择RAG? #
提升答案准确性
RAG通过实时检索外部知识库中的相关信息,有效增强生成内容的准确性和权威性,避免模型“胡编乱造”。显著降低训练与维护成本
相较于完全依赖大规模数据训练的生成模型,RAG只需较少的训练数据即可实现高质量输出,大幅减少算力和数据投入。灵活应对新知识与变化
RAG具备极强的适应性,能够动态检索和利用最新的知识库内容,面对新领域、新事件时无需重新训练模型即可快速响应和生成相关答案。
1.4 RAG工作流程 #
1.4.1 步骤一:知识库构建 #
- 管理员将本地文档转化为可检索的向量,供后续问答使用。
1.4.2 步骤二:检索(Retrieval) #
- 用户输入一个问题(Query)。
- 系统将问题转化为向量(embedding),在知识库中检索与问题最相关的若干条文档或片段。
1.4.3 步骤三:生成(Generation) #
- 将检索到的内容与原始问题一起输入到生成式模型中。
- 生成式模型根据这些信息,生成更准确、更有依据的答案。
1.5 优势 #
- 知识可扩展:模型不需要记住所有知识,只需检索外部知识库,便于更新和扩展。
- 答案更有依据:生成的答案可以引用检索到的具体内容,提升可信度。
- 减少幻觉:降低生成模型“胡编乱造”的概率。
1.6 应用场景 #
- 智能问答(如企业知识库问答、法律咨询等)
- 智能客服
- 文档摘要与分析
- 代码检索与生成
1.7 长上下文和RAG #
1.7.1 Long Context(长上下文) #
直接将大量上下文(如长文档、历史对话等)整体输入到大模型,让模型在“记住”全部内容的基础上直接生成答案。
| 维度 | RAG(检索增强生成) | Long Context(长上下文) |
|---|---|---|
| 知识容量 | 理论上无限(依赖外部知识库) | 受限于模型最大上下文窗口(如32K、128K) |
| 实时性 | 检索+生成,检索速度快,生成速度取决于模型 | 直接生成,长文本输入会显著拖慢推理速度 |
| 可扩展性 | 知识库易于扩展和更新,无需重新训练模型 | 上下文窗口有限,超长内容需截断或摘要 |
| 准确性 | 检索相关性决定答案质量,依赖检索效果 | 只要内容在窗口内,模型可直接引用 |
| 幻觉风险 | 检索内容可控,幻觉概率较低 | 长上下文内仍可能出现幻觉 |
| 实现难度 | 需搭建检索系统和知识库,流程较复杂 | 只需支持大窗口模型,流程简单 |
| 成本 | 检索和生成分开,资源消耗可控 | 长上下文推理消耗显著增加 |
| 模型名称 | 上下文长度(token) | 备注说明 |
|---|---|---|
| GPT-3 (davinci) | 4,096 | OpenAI,早期主力模型 |
| GPT-3.5 (turbo) | 4,096 | OpenAI,API常用 |
| gpt-4o-16k | 16,384 | OpenAI,16k版本 |
| GPT-4 (标准) | 8,192 | OpenAI,API标准版 |
| GPT-4-32k | 32,768 | OpenAI,API高配版 |
| GPT-4o | 128,000 | OpenAI,2024年发布,极大提升上下文窗口 |
| Claude 2 | 100,000 | Anthropic,长文本处理能力强 |
| Claude 3 (Opus) | 200,000 | Anthropic,2024年旗舰版 |
| Gemini 1.5 Pro | 1,000,000 | Google,百万级上下文,适合超长文档 |
| Llama 2 | 4,096 | Meta,开源大模型 |
| Llama 3 | 8,192 / 128,000 | Meta,8k为基础版,128k为高配版 |
| ChatGLM2-6B | 8,192 | 智谱AI,国产开源模型 |
| ChatGLM3-6B | 8,192 | 智谱AI,国产开源模型 |
| Qwen-72B | 32,768 | 通义千问,阿里巴巴 |
| Yi-34B | 32,000 | 零一万物,国产开源 |
| ERNIE Bot 4.0 | 8,192 | 百度文心一言 |
1.7.2 适用场景对比 #
RAG适合:
- 企业知识库问答、法律/医疗/金融等专业领域
- 知识库大、内容经常更新的场景
- 需要可追溯、可解释答案的场景
Long Context适合:
- 需要处理长文档、长对话、论文、小说等场景
- 上下文内容有限且全部相关时
- 不方便搭建知识库或检索系统时
2.RAG工作流 #

- 管理员负责知识的整理、切分、向量化和入库,保证知识库的丰富和可检索性。
- 用户只需输入问题,系统会自动完成向量化、检索、生成答案等一系列操作,最终返回高质量的答案。
2.1 管理员部分:知识入库流程 #
这部分主要是将本地文档转化为可检索的向量,供后续问答使用。
Local Documents(本地文档)
管理员准备好需要导入的知识文档,格式可以多样(如PDF、Word、txt等)。Unstructured Loader(非结构化加载器)
使用加载器将各种格式的文档解析为纯文本,便于后续处理。Text(文本)
得到的纯文本内容。Text Splitter(文本切分器)
将长文本按照一定规则(如段落、句子、字数等)切分成较小的文本块。Text Chunks(文本块)
切分后得到的多个小文本片段。Embedding(向量化)
利用嵌入模型(如BERT、SentenceTransformer等)将每个文本块转化为高维向量。VectorStore(向量存储)
所有文本块的向量被存入向量数据库(如Milvus等),用于后续的高效检索。
2.2 用户部分:问答检索流程 #
这部分是用户实际提问并获得答案的过程。
Query(查询)
用户输入一个问题或查询。Embedding(向量化)
将用户的查询同样转化为向量表示。Query Vector(查询向量)
得到用户问题的向量。Vector Similarity(向量相似度)
计算查询向量与知识库中所有文本块向量的相似度,找出最相关的若干文本块。Related Text Chunks(相关文本块)
检索到与用户问题最相关的文本片段。Prompt Template(提示词模板)
将用户问题和相关文本块填充到预设的提示词模板中,构建最终的Prompt。Prompt(提示词)
生成用于大语言模型的完整输入。LLM(大语言模型)
将Prompt输入到大语言模型,生成最终的答案。Answer(答案)
返回给用户的最终回答。
3.文档解析 #
3.1 数据格式的多样性 #
在实际开展文档处理工作前,我们首先要对企业内部的数据类型和特点有一个全面的了解。只有充分认识到数据的多样性和复杂性,才能为后续的数据解析和知识库建设打下坚实基础。
企业在日常运营中会积累大量数据,这些数据因行业、业务流程和管理方式的不同而呈现出极大的多样性。部分数据甚至是企业独有的,外部很难见到。因此,数据处理往往需要根据实际情况定制脚本和工具,灵活应对各种数据格式和内容。
企业数据的复杂性主要体现在两个方面:数据格式的多样性和数据内容的复杂性。
3.2 数据内容的复杂性 #
企业数据大致可以分为两大类:结构化数据和非结构化/半结构化数据。
结构化数据
这类数据通常存储在关系型数据库(如 MySQL、Oracle)中,数据以表格形式组织,每个字段都有明确的定义和含义。除了传统的关系型数据库,还有文档型数据库(如 MongoDB)、全文检索数据库(如 Elasticsearch)、列式数据库(如 ClickHouse)、图数据库(如 Neo4j)以及分布式 NoSQL 数据库(如 Cassandra)等。结构化数据的优点是格式统一、易于查询和分析,通常通过 SQL 或类似的查询语言进行操作。非结构化与半结构化数据
这部分数据类型丰富,包括但不限于 PDF、Word、PPT、Excel 等常见办公文档,纯文本文件(如日志、txt),网页数据(HTML)、数据交换格式(JSON、XML)、Markdown 文档,以及图片、音频、视频等多媒体文件。非结构化数据没有固定的字段和格式,内容表现形式灵活多变,解析和处理难度较大。半结构化数据则介于两者之间,既有一定的结构性,又保留了内容的灵活性。
除了格式多样,企业数据在内容层面也极具挑战性,尤其是非结构化文档。以 PDF 文件为例,单个文档中可能同时包含标题、段落、表格、图片、公式等多种元素。不同文档的排版和布局也各不相同,有的采用单栏,有的为双栏,内容组织方式千变万化。文本、标题、列表、表格、图片等元素可能交错出现,给自动化解析带来很大难度。
正因为企业数据在格式和内容上的复杂性,构建高质量的知识库时,往往需要针对不同类型的数据进行专门的预处理和解析策略。只有这样,才能确保后续知识抽取和检索的准确性和有效性。
总之,企业数据的多样性和复杂性是知识管理和智能文档处理中的一大挑战。理解这些特点,有助于我们选择合适的技术方案和工具,提升数据处理的效率和质量。
3.3 GIGO #
在数据处理和智能系统开发领域,有一个非常重要的理念——“垃圾输入,必然导致垃圾输出”(Garbage In, Garbage Out,简称GIGO)。这个原则广泛应用于计算机科学和信息系统建设中,强调了数据质量对最终结果的决定性作用。无论你的算法多么先进、流程多么完善,如果最初输入的数据存在问题,最终的输出也难以令人满意。
在构建RAG(检索增强生成)知识库的过程中,数据质量同样是成败的关键。整个流程大致包括:业务数据的选取、数据解析、内容分块、向量化处理,以及将向量存入数据库。每一个环节都可能成为“垃圾数据”产生的源头。
数据源选择
首先,数据源的甄别至关重要。只有与业务需求高度相关的数据,才能为系统提供有价值的支撑。如果引入了无关或低质量的信息,不仅会影响检索效果,还可能导致系统输出错误或无意义的答案。此外,数据本身的准确性和一致性也必须得到保证。如果原始数据中存在矛盾、错误或遗漏,后续所有处理都无法弥补这些缺陷。最后,数据的覆盖面也要足够广泛,避免因信息不全而影响系统的整体表现。
数据解析环节
企业数据格式多样,内容复杂,解析过程中极易出现问题。例如,文本内容可能被错误识别,表格结构解析混乱,或者某些关键信息被遗漏。这些解析失误都会直接影响后续的知识抽取和检索效果。因此,数据解析工具和脚本的选择与调优同样重要。
内容分块策略
文档分块是将大文本拆分为更小的片段,以便后续向量化和检索。但如果分块策略不合理,比如将本应连贯的语义拆散,或者把无关内容混在一起,都会导致上下文丢失或检索噪声增加。这不仅影响检索的相关性,还可能让生成模型输出不准确的答案。
全流程质量把控
综上所述,RAG系统的每一步都可能成为“垃圾数据”的入口。只有在数据源筛选、解析、分块等各个环节都严格把控质量,才能确保最终系统的有效性和可靠性。数据质量的保障,是构建高效智能知识系统的基石。始终牢记:只有高质量的输入,才能带来高价值的输出。
3.4 文档解析 #
uv add PyMuPDF python-docx openpyxl python-pptx beautifulsoup4 lxml
3.4.1 读取 PDF 文件 #
# 安装依赖:uv add PyMuPDF
import fitz # PyMuPDF
# 定义一个函数用于提取PDF文件中的所有文本内容
def extract_pdf_text(pdf_path):
"""
提取PDF文件中的所有文本内容
参数:
pdf_path (str): PDF文件路径
返回:
str: 合并后的所有页文本
"""
# 打开PDF文件
pdf = fitz.open(pdf_path)
# 创建一个空列表用于存放每一页的文本内容
text_list = []
# 遍历PDF中的每一页
for page in pdf:
# 提取当前页的文本内容,并添加到列表中
text_list.append(page.get_text("text")) # type: ignore
# 将所有页的文本内容合并为一个字符串
all_text = "\n".join(text_list)
# 返回合并后的文本
return all_text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的PDF文件名
pdf_file = "example.pdf"
# 调用函数提取PDF文本
result_text = extract_pdf_text(pdf_file)
# 打印提取到的文本内容
print(result_text)
3.4.2 读取 Word 文件 #
# 导入python-docx库中的Document类
from docx import Document
# 定义函数:从Word文档中提取所有段落文本
def extract_text_from_word(file_path):
"""
从Word文档中提取所有段落的文本,并以字符串返回。
:param file_path: Word文档的路径
:return: 文本内容字符串
"""
# 加载Word文档
doc = Document(file_path)
# 遍历所有段落,将段落文本拼接为一个字符串(以换行符分隔)
text = "\n".join([para.text for para in doc.paragraphs])
# 返回拼接后的文本
return text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的Word文件名
file_path = "example.docx"
# 调用函数提取Word文本
result = extract_text_from_word(file_path)
# 打印提取到的文本内容
print(result)
3.4.3 读取 Excel 文件 #
# 安装依赖:uv add openpyxl
import openpyxl
# 定义函数:从Excel文件中提取所有单元格内容为文本
def extract_text_from_excel(file_path):
"""
从Excel文件中提取所有单元格内容为文本,并以字符串返回。
:param file_path: Excel文件路径
:return: 文本内容字符串
"""
# 加载Excel工作簿
wb = openpyxl.load_workbook(file_path)
# 获取活动工作表
ws = wb.active
# 初始化用于存储每一行文本的列表
rows = []
# 遍历工作表中的每一行,values_only=True表示只获取单元格的值
for row in ws.iter_rows(values_only=True):
# 将每一行的所有单元格内容转换为字符串,并用制表符分隔,空单元格用空字符串代替
rows.append("\t".join([str(cell) if cell is not None else "" for cell in row]))
# 将所有行用换行符拼接为一个大字符串
all_text = "\n".join(rows)
# 返回拼接后的文本
return all_text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的Excel文件名
file_path = "example.xlsx"
# 调用函数提取Excel文本
result = extract_text_from_excel(file_path)
# 打印提取到的文本内容
print(result)
3.4.4 读取 PPT 文件 #
# 导入Presentation类,用于读取PPT文件
from pptx import Presentation
# 定义函数:提取PPT文件中的所有文本内容
def extract_ppt_text(file_path):
"""
提取PPT文件中的所有文本内容,并以字符串返回。
:param file_path: PPT文件路径
:return: 所有文本内容(以换行符分隔)
"""
# 加载PPT文件
ppt = Presentation(file_path)
# 初始化用于存储所有文本的列表
text_list = []
# 遍历PPT中的每一页幻灯片
for slide in ppt.slides:
# 遍历幻灯片中的每一个形状
for shape in slide.shapes:
# 判断该形状是否有text属性(即是否包含文本)
if hasattr(shape, "text"):
# 如果有,则将文本内容添加到列表中
text_list.append(shape.text)
# 将所有文本用换行符拼接成一个字符串
all_text = "\n".join(text_list)
# 返回拼接后的所有文本
return all_text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的PPT文件名
ppt_file = "example.pptx"
# 调用函数提取PPT文本内容
result = extract_ppt_text(ppt_file)
# 打印提取到的文本内容
print(result)3.4.5 读取 HTML 文件 #
# 安装依赖:uv add beautifulsoup4
from bs4 import BeautifulSoup # 导入BeautifulSoup库用于解析HTML
# 定义函数:从HTML文件中提取所有文本内容
def extract_text_from_html(file_path):
"""
从指定HTML文件中提取所有文本内容
参数:
file_path (str): HTML文件路径
返回:
str: 提取的文本内容
"""
# 以utf-8编码方式打开HTML文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取整个HTML文件内容为字符串
html = f.read()
# 使用BeautifulSoup解析HTML内容
soup = BeautifulSoup(html, "html.parser")
# 提取所有文本内容,使用换行符分隔
text = soup.get_text(separator="\n")
# 返回提取到的文本内容
return text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的HTML文件名
file_path = "example.html"
# 调用函数提取HTML文本内容
result = extract_text_from_html(file_path)
# 打印提取到的文本内容
print(result)
3.4.6 读取 JSON 文件 #
# 导入内置的json库
import json
# 定义函数:读取并格式化打印JSON文件内容
def read_and_print_json(filename):
"""
读取指定JSON文件并以格式化字符串打印内容
:param filename: JSON文件名
"""
# 以utf-8编码方式打开指定的JSON文件
with open(filename, "r", encoding="utf-8") as f:
# 使用json.load读取文件内容为Python对象
data = json.load(f)
# 使用json.dumps将Python对象格式化为带缩进的字符串,确保中文正常显示
text = json.dumps(data, ensure_ascii=False, indent=2)
# 打印格式化后的JSON字符串
print(text)
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 调用函数,读取并打印example.json文件内容
read_and_print_json("example.json")
3.4.7 读取 XML 文件 #
# 导入lxml库中的etree模块,用于解析XML
from lxml import etree
# 定义函数:从XML文件中提取所有文本内容
def extract_xml_text(file_path):
"""
读取XML文件并提取所有文本内容
参数:
file_path (str): XML文件路径
返回:
str: 提取的所有文本内容
"""
# 以utf-8编码方式打开XML文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取XML文件的全部内容为字符串
xml = f.read()
# 将字符串形式的XML内容解析为XML树结构
root = etree.fromstring(xml.encode("utf-8"))
# 遍历XML树,提取所有文本内容,并用空格连接
text = " ".join(root.itertext())
# 返回提取到的文本内容
return text
# 主程序入口,进行测试调用
if __name__ == "__main__":
# 指定要读取的XML文件名
xml_file = "example.xml"
# 调用函数提取XML文本内容
result = extract_xml_text(xml_file)
# 打印提取到的文本内容
print(result)
3.4.8 读取 CSV 文件 #
# 导入内置的csv库
import csv
# 定义函数:读取CSV文件内容,并将每行用逗号连接,所有行用换行符拼接成一个字符串返回
def read_csv_to_text(filename):
"""
读取CSV文件内容,并将每行用逗号连接,所有行用换行符拼接成一个字符串返回。
"""
# 以utf-8编码方式打开CSV文件
with open(filename, "r", encoding="utf-8") as f:
# 创建csv.reader对象,按行读取CSV内容
reader = csv.reader(f)
# 对每一行,用逗号连接各列,生成字符串列表
rows = [", ".join(row) for row in reader]
# 用换行符拼接所有行,得到完整文本
all_text = "\n".join(rows)
# 返回拼接后的文本
return all_text
# 测试调用
if __name__ == "__main__":
# 调用函数读取example.csv文件内容
result = read_csv_to_text("example.csv")
# 打印读取到的内容
print(result)
3.4.9 读取纯文本文件 #
# 定义一个函数,用于读取指定文本文件的内容并返回
def read_text_file(filename):
"""
读取指定文本文件内容并返回
:param filename: 文件名
:return: 文件内容字符串
"""
# 以只读模式并指定utf-8编码打开文件
with open(filename, "r", encoding="utf-8") as f:
# 读取文件全部内容
text = f.read()
# 返回读取到的文本内容
return text
# 测试代码块,只有当本文件作为主程序运行时才会执行
if __name__ == "__main__":
# 调用函数读取example.txt文件内容
result = read_text_file("example.txt")
# 打印读取到的内容
print(result)
3.4.10 读取 Markdown #
# 定义一个函数,用于读取Markdown文件内容
def read_markdown_file(file_path):
# 以只读模式并指定utf-8编码打开Markdown文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取并返回文件的全部内容
return f.read()
# 测试代码块,只有当本文件作为主程序运行时才会执行
if __name__ == "__main__":
# 调用函数读取example.md文件内容
content = read_markdown_file("example.md")
# 打印读取到的内容
print(content)
4 文档分割 #
文本分块(Text Chunking)是一种将长文档分解为更小、更易处理的文本片段的技术。这个过程类似于将一本厚重的书籍拆分成多个章节,每个章节都包含相对独立且完整的信息单元。
4.1 文本分块的核心价值 #
在构建检索增强生成(RAG)系统时,文本分块技术具有以下重要意义:
语义纯度保证:长文档往往包含多个不同的主题和语义信息,直接处理会导致语义混淆,影响检索精度。
检索精度提升:通过分块,系统能够更精确地匹配用户查询,避免返回无关信息,从而提高回答质量。
模型限制适配:现代语言模型都有输入长度限制,分块技术确保每个文本片段都能在模型的处理范围内。
4.2. 文本分块的基本原则 #
4.2.1 语义完整性原则 #
每个文本块应该包含一个完整且独立的语义单元。这就像拼图游戏中的每一块,都应该能够独立表达一个完整的概念。
4.2.2 大小平衡原则 #
- 过小的块:会破坏语义的完整性,导致上下文信息缺失
- 过大的块:可能包含多个不相关的主题,降低检索精度
4.3 递归分割策略 #
递归分割是一种基于规则的分块方法,其核心思想是使用预定义的分隔符对文档进行逐级分解。
4.3.1 工作原理 #
文档 → 段落 → 句子 → 单词 → 字符系统首先使用最粗粒度的分隔符(如段落分隔符),如果生成的块仍然过大,则使用更细粒度的分隔符继续分割,直到所有块都符合预设大小。
4.3.2 text_splitter #
4.4 语义感知分块策略 #
基于规则的分块方法虽然简单有效,但有时无法准确识别语义边界。语义感知分块技术通过分析文本的语义特征来进行更智能的分割。
基于Embedding的语义分块利用预训练的embedding模型来计算文本片段之间的语义相似度。
核心思想
- 语义分块:不是简单的按长度或句子数分割,而是基于语义相似度进行智能分割。
- 滑动窗口:聚合上下文信息,提高语义稳定性。先用固定大小的窗口创建初始块,再根据语义相似度调整。
- 相似度阈值:通过阈值控制分割的粒度,阈值越高分割越细,阈值越低分割越粗。
4.4.1 代码实现 #
# 导入句子嵌入模型
from sentence_transformers import SentenceTransformer
# 导入numpy用于数值计算
import numpy as np
# 导入正则表达式模块
import re
# 加载预训练的句子嵌入模型
print("正在加载句子嵌入模型...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("模型加载完成。")
# 定义基于语义的分块器类
class SemanticChunker:
# 初始化方法,设置窗口大小和相似度阈值
def __init__(self, window_size=2, threshold=0.85):
# 设置每个窗口包含的句子数
self.window_size = window_size
# 设置相邻窗口的相似度阈值
self.threshold = threshold
# 日志:输出初始化参数
print(
f"SemanticChunker初始化,窗口大小: {window_size},相似度阈值: {threshold}"
)
# 创建分块文档的方法
def create_documents(self, text):
# 使用正则表达式按中英文标点和换行分割句子
# 当正则表达式中使用捕获分组时,分隔符会包含在结果中
print("正在分割原始文本为句子...")
sentences = re.split(r"(。|!|?|\!|\?|\.|\n)", text)
# 初始化句子列表
sents = []
# 遍历分割后的句子和标点,合并为完整句子
for i in range(0, len(sentences) - 1, 2):
# 拼接句子内容和分隔符
s = sentences[i].strip() + sentences[i + 1].strip()
# 如果拼接后不为空,则加入句子列表
if s.strip():
sents.append(s)
# 日志:输出分割后的句子数量
print(f"分割得到 {len(sents)} 个句子。")
# 初始化分块列表
docs = []
# 设置起始索引
start = 0
# 使用滑动窗口将句子分组
# 窗口用于聚合上下文,窗口合并多个句子,嵌入能表达更完整的语义,让语义比较更稳定、更可靠
print("正在使用滑动窗口进行初步分块...")
while start < len(sents):
# 计算窗口结束位置
end = min(start + self.window_size, len(sents))
# 获取当前窗口的句子
window = sents[start:end]
# 合并窗口内句子为一个块
docs.append("".join(window))
# 移动到下一个窗口
start = end
# 日志:输出初步分块数量
print(f"初步分块完成,共 {len(docs)} 个块。")
# 计算每个窗口的嵌入向量
print("正在计算每个块的嵌入向量...")
embeddings = model.encode(docs)
# 初始化分割点列表,起始点为0
split_points = [0]
# 遍历相邻窗口,计算相似度
print("正在计算相邻块之间的相似度...")
for i in range(1, len(docs)):
# 计算余弦相似度
sim = np.dot(embeddings[i - 1], embeddings[i]) / (
np.linalg.norm(embeddings[i - 1]) * np.linalg.norm(embeddings[i])
)
# 日志:输出每对块的相似度
print(f"块 {i} 与块 {i-1} 的相似度为: {sim:.4f}")
# 如果相似度低于阈值,则作为新的分割点
if sim < self.threshold:
print(f"相似度低于阈值({self.threshold}),在位置 {i} 添加分割点。")
split_points.append(i)
# 初始化最终分块结果列表
result = []
# 遍历所有分割点,生成最终文本块
print("正在根据分割点生成最终分块结果...")
for i in range(len(split_points)):
# 当前块的起始索引
start = split_points[i]
# 当前块的结束索引
end = split_points[i + 1] if i + 1 < len(split_points) else len(docs)
# 合并该范围内的窗口为一个块
chunk = "".join(docs[start:end])
# 如果块内容不为空,则加入结果列表
if chunk.strip():
print(f"生成第 {len(result)+1} 个块,内容长度: {len(chunk)}")
result.append(chunk)
# 返回所有分块
print(f"最终分块完成,共 {len(result)} 个块。")
return result
# 创建语义分块器对象,设置窗口大小和相似度阈值
print("正在创建语义分块器对象...")
semantic_splitter = SemanticChunker(window_size=2, threshold=0.85)
# 准备需要分割的长文本
print("准备待分割的长文本...")
long_text = """今天天气晴朗,适合去公园散步。
量子力学中的叠加态是描述粒子同时处于多个状态的数学工具。
Windows命令行中复制文件可以使用copy命令。
大熊猫主要以竹子为食,是中国的国宝。
欧拉公式被誉为“最美的数学公式”。"""
# 执行文本分割,得到分块结果
print("开始执行文本分割...")
documents = semantic_splitter.create_documents(long_text)
# 打印分割结果,显示每个块的内容
print(f"总共分割为 {len(documents)} 个块:\n")
for i, doc in enumerate(documents, 1):
# 打印当前块的编号
print(f"=== 第 {i} 个块 ===")
# 打印当前块的内容
print(doc)4.4.2 执行流程 #
按标点分割] E1 --> E2[遍历合并句子和标点] G --> G1[窗口大小=2] G1 --> G2[滑动窗口分组] H --> H1[model.encode
生成嵌入] I --> I1[余弦相似度计算] N --> N1[合并连续相似块] N1 --> N2[生成最终文档] %% 样式定义 classDef process fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef decision fill:#fff3e0,stroke:#ef6c00,stroke-width:2px classDef startend fill:#c8e6c9,stroke:#2e7d32,stroke-width:2px classDef subprocess fill:#f3e5f5,stroke:#4a148c,stroke-width:2px class A,P startend class B,C,D,E,F,G,H,I,N,O process class J,M decision class E1,E2,G1,G2,H1,I1,N1,N2 subprocess
1. 句子分割阶段
sentences = re.split(r"(。|!|?|\!|\?|\.|\n)", text)- 作用:使用正则表达式按中英文标点符号和换行符分割文本。
- 正则表达式说明:
(。|!|?|\!|\?|\.|\n):匹配中文句号、感叹号、问号,英文感叹号、问号、句号,以及换行符。- 由于使用了捕获分组
(),分隔符会保留在结果中。
- 结果:得到一个包含句子内容和分隔符的列表。
2. 句子重构阶段
sents = []
for i in range(0, len(sentences) - 1, 2):
s = sentences[i].strip() + sentences[i + 1].strip()
if s.strip():
sents.append(s)- 作用:将分割后的句子内容和分隔符合并为完整的句子。
- 逻辑说明:
- 由于正则分割的结果是
[句子1, 分隔符1, 句子2, 分隔符2, ...],所以每两个元素为一组。 sentences[i]是句子内容,sentences[i + 1]是对应的分隔符。- 合并后得到完整的句子,并过滤掉空字符串。
- 由于正则分割的结果是
3. 初始分块阶段(滑动窗口)
docs = []
start = 0
while start < len(sents):
end = min(start + self.window_size, len(sents))
window = sents[start:end]
docs.append("".join(window))
start = end- 作用:使用滑动窗口将句子分组,创建初始的文本块。
- 参数说明:
window_size=2:每个窗口包含2个句子。
- 执行过程:
- 从第0个句子开始,每次取2个句子组成一个窗口。
- 将窗口内的句子合并为一个文本块。
- 移动到下一个窗口,直到处理完所有句子。
4. 嵌入向量计算阶段
embeddings = model.encode(docs)- 作用:使用预训练的句子嵌入模型将每个文本块转换为向量表示。
- 模型说明:
all-MiniLM-L6-v2是一个轻量级的句子嵌入模型,能够将文本转换为384维的向量。
5. 相似度计算和分割点确定阶段
split_points = [0]
for i in range(1, len(docs)):
sim = np.dot(embeddings[i - 1], embeddings[i]) / (
np.linalg.norm(embeddings[i - 1]) * np.linalg.norm(embeddings[i])
)
if sim < self.threshold:
split_points.append(i)- 作用:计算相邻文本块的余弦相似度,根据相似度阈值确定分割点。
- 计算过程:
- 使用余弦相似度公式
- 如果相似度低于阈值(0.85),则认为这两个块语义差异较大,需要分割。
- 结果:得到一个分割点列表,表示哪些位置需要分割。
6. 最终分块生成阶段
result = []
for i in range(len(split_points)):
start = split_points[i]
end = split_points[i + 1] if i + 1 < len(split_points) else len(docs)
chunk = "".join(docs[start:end])
if chunk.strip():
result.append(chunk)- 作用:根据分割点生成最终的文本块。
- 逻辑说明:
- 遍历所有分割点,每个分割点之间的范围作为一个文本块。
- 将范围内的初始文本块合并为一个最终块。
- 过滤掉空字符串,返回最终的分块结果。
4.4.2 滑动窗口 #
# 导入句子向量模型库
from sentence_transformers import SentenceTransformer
# 导入操作系统库
import os
# 导入正则表达式库
import re
# 导入numpy库,用于数值运算
import numpy as np
# 设置HuggingFace镜像地址为国内源
os.environ["HF-ENDPOINT"] = "https://hf-mirror.com"
# 加载SentenceTransformer模型
model = SentenceTransformer("all-MiniLM-L6-v2")
# 定义语义分块器类
class SemanticChunker:
# 初始化方法,设置窗口大小和阈值
def __init__(self, window_size=2, threshold=0.85):
# 设置窗口大小
self.window_size = window_size
# 设置阈值
self.threshold = threshold
# 创建分块文档的方法
def create_documents(self, text):
# 用正则表达式按标点符号或换行符分割文本为句子
sentences = re.split(r"(。|!|?|\!|\?|\.|\n)", text)
# 初始化句子列表
sents = []
# 按照标点符号把文本拼成完整的句子段
for i in range(0, len(sentences) - 1, 2):
# 合并完整的句子
s = sentences[i].strip() + sentences[i + 1].strip()
# 如果合并后的句子非空,则加入句子列表
if s.strip():
sents.append(s)
# 打印每个句子,便于调试
for sent in sents:
print(sent)
# 计算每个句子的向量嵌入
sentence_embeddings = model.encode(sents)
# 初始化分割点列表,起始位置为0
split_points = [0]
# 遍历每一个可能的分割点(从窗口大小处开始)
for i in range(self.window_size, len(sents)):
# 计算前一个窗口的起始和结束索引
prev_window_start = max(0, i - self.window_size)
prev_window_end = i
# 计算后一个窗口的起始和结束索引
next_window_start = i
next_window_end = min(len(sents), i + self.window_size)
# 获取前一个窗口的向量
prev_window_embeddings = sentence_embeddings[
prev_window_start:prev_window_end
]
# 计算前一个窗口的平均向量
prev_avg_embedding = np.mean(prev_window_embeddings, axis=0)
# 获取后一个窗口的向量
next_window_embeddings = sentence_embeddings[
next_window_start:next_window_end
]
# 计算后一个窗口的平均向量
next_avg_embedding = np.mean(next_window_embeddings, axis=0)
# 计算两个窗口向量的余弦相似度
sim = np.dot(prev_avg_embedding, next_avg_embedding) / (
np.linalg.norm(prev_avg_embedding) * np.linalg.norm(next_avg_embedding)
)
# 如果相似度低于阈值,就把当前位置作为新的分割点
if sim < self.threshold:
split_points.append(i)
# 初始化最终的分割结果
result = []
# 遍历分割点,把句子合并成块
for i in range(len(split_points) - 1):
# 当前块的起始索引
start = split_points[i]
# 当前块的结束索引
end = split_points[i + 1]
# 合并当前范围内的句子作为分块
chunk = "".join(sents[start:end])
# 如果分块非空,则加入分割结果列表
if chunk.strip():
result.append(chunk)
# 合并最后一个分块(尾部)
chunk = "".join(sents[end:])
# 如果最后一个分块非空,则加入结果
if chunk.strip():
result.append(chunk)
# 返回所有分块
return result
# 再次设置HuggingFace镜像地址为国内源(保险起见)
os.environ["HF-ENDPOINT"] = "https://hf-mirror.com"
# 定义要分块的长文本
long_text = """今天天气晴朗。今天天气晴朗。今天天气晴朗。今天天气晴朗。今天天气晴朗。
Windows。Windows。Windows。Windows。Windows。"""
# 创建语义分块器对象,窗口大小为2,相似度阈值0.85
semanticChunker = SemanticChunker(window_size=2, threshold=0.85)
# 使用分块器对长文本进行分块
documents = semanticChunker.create_documents(long_text)
# 打印每个分块的编号和内容
for i, doc in enumerate(documents, 1):
print(i, doc)
5. 向量化 #
5.1 使用云服务 #
# 导入os模块,用于读取环境变量
import os
# 导入requests库,用于发送HTTP请求
import requests
# 设置文本向量API的URL
VOLC_EMBEDDINGS_API_URL = "https://ark.cn-beijing.volces.com/api/v3/embeddings"
# 设置API密钥
VOLC_API_KEY = "d52e49a1-36ea-44bb-bc6e-65ce789a72f6"
# 定义获取文档向量的函数,参数为文档内容
def get_doubao_embedding(doc_content):
# 构造请求头,包含内容类型和认证信息
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {VOLC_API_KEY}",
}
# 构造请求体,指定模型和输入内容
payload = {"model": "doubao-embedding-text-240715", "input": doc_content}
# 发送POST请求到向量API,获取响应
response = requests.post(VOLC_EMBEDDINGS_API_URL, json=payload, headers=headers)
# 判断响应状态码是否为200,表示请求成功
if response.status_code == 200:
# 解析响应的JSON数据
data = response.json()
# 提取嵌入向量
embedding = data["data"][0]["embedding"]
# 返回嵌入向量
return embedding
else:
# 如果请求失败,抛出异常并输出错误信息
raise Exception(f"Embedding API error: {response.text}")
# 定义待处理的文档内容
doc_content = "这是一个示例文档"
# 调用函数获取嵌入向量
embedding = get_doubao_embedding(doc_content)
# 打印嵌入向量
print(embedding)5.2 本地向量化 #
# 从sentence_transformers库中导入SentenceTransformer类
from sentence_transformers import SentenceTransformer
# 加载预训练的句子嵌入模型"all-MiniLM-L6-v2"
model = SentenceTransformer("all-MiniLM-L6-v2")
# 定义一个函数,用于获取输入文档内容的向量表示
def get_sentence_embedding(doc_content):
# 使用模型对文档内容进行编码,得到嵌入向量
embedding = model.encode(doc_content)
# 返回嵌入向量
return embedding
# 定义一个示例文档内容
doc_content = "这是一个示例文档"
# 获取示例文档的嵌入向量
embedding = get_sentence_embedding(doc_content)
# 打印嵌入向量
print(embedding)
6.RAG工作流 #
Retrieval(检索) → Augmented(增强) → Generation(生成)
↓ ↓ ↓
向量数据库检索 拼接上下文到Prompt 大模型生成答案6.1 extract.py #
extract.py
# 导入PyMuPDF库(fitz),用于处理PDF文件
import fitz # PyMuPDF
# 导入Optional类型提示
from typing import Optional
# 导入日志logging功能
import logging
# 获取当前模块日志记录器
logger = logging.getLogger(__name__)
# 定义用于提取PDF所有文本内容的函数
def extract_pdf_text(pdf_path: str) -> str:
"""
提取PDF文件中的所有文本内容
参数:
pdf_path (str): PDF文件路径
返回:
str: 合并后的所有页文本
异常:
FileNotFoundError: 文件不存在
Exception: PDF文件读取失败
"""
try:
# 打开PDF文件
pdf = fitz.open(pdf_path)
try:
# 新建一个空列表,用来存储每页文本
text_list = []
# 遍历每一页
for page in pdf:
# 获取当前页文本,并加入列表
text_list.append(page.get_text("text")) # type: ignore
# 将每页文本用换行拼接成一个大字符串
all_text = "\n".join(text_list)
# 返回拼接后的文本
return all_text
finally:
# 确保关闭PDF文件
pdf.close()
except FileNotFoundError:
# 如果文件未找到,记录错误日志
logger.error(f"PDF文件不存在: {pdf_path}")
# 向上抛出异常
raise
except Exception as e:
# 其他异常情况,记录错误信息
logger.error(f"提取PDF文本失败: {pdf_path}, 错误: {str(e)}")
# 抛出异常
raise
# 导入python-docx的Document类
from docx import Document
# 定义提取Word文档所有段落文本的函数
def extract_text_from_word(file_path: str) -> str:
"""
从Word文档中提取所有段落的文本,并以字符串返回。
参数:
file_path (str): Word文档的路径
返回:
str: 文本内容字符串
异常:
FileNotFoundError: 文件不存在
Exception: Word文件读取失败
"""
try:
# 加载Word文档
doc = Document(file_path)
# 取所有段落的文本,并用换行符拼接
text = "\n".join([para.text for para in doc.paragraphs])
# 返回拼接好的文本
return text
except FileNotFoundError:
# 文件未找到时记录日志
logger.error(f"Word文件不存在: {file_path}")
# 抛出异常
raise
except Exception as e:
# 其它异常记录错误信息
logger.error(f"提取Word文本失败: {file_path}, 错误: {str(e)}")
# 抛出异常
raise
# 导入openpyxl库,用于操作Excel文件
import openpyxl
# 定义函数提取Excel文件中的所有文本
def extract_text_from_excel(file_path: str) -> str:
"""
从Excel文件中提取所有单元格内容为文本,并以字符串返回。
参数:
file_path (str): Excel文件路径
返回:
str: 文本内容字符串
异常:
FileNotFoundError: 文件不存在
Exception: Excel文件读取失败
"""
try:
# 加载Excel工作簿
wb = openpyxl.load_workbook(file_path, data_only=True)
try:
# 取得活动工作表
ws = wb.active
# 新建空列表保存每一行字符串
rows = []
# 遍历所有行,只取单元格的值
for row in ws.iter_rows(values_only=True):
# 将每行单元格内容用Tab连接,空值转换为空字符串
rows.append("\t".join([str(cell) if cell is not None else "" for cell in row]))
# 用换行符拼接所有行
all_text = "\n".join(rows)
# 返回最终文本
return all_text
finally:
# 关闭Excel工作簿
wb.close()
except FileNotFoundError:
# 文件未找到时日志记录
logger.error(f"Excel文件不存在: {file_path}")
raise
except Exception as e:
# 其它异常日志并抛出
logger.error(f"提取Excel文本失败: {file_path}, 错误: {str(e)}")
raise
# 导入python-pptx库的Presentation类
from pptx import Presentation
# 定义函数提取PPT文件所有文本内容
def extract_ppt_text(file_path: str) -> str:
"""
提取PPT文件中的所有文本内容,并以字符串返回。
参数:
file_path (str): PPT文件路径
返回:
str: 所有文本内容(以换行符分隔)
异常:
FileNotFoundError: 文件不存在
Exception: PPT文件读取失败
"""
try:
# 加载PPT文件
ppt = Presentation(file_path)
# 新建列表存储所有文本内容
text_list = []
# 遍历PPT中的每张幻灯片
for slide in ppt.slides:
# 遍历当前幻灯片的每个形状
for shape in slide.shapes:
# 判断是否含有文本,且文本不为空
if hasattr(shape, "text") and shape.text.strip():
# 有文本时加入结果列表
text_list.append(shape.text)
# 用换行符拼接所有文本
all_text = "\n".join(text_list)
# 返回所有文本内容
return all_text
except FileNotFoundError:
# 文件未找到时日志打印
logger.error(f"PPT文件不存在: {file_path}")
raise
except Exception as e:
# 处理其它异常
logger.error(f"提取PPT文本失败: {file_path}, 错误: {str(e)}")
raise
# 导入BeautifulSoup用于解析HTML
from bs4 import BeautifulSoup # BeautifulSoup用于解析HTML
# 定义函数,从HTML文件提取所有文本内容
def extract_text_from_html(file_path: str) -> str:
"""
从指定HTML文件中提取所有文本内容
参数:
file_path (str): HTML文件路径
返回:
str: 提取的文本内容
异常:
FileNotFoundError: 文件不存在
Exception: HTML文件读取失败
"""
try:
# 以utf-8编码方式打开HTML文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取HTML文件所有内容
html = f.read()
# 创建BeautifulSoup对象
soup = BeautifulSoup(html, "html.parser")
# 用换行分隔符获取全部文本
text = soup.get_text(separator="\n", strip=True)
# 返回文本
return text
except FileNotFoundError:
# 文件不存在,记录日志
logger.error(f"HTML文件不存在: {file_path}")
raise
except Exception as e:
# 其它异常记录并抛出
logger.error(f"提取HTML文本失败: {file_path}, 错误: {str(e)}")
raise
# 导入内置json库
import json
# 定义提取JSON文件文本内容的函数
def extract_text_from_json(filename: str) -> str:
"""
从JSON文件中提取文本内容并格式化为字符串
参数:
filename (str): JSON文件路径
返回:
str: 格式化后的JSON文本内容
异常:
FileNotFoundError: 文件不存在
json.JSONDecodeError: JSON解析失败
"""
try:
# 以utf-8编码打开JSON文件
with open(filename, "r", encoding="utf-8") as f:
# 加载JSON内容到Python对象
data = json.load(f)
# 格式化JSON为缩进文本,显示中文
text = json.dumps(data, ensure_ascii=False, indent=2)
# 返回字符串格式JSON内容
return text
except FileNotFoundError:
# 文件不存在时记录日志
logger.error(f"JSON文件不存在: {filename}")
raise
except json.JSONDecodeError as e:
# JSON解析异常日志
logger.error(f"JSON解析失败: {filename}, 错误: {str(e)}")
raise
# 导入lxml库的etree模块用于XML处理
from lxml import etree
# 定义函数,从XML文件提取所有文本内容
def extract_xml_text(file_path: str) -> str:
"""
读取XML文件并提取所有文本内容
参数:
file_path (str): XML文件路径
返回:
str: 提取的所有文本内容
异常:
FileNotFoundError: 文件不存在
etree.XMLSyntaxError: XML解析失败
"""
try:
# 用utf-8编码打开XML文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取XML字符串内容
xml = f.read()
# 解析为XML树结构对象
root = etree.fromstring(xml.encode("utf-8"))
# 遍历所有文本节点并用空格拼接
text = " ".join(root.itertext())
# 返回拼接后的文本
return text
except FileNotFoundError:
# 文件不存在日志
logger.error(f"XML文件不存在: {file_path}")
raise
except etree.XMLSyntaxError as e:
# XML语法异常日志
logger.error(f"XML解析失败: {file_path}, 错误: {str(e)}")
raise
except Exception as e:
# 其它异常日志
logger.error(f"提取XML文本失败: {file_path}, 错误: {str(e)}")
raise
# 导入csv模块
import csv
# 定义读取CSV内容并串成字符串的函数
def read_csv_to_text(filename: str) -> str:
"""
读取CSV文件内容,并将每行用逗号连接,所有行用换行符拼接成一个字符串返回。
参数:
filename (str): CSV文件路径
返回:
str: 拼接后的字符串
异常:
FileNotFoundError: 文件不存在
"""
try:
# 以utf-8编码方式打开CSV文件
with open(filename, "r", encoding="utf-8") as f:
# 创建csv.reader对象逐行读取
reader = csv.reader(f)
# 每行用逗号拼接并放到列表
rows = [", ".join(row) for row in reader]
# 用换行拼接所有行
all_text = "\n".join(rows)
# 返回结果
return all_text
except FileNotFoundError:
# 文件不存在日志
logger.error(f"CSV文件不存在: {filename}")
raise
except Exception as e:
# 其它异常日志
logger.error(f"读取CSV文件失败: {filename}, 错误: {str(e)}")
raise
# 定义读取文本文件内容的函数
def read_text_file(filename: str) -> str:
"""
读取指定文本文件内容并返回
参数:
filename (str): 文件路径
返回:
str: 文件内容字符串
异常:
FileNotFoundError: 文件不存在
"""
try:
# 以utf-8只读方式打开文本文件
with open(filename, "r", encoding="utf-8") as f:
# 读取文件的所有内容
text = f.read()
# 返回字符串
return text
except FileNotFoundError:
# 文件未找到记录日志
logger.error(f"文本文件不存在: {filename}")
raise
except Exception as e:
# 其它异常情况日志记录
logger.error(f"读取文本文件失败: {filename}, 错误: {str(e)}")
raise
# 定义读取Markdown文件内容的函数
def read_markdown_file(file_path: str) -> str:
"""
读取Markdown文件内容并返回
参数:
file_path (str): Markdown文件路径
返回:
str: 文件内容字符串
异常:
FileNotFoundError: 文件不存在
"""
try:
# 以utf-8编码只读打开Markdown文件
with open(file_path, "r", encoding="utf-8") as f:
# 读取并返回全部内容
return f.read()
except FileNotFoundError:
# 文件不存在日志
logger.error(f"Markdown文件不存在: {file_path}")
raise
except Exception as e:
# 其它异常日志
logger.error(f"读取Markdown文件失败: {file_path}, 错误: {str(e)}")
raise
┌─────────────────────────────────┐
│ 步骤1: 打开/加载文件 │
│ 使用对应的库打开文件 │
└─────────────────────────────────┘
↓
┌─────────────────────────────────┐
│ 步骤2: 解析/遍历内容 │
│ 根据不同格式特点解析 │
└─────────────────────────────────┘
↓
┌─────────────────────────────────┐
│ 步骤3: 提取文本内容 │
│ 收集所有文本片段 │
└─────────────────────────────────┘
↓
┌─────────────────────────────────┐
│ 步骤4: 拼接成字符串 │
│ 用分隔符连接所有文本 │
└─────────────────────────────────┘
↓
┌─────────────────────────────────┐
│ 步骤5: 返回结果 │
│ 返回最终的文本字符串 │
└─────────────────────────────────┘6.2 db.py #
db.py
# 导入 chromadb 库
import chromadb
# 导入 Optional 类型,用于类型标注
from typing import Optional
# 导入 logging 模块,用于日志记录
import logging
# 导入 sentence_transformers 库中的 SentenceTransformer 类
from sentence_transformers import SentenceTransformer
# 获取当前模块的 logger 实例
logger = logging.getLogger(__name__)
# 设置默认的集合名称
DEFAULT_COLLECTION_NAME = "rag"
# 设置默认的模型名称
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# 设置默认的数据库文件路径
DEFAULT_DB_PATH = "./chroma_db"
# 定义全局变量 _model,用于存放 SentenceTransformer 实例,初始为 None
_model: Optional[SentenceTransformer] = None
# 定义全局变量 _client,用于存放 chromadb 的 PersistentClient 实例,初始为 None
_client: Optional[chromadb.PersistentClient] = None
# 定义获取嵌入模型的内部方法,如果未初始化则进行加载
def _get_model() -> SentenceTransformer:
"""
获取嵌入模型实例(单例模式)
返回:
SentenceTransformer: 嵌入模型实例
"""
# 声明使用全局变量 _model
global _model
# 如果 _model 尚未实例化,则进行初始化
if _model is None:
# 记录开始加载模型的日志
logger.info(f"正在加载嵌入模型: {DEFAULT_MODEL_NAME}")
# 加载 SentenceTransformer 模型
_model = SentenceTransformer(DEFAULT_MODEL_NAME)
# 记录模型加载完成的日志
logger.info("嵌入模型加载完成")
# 返回模型实例
return _model
# 定义获取 ChromaDB 客户端的内部方法,如果未初始化则进行加载
def _get_client() -> chromadb.PersistentClient:
"""
获取ChromaDB客户端实例(单例模式)
返回:
chromadb.PersistentClient: 客户端实例
"""
# 声明使用全局变量 _client
global _client
# 如果 _client 尚未实例化,则进行初始化
if _client is None:
# 记录开始初始化客户端的日志,并输出路径信息
logger.info(f"正在初始化ChromaDB客户端,路径: {DEFAULT_DB_PATH}")
# 初始化 PersistentClient 实例
_client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
# 记录客户端初始化完成的日志
logger.info("ChromaDB客户端初始化完成")
# 返回客户端实例
return _client
# 定义将文本保存到 ChromaDB 的函数
def save_text_to_db(text: str, collection_name: str = DEFAULT_COLLECTION_NAME, source: Optional[str] = None) -> str:
"""
将文本保存到ChromaDB指定集合中,使用sentence_transformers生成embedding。
参数:
text (str): 要保存的文本
collection_name (str): 集合名称,默认为 "rag"
source (str, optional): 数据来源标识,默认为 "document"
返回:
str: 保存的文本ID
异常:
Exception: 保存失败
"""
try:
# 如果文本为空或者全是空白字符,直接记录警告并返回空字符串
if not text or not text.strip():
logger.warning("尝试保存空文本,已跳过")
return ""
# 获取全局模型实例
model = _get_model()
# 获取全局客户端实例
client = _get_client()
# 获取指定名称的集合,如果集合不存在则自动创建
collection = client.get_or_create_collection(collection_name)
# 使用文本内容的哈希值生成唯一的文本ID(转为正整数再转为字符串)
text_id = str(abs(hash(text)))
# 检查数据库中是否已经存在相同的ID
existing = collection.get(ids=[text_id])
# 如果存在该ID,说明相同内容已保存,无需重复保存
if existing and existing.get("ids"):
logger.debug(f"文本已存在,跳过保存,id={text_id}")
return text_id
# 生成文本的 embedding,模型处理结果为 ndarray,通过tolist 转换为列表
embedding = model.encode([text])[0].tolist()
# 向集合中添加文本、元数据、ID 以及 embedding(均为单元素列表)
collection.add(
documents=[text],
metadatas=[{"source": source or "document"}],
ids=[text_id],
embeddings=[embedding],
)
# 记录成功保存的调试日志,包含文本id和集合名称
logger.debug(f"文本已保存到ChromaDB,id={text_id}, collection={collection_name}")
# 返回本次保存的文本ID
return text_id
# 捕捉整个保存过程中的异常
except Exception as e:
# 记录错误日志并输出异常信息
logger.error(f"保存文本到数据库失败: {str(e)}")
# 抛出异常
raise文本输入 → 生成向量 → 检查重复 → 保存到数据库 → 返回ID
ChromaDB 存储结构:
┌─────────────┬──────────────┬─────┬─────────────┐
│ documents │ metadatas │ ids │ embeddings │
├─────────────┼──────────────┼─────┼─────────────┤
│ [文本内容] │[{source:..}] │ [ID]│ [384维向量] │
└─────────────┴──────────────┴─────┴─────────────┘
用户调用 save_text_to_db(text, collection_name, source)
↓
[输入验证]
文本是否为空?
├─ 是 → 记录警告 → 返回 ""
└─ 否 → 继续
↓
[初始化资源]
获取模型(单例,首次调用时加载)
获取客户端(单例,首次调用时连接)
↓
[获取集合]
client.get_or_create_collection()
↓
[生成ID]
text_id = str(abs(hash(text)))
↓
[检查重复]
数据库中是否已存在该ID?
├─ 是 → 记录调试日志 → 返回现有ID
└─ 否 → 继续
↓
[生成向量]
embedding = model.encode([text])[0].tolist()
↓
[保存到数据库]
collection.add(
documents=[text],
metadatas=[{"source": ...}],
ids=[text_id],
embeddings=[embedding]
)
↓
[返回结果]
记录成功日志 → 返回 text_id6.3 save.py #
save.py
# 导入os模块,用于路径和文件操作
import os
# 导入Optional类型用于类型注解(本文件其实未用到)
from typing import Optional
# 从db模块导入保存文本到数据库的函数
from db import save_text_to_db
# 导入extract模块,用于处理各种格式的文本提取
import extract
# 导入递归字符分割器,用于文本分块
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 导入logging模块,用于日志记录
import logging
# 配置日志的输出格式和日志级别
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 默认分块大小
DEFAULT_CHUNK_SIZE = 200
# 默认分块重叠长度
DEFAULT_CHUNK_OVERLAP = 30
# 默认集合名称
DEFAULT_COLLECTION_NAME = "rag"
# 定义自动根据文件类型提取文本内容的函数
def extract_text_auto(file_path: str) -> str:
"""
根据文件类型自动提取文本内容
参数:
file_path (str): 文件路径
返回:
str: 提取的文本内容
异常:
FileNotFoundError: 文件不存在
ValueError: 不支持的文件类型
"""
# 检查文件是否存在
if not os.path.exists(file_path):
# 文件不存在时记录错误日志
logger.error(f"文件不存在: {file_path}")
# 抛出文件不存在异常
raise FileNotFoundError(f"文件不存在: {file_path}")
# 获取文件扩展名,并转换为小写
ext = os.path.splitext(file_path)[-1].lower()
try:
# 如果是pdf文件
if ext == ".pdf":
logger.info(f"检测到PDF文件,开始提取文本: {file_path}")
return extract.extract_pdf_text(file_path)
# 如果是Word文档
elif ext in [".docx", ".doc"]:
logger.info(f"检测到Word文件,开始提取文本: {file_path}")
return extract.extract_text_from_word(file_path)
# 如果是Excel文件
elif ext in [".xlsx", ".xls"]:
logger.info(f"检测到Excel文件,开始提取文本: {file_path}")
return extract.extract_text_from_excel(file_path)
# 如果是PPT文件
elif ext in [".pptx", ".ppt"]:
logger.info(f"检测到PPT文件,开始提取文本: {file_path}")
return extract.extract_ppt_text(file_path)
# 如果是HTML文件
elif ext in [".html", ".htm"]:
logger.info(f"检测到HTML文件,开始提取文本: {file_path}")
return extract.extract_text_from_html(file_path)
# 如果是XML文件
elif ext == ".xml":
logger.info(f"检测到XML文件,开始提取文本: {file_path}")
return extract.extract_xml_text(file_path)
# 如果是CSV文件
elif ext == ".csv":
logger.info(f"检测到CSV文件,开始提取文本: {file_path}")
return extract.read_csv_to_text(file_path)
# 如果是JSON文件
elif ext == ".json":
logger.info(f"检测到JSON文件,开始提取文本: {file_path}")
return extract.extract_text_from_json(file_path)
# 如果是纯文本、Markdown、JSONL文件
elif ext in [".md", ".txt", ".jsonl"]:
logger.info(f"检测到文本/Markdown/JSONL文件,开始读取: {file_path}")
return extract.read_text_file(file_path)
# 其余不支持的文件类型
else:
logger.error(f"不支持的文件类型: {ext}")
raise ValueError(f"不支持的文件类型: {ext}")
# 捕获全部异常,记录日志并抛出
except Exception as e:
logger.error(f"提取文件内容失败: {file_path}, 错误: {str(e)}")
raise
# 定义文档入库的主流程函数
def doc_to_vectorstore(
file_path: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> int:
"""
将文档提取、分块并保存到向量数据库
参数:
file_path (str): 文件路径
collection_name (str): 集合名称,默认为 "rag"
chunk_size (int): 分块大小,默认为 200
chunk_overlap (int): 分块重叠长度,默认为 30
返回:
int: 成功保存的分块数量
异常:
FileNotFoundError: 文件不存在
ValueError: 不支持的文件类型或其他参数错误
"""
try:
# 步骤1:加载非结构化文本
logger.info(f"开始提取文件内容: {file_path}")
text = extract_text_auto(file_path)
logger.info(f"文件内容提取完成,长度为{len(text)}个字符")
# 检查文本是否为空
if not text.strip():
logger.warning(f"文件内容为空: {file_path}")
return 0
# 步骤2:将文本进行分块
logger.info(f"开始进行文本分块 (chunk_size={chunk_size}, chunk_overlap={chunk_overlap})")
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
chunks = splitter.split_text(text)
logger.info(f"文本分块完成,共分为{len(chunks)}块")
# 步骤3:为每个分块生成向量并保存入库
success_count = 0
for idx, chunk in enumerate(chunks):
try:
logger.info(f"正在保存第{idx+1}/{len(chunks)}块到向量数据库")
save_text_to_db(chunk, collection_name=collection_name)
success_count += 1
except Exception as e:
logger.error(f"保存第{idx+1}块失败: {str(e)}")
# 继续处理下一块,不中断整个流程
logger.info(f"文件 {file_path} 已完成入库,成功保存 {success_count}/{len(chunks)} 个分块")
return success_count
except FileNotFoundError:
logger.error(f"文件不存在: {file_path}")
raise
except Exception as e:
logger.error(f"文档入库失败: {file_path}, 错误: {str(e)}")
raise
# 程序入口,如果直接运行本脚本,则执行入库操作(示例:入库‘红楼梦.txt’)
if __name__ == "__main__":
doc_to_vectorstore('红楼梦.txt')用户调用 doc_to_vectorstore(file_path, ...)
↓
[步骤1:提取文本]
extract_text_auto(file_path)
↓
├─ 检查文件是否存在
├─ 识别文件类型(根据扩展名)
├─ 调用对应的提取函数
└─ 返回文本内容
↓
[步骤2:验证文本]
文本是否为空?
├─ 是 → 记录警告 → 返回 0
└─ 否 → 继续
↓
[步骤3:文本分块]
RecursiveCharacterTextSplitter.split_text()
↓
├─ 创建分块器(设置 chunk_size 和 chunk_overlap)
├─ 执行分块操作
└─ 返回分块列表
↓
[步骤4:保存到数据库]
遍历每个分块
↓
┌─────────────────────┐
│ 对每个chunk: │
│ 1. 调用save_text_to_db() │
│ 2. 生成向量嵌入 │
│ 3. 保存到ChromaDB │
│ 4. 记录成功/失败 │
└─────────────────────┘
↓
[步骤5:返回结果]
统计成功数量 → 返回 success_count6.4 llm.py #
llm.py
# 导入OpenAI客户端库
from openai import OpenAI
# 导入os库,用于读取环境变量
import os
# 导入logging库,用于记录日志
import logging
# 导入Optional类型,便于类型注解
from typing import Optional
# 获取当前模块的logger日志对象
logger = logging.getLogger(__name__)
# 从环境变量获取OPENAI_BASE_URL,若未设置则使用默认地址
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
# 从环境变量获取OPENAI_API_KEY,若未设置则使用默认测试key(生产环境必须配置!)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "d52e49a1-36ea-44bb-bc6e-65ce789a72f6")
# 从环境变量获取模型名称,若未设置则使用默认模型名
MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "doubao-seed-1-6-250615")
# 全局OpenAI客户端实例,初始为None,延迟初始化
_client: Optional[OpenAI] = None
# 获取OpenAI客户端实例(单例模式)
def _get_client() -> OpenAI:
"""
获取OpenAI客户端实例(单例模式)
返回:
OpenAI: 客户端实例
"""
# 声明全局变量_client
global _client
# 如果客户端尚未初始化,则进行初始化
if _client is None:
# 如果API密钥不存在,则抛出异常
if not OPENAI_API_KEY:
raise ValueError(
"OPENAI_API_KEY 未设置。请设置环境变量 OPENAI_API_KEY 或在代码中配置。"
)
# 使用指定的base_url和api_key初始化OpenAI客户端
_client = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
# 记录客户端初始化成功的日志
logger.info(f"OpenAI客户端已初始化,base_url: {OPENAI_BASE_URL}")
# 返回客户端实例
return _client
# 定义调用大模型的函数
def invoke(prompt: str, model: Optional[str] = None, temperature: float = 0.7) -> str:
"""
调用大模型生成回复
参数:
prompt (str): 输入的提示词
model (str, optional): 模型名称,默认使用环境变量或默认值
temperature (float): 生成温度,默认0.7
返回:
str: 大模型生成的回复内容
异常:
ValueError: API密钥未设置
Exception: API调用失败
"""
try:
# 获取OpenAI客户端对象
client = _get_client()
# 如果model参数为空,则使用默认模型名
model_name = model or MODEL_NAME
# 记录调试日志,显示模型名和prompt长度
logger.debug(f"调用大模型,model: {model_name}, prompt长度: {len(prompt)}")
# 调用OpenAI聊天模型接口生成回复
response = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
],
}
],
temperature=temperature,
)
# 获取大模型生成的回复内容
content = response.choices[0].message.content
# 记录调试日志,标记回复内容的长度
logger.debug(f"大模型回复生成成功,长度: {len(content) if content else 0}")
# 返回回复内容(若为空则返回空字符串)
return content or ""
# 捕捉并处理ValueError异常(如API密钥未配置)
except ValueError as e:
logger.error(f"配置错误: {str(e)}")
raise
# 捕捉并处理所有其他异常
except Exception as e:
logger.error(f"调用大模型失败: {str(e)}")
raise用户调用 invoke(prompt, model=None, temperature=0.7)
↓
[获取客户端]
client = _get_client()
↓
├─ 检查客户端是否已初始化
├─ 否 → 创建新客户端(单例)
└─ 是 → 使用现有客户端
↓
[确定模型]
model_name = model or MODEL_NAME
↓
[记录日志]
logger.debug("调用大模型...")
↓
[调用API]
response = client.chat.completions.create(
model=model_name,
messages=[{
"role": "user",
"content": [{"type": "text", "text": prompt}]
}],
temperature=temperature
)
↓
[提取回复]
content = response.choices[0].message.content
↓
[记录日志]
logger.debug("回复生成成功...")
↓
[返回结果]
return content or ""6.5 query.py #
query.py
# 导入sentence_transformers库中的SentenceTransformer类
from sentence_transformers import SentenceTransformer
# 导入chromadb库
import chromadb
# 从typing库导入List和Optional类型
from typing import List, Optional
# 导入logging库用于日志记录
import logging
# 导入llm模块(自定义的大模型API封装)
import llm
# 配置日志:设置日志等级为INFO,指定日志格式
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 默认集合名称,存储块的标签名
DEFAULT_COLLECTION_NAME = "rag"
# 默认嵌入模型名称
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# 默认Chroma数据库路径
DEFAULT_DB_PATH = "./chroma_db"
# 默认检索返回文本块数目
DEFAULT_N_RESULTS = 3
# 全局SentenceTransformer模型实例(延迟初始化)
_model: Optional[SentenceTransformer] = None
# 全局ChromaDB客户端实例(延迟初始化)
_client: Optional[chromadb.PersistentClient] = None
# 全局ChromaDB集合实例(延迟初始化)
_collection: Optional[chromadb.Collection] = None
# 获取嵌入模型实例(单例模式,只有一个模型)
def _get_model() -> SentenceTransformer:
# 声明使用全局变量_model
global _model
# 如果还没有实例化,则初始化模型
if _model is None:
# 打印加载模型信息
logger.info(f"正在加载嵌入模型: {DEFAULT_MODEL_NAME}")
_model = SentenceTransformer(DEFAULT_MODEL_NAME)
# 加载完成
logger.info("嵌入模型加载完成")
# 返回模型实例
return _model
# 获取ChromaDB客户端实例(单例模式)
def _get_client() -> chromadb.PersistentClient:
# 声明全局变量_client
global _client
# 如果客户端还未初始化,则进行初始化
if _client is None:
# 打印初始化信息
logger.info(f"正在初始化ChromaDB客户端,路径: {DEFAULT_DB_PATH}")
_client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
logger.info("ChromaDB客户端初始化完成")
# 返回客户端实例
return _client
# 获取或创建集合实例(单例模式),collection_name可指定集合名
def _get_collection(collection_name: str = DEFAULT_COLLECTION_NAME) -> chromadb.Collection:
# 声明全局变量_collection
global _collection
# 如果集合还未初始化,则获取或创建集合
if _collection is None:
# 获取客户端
client = _get_client()
# 打印获取/创建集合信息
logger.info(f"正在获取或创建集合: {collection_name}")
_collection = client.get_or_create_collection(collection_name)
logger.info(f"集合 '{collection_name}' 已准备就绪")
# 返回集合实例
return _collection
# 将query字符串转为embedding向量
def get_query_embedding(query: str) -> List[float]:
"""
将查询文本转换为embedding向量
参数:
query (str): 查询文本
返回:
List[float]: embedding向量
"""
# 打印debug信息,开始向量化
logger.debug("正在将Query转为向量...")
# 获取模型实例
model = _get_model()
# 调用模型将输入文本转为embedding,并转为list
embedding = model.encode(query).tolist()
# 打印向量化完成的debug信息
logger.debug(f"Query向量化完成,向量维度: {len(embedding)}")
# 返回embedding
return embedding
# 向量检索,返回最相关的文本块列表
def retrieve_related_chunks(
query_embedding: List[float],
n_results: int = DEFAULT_N_RESULTS,
collection_name: str = DEFAULT_COLLECTION_NAME,
) -> List[str]:
"""
向量检索,返回最相关的文本块列表
参数:
query_embedding (List[float]): 查询向量
n_results (int): 返回的结果数量,默认为3
collection_name (str): 集合名称,默认为 "rag"
返回:
List[str]: 最相关的文本块列表
异常:
ValueError: 未检索到相关内容
"""
try:
# 打印检索动作的日志
logger.info(f"正在进行向量检索,返回最相关的{n_results}个文本块...")
# 获取集合实例
collection = _get_collection(collection_name)
# 在指定集合中做向量相似度检索,n_results为最多返回的结果数
results = collection.query(
query_embeddings=[query_embedding],
n_results=n_results
)
# 获取检索到的文档内容
related_chunks = results.get("documents")
# 检查是否检索到相关内容
if not related_chunks or not related_chunks[0]:
# 未检索到内容则打印警告并抛出异常
logger.warning("未检索到相关内容,请先入库或检查数据库!")
raise ValueError("未检索到相关内容,请先入库或检查数据库!")
# 打印检索到的文本块数量
logger.info(f"成功检索到{len(related_chunks[0])}个相关文本块")
# 返回第一个结果list(按设计,一个query只查一个batch,取[0]即可)
return related_chunks[0]
except Exception as e:
# 打印并抛出错误
logger.error(f"向量检索失败: {str(e)}")
raise
# RAG查询主函数:向量检索 + LLM生成答案
def query_rag(
query: str,
n_results: int = DEFAULT_N_RESULTS,
collection_name: str = DEFAULT_COLLECTION_NAME,
) -> str:
"""
RAG查询主函数:向量检索 + LLM生成答案
参数:
query (str): 用户查询问题
n_results (int): 检索的文档块数量,默认为3
collection_name (str): 集合名称,默认为 "rag"
返回:
str: LLM生成的答案
异常:
ValueError: 检索失败或未找到相关内容
"""
try:
# 打印RAG查询开始日志
logger.info(f"开始RAG查询: {query}")
# 步骤1:将查询文本转为向量
query_embedding = get_query_embedding(query)
# 步骤2:基于query embedding做向量检索
related_chunks = retrieve_related_chunks(
query_embedding, n_results=n_results, collection_name=collection_name
)
# 步骤3:将检索到的文本块合并为上下文,拼接prompt
context = "\n".join(related_chunks)
prompt = f"已知信息:\n{context}\n\n请根据上述内容回答用户问题:{query}"
# 打印构建的prompt长度
logger.debug(f"Prompt已构建,长度: {len(prompt)}")
# 步骤4:调用llm.invoke(大语言模型调用)生成最终答案
logger.info("正在调用大模型生成答案...")
answer = llm.invoke(prompt)
# 打印答案生成完成
logger.info("答案生成完成")
# 返回模型生成的答案
return answer
except ValueError as e:
# 捕获并打印检索失败相关的异常
logger.error(f"RAG查询失败: {str(e)}")
raise
except Exception as e:
# 捕获并打印所有其他异常
logger.error(f"RAG查询过程中发生错误: {str(e)}")
raise
# 主程序入口,支持直接命令行运行本脚本
if __name__ == "__main__":
# 设定一个查询问题
query = "红楼梦的作者是谁?"
logger.info(f"用户查询: {query}")
try:
# 进行RAG查询,设置n_results为10
answer = query_rag(query, n_results=10)
# 打印结果
print("\n【答案】\n", answer)
except ValueError as e:
# 捕获未找到相关内容的错误,打印提示
print(f"\n【错误】\n{str(e)}")
except Exception as e:
# 捕获程序异常,打印日志并提示
logger.exception("程序执行失败")
print(f"\n【错误】\n程序执行失败: {str(e)}")
步骤1: 查询向量化
用户问题 → 嵌入模型 → 查询向量
"红楼梦的作者是谁?" → [384维向量]
步骤2: 向量检索
查询向量 → ChromaDB → 最相关的文档块
[向量] → 相似度计算 → ["文档块1", "文档块2", ...]
步骤3: 构建 Prompt
文档块 + 用户问题 → 完整的 Prompt
"""
已知信息:
文档块1...
文档块2...
请根据上述内容回答用户问题:红楼梦的作者是谁?
"""
步骤4: LLM 生成答案
Prompt → 大语言模型 → 最终答案
"已知信息..." → GPT/豆包 → "红楼梦的作者是曹雪芹。"