1.项目说明 #
1.1 项目简介 #
本项目基于 Python、Milvus 向量数据库和火山引擎豆包Embedding模型,实现了一个"文搜文"Web系统。支持文档上传、文本向量化、语义检索。
1.2 主要功能 #
- 支持上传本地txt文档,自动向量化并入库
- 支持输入任意中文文本,检索最相关的文档
- 检索结果展示文档内容及相似度分数
- 简洁的Flask Web页面
- 向量模型采用火山引擎"豆包Embedding"API
- 向量数据库采用 Milvus,检索方式为余弦相似度(COSINE)
1.3 安装 #
# 创建虚拟环境
python -m venv .venv
# 激活虚拟环境(Windows)
.venv\Scripts\activate.bat
# 激活虚拟环境(Linux/macOS)
source .venv/Scripts/activate
# 安装依赖d
pip install flask python-dotenv requests pymilvus cos-python-sdk-v5 pdfplumber python-docx1.4 参考 #
2. 创建服务器 #
2.1. .env #
.env
# 秘钥,用于加密会话
SECRET_KEY="rensheng"
# 上传文件夹路径
UPLOAD_FOLDER = "uploads"
# 火山引擎文本向量化API地址
VOLC_EMBEDDINGS_API_URL="https://ark.cn-beijing.volces.com/api/v3/embeddings"
# 火山引擎多模态(图文)向量化API地址
VOLC_EMBEDDINGS_VISION_API_URL="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
# 火山引擎API密钥
VOLC_API_KEY="d52e49a1-36ea-44bb-bc6e-65ce789a72f6"
# 腾讯云SecretId
TENCENT_SECRET_ID="AKIDYzEI9nDRgrxufDi5FB0ueLuJZo7ITyQm"
# 腾讯云SecretKey
TENCENT_SECRET_KEY="HXmu11voYAdv483nS0OacVAg6Sj3GLFz"
# 腾讯云存储区域
TENCENT_REGION="ap-beijing"
# 腾讯云存储桶名称
TENCENT_BUCKET="xiaohongshu-1258145019"2.2. app.py #
app.py
# 导入Flask框架中的Flask类和render_template函数,用于创建应用和渲染模板
from flask import Flask, render_template
# 导入dotenv库中的load_dotenv函数,用于加载.env环境变量文件
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 加载.env文件中的环境变量
load_dotenv()
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 定义一个创建Flask应用的工厂函数
def create_app():
# 创建Flask应用实例
app = Flask(__name__)
# 定义根路由"/"的视图函数
@app.route("/")
def home():
# 记录访问首页的警告日志
logger.warning("访问首页")
# 渲染home.html模板并返回
return render_template("home.html")
# 返回Flask应用实例
return app
# 如果当前脚本作为主程序运行
if __name__ == "__main__":
# 创建Flask应用
app = create_app()
# 启动Flask开发服务器
app.run(debug=True)
2.3. home.html #
templates/home.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>AI智能检索首页</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.nav {
margin: 40px auto;
max-width: 400px;
}
.nav a {
display: block;
margin: 20px 0;
padding: 16px;
background: #f0f0f0;
border-radius: 8px;
text-align: center;
text-decoration: none;
color: #333;
font-size: 18px;
}
.nav a:hover {
background: #d0eaff;
}
</style>
</head>
<body>
<div class="nav">
<h1>AI智能检索</h1>
</div>
</body>
</html>2.4. .gitignore #
.gitignore
.venv
__pycache__
uploads3.上传文档 #
3.1. doc_service.py #
services/doc_service.py
# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename函数,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file函数,用于判断文件类型是否合法
from utils import allowed_doc_file
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 从环境变量中获取上传文件夹路径,默认为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
# 定义处理文档上传的函数,参数为request对象
def handle_doc_upload(request):
# 如果请求方法为POST
if request.method == "POST":
# 记录收到文档上传POST请求的警告日志
logger.warning("收到文档上传POST请求")
# 如果请求中没有文件部分
if "file" not in request.files:
# 记录未选择文件的警告日志
logger.warning("未选择文件")
# 闪现提示未选择文件
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 如果文件名为空
if file.filename == "":
# 记录文件名为空的警告日志
logger.warning("文件名为空")
# 闪现提示未选择文件
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 如果文件存在且文件类型合法
if file and allowed_doc_file(file.filename):
# 对文件名进行安全处理
filename = secure_filename(file.filename)
# 拼接文件保存路径
file_path = os.path.join(UPLOAD_FOLDER, filename)
# 保存文件到指定路径
file.save(file_path)
# 记录文件已保存的警告日志
logger.warning(f"文件已保存到: {file_path}")
# 闪现提示文件上传成功,类型为success
flash("文件上传成功", "success")
# 重定向到上传文档页面
return redirect(url_for("upload_doc"))
# 渲染上传文档页面
return render_template("upload_doc.html")
3.2. upload_doc.html #
templates/upload_doc.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>上传文档 - 文搜文</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.container {
max-width: 600px;
margin: auto;
}
.section {
margin-bottom: 32px;
padding: 24px;
border: 1px solid #eee;
border-radius: 8px;
}
.flash {
color: red;
}
.flash.success {
color: green;
}
.flash.error {
color: red;
}
</style>
</head>
<body>
<h1>上传文档</h1>
<a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
<div class="section">
<form method="post" enctype="multipart/form-data" action="{{ url_for('upload_doc') }}">
<input type="file" name="file" accept=".txt,.docx,.pdf">
<button type="submit">上传</button>
</form>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
<ul class="flashes">
{% for category, message in messages %}
<li class="flash {{ category }}">{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
</div>
</body>
</html>3.3. utils.py #
utils.py
# 允许上传的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}
# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
# 检查文件名中是否包含点,并且扩展名是否在允许的集合中
return (
"." in filename
and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
)
3.4. app.py #
app.py
# 导入os模块,用于读取环境变量
+import os
# 从Flask框架中导入Flask类、render_template函数和request对象
+from flask import Flask, render_template, request
# 从dotenv库中导入load_dotenv函数,用于加载.env文件中的环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从自定义的services.doc_service模块中导入handle_doc_upload函数,用于处理文档上传逻辑
+from services.doc_service import handle_doc_upload
# 加载.env文件中的环境变量
load_dotenv()
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 定义一个创建Flask应用的工厂函数
def create_app():
# 创建Flask应用实例
app = Flask(__name__)
# 设置应用的密钥,用于会话加密,从环境变量中获取
+ app.secret_key = os.getenv("SECRET_KEY")
# 定义根路由"/"的视图函数
@app.route("/")
def home():
# 记录访问首页的警告日志
logger.warning("访问首页")
# 渲染home.html模板并返回
return render_template("home.html")
# 定义上传文档路由"/upload_doc",支持GET和POST方法
+ @app.route("/upload_doc", methods=["GET", "POST"])
+ def upload_doc():
+ # 记录访问上传文档页面的警告日志
+ logger.warning("访问上传文档页面")
+ # 调用handle_doc_upload函数处理上传逻辑,并返回结果
+ return handle_doc_upload(request)
# 返回Flask应用实例
return app
# 如果当前模块是主程序入口,则启动应用
if __name__ == "__main__":
# 创建Flask应用
app = create_app()
# 启动Flask开发服务器
app.run(debug=True)
3.5. home.html #
templates/home.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>AI智能检索首页</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.nav {
margin: 40px auto;
max-width: 400px;
}
.nav a {
display: block;
margin: 20px 0;
padding: 16px;
background: #f0f0f0;
border-radius: 8px;
text-align: center;
text-decoration: none;
color: #333;
font-size: 18px;
}
.nav a:hover {
background: #d0eaff;
}
</style>
</head>
<body>
<div class="nav">
<h1>AI智能检索</h1>
+ <a href="{{ url_for('upload_doc') }}">上传文档</a>
</div>
</body>
</html>4.获取文档内容 #
4.1. doc_service.py #
services/doc_service.py
# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils中导入secure_filename,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file和extract_doc_content函数
+from utils import allowed_doc_file, extract_doc_content
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认值为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
# 处理文档上传的主函数
def handle_doc_upload(request):
# 如果请求方法为POST,表示有文件上传
if request.method == "POST":
# 记录收到POST请求的日志
logger.warning("收到文档上传POST请求")
# 检查请求中是否包含文件
if "file" not in request.files:
# 记录未选择文件的日志
logger.warning("未选择文件")
# 弹出提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 检查文件名是否为空
if file.filename == "":
# 记录文件名为空的日志
logger.warning("文件名为空")
# 弹出提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 检查文件是否存在且文件类型是否允许
if file and allowed_doc_file(file.filename):
# 对文件名进行安全处理
filename = secure_filename(file.filename)
# 构建文件保存路径
file_path = os.path.join(UPLOAD_FOLDER, filename)
# 保存文件到指定路径
file.save(file_path)
# 记录文件保存成功的日志
logger.warning(f"文件已保存到: {file_path}")
# 提取文档内容
+ doc_content = extract_doc_content(file_path)
# 记录文档内容提取成功及内容长度的日志
+ logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
# 弹出文件上传成功的提示信息
flash("文件上传成功", "success")
# 重定向到上传文档页面
return redirect(url_for("upload_doc"))
# 如果不是POST请求,则渲染上传文档页面
return render_template("upload_doc.html")
4.2. utils.py #
utils.py
# 允许上传的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}
# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
# 检查文件名中是否包含点,并且扩展名是否在允许的集合中
return (
"." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
)
# 提取文档内容,根据不同文件类型分别处理
+def extract_doc_content(file_path):
"""
根据文件类型(txt, pdf, docx)提取文件内容。
:param file_path: 文件路径
:return: 文件内容字符串
+ """
# 获取文件扩展名并转为小写
+ ext = file_path.rsplit(".", 1)[-1].lower()
# 如果是txt文件
+ if ext == "txt":
+ # 以utf-8编码读取文本内容
+ with open(file_path, "r", encoding="utf-8") as f:
+ return f.read()
# 如果是pdf文件
+ elif ext == "pdf":
# 导入pdfplumber库
+ import pdfplumber
# 初始化文本内容为空字符串
+ text = ""
# 使用pdfplumber打开pdf文件
+ with pdfplumber.open(file_path) as pdf:
+ # 遍历每一页,提取文本内容
+ for page in pdf.pages:
+ text += page.extract_text() or ""
# 返回提取的文本内容
+ return text
# 如果是docx文件
+ elif ext == "docx":
# 导入Document类
+ from docx import Document
+
# 加载docx文档
+ doc = Document(file_path)
# 将所有段落的文本拼接为一个字符串
+ return "\n".join([para.text for para in doc.paragraphs])
# 如果文件类型不支持
+ else:
# 抛出异常,提示不支持的文件类型
+ raise ValueError("不支持的文件类型: " + ext)5.计算文档向量 #
5.1. embedding_utils.py #
embedding_utils.py
# 导入os模块,用于读取环境变量
import os
# 导入requests库,用于发送HTTP请求
import requests
# 从环境变量中获取文本向量API的URL
VOLC_EMBEDDINGS_API_URL = os.getenv("VOLC_EMBEDDINGS_API_URL")
# 从环境变量中获取视觉向量API的URL
VOLC_EMBEDDINGS_VISION_API_URL = os.getenv("VOLC_EMBEDDINGS_VISION_API_URL")
# 从环境变量中获取API密钥
VOLC_API_KEY = os.getenv("VOLC_API_KEY")
# 定义获取文档向量的函数,参数为文档内容
def get_doc_embedding(doc_content):
# 构造请求头,包含内容类型和认证信息
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {VOLC_API_KEY}",
}
# 构造请求体,指定模型和输入内容
payload = {"model": "doubao-embedding-text-240715", "input": doc_content}
# 发送POST请求到向量API,获取响应
response = requests.post(VOLC_EMBEDDINGS_API_URL, json=payload, headers=headers)
# 如果响应状态码为200,表示请求成功
if response.status_code == 200:
# 解析响应的JSON数据
data = response.json()
# 提取嵌入向量
embedding = data["data"][0]["embedding"]
# 返回嵌入向量
return embedding
else:
# 如果请求失败,抛出异常并输出错误信息
raise Exception(f"Embedding API error: {response.text}")
5.2. doc_service.py #
services/doc_service.py
import os
import logging
from flask import redirect, flash, render_template, url_for
from werkzeug.utils import secure_filename
from utils import allowed_doc_file, extract_doc_content
# 从embedding_utils模块导入get_doc_embedding函数,用于获取文档向量
+from embedding_utils import get_doc_embedding
logger = logging.getLogger(__name__)
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
def handle_doc_upload(request):
if request.method == "POST":
logger.warning("收到文档上传POST请求")
if "file" not in request.files:
logger.warning("未选择文件")
flash("未选择文件")
return redirect(request.url)
file = request.files["file"]
if file.filename == "":
logger.warning("文件名为空")
flash("未选择文件")
return redirect(request.url)
if file and allowed_doc_file(file.filename):
filename = secure_filename(file.filename)
file_path = os.path.join(UPLOAD_FOLDER, filename)
file.save(file_path)
logger.warning(f"文件已保存到: {file_path}")
doc_content = extract_doc_content(file_path)
logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
# 计算文档向量
+ doc_embedding = get_doc_embedding(doc_content)
# 记录文档向量计算成功的日志
+ logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
insert_doc_vectors([doc_embedding], [doc_content])
flash("文件上传成功", "success")
return redirect(url_for("upload_doc"))
return render_template("upload_doc.html")
6.保存向量 #
6.1. init.py #
milvus_utils/init.py
# 导入pymilvus中的connections模块,用于连接Milvus服务
from pymilvus import connections
# 从当前包的doc模块导入init_doc_collection和insert_doc_vectors函数
from .doc import (
init_doc_collection,
insert_doc_vectors,
)
# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False
# 定义确保Milvus连接的函数,默认主机为localhost,端口为19530
def ensure_milvus_connection(host="localhost", port="19530"):
# 声明使用全局变量_milvus_connected
global _milvus_connected
# 如果尚未连接Milvus,则进行连接和集合初始化
if not _milvus_connected:
# 连接到Milvus服务
connections.connect(host=host, port=port)
# 初始化文档向量集合
init_doc_collection()
# 设置连接标志为True,避免重复连接
_milvus_connected = True
# 定义模块对外暴露的成员
__all__ = [
ensure_milvus_connection,
"insert_doc_vectors",
]
6.2. doc.py #
milvus_utils/doc.py
# 从pymilvus库中导入Collection、FieldSchema、CollectionSchema、DataType和utility模块
from pymilvus import (
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility,
)
# 定义Milvus中用于存储文档向量的集合名称
DOC_COLLECTION_NAME = "doc_vectors"
# 定义初始化文档向量集合的函数
def init_doc_collection():
"""
初始化Milvus连接和文档向量集合。
"""
# 如果Milvus中不存在指定名称的集合,则进行创建
if not utility.has_collection(DOC_COLLECTION_NAME):
# 定义集合的字段,包括自增主键id、浮点向量embedding和原文档内容doc
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560),
FieldSchema(name="doc", dtype=DataType.VARCHAR, max_length=4096),
]
# 根据字段定义创建集合的schema
schema = CollectionSchema(fields, description="文档向量集合")
# 创建集合对象
col = Collection(DOC_COLLECTION_NAME, schema)
# 定义向量字段的索引参数
index_params = {
"index_type": "IVF_FLAT",
"metric_type": "COSINE",
"params": {"nlist": 128},
}
# 为embedding字段创建索引
col.create_index(field_name="embedding", index_params=index_params)
# 加载集合到内存
col.load()
else:
# 如果集合已存在,则直接加载到内存
col = Collection(DOC_COLLECTION_NAME)
col.load()
# 定义插入文档向量和原文档内容到Milvus的函数
def insert_doc_vectors(embeddings, docs):
"""
插入文本向量及原文档内容到Milvus。
"""
# 获取指定名称的集合对象
col = Collection(DOC_COLLECTION_NAME)
# 组织插入的数据,顺序与schema一致
data = [embeddings, docs]
# 执行插入操作
col.insert(data)
6.3. app.py #
app.py
# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 从dotenv库导入load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
+from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入handle_doc_upload函数,处理文档上传
from services.doc_service import handle_doc_upload
# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
+ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 定义创建Flask应用的工厂函数
def create_app():
# 创建Flask应用实例
app = Flask(__name__)
# 设置应用的密钥,用于会话加密
app.secret_key = os.getenv("SECRET_KEY")
# 定义根路由,访问首页
@app.route("/")
def home():
# 记录访问首页的警告日志
logger.warning("访问首页")
# 渲染home.html模板
return render_template("home.html")
# 定义上传文档的路由,支持GET和POST方法
@app.route("/upload_doc", methods=["GET", "POST"])
def upload_doc():
# 记录访问上传文档页面的警告日志
logger.warning("访问上传文档页面")
# 调用handle_doc_upload函数处理上传逻辑
return handle_doc_upload(request)
# 返回Flask应用实例
return app
# 判断当前模块是否为主程序入口
if __name__ == "__main__":
# 创建Flask应用
app = create_app()
# 启动Flask开发服务器
app.run(debug=True)
6.4. doc_service.py #
services/doc_service.py
# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入重定向、消息闪现、模板渲染和url生成相关函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename函数,用于安全处理上传文件名
from werkzeug.utils import secure_filename
# 导入自定义工具函数,判断文件类型和提取文档内容
from utils import allowed_doc_file, extract_doc_content
# 导入获取文档向量的函数
from embedding_utils import get_doc_embedding
# 从milvus_utils模块导入插入向量的函数
+from milvus_utils import insert_doc_vectors
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认值为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
# 定义处理文档上传的函数
def handle_doc_upload(request):
# 如果请求方法为POST,表示有文件上传
if request.method == "POST":
# 记录收到POST请求的日志
logger.warning("收到文档上传POST请求")
# 检查请求中是否包含文件
if "file" not in request.files:
# 记录未选择文件的日志
logger.warning("未选择文件")
# 闪现提示信息
flash("未选择文件")
# 重定向回上传页面
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 检查文件名是否为空
if file.filename == "":
# 记录文件名为空的日志
logger.warning("文件名为空")
# 闪现提示信息
flash("未选择文件")
# 重定向回上传页面
return redirect(request.url)
# 检查文件是否存在且类型允许
if file and allowed_doc_file(file.filename):
# 对文件名进行安全处理
filename = secure_filename(file.filename)
# 拼接文件保存路径
file_path = os.path.join(UPLOAD_FOLDER, filename)
# 保存文件到指定路径
file.save(file_path)
# 记录文件保存成功的日志
logger.warning(f"文件已保存到: {file_path}")
# 提取文档内容
doc_content = extract_doc_content(file_path)
# 记录文档内容提取成功的日志
logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
# 获取文档的向量表示
doc_embedding = get_doc_embedding(doc_content)
# 记录向量计算成功的日志
logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
# 插入文档向量和原文档内容到Milvus数据库
+ insert_doc_vectors([doc_embedding], [doc_content])
# 闪现上传成功的提示信息
flash("文件上传成功", "success")
# 重定向回上传页面
return redirect(url_for("upload_doc"))
# 渲染上传文档页面
return render_template("upload_doc.html")
7.检索文档向量 #
7.1. search_doc.html #
templates/search_doc.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>文本检索 - 文搜文</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.container {
max-width: 600px;
margin: auto;
}
.section {
margin-bottom: 32px;
padding: 24px;
border: 1px solid #eee;
border-radius: 8px;
}
.results {
margin-top: 24px;
}
.result-item {
margin-bottom: 16px;
padding: 12px;
background: #f8f8f8;
border-radius: 6px;
}
.flash {
color: red;
}
.flash.success {
color: green;
}
.flash.error {
color: red;
}
</style>
</head>
<body>
<h1>文本检索</h1>
<a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
<div class="section">
<form action="{{ url_for('search_doc') }}" method="post">
<input type="text" name="query" value="{{ query if query else '' }}" style="width:70%" placeholder="请输入检索文本">
<button type="submit">检索</button>
</form>
<div class="results">
{% if results is defined and results|length > 0 %}
<h5>检索结果:</h5>
{% for item in results %}
<div class="result-item">
<div><b>文档:</b>{{ item['doc'] }}</div>
<div><b>相似度:</b>{{ '%.4f'|format(item['score']) }}</div>
</div>
{% endfor %}
{% elif query and results is defined and results|length == 0 %}
<div>未检索到相关文档。</div>
{% endif %}
</div>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
<ul class="flashes">
{% for category, message in messages %}
<li class="flash {{ category }}">{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
</div>
</body>
</html>7.2. app.py #
app.py
# 导入os模块,用于操作环境变量等
import os
# 从flask中导入Flask、render_template、request对象
from flask import Flask, render_template, request
# 导入dotenv的load_dotenv函数,用于加载.env环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils中导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索的处理函数
+from services.doc_service import handle_doc_upload, handle_doc_search
# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建Flask应用的工厂函数
def create_app():
# 实例化Flask应用
app = Flask(__name__)
# 设置Flask应用的密钥
app.secret_key = os.getenv("SECRET_KEY")
# 定义首页路由
@app.route("/")
def home():
# 记录访问首页的日志
logger.warning("访问首页")
# 渲染首页模板
return render_template("home.html")
# 定义上传文档的路由,支持GET和POST方法
@app.route("/upload_doc", methods=["GET", "POST"])
def upload_doc():
# 记录访问上传文档页面的日志
logger.warning("访问上传文档页面")
# 调用文档上传处理函数
return handle_doc_upload(request)
+ # 定义文本检索的路由,支持GET和POST方法
+ @app.route("/search_doc", methods=["GET", "POST"])
+ def search_doc():
+ # 记录访问文本检索页面的日志
+ logger.warning("访问文本检索页面")
+ # 调用文档检索处理函数
+ return handle_doc_search(request)
# 返回Flask应用实例
return app
# 判断当前脚本是否作为主程序运行
if __name__ == "__main__":
# 调用工厂函数创建Flask应用实例
app = create_app()
# 启动Flask开发服务器
app.run(debug=True)
7.3. init.py #
milvus_utils/init.py
# 导入pymilvus的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从当前包的doc模块导入初始化集合、插入向量和检索向量的函数
+from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors
# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False
# 定义确保Milvus连接的函数
def ensure_milvus_connection(host="localhost", port="19530"):
# 声明使用全局变量
global _milvus_connected
# 如果尚未连接Milvus,则进行连接和集合初始化
if not _milvus_connected:
# 连接到Milvus服务器
connections.connect(host=host, port=port)
# 初始化文档向量集合
init_doc_collection()
# 设置已连接标志为True
_milvus_connected = True
# 定义模块对外暴露的成员
__all__ = [
# 暴露ensure_milvus_connection函数
ensure_milvus_connection,
# 暴露insert_doc_vectors函数
"insert_doc_vectors",
+ # 暴露search_doc_vectors函数
+ "search_doc_vectors",
]
7.4. doc.py #
milvus_utils/doc.py
# 从pymilvus库导入所需的类和方法
from pymilvus import (
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility,
)
# 定义Milvus中用于存储文档向量的集合名称
DOC_COLLECTION_NAME = "doc_vectors"
# 初始化文档向量集合(Collection),如不存在则创建
def init_doc_collection():
"""
初始化Milvus连接和文档向量集合。
"""
# 如果集合不存在,则创建集合及其schema
if not utility.has_collection(DOC_COLLECTION_NAME):
# 定义集合的字段,包括主键id、向量字段embedding、文档内容doc
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560),
FieldSchema(name="doc", dtype=DataType.VARCHAR, max_length=4096),
]
# 创建集合schema
schema = CollectionSchema(fields, description="文档向量集合")
# 创建集合
col = Collection(DOC_COLLECTION_NAME, schema)
# 定义向量索引参数
index_params = {
"index_type": "IVF_FLAT",
"metric_type": "COSINE",
"params": {"nlist": 128},
}
# 为embedding字段创建索引
col.create_index(field_name="embedding", index_params=index_params)
# 加载集合到内存
col.load()
else:
# 如果集合已存在,直接加载
col = Collection(DOC_COLLECTION_NAME)
col.load()
# 向Milvus插入文档向量和原文档内容
def insert_doc_vectors(embeddings, docs):
"""
插入文本向量及原文档内容到Milvus。
"""
# 获取集合对象
col = Collection(DOC_COLLECTION_NAME)
# 组织插入数据
data = [embeddings, docs]
# 插入数据到集合
col.insert(data)
# 检索与查询向量最相似的文档
+def search_doc_vectors(query_embedding, top_k=5):
+ """
+ 检索最相似的文本向量。
+ """
+ # 获取集合对象
+ col = Collection(DOC_COLLECTION_NAME)
+ # 执行向量检索,返回最相似的top_k条结果
+ results = col.search(
+ data=[query_embedding],
+ anns_field="embedding",
+ param={"metric_type": "COSINE", "params": {"nprobe": 10}},
+ limit=top_k,
+ output_fields=["doc"],
+ )
+ # 返回检索结果
+ return results
7.5. doc_service.py #
services/doc_service.py
# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file和extract_doc_content函数
from utils import allowed_doc_file, extract_doc_content
# 从embedding_utils模块导入get_doc_embedding函数
from embedding_utils import get_doc_embedding
+# 从milvus_utils模块导入insert_doc_vectors和search_doc_vectors函数
+from milvus_utils import insert_doc_vectors, search_doc_vectors
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
os.makedirs(UPLOAD_FOLDER)
# 处理文档上传的函数
def handle_doc_upload(request):
# 如果请求方法为POST
if request.method == "POST":
# 记录收到文档上传POST请求的日志
logger.warning("收到文档上传POST请求")
# 如果请求中没有"file"字段
if "file" not in request.files:
# 记录未选择文件的日志
logger.warning("未选择文件")
# 闪现提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 如果文件名为空
if file.filename == "":
# 记录文件名为空的日志
logger.warning("文件名为空")
# 闪现提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 如果文件存在且类型允许
if file and allowed_doc_file(file.filename):
# 对文件名进行安全处理
filename = secure_filename(file.filename)
# 拼接文件保存路径
file_path = os.path.join(UPLOAD_FOLDER, filename)
# 保存文件到指定路径
file.save(file_path)
# 记录文件保存成功的日志
logger.warning(f"文件已保存到: {file_path}")
# 提取文档内容
doc_content = extract_doc_content(file_path)
# 记录文档内容提取成功的日志
logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
# 获取文档的向量表示
doc_embedding = get_doc_embedding(doc_content)
# 记录向量计算成功的日志
logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
# 插入文档向量和原文档内容到Milvus数据库
insert_doc_vectors([doc_embedding], [doc_content])
# 闪现上传成功的提示信息
flash("文件上传成功", "success")
# 重定向回上传页面
return redirect(url_for("upload_doc"))
# 渲染上传文档页面
return render_template("upload_doc.html")
+# 处理文本检索的函数
+def handle_doc_search(request):
+ # 初始化检索结果列表
+ results = []
+ # 初始化检索文本
+ query = ""
+ # 如果请求方法为POST
+ if request.method == "POST":
+ # 记录收到文本检索POST请求的日志
+ logger.warning("收到文本检索POST请求")
+ # 获取表单中的检索文本
+ query = request.form.get("query", "")
+ # 记录检索文本内容
+ logger.warning(f"检索文本: {query}")
+ # 如果检索文本非空
+ if query.strip():
+ try:
+ # 记录开始文本向量化的日志
+ logger.warning("开始文本向量化(检索)")
+ # 获取检索文本的向量表示
+ query_embedding = get_doc_embedding(query)
+ # 记录向量化成功,准备检索
+ logger.warning("文本向量化成功,开始Milvus检索")
+ # 在Milvus中检索最相似的文档
+ milvus_results = search_doc_vectors(query_embedding, top_k=5)
+ # 记录检索命中组数
+ logger.warning(f"Milvus检索完成,命中组数: {len(milvus_results)}")
+ # 遍历检索结果
+ for hits in milvus_results:
+ for hit in hits:
+ # 将每个命中结果添加到结果列表
+ results.append(
+ {"doc": hit.entity.get("doc", ""), "score": hit.distance}
+ )
+ # 捕获异常并记录错误日志
+ except Exception as e:
+ logger.error(f"检索失败: {e}")
+ # 闪现检索失败的提示信息
+ flash(f"检索失败: {e}")
+ # 记录最终检索结果
+ logger.warning(f"检索结果: {results}")
+ # 渲染检索结果页面
+ return render_template("search_doc.html", results=results, query=query)
7.6. home.html #
templates/home.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>AI智能检索首页</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.nav {
margin: 40px auto;
max-width: 400px;
}
.nav a {
display: block;
margin: 20px 0;
padding: 16px;
background: #f0f0f0;
border-radius: 8px;
text-align: center;
text-decoration: none;
color: #333;
font-size: 18px;
}
.nav a:hover {
background: #d0eaff;
}
</style>
</head>
<body>
<div class="nav">
+ <h1 style="text-align: center;">AI智能检索</h1>
<a href="{{ url_for('upload_doc') }}">上传文档</a>
+ <a href="{{ url_for('search_doc') }}">文本检索</a>
</div>
</body>
</html>8.上传图片 #
8.1. cos_utils.py #
cos_utils.py
# 导入os模块,用于获取环境变量
import os
# 导入logging模块,用于日志记录
import logging
# 从qcloud_cos库中导入CosConfig和CosS3Client类
from qcloud_cos import CosConfig, CosS3Client
# 从环境变量中获取腾讯云的SecretId
TENCENT_SECRET_ID = os.getenv("TENCENT_SECRET_ID")
# 从环境变量中获取腾讯云的SecretKey
TENCENT_SECRET_KEY = os.getenv("TENCENT_SECRET_KEY")
# 从环境变量中获取腾讯云的区域信息
TENCENT_REGION = os.getenv("TENCENT_REGION")
# 从环境变量中获取腾讯云的存储桶名称
TENCENT_BUCKET = os.getenv("TENCENT_BUCKET")
# 令牌,默认为None
TENCENT_TOKEN = None
# 访问协议,默认为https
TENCENT_SCHEME = "https"
# 打印获取到的配置信息,便于调试
print(TENCENT_SECRET_ID, TENCENT_SECRET_KEY, TENCENT_REGION, TENCENT_BUCKET)
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 定义上传文件到腾讯云COS的函数
def upload_to_cos(stream, key):
# 创建CosConfig对象,配置COS客户端参数
config = CosConfig(
Region=TENCENT_REGION,
SecretId=TENCENT_SECRET_ID,
SecretKey=TENCENT_SECRET_KEY,
Token=TENCENT_TOKEN,
Scheme=TENCENT_SCHEME,
)
# 创建CosS3Client客户端对象
client = CosS3Client(config)
# 调用put_object方法上传对象到COS
client.put_object(
Bucket=TENCENT_BUCKET,
Body=stream,
Key=key,
StorageClass="STANDARD",
EnableMD5=False,
)
# 构造文件的访问URL
url = f"https://{TENCENT_BUCKET}.cos.{TENCENT_REGION}.myqcloud.com/{key}"
# 记录上传成功的警告日志,包含文件访问地址
logger.warning(f"文件已上传,访问地址: {url}")
# 返回文件的访问URL
return url
8.2. image.py #
milvus_utils/image.py
# 从pymilvus库导入所需的类和方法
from pymilvus import (
Collection, # 导入Collection类,用于操作Milvus集合
FieldSchema, # 导入FieldSchema类,用于定义字段结构
CollectionSchema, # 导入CollectionSchema类,用于定义集合结构
DataType, # 导入DataType类,用于指定字段类型
utility, # 导入utility模块,用于集合相关的工具函数
)
# 导入logging模块,用于日志记录
import logging
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 定义图片向量集合的名称
IMAGE_COLLECTION_NAME = "image_vectors"
# 定义初始化图片向量集合的函数
def init_image_collection():
# 如果Milvus中不存在该集合,则进行创建
if not utility.has_collection(IMAGE_COLLECTION_NAME):
# 定义集合的字段结构
fields = [
# 定义主键id字段,类型为INT64,自动生成
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
# 定义embedding字段,类型为FLOAT_VECTOR,维度为2048
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2048),
# 定义image_url字段,类型为VARCHAR,最大长度4096
FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=4096),
]
# 创建集合的schema对象,描述为“图片向量集合”
schema = CollectionSchema(fields, description="图片向量集合")
# 创建集合对象
col = Collection(IMAGE_COLLECTION_NAME, schema)
# 定义索引参数
index_params = {
"index_type": "IVF_FLAT", # 索引类型为IVF_FLAT
"metric_type": "COSINE", # 度量方式为余弦相似度
"params": {"nlist": 128}, # nlist参数设置为128
}
# 为embedding字段创建索引
col.create_index(field_name="embedding", index_params=index_params)
# 加载集合到内存
col.load()
else:
# 如果集合已存在,则直接加载
col = Collection(IMAGE_COLLECTION_NAME)
col.load()
# 定义插入图片向量及图片URL到Milvus的函数
def insert_image_vectors(embeddings, image_urls):
"""
插入图片向量及图片URL到Milvus。
"""
# 获取图片向量集合对象
col = Collection(IMAGE_COLLECTION_NAME)
# 构造插入数据,包含向量和图片URL
data = [embeddings, image_urls]
# 执行插入操作
col.insert(data)
8.3. image_service.py #
services/image_service.py
# 导入Flask的模板渲染、重定向、URL生成和消息闪现函数
from flask import render_template, redirect, url_for, flash
# 导入判断图片文件类型的工具函数
from utils import allowed_image_file
# 导入Werkzeug的安全文件名处理函数
from werkzeug.utils import secure_filename
# 导入腾讯云COS上传工具
from cos_utils import upload_to_cos
# 导入图片向量化工具
from embedding_utils import get_image_embedding
# 导入Milvus图片向量插入函数
from milvus_utils import insert_image_vectors
# 导入日志模块
import logging
# 定义处理图片上传的函数
def handle_image_upload(request):
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 判断请求方法是否为POST
if request.method == "POST":
# 记录收到POST请求的日志
logger.warning("收到图片上传POST请求")
# 检查请求中是否包含文件
if "file" not in request.files:
# 记录未选择文件的日志
logger.warning("未选择文件")
# 闪现提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 检查文件名是否为空
if file.filename == "":
# 记录文件名为空的日志
logger.warning("文件名为空")
# 闪现提示信息
flash("未选择文件")
# 重定向回当前页面
return redirect(request.url)
# 判断文件存在且为允许的图片类型
if file and allowed_image_file(file.filename):
# 获取安全的文件名
filename = secure_filename(file.filename)
try:
# 记录开始上传图片的日志
logger.warning("开始上传图片")
# 上传图片到腾讯云COS,返回图片URL
image_url = upload_to_cos(file.stream, key=filename)
# 记录图片上传成功的日志
logger.warning(f"图片上传成功,URL: {image_url}")
# 获取图片的向量表示
embedding = get_image_embedding(image_url)
# 记录图片向量化成功的日志
logger.warning("图片向量化成功,开始入库")
# 向Milvus插入图片向量和图片URL
insert_image_vectors([embedding], [image_url])
# 记录图片向量入库成功的日志
logger.warning("图片向量入库成功")
# 闪现成功提示信息
flash("文件上传并入库成功", "success")
except Exception as e:
# 记录向量化或入库失败的错误日志
logger.error(f"向量化或入库失败: {e}")
# 闪现错误提示信息
flash(f"向量化或入库失败: {e}", "error")
# 重定向到图片上传页面
return redirect(url_for("upload_image"))
# 渲染图片上传页面
return render_template("upload_image.html")
8.4. upload_image.html #
templates/upload_image.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>上传图片 - 图搜图</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.container {
max-width: 600px;
margin: auto;
}
.section {
margin-bottom: 32px;
padding: 24px;
border: 1px solid #eee;
border-radius: 8px;
}
.flash {
color: red;
}
.flash.success {
color: green;
}
.flash.error {
color: red;
}
</style>
</head>
<body>
<h1>上传图片</h1>
<a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
<div class="section">
<form method="post" enctype="multipart/form-data" action="{{ url_for('upload_image') }}">
<input type="file" name="file" accept=".jpg,.jpeg,.png,.bmp,.gif">
<button type="submit">上传</button>
</form>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
<ul class="flashes">
{% for category, message in messages %}
<li class="flash {{ category }}">{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
</div>
</body>
</html>8.5. app.py #
app.py
# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 导入dotenv库中的load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索处理函数
from services.doc_service import handle_doc_upload, handle_doc_search
+# 从services.image_service模块导入图片上传处理函数
+from services.image_service import handle_image_upload
# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库连接正常
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建Flask应用的工厂函数
def create_app():
# 实例化Flask应用
app = Flask(__name__)
# 设置应用的密钥,从环境变量中获取
app.secret_key = os.getenv("SECRET_KEY")
# 定义首页路由
@app.route("/")
def home():
# 记录访问首页的警告日志
logger.warning("访问首页")
# 渲染home.html模板
return render_template("home.html")
# 定义文档上传路由,支持GET和POST方法
@app.route("/upload_doc", methods=["GET", "POST"])
def upload_doc():
# 记录访问上传文档页面的警告日志
logger.warning("访问上传文档页面")
# 调用文档上传处理函数
return handle_doc_upload(request)
# 定义文档检索路由,支持GET和POST方法
@app.route("/search_doc", methods=["GET", "POST"])
def search_doc():
# 记录访问文本检索页面的警告日志
logger.warning("访问文本检索页面")
# 调用文档检索处理函数
return handle_doc_search(request)
+ # 定义图片上传路由,支持GET和POST方法
+ @app.route("/upload_image", methods=["GET", "POST"])
+ def upload_image():
+ # 记录访问上传图片页面的警告日志
+ logger.warning("访问上传图片页面")
+ # 调用图片上传处理函数
+ return handle_image_upload(request)
# 返回Flask应用实例
return app
# 判断当前模块是否为主程序入口
if __name__ == "__main__":
# 创建Flask应用实例
app = create_app()
+ # 启动Flask应用,开启调试模式
+ app.run(debug=True)
8.6. embedding_utils.py #
embedding_utils.py
# 导入os模块,用于读取环境变量
import os
# 导入requests库,用于发送HTTP请求
import requests
# 从环境变量中获取文本嵌入API的URL
VOLC_EMBEDDINGS_API_URL = os.getenv("VOLC_EMBEDDINGS_API_URL")
# 从环境变量中获取图像嵌入API的URL
VOLC_EMBEDDINGS_VISION_API_URL = os.getenv("VOLC_EMBEDDINGS_VISION_API_URL")
# 从环境变量中获取API密钥
VOLC_API_KEY = os.getenv("VOLC_API_KEY")
# 定义获取文档嵌入向量的函数
def get_doc_embedding(doc_content):
# 构造请求头,包含内容类型和认证信息
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {VOLC_API_KEY}",
}
# 构造请求体,指定模型和输入内容
payload = {"model": "doubao-embedding-text-240715", "input": doc_content}
# 发送POST请求到文本嵌入API
response = requests.post(VOLC_EMBEDDINGS_API_URL, json=payload, headers=headers)
# 如果响应状态码为200,表示请求成功
if response.status_code == 200:
# 解析响应的JSON数据
data = response.json()
# 提取嵌入向量
embedding = data["data"][0]["embedding"]
# 返回嵌入向量
return embedding
else:
# 如果请求失败,抛出异常并输出错误信息
raise Exception(f"Embedding API error: {response.text}")
+# 定义获取图片嵌入向量的函数
+def get_image_embedding(image_url):
+ # 构造请求体,指定模型、编码格式和图片URL
+ payload = {
+ "model": "doubao-embedding-vision-250615",
+ "encoding_format": "float",
+ "input": [
+ {
+ "type": "image_url",
+ "image_url": {"url": image_url},
+ }
+ ],
+ }
+ # 构造请求头,包含内容类型和认证信息
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {VOLC_API_KEY}",
+ }
+ # 发送POST请求到图像嵌入API
+ response = requests.post(
+ VOLC_EMBEDDINGS_VISION_API_URL, json=payload, headers=headers
+ )
+ # 如果响应状态码为200,表示请求成功
+ if response.status_code == 200:
+ # 解析响应的JSON数据
+ data = response.json()
+ # 提取嵌入向量
+ embedding = data["data"]["embedding"]
+ # 返回嵌入向量
+ return embedding
+ else:
+ # 如果请求失败,抛出异常并输出错误信息和状态码
+ raise Exception(f"Embedding API error: {response.status_code} {response.text}")
8.7. init.py #
milvus_utils/init.py
# 导入pymilvus中的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从doc模块导入初始化文档集合、插入文档向量、检索文档向量的函数
from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors
+# 从image模块导入初始化图片集合和插入图片向量的函数
+from .image import init_image_collection, insert_image_vectors
# 定义全局变量,标记Milvus是否已连接
_milvus_connected = False
# 定义确保Milvus数据库连接的函数
def ensure_milvus_connection(host="localhost", port="19530"):
# 声明使用全局变量
global _milvus_connected
# 如果尚未连接Milvus,则进行连接和集合初始化
if not _milvus_connected:
# 连接到Milvus服务器
connections.connect(host=host, port=port)
# 初始化文档向量集合
init_doc_collection()
+ # 初始化图片向量集合
+ init_image_collection()
# 设置已连接标志为True
_milvus_connected = True
# 定义模块对外暴露的成员
__all__ = [
# 暴露ensure_milvus_connection函数
ensure_milvus_connection,
# 暴露insert_doc_vectors函数
"insert_doc_vectors",
# 暴露search_doc_vectors函数
"search_doc_vectors",
+ # 暴露insert_image_vectors函数
+ "insert_image_vectors",
]
8.8. home.html #
templates/home.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>AI智能检索首页</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.nav {
margin: 40px auto;
max-width: 400px;
}
.nav a {
display: block;
margin: 20px 0;
padding: 16px;
background: #f0f0f0;
border-radius: 8px;
text-align: center;
text-decoration: none;
color: #333;
font-size: 18px;
}
.nav a:hover {
background: #d0eaff;
}
</style>
</head>
<body>
<div class="nav">
<h1 style="text-align: center;">AI智能检索</h1>
<a href="{{ url_for('upload_doc') }}">上传文档</a>
<a href="{{ url_for('search_doc') }}">文本检索</a>
+ <a href="{{ url_for('upload_image') }}">上传图片</a>
</div>
</body>
</html>8.9. utils.py #
utils.py
# 允许的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}
+# 允许的图片扩展名集合
+ALLOWED_IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "bmp", "gif"}
+
# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
# 检查文件名中是否包含点,并且扩展名在允许的文档扩展名集合中
return (
"." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
)
+
# 提取文档内容,根据不同类型分别处理
def extract_doc_content(file_path):
"""
根据文件类型(txt, pdf, docx, pdf)提取文件内容。
:param file_path: 文件路径
:return: 文件内容字符串
"""
# 获取文件扩展名并转为小写
ext = file_path.rsplit(".", 1)[-1].lower()
# 如果是txt文件,按utf-8编码读取全部内容
if ext == "txt":
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
# 如果是pdf文件,使用pdfplumber提取每一页的文本
elif ext == "pdf":
import pdfplumber
text = ""
with pdfplumber.open(file_path) as pdf:
for page in pdf.pages:
text += page.extract_text() or ""
return text
# 如果是docx文件,使用python-docx提取所有段落文本
elif ext == "docx":
from docx import Document
doc = Document(file_path)
return "\n".join([para.text for para in doc.paragraphs])
# 其他类型抛出异常
else:
raise ValueError("不支持的文件类型: " + ext)
+
+# 判断文件名是否为允许的图片类型
+def allowed_image_file(filename):
+ # 检查文件名中是否包含点,并且扩展名在允许的图片扩展名集合中
+ return (
+ "." in filename
+ and filename.rsplit(".", 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
+ )
9.图片检索 #
9.1. search_image.html #
templates/search_image.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>图片检索 - 图搜图</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.container {
max-width: 600px;
margin: auto;
}
.section {
margin-bottom: 32px;
padding: 24px;
border: 1px solid #eee;
border-radius: 8px;
}
.results {
margin-top: 24px;
}
.result-item {
margin-bottom: 16px;
padding: 12px;
background: #f8f8f8;
border-radius: 6px;
}
.result-item img {
width: 100px;
height: 100px;
}
.flash {
color: red;
}
.flash.success {
color: green;
}
.flash.error {
color: red;
}
</style>
</head>
<body>
<h1>图片检索</h1>
<div class="section">
<form action="{{ url_for('search_image') }}" method="post" enctype="multipart/form-data">
<input type="file" name="file" accept=".jpg,.jpeg,.png,.bmp,.gif">
<button type="submit">检索</button>
</form>
<div class="results">
{% if results is defined and results %}
<h5>检索结果:</h5>
{% for item in results %}
<div class="result-item">
<img class="result-img" src="{{ item['image_url'] }}" alt="图片">
<div>
<div><b>相似度:</b>{{ '%.4f'|format(item['score']) }}</div>
</div>
</div>
{% endfor %}
{% elif query and results is defined and results|length == 0 %}
<div>未检索到相关图片。</div>
{% endif %}
</div>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
<ul class="flashes">
{% for category, message in messages %}
<li class="flash {{ category }}">{{ message }}</li>
{% endfor %}
</ul>
{% endif %}
{% endwith %}
</div>
<a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
</body>
</html>9.2. app.py #
app.py
# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 导入dotenv库中的load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索处理函数
from services.doc_service import handle_doc_upload, handle_doc_search
+# 从services.image_service模块导入图片上传和检索处理函数
+from services.image_service import handle_image_upload, handle_image_search
# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 创建Flask应用的工厂函数
def create_app():
# 实例化Flask应用
app = Flask(__name__)
# 设置应用的密钥,从环境变量中获取
app.secret_key = os.getenv("SECRET_KEY")
# 定义首页路由
@app.route("/")
def home():
# 记录访问首页的警告日志
logger.warning("访问首页")
# 渲染home.html模板
return render_template("home.html")
# 定义上传文档的路由,支持GET和POST方法
@app.route("/upload_doc", methods=["GET", "POST"])
def upload_doc():
# 记录访问上传文档页面的警告日志
logger.warning("访问上传文档页面")
# 调用文档上传处理函数
return handle_doc_upload(request)
# 定义文本检索的路由,支持GET和POST方法
@app.route("/search_doc", methods=["GET", "POST"])
def search_doc():
# 记录访问文本检索页面的警告日志
logger.warning("访问文本检索页面")
# 调用文档检索处理函数
return handle_doc_search(request)
# 定义上传图片的路由,支持GET和POST方法
@app.route("/upload_image", methods=["GET", "POST"])
def upload_image():
# 记录访问上传图片页面的警告日志
logger.warning("访问上传图片页面")
# 调用图片上传处理函数
return handle_image_upload(request)
+ # 定义图片检索的路由,支持GET和POST方法
+ @app.route("/search_image", methods=["GET", "POST"])
+ def search_image():
+ # 记录访问图片检索页面的警告日志
+ logger.warning("访问图片检索页面")
+ # 调用图片检索处理函数
+ return handle_image_search(request)
# 返回Flask应用实例
return app
# 如果当前模块作为主程序运行
if __name__ == "__main__":
# 创建Flask应用
app = create_app()
# 启动应用,开启调试模式
app.run(debug=True)
9.3. init.py #
milvus_utils/init.py
# 导入pymilvus中的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从当前包的doc模块导入文档相关的初始化、插入和检索函数
from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors
+# 从当前包的image模块导入图片相关的初始化、插入和检索函数
+from .image import init_image_collection, insert_image_vectors, search_image_vectors
# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False
# 定义确保Milvus连接的函数,默认主机为localhost,端口为19530
def ensure_milvus_connection(host="localhost", port="19530"):
global _milvus_connected
# 如果尚未连接Milvus,则进行连接和初始化
if not _milvus_connected:
# 连接Milvus服务器
connections.connect(host=host, port=port)
# 初始化文档向量集合
init_doc_collection()
# 初始化图片向量集合
init_image_collection()
# 设置已连接标志为True
_milvus_connected = True
# 定义模块对外暴露的接口列表
__all__ = [
ensure_milvus_connection,
"insert_doc_vectors",
"search_doc_vectors",
"insert_image_vectors",
+ "search_image_vectors",
]
9.4. image.py #
milvus_utils/image.py
# 从pymilvus导入Collection、FieldSchema、CollectionSchema、DataType、utility等类和方法
from pymilvus import (
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility,
)
# 导入logging模块用于日志记录
import logging
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 定义图片向量集合的名称
IMAGE_COLLECTION_NAME = "image_vectors"
# 初始化图片向量集合
def init_image_collection():
# 如果Milvus中不存在该集合,则进行创建
if not utility.has_collection(IMAGE_COLLECTION_NAME):
# 定义集合的字段,包括自增主键id、图片向量embedding、图片url
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2048),
FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=4096),
]
# 创建集合的schema,描述信息为“图片向量集合”
schema = CollectionSchema(fields, description="图片向量集合")
# 创建集合对象
col = Collection(IMAGE_COLLECTION_NAME, schema)
# 定义索引参数,使用IVF_FLAT索引和余弦相似度
index_params = {
"index_type": "IVF_FLAT",
"metric_type": "COSINE",
"params": {"nlist": 128},
}
# 在embedding字段上创建索引
col.create_index(field_name="embedding", index_params=index_params)
# 加载集合到内存
col.load()
else:
# 如果集合已存在,则直接加载
col = Collection(IMAGE_COLLECTION_NAME)
col.load()
# 插入图片向量及图片URL到Milvus
def insert_image_vectors(embeddings, image_urls):
"""
插入图片向量及图片URL到Milvus。
"""
# 获取图片向量集合对象
col = Collection(IMAGE_COLLECTION_NAME)
# 组织插入数据,顺序与schema一致
data = [embeddings, image_urls]
# 执行插入操作
col.insert(data)
+# 检索最相似的图片向量
+def search_image_vectors(query_embedding, top_k=5):
+ """
+ 检索最相似的图片向量。
+ """
+ # 获取图片向量集合对象
+ col = Collection(IMAGE_COLLECTION_NAME)
+ # 执行向量检索,设置检索参数
+ results = col.search(
+ data=[query_embedding], # 查询向量
+ anns_field="embedding", # 检索字段
+ param={"metric_type": "COSINE", "params": {"nprobe": 10}}, # 检索参数
+ limit=top_k, # 返回top_k个结果
+ output_fields=["image_url"], # 返回图片url字段
+ )
+ # 打印检索成功日志
+ logger.warning(f"图片检索成功")
+ # 返回检索结果
+ return results
9.5. image_service.py #
services/image_service.py
# 导入Flask的渲染模板、重定向、URL生成和消息闪现方法
from flask import render_template, redirect, url_for, flash
# 导入判断图片文件类型的工具函数
from utils import allowed_image_file
# 导入Werkzeug的安全文件名处理方法
from werkzeug.utils import secure_filename
# 导入COS对象存储上传方法
from cos_utils import upload_to_cos
# 导入图片向量化方法
from embedding_utils import get_image_embedding
+# 导入Milvus图片向量插入和检索方法
+from milvus_utils import insert_image_vectors, search_image_vectors
# 导入日志模块
import logging
# 处理图片上传的业务逻辑
def handle_image_upload(request):
# 获取logger对象
logger = logging.getLogger(__name__)
# 判断请求方法是否为POST
if request.method == "POST":
# 记录收到POST请求日志
logger.warning("收到图片上传POST请求")
# 检查请求中是否包含文件
if "file" not in request.files:
logger.warning("未选择文件")
flash("未选择文件")
return redirect(request.url)
# 获取上传的文件对象
file = request.files["file"]
# 检查文件名是否为空
if file.filename == "":
logger.warning("文件名为空")
flash("未选择文件")
return redirect(request.url)
# 检查文件是否存在且为允许的图片类型
if file and allowed_image_file(file.filename):
# 对文件名进行安全处理
filename = secure_filename(file.filename)
try:
# 记录开始上传图片日志
logger.warning("开始上传图片")
# 上传图片到COS,返回图片URL
image_url = upload_to_cos(file.stream, key=filename)
logger.warning(f"图片上传成功,URL: {image_url}")
# 获取图片的向量表示
embedding = get_image_embedding(image_url)
logger.warning("图片向量化成功,开始入库")
# 向Milvus插入图片向量和图片URL
insert_image_vectors([embedding], [image_url])
logger.warning("图片向量入库成功")
# 闪现成功消息
flash("文件上传并入库成功", "success")
except Exception as e:
# 记录异常日志并闪现错误消息
logger.error(f"向量化或入库失败: {e}")
flash(f"向量化或入库失败: {e}", "error")
# 上传后重定向到上传页面
return redirect(url_for("upload_image"))
# GET请求渲染上传页面
return render_template("upload_image.html")
+# 处理图片检索的业务逻辑
+def handle_image_search(request):
+ # 获取logger对象
+ logger = logging.getLogger(__name__)
+ # 初始化检索结果列表
+ results = []
+ # 判断请求方法是否为POST
+ if request.method == "POST":
+ # 记录收到图片检索POST请求日志
+ logger.warning("收到图片检索POST请求")
+ # 检查请求中是否包含文件
+ if "file" not in request.files:
+ logger.warning("未选择文件")
+ flash("请上传一张图片进行检索")
+ return redirect(url_for("search_image"))
+ # 获取上传的文件对象
+ file = request.files["file"]
+ # 检查文件名是否为空
+ if file.filename == "":
+ logger.warning("文件名为空")
+ flash("未选择文件")
+ return redirect(url_for("search_image"))
+ # 检查文件是否存在且为允许的图片类型
+ if file and allowed_image_file(file.filename):
+ # 对文件名进行安全处理
+ safe_filename = secure_filename(file.filename)
+ # 上传图片到COS,返回图片URL
+ image_url = upload_to_cos(file.stream, key=safe_filename)
+ logger.warning(f"图片上传成功,URL: {image_url}")
+ try:
+ # 记录开始图片向量化日志
+ logger.warning("开始图片向量化(检索)")
+ # 获取图片的向量表示
+ query_embedding = get_image_embedding(image_url)
+ logger.warning("图片向量化成功,开始Milvus检索")
+ # 调用Milvus进行图片向量检索,返回top5结果
+ milvus_results = search_image_vectors(query_embedding, top_k=5)
+ logger.warning(f"Milvus图片检索完成,命中组数: {len(milvus_results)}")
+ # 遍历检索结果,提取图片URL和相似度分数
+ for hits in milvus_results:
+ for hit in hits:
+ img_url = hit.entity.get("image_url", "")
+ results.append({"image_url": img_url, "score": hit.distance})
+ except Exception as e:
+ # 记录异常日志并闪现错误消息
+ logger.error(f"图片检索失败: {e}")
+ flash(f"检索失败: {e}", "error")
+ # 渲染图片检索结果页面
+ return render_template("search_image.html", results=results)
9.6. home.html #
templates/home.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>AI智能检索首页</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 40px;
}
.nav {
margin: 40px auto;
max-width: 400px;
}
.nav a {
display: block;
margin: 20px 0;
padding: 16px;
background: #f0f0f0;
border-radius: 8px;
text-align: center;
text-decoration: none;
color: #333;
font-size: 18px;
}
.nav a:hover {
background: #d0eaff;
}
</style>
</head>
<body>
<div class="nav">
<h1 style="text-align: center;">AI智能检索</h1>
<a href="{{ url_for('upload_doc') }}">上传文档</a>
<a href="{{ url_for('search_doc') }}">文本检索</a>
<a href="{{ url_for('upload_image') }}">上传图片</a>
+ <a href="{{ url_for('search_image') }}">图片检索</a>
</div>
</body>
</html>