导航菜单

  • 1.vector
  • 2.milvus
  • 3.pymilvus
  • 4.rag
  • 5.rag_measure
  • 7.search
  • ragflow
  • heapq
  • HNSW
  • cosine_similarity
  • math
  • typing
  • etcd
  • minio
  • collections
  • jieba
  • random
  • beautifulsoup4
  • chromadb
  • sentence_transformers
  • numpy
  • lxml
  • openpyxl
  • PyMuPDF
  • python-docx
  • requests
  • python-pptx
  • text_splitter
  • all-MiniLM-L6-v2
  • openai
  • llm
  • BPETokenizer
  • Flask
  • RAGAS
  • BagofWords
  • langchain
  • Pydantic
  • abc
  • faiss
  • MMR
  • scikit-learn
  • 1.项目说明
  • 1.1 项目简介
  • 1.2 主要功能
  • 1.3 安装
  • 1.4 参考
  • 2. 创建服务器
    • 2.1. .env
    • 2.2. app.py
    • 2.3. home.html
    • 2.4. .gitignore
  • 3.上传文档
    • 3.1. doc_service.py
    • 3.2. upload_doc.html
    • 3.3. utils.py
    • 3.4. app.py
    • 3.5. home.html
  • 4.获取文档内容
    • 4.1. doc_service.py
    • 4.2. utils.py
  • 5.计算文档向量
    • 5.1. embedding_utils.py
    • 5.2. doc_service.py
  • 6.保存向量
    • 6.1. init.py
    • 6.2. doc.py
    • 6.3. app.py
    • 6.4. doc_service.py
  • 7.检索文档向量
    • 7.1. search_doc.html
    • 7.2. app.py
    • 7.3. init.py
    • 7.4. doc.py
    • 7.5. doc_service.py
    • 7.6. home.html
  • 8.上传图片
    • 8.1. cos_utils.py
    • 8.2. image.py
    • 8.3. image_service.py
    • 8.4. upload_image.html
    • 8.5. app.py
    • 8.6. embedding_utils.py
    • 8.7. init.py
    • 8.8. home.html
    • 8.9. utils.py
  • 9.图片检索
    • 9.1. search_image.html
    • 9.2. app.py
    • 9.3. init.py
    • 9.4. image.py
    • 9.5. image_service.py
    • 9.6. home.html

1.项目说明 #

1.1 项目简介 #

本项目基于 Python、Milvus 向量数据库和火山引擎豆包Embedding模型,实现了一个"文搜文"Web系统。支持文档上传、文本向量化、语义检索。

1.2 主要功能 #

  • 支持上传本地txt文档,自动向量化并入库
  • 支持输入任意中文文本,检索最相关的文档
  • 检索结果展示文档内容及相似度分数
  • 简洁的Flask Web页面
  • 向量模型采用火山引擎"豆包Embedding"API
  • 向量数据库采用 Milvus,检索方式为余弦相似度(COSINE)

1.3 安装 #

# 创建虚拟环境
python -m venv .venv
# 激活虚拟环境(Windows)
.venv\Scripts\activate.bat
# 激活虚拟环境(Linux/macOS)
source .venv/Scripts/activate
# 安装依赖d
pip install flask python-dotenv requests pymilvus cos-python-sdk-v5 pdfplumber  python-docx

1.4 参考 #

  • 火山引擎服务开通管理
  • 图像向量化 API
  • 文本向量化 API
  • API密钥
  • API Key 管理

2. 创建服务器 #

2.1. .env #

.env

# 秘钥,用于加密会话
SECRET_KEY="rensheng"
# 上传文件夹路径
UPLOAD_FOLDER = "uploads"

# 火山引擎文本向量化API地址
VOLC_EMBEDDINGS_API_URL="https://ark.cn-beijing.volces.com/api/v3/embeddings"
# 火山引擎多模态(图文)向量化API地址
VOLC_EMBEDDINGS_VISION_API_URL="https://ark.cn-beijing.volces.com/api/v3/embeddings/multimodal"
# 火山引擎API密钥
VOLC_API_KEY="d52e49a1-36ea-44bb-bc6e-65ce789a72f6"

# 腾讯云SecretId
TENCENT_SECRET_ID="AKIDYzEI9nDRgrxufDi5FB0ueLuJZo7ITyQm"
# 腾讯云SecretKey
TENCENT_SECRET_KEY="HXmu11voYAdv483nS0OacVAg6Sj3GLFz"
# 腾讯云存储区域
TENCENT_REGION="ap-beijing"
# 腾讯云存储桶名称
TENCENT_BUCKET="xiaohongshu-1258145019"

2.2. app.py #

app.py

# 导入Flask框架中的Flask类和render_template函数,用于创建应用和渲染模板
from flask import Flask, render_template
# 导入dotenv库中的load_dotenv函数,用于加载.env环境变量文件
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging

# 加载.env文件中的环境变量
load_dotenv()
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)


# 定义一个创建Flask应用的工厂函数
def create_app():
    # 创建Flask应用实例
    app = Flask(__name__)

    # 定义根路由"/"的视图函数
    @app.route("/")
    def home():
        # 记录访问首页的警告日志
        logger.warning("访问首页")
        # 渲染home.html模板并返回
        return render_template("home.html")

    # 返回Flask应用实例
    return app


# 如果当前脚本作为主程序运行
if __name__ == "__main__":
    # 创建Flask应用
    app = create_app()
    # 启动Flask开发服务器
    app.run(debug=True)

2.3. home.html #

templates/home.html

<!DOCTYPE html>
<html lang="zh">

<head>
  <meta charset="UTF-8">
  <title>AI智能检索首页</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .nav {
      margin: 40px auto;
      max-width: 400px;
    }

    .nav a {
      display: block;
      margin: 20px 0;
      padding: 16px;
      background: #f0f0f0;
      border-radius: 8px;
      text-align: center;
      text-decoration: none;
      color: #333;
      font-size: 18px;
    }

    .nav a:hover {
      background: #d0eaff;
    }
  </style>
</head>

<body>
  <div class="nav">
    <h1>AI智能检索</h1>
  </div>
</body>

</html>

2.4. .gitignore #

.gitignore

.venv
__pycache__
uploads

3.上传文档 #

3.1. doc_service.py #

services/doc_service.py

# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename函数,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file函数,用于判断文件类型是否合法
from utils import allowed_doc_file

# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 从环境变量中获取上传文件夹路径,默认为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)

# 定义处理文档上传的函数,参数为request对象
def handle_doc_upload(request):
    # 如果请求方法为POST
    if request.method == "POST":
        # 记录收到文档上传POST请求的警告日志
        logger.warning("收到文档上传POST请求")
        # 如果请求中没有文件部分
        if "file" not in request.files:
            # 记录未选择文件的警告日志
            logger.warning("未选择文件")
            # 闪现提示未选择文件
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 如果文件名为空
        if file.filename == "":
            # 记录文件名为空的警告日志
            logger.warning("文件名为空")
            # 闪现提示未选择文件
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 如果文件存在且文件类型合法
        if file and allowed_doc_file(file.filename):
            # 对文件名进行安全处理
            filename = secure_filename(file.filename)
            # 拼接文件保存路径
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            # 保存文件到指定路径
            file.save(file_path)
            # 记录文件已保存的警告日志
            logger.warning(f"文件已保存到: {file_path}")
            # 闪现提示文件上传成功,类型为success
            flash("文件上传成功", "success")
            # 重定向到上传文档页面
            return redirect(url_for("upload_doc"))
    # 渲染上传文档页面
    return render_template("upload_doc.html")

3.2. upload_doc.html #

templates/upload_doc.html

<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>上传文档 - 文搜文</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .container {
      max-width: 600px;
      margin: auto;
    }

    .section {
      margin-bottom: 32px;
      padding: 24px;
      border: 1px solid #eee;
      border-radius: 8px;
    }

    .flash {
      color: red;
    }

    .flash.success {
      color: green;
    }

    .flash.error {
      color: red;
    }
  </style>
</head>

<body>
  <h1>上传文档</h1>
  <a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
  <div class="section">
    <form method="post" enctype="multipart/form-data" action="{{ url_for('upload_doc') }}">
      <input type="file" name="file" accept=".txt,.docx,.pdf">
      <button type="submit">上传</button>
    </form>
    {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
    <ul class="flashes">
      {% for category, message in messages %}
      <li class="flash {{ category }}">{{ message }}</li>
      {% endfor %}
    </ul>
    {% endif %}
    {% endwith %}
  </div>
</body>

</html>

3.3. utils.py #

utils.py

# 允许上传的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}


# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
    # 检查文件名中是否包含点,并且扩展名是否在允许的集合中
    return (
        "." in filename
        and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
    )

3.4. app.py #

app.py

# 导入os模块,用于读取环境变量
+import os
# 从Flask框架中导入Flask类、render_template函数和request对象
+from flask import Flask, render_template, request
# 从dotenv库中导入load_dotenv函数,用于加载.env文件中的环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从自定义的services.doc_service模块中导入handle_doc_upload函数,用于处理文档上传逻辑
+from services.doc_service import handle_doc_upload

# 加载.env文件中的环境变量
load_dotenv()
# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)


# 定义一个创建Flask应用的工厂函数
def create_app():
    # 创建Flask应用实例
    app = Flask(__name__)
    # 设置应用的密钥,用于会话加密,从环境变量中获取
+   app.secret_key = os.getenv("SECRET_KEY")

    # 定义根路由"/"的视图函数
    @app.route("/")
    def home():
        # 记录访问首页的警告日志
        logger.warning("访问首页")
        # 渲染home.html模板并返回
        return render_template("home.html")

    # 定义上传文档路由"/upload_doc",支持GET和POST方法
+   @app.route("/upload_doc", methods=["GET", "POST"])
+   def upload_doc():
+       # 记录访问上传文档页面的警告日志
+       logger.warning("访问上传文档页面")
+       # 调用handle_doc_upload函数处理上传逻辑,并返回结果
+       return handle_doc_upload(request)

    # 返回Flask应用实例
    return app


# 如果当前模块是主程序入口,则启动应用
if __name__ == "__main__":
    # 创建Flask应用
    app = create_app()
    # 启动Flask开发服务器
    app.run(debug=True)

3.5. home.html #

templates/home.html

<!DOCTYPE html>
<html lang="zh">

<head>
  <meta charset="UTF-8">
  <title>AI智能检索首页</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .nav {
      margin: 40px auto;
      max-width: 400px;
    }

    .nav a {
      display: block;
      margin: 20px 0;
      padding: 16px;
      background: #f0f0f0;
      border-radius: 8px;
      text-align: center;
      text-decoration: none;
      color: #333;
      font-size: 18px;
    }

    .nav a:hover {
      background: #d0eaff;
    }
  </style>
</head>

<body>
  <div class="nav">
    <h1>AI智能检索</h1>
+   <a href="{{ url_for('upload_doc') }}">上传文档</a>
  </div>
</body>

</html>

4.获取文档内容 #

4.1. doc_service.py #

services/doc_service.py

# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils中导入secure_filename,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file和extract_doc_content函数
+from utils import allowed_doc_file, extract_doc_content

# 获取以当前模块名为名的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认值为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)

# 处理文档上传的主函数
def handle_doc_upload(request):
    # 如果请求方法为POST,表示有文件上传
    if request.method == "POST":
        # 记录收到POST请求的日志
        logger.warning("收到文档上传POST请求")
        # 检查请求中是否包含文件
        if "file" not in request.files:
            # 记录未选择文件的日志
            logger.warning("未选择文件")
            # 弹出提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 检查文件名是否为空
        if file.filename == "":
            # 记录文件名为空的日志
            logger.warning("文件名为空")
            # 弹出提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 检查文件是否存在且文件类型是否允许
        if file and allowed_doc_file(file.filename):
            # 对文件名进行安全处理
            filename = secure_filename(file.filename)
            # 构建文件保存路径
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            # 保存文件到指定路径
            file.save(file_path)
            # 记录文件保存成功的日志
            logger.warning(f"文件已保存到: {file_path}")
            # 提取文档内容
+           doc_content = extract_doc_content(file_path)
            # 记录文档内容提取成功及内容长度的日志
+           logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
            # 弹出文件上传成功的提示信息
            flash("文件上传成功", "success")
            # 重定向到上传文档页面
            return redirect(url_for("upload_doc"))
    # 如果不是POST请求,则渲染上传文档页面
    return render_template("upload_doc.html")

4.2. utils.py #

utils.py

# 允许上传的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}


# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
    # 检查文件名中是否包含点,并且扩展名是否在允许的集合中
    return (
        "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
    )


# 提取文档内容,根据不同文件类型分别处理
+def extract_doc_content(file_path):
     """
     根据文件类型(txt, pdf, docx)提取文件内容。
     :param file_path: 文件路径
     :return: 文件内容字符串
+    """
     # 获取文件扩展名并转为小写
+    ext = file_path.rsplit(".", 1)[-1].lower()
     # 如果是txt文件
+    if ext == "txt":
+        # 以utf-8编码读取文本内容
+        with open(file_path, "r", encoding="utf-8") as f:
+            return f.read()
     # 如果是pdf文件
+    elif ext == "pdf":
         # 导入pdfplumber库
+        import pdfplumber

         # 初始化文本内容为空字符串
+        text = ""
         # 使用pdfplumber打开pdf文件
+        with pdfplumber.open(file_path) as pdf:
+            # 遍历每一页,提取文本内容
+            for page in pdf.pages:
+                text += page.extract_text() or ""
         # 返回提取的文本内容
+        return text
     # 如果是docx文件
+    elif ext == "docx":
         # 导入Document类
+        from docx import Document
+
         # 加载docx文档
+        doc = Document(file_path)
         # 将所有段落的文本拼接为一个字符串
+        return "\n".join([para.text for para in doc.paragraphs])
     # 如果文件类型不支持
+    else:
         # 抛出异常,提示不支持的文件类型
+        raise ValueError("不支持的文件类型: " + ext)

5.计算文档向量 #

5.1. embedding_utils.py #

embedding_utils.py

# 导入os模块,用于读取环境变量
import os
# 导入requests库,用于发送HTTP请求
import requests

# 从环境变量中获取文本向量API的URL
VOLC_EMBEDDINGS_API_URL = os.getenv("VOLC_EMBEDDINGS_API_URL")
# 从环境变量中获取视觉向量API的URL
VOLC_EMBEDDINGS_VISION_API_URL = os.getenv("VOLC_EMBEDDINGS_VISION_API_URL")
# 从环境变量中获取API密钥
VOLC_API_KEY = os.getenv("VOLC_API_KEY")


# 定义获取文档向量的函数,参数为文档内容
def get_doc_embedding(doc_content):
    # 构造请求头,包含内容类型和认证信息
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {VOLC_API_KEY}",
    }
    # 构造请求体,指定模型和输入内容
    payload = {"model": "doubao-embedding-text-240715", "input": doc_content}
    # 发送POST请求到向量API,获取响应
    response = requests.post(VOLC_EMBEDDINGS_API_URL, json=payload, headers=headers)
    # 如果响应状态码为200,表示请求成功
    if response.status_code == 200:
        # 解析响应的JSON数据
        data = response.json()
        # 提取嵌入向量
        embedding = data["data"][0]["embedding"]
        # 返回嵌入向量
        return embedding
    else:
        # 如果请求失败,抛出异常并输出错误信息
        raise Exception(f"Embedding API error: {response.text}")

5.2. doc_service.py #

services/doc_service.py

import os
import logging
from flask import redirect, flash, render_template, url_for
from werkzeug.utils import secure_filename
from utils import allowed_doc_file, extract_doc_content
# 从embedding_utils模块导入get_doc_embedding函数,用于获取文档向量
+from embedding_utils import get_doc_embedding

logger = logging.getLogger(__name__)
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)


def handle_doc_upload(request):
    if request.method == "POST":
        logger.warning("收到文档上传POST请求")
        if "file" not in request.files:
            logger.warning("未选择文件")
            flash("未选择文件")
            return redirect(request.url)
        file = request.files["file"]
        if file.filename == "":
            logger.warning("文件名为空")
            flash("未选择文件")
            return redirect(request.url)
        if file and allowed_doc_file(file.filename):
            filename = secure_filename(file.filename)
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            file.save(file_path)
            logger.warning(f"文件已保存到: {file_path}")
            doc_content = extract_doc_content(file_path)
            logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
            # 计算文档向量
+           doc_embedding = get_doc_embedding(doc_content)
            # 记录文档向量计算成功的日志
+           logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
            insert_doc_vectors([doc_embedding], [doc_content])
            flash("文件上传成功", "success")
            return redirect(url_for("upload_doc"))
    return render_template("upload_doc.html")

6.保存向量 #

6.1. init.py #

milvus_utils/init.py

# 导入pymilvus中的connections模块,用于连接Milvus服务
from pymilvus import connections
# 从当前包的doc模块导入init_doc_collection和insert_doc_vectors函数
from .doc import (
    init_doc_collection,
    insert_doc_vectors,
)

# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False


# 定义确保Milvus连接的函数,默认主机为localhost,端口为19530
def ensure_milvus_connection(host="localhost", port="19530"):
    # 声明使用全局变量_milvus_connected
    global _milvus_connected
    # 如果尚未连接Milvus,则进行连接和集合初始化
    if not _milvus_connected:
        # 连接到Milvus服务
        connections.connect(host=host, port=port)
        # 初始化文档向量集合
        init_doc_collection()
        # 设置连接标志为True,避免重复连接
        _milvus_connected = True


# 定义模块对外暴露的成员
__all__ = [
    ensure_milvus_connection,
    "insert_doc_vectors",
]

6.2. doc.py #

milvus_utils/doc.py

# 从pymilvus库中导入Collection、FieldSchema、CollectionSchema、DataType和utility模块
from pymilvus import (
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)

# 定义Milvus中用于存储文档向量的集合名称
DOC_COLLECTION_NAME = "doc_vectors"

# 定义初始化文档向量集合的函数
def init_doc_collection():
    """
    初始化Milvus连接和文档向量集合。
    """
    # 如果Milvus中不存在指定名称的集合,则进行创建
    if not utility.has_collection(DOC_COLLECTION_NAME):
        # 定义集合的字段,包括自增主键id、浮点向量embedding和原文档内容doc
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560),
            FieldSchema(name="doc", dtype=DataType.VARCHAR, max_length=4096),
        ]
        # 根据字段定义创建集合的schema
        schema = CollectionSchema(fields, description="文档向量集合")
        # 创建集合对象
        col = Collection(DOC_COLLECTION_NAME, schema)
        # 定义向量字段的索引参数
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "COSINE",
            "params": {"nlist": 128},
        }
        # 为embedding字段创建索引
        col.create_index(field_name="embedding", index_params=index_params)
        # 加载集合到内存
        col.load()
    else:
        # 如果集合已存在,则直接加载到内存
        col = Collection(DOC_COLLECTION_NAME)
        col.load()

# 定义插入文档向量和原文档内容到Milvus的函数
def insert_doc_vectors(embeddings, docs):
    """
    插入文本向量及原文档内容到Milvus。
    """
    # 获取指定名称的集合对象
    col = Collection(DOC_COLLECTION_NAME)
    # 组织插入的数据,顺序与schema一致
    data = [embeddings, docs]
    # 执行插入操作
    col.insert(data)

6.3. app.py #

app.py

# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 从dotenv库导入load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
+from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入handle_doc_upload函数,处理文档上传
from services.doc_service import handle_doc_upload

# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
+ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)


# 定义创建Flask应用的工厂函数
def create_app():
    # 创建Flask应用实例
    app = Flask(__name__)
    # 设置应用的密钥,用于会话加密
    app.secret_key = os.getenv("SECRET_KEY")

    # 定义根路由,访问首页
    @app.route("/")
    def home():
        # 记录访问首页的警告日志
        logger.warning("访问首页")
        # 渲染home.html模板
        return render_template("home.html")

    # 定义上传文档的路由,支持GET和POST方法
    @app.route("/upload_doc", methods=["GET", "POST"])
    def upload_doc():
        # 记录访问上传文档页面的警告日志
        logger.warning("访问上传文档页面")
        # 调用handle_doc_upload函数处理上传逻辑
        return handle_doc_upload(request)

    # 返回Flask应用实例
    return app


# 判断当前模块是否为主程序入口
if __name__ == "__main__":
    # 创建Flask应用
    app = create_app()
    # 启动Flask开发服务器
    app.run(debug=True)

6.4. doc_service.py #

services/doc_service.py

# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入重定向、消息闪现、模板渲染和url生成相关函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename函数,用于安全处理上传文件名
from werkzeug.utils import secure_filename
# 导入自定义工具函数,判断文件类型和提取文档内容
from utils import allowed_doc_file, extract_doc_content
# 导入获取文档向量的函数
from embedding_utils import get_doc_embedding
# 从milvus_utils模块导入插入向量的函数
+from milvus_utils import insert_doc_vectors

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认值为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)


# 定义处理文档上传的函数
def handle_doc_upload(request):
    # 如果请求方法为POST,表示有文件上传
    if request.method == "POST":
        # 记录收到POST请求的日志
        logger.warning("收到文档上传POST请求")
        # 检查请求中是否包含文件
        if "file" not in request.files:
            # 记录未选择文件的日志
            logger.warning("未选择文件")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回上传页面
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 检查文件名是否为空
        if file.filename == "":
            # 记录文件名为空的日志
            logger.warning("文件名为空")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回上传页面
            return redirect(request.url)
        # 检查文件是否存在且类型允许
        if file and allowed_doc_file(file.filename):
            # 对文件名进行安全处理
            filename = secure_filename(file.filename)
            # 拼接文件保存路径
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            # 保存文件到指定路径
            file.save(file_path)
            # 记录文件保存成功的日志
            logger.warning(f"文件已保存到: {file_path}")
            # 提取文档内容
            doc_content = extract_doc_content(file_path)
            # 记录文档内容提取成功的日志
            logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
            # 获取文档的向量表示
            doc_embedding = get_doc_embedding(doc_content)
            # 记录向量计算成功的日志
            logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
            # 插入文档向量和原文档内容到Milvus数据库
+           insert_doc_vectors([doc_embedding], [doc_content])
            # 闪现上传成功的提示信息
            flash("文件上传成功", "success")
            # 重定向回上传页面
            return redirect(url_for("upload_doc"))
    # 渲染上传文档页面
    return render_template("upload_doc.html")

7.检索文档向量 #

7.1. search_doc.html #

templates/search_doc.html

<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>文本检索 - 文搜文</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .container {
      max-width: 600px;
      margin: auto;
    }

    .section {
      margin-bottom: 32px;
      padding: 24px;
      border: 1px solid #eee;
      border-radius: 8px;
    }

    .results {
      margin-top: 24px;
    }

    .result-item {
      margin-bottom: 16px;
      padding: 12px;
      background: #f8f8f8;
      border-radius: 6px;
    }

    .flash {
      color: red;
    }

    .flash.success {
      color: green;
    }

    .flash.error {
      color: red;
    }
  </style>
</head>

<body>
  <h1>文本检索</h1>
  <a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
  <div class="section">
    <form action="{{ url_for('search_doc') }}" method="post">
      <input type="text" name="query" value="{{ query if query else '' }}" style="width:70%" placeholder="请输入检索文本">
      <button type="submit">检索</button>
    </form>
    <div class="results">
      {% if results is defined and results|length > 0 %}
      <h5>检索结果:</h5>
      {% for item in results %}
      <div class="result-item">
        <div><b>文档:</b>{{ item['doc'] }}</div>
        <div><b>相似度:</b>{{ '%.4f'|format(item['score']) }}</div>
      </div>
      {% endfor %}
      {% elif query and results is defined and results|length == 0 %}
      <div>未检索到相关文档。</div>
      {% endif %}
    </div>
    {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
    <ul class="flashes">
      {% for category, message in messages %}
      <li class="flash {{ category }}">{{ message }}</li>
      {% endfor %}
    </ul>
    {% endif %}
    {% endwith %}
  </div>
</body>

</html>

7.2. app.py #

app.py

# 导入os模块,用于操作环境变量等
import os
# 从flask中导入Flask、render_template、request对象
from flask import Flask, render_template, request
# 导入dotenv的load_dotenv函数,用于加载.env环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils中导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索的处理函数
+from services.doc_service import handle_doc_upload, handle_doc_search

# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)


# 创建Flask应用的工厂函数
def create_app():
    # 实例化Flask应用
    app = Flask(__name__)
    # 设置Flask应用的密钥
    app.secret_key = os.getenv("SECRET_KEY")

    # 定义首页路由
    @app.route("/")
    def home():
        # 记录访问首页的日志
        logger.warning("访问首页")
        # 渲染首页模板
        return render_template("home.html")

    # 定义上传文档的路由,支持GET和POST方法
    @app.route("/upload_doc", methods=["GET", "POST"])
    def upload_doc():
        # 记录访问上传文档页面的日志
        logger.warning("访问上传文档页面")
        # 调用文档上传处理函数
        return handle_doc_upload(request)

+   # 定义文本检索的路由,支持GET和POST方法
+   @app.route("/search_doc", methods=["GET", "POST"])
+   def search_doc():
+       # 记录访问文本检索页面的日志
+       logger.warning("访问文本检索页面")
+       # 调用文档检索处理函数
+       return handle_doc_search(request)

    # 返回Flask应用实例
    return app


# 判断当前脚本是否作为主程序运行
if __name__ == "__main__":
    # 调用工厂函数创建Flask应用实例
    app = create_app()
    # 启动Flask开发服务器
    app.run(debug=True)

7.3. init.py #

milvus_utils/init.py

# 导入pymilvus的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从当前包的doc模块导入初始化集合、插入向量和检索向量的函数
+from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors

# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False

# 定义确保Milvus连接的函数
def ensure_milvus_connection(host="localhost", port="19530"):
    # 声明使用全局变量
    global _milvus_connected
    # 如果尚未连接Milvus,则进行连接和集合初始化
    if not _milvus_connected:
        # 连接到Milvus服务器
        connections.connect(host=host, port=port)
        # 初始化文档向量集合
        init_doc_collection()
        # 设置已连接标志为True
        _milvus_connected = True

# 定义模块对外暴露的成员
__all__ = [
    # 暴露ensure_milvus_connection函数
    ensure_milvus_connection,
    # 暴露insert_doc_vectors函数
    "insert_doc_vectors",
+   # 暴露search_doc_vectors函数
+   "search_doc_vectors",
]

7.4. doc.py #

milvus_utils/doc.py

# 从pymilvus库导入所需的类和方法
from pymilvus import (
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)

# 定义Milvus中用于存储文档向量的集合名称
DOC_COLLECTION_NAME = "doc_vectors"

# 初始化文档向量集合(Collection),如不存在则创建
def init_doc_collection():
    """
    初始化Milvus连接和文档向量集合。
    """
    # 如果集合不存在,则创建集合及其schema
    if not utility.has_collection(DOC_COLLECTION_NAME):
        # 定义集合的字段,包括主键id、向量字段embedding、文档内容doc
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2560),
            FieldSchema(name="doc", dtype=DataType.VARCHAR, max_length=4096),
        ]
        # 创建集合schema
        schema = CollectionSchema(fields, description="文档向量集合")
        # 创建集合
        col = Collection(DOC_COLLECTION_NAME, schema)
        # 定义向量索引参数
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "COSINE",
            "params": {"nlist": 128},
        }
        # 为embedding字段创建索引
        col.create_index(field_name="embedding", index_params=index_params)
        # 加载集合到内存
        col.load()
    else:
        # 如果集合已存在,直接加载
        col = Collection(DOC_COLLECTION_NAME)
        col.load()

# 向Milvus插入文档向量和原文档内容
def insert_doc_vectors(embeddings, docs):
    """
    插入文本向量及原文档内容到Milvus。
    """
    # 获取集合对象
    col = Collection(DOC_COLLECTION_NAME)
    # 组织插入数据
    data = [embeddings, docs]
    # 插入数据到集合
    col.insert(data)

# 检索与查询向量最相似的文档
+def search_doc_vectors(query_embedding, top_k=5):
+   """
+   检索最相似的文本向量。
+   """
+   # 获取集合对象
+   col = Collection(DOC_COLLECTION_NAME)
+   # 执行向量检索,返回最相似的top_k条结果
+   results = col.search(
+       data=[query_embedding],
+       anns_field="embedding",
+       param={"metric_type": "COSINE", "params": {"nprobe": 10}},
+       limit=top_k,
+       output_fields=["doc"],
+   )
+   # 返回检索结果
+   return results

7.5. doc_service.py #

services/doc_service.py

# 导入os模块,用于文件和路径操作
import os
# 导入logging模块,用于日志记录
import logging
# 从flask中导入redirect、flash、render_template、url_for等函数
from flask import redirect, flash, render_template, url_for
# 从werkzeug.utils导入secure_filename,用于安全处理文件名
from werkzeug.utils import secure_filename
# 从utils模块导入allowed_doc_file和extract_doc_content函数
from utils import allowed_doc_file, extract_doc_content
# 从embedding_utils模块导入get_doc_embedding函数
from embedding_utils import get_doc_embedding
+# 从milvus_utils模块导入insert_doc_vectors和search_doc_vectors函数
+from milvus_utils import insert_doc_vectors, search_doc_vectors

# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)
# 获取上传文件夹路径,默认为"uploads"
UPLOAD_FOLDER = os.getenv("UPLOAD_FOLDER", "uploads")
# 如果上传文件夹不存在,则创建该文件夹
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER)


# 处理文档上传的函数
def handle_doc_upload(request):
    # 如果请求方法为POST
    if request.method == "POST":
        # 记录收到文档上传POST请求的日志
        logger.warning("收到文档上传POST请求")
        # 如果请求中没有"file"字段
        if "file" not in request.files:
            # 记录未选择文件的日志
            logger.warning("未选择文件")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 如果文件名为空
        if file.filename == "":
            # 记录文件名为空的日志
            logger.warning("文件名为空")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 如果文件存在且类型允许
        if file and allowed_doc_file(file.filename):
            # 对文件名进行安全处理
            filename = secure_filename(file.filename)
            # 拼接文件保存路径
            file_path = os.path.join(UPLOAD_FOLDER, filename)
            # 保存文件到指定路径
            file.save(file_path)
            # 记录文件保存成功的日志
            logger.warning(f"文件已保存到: {file_path}")
            # 提取文档内容
            doc_content = extract_doc_content(file_path)
            # 记录文档内容提取成功的日志
            logger.warning(f"文件内容提取成功,内容长度: {len(doc_content)}")
            # 获取文档的向量表示
            doc_embedding = get_doc_embedding(doc_content)
            # 记录向量计算成功的日志
            logger.warning(f"文件向量计算成功,向量长度: {len(doc_embedding)}")
            # 插入文档向量和原文档内容到Milvus数据库
            insert_doc_vectors([doc_embedding], [doc_content])
            # 闪现上传成功的提示信息
            flash("文件上传成功", "success")
            # 重定向回上传页面
            return redirect(url_for("upload_doc"))
    # 渲染上传文档页面
    return render_template("upload_doc.html")


+# 处理文本检索的函数
+def handle_doc_search(request):
+   # 初始化检索结果列表
+   results = []
+   # 初始化检索文本
+   query = ""
+   # 如果请求方法为POST
+   if request.method == "POST":
+       # 记录收到文本检索POST请求的日志
+       logger.warning("收到文本检索POST请求")
+       # 获取表单中的检索文本
+       query = request.form.get("query", "")
+       # 记录检索文本内容
+       logger.warning(f"检索文本: {query}")
+       # 如果检索文本非空
+       if query.strip():
+           try:
+               # 记录开始文本向量化的日志
+               logger.warning("开始文本向量化(检索)")
+               # 获取检索文本的向量表示
+               query_embedding = get_doc_embedding(query)
+               # 记录向量化成功,准备检索
+               logger.warning("文本向量化成功,开始Milvus检索")
+               # 在Milvus中检索最相似的文档
+               milvus_results = search_doc_vectors(query_embedding, top_k=5)
+               # 记录检索命中组数
+               logger.warning(f"Milvus检索完成,命中组数: {len(milvus_results)}")
+               # 遍历检索结果
+               for hits in milvus_results:
+                   for hit in hits:
+                       # 将每个命中结果添加到结果列表
+                       results.append(
+                           {"doc": hit.entity.get("doc", ""), "score": hit.distance}
+                       )
+           # 捕获异常并记录错误日志
+           except Exception as e:
+               logger.error(f"检索失败: {e}")
+               # 闪现检索失败的提示信息
+               flash(f"检索失败: {e}")
+   # 记录最终检索结果
+   logger.warning(f"检索结果: {results}")
+   # 渲染检索结果页面
+   return render_template("search_doc.html", results=results, query=query)

7.6. home.html #

templates/home.html

<!DOCTYPE html>
<html lang="zh">

<head>
  <meta charset="UTF-8">
  <title>AI智能检索首页</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .nav {
      margin: 40px auto;
      max-width: 400px;
    }

    .nav a {
      display: block;
      margin: 20px 0;
      padding: 16px;
      background: #f0f0f0;
      border-radius: 8px;
      text-align: center;
      text-decoration: none;
      color: #333;
      font-size: 18px;
    }

    .nav a:hover {
      background: #d0eaff;
    }
  </style>
</head>

<body>
  <div class="nav">
+   <h1 style="text-align: center;">AI智能检索</h1>
    <a href="{{ url_for('upload_doc') }}">上传文档</a>
+   <a href="{{ url_for('search_doc') }}">文本检索</a>
  </div>
</body>

</html>

8.上传图片 #

8.1. cos_utils.py #

cos_utils.py

# 导入os模块,用于获取环境变量
import os
# 导入logging模块,用于日志记录
import logging
# 从qcloud_cos库中导入CosConfig和CosS3Client类
from qcloud_cos import CosConfig, CosS3Client

# 从环境变量中获取腾讯云的SecretId
TENCENT_SECRET_ID = os.getenv("TENCENT_SECRET_ID")
# 从环境变量中获取腾讯云的SecretKey
TENCENT_SECRET_KEY = os.getenv("TENCENT_SECRET_KEY")
# 从环境变量中获取腾讯云的区域信息
TENCENT_REGION = os.getenv("TENCENT_REGION")
# 从环境变量中获取腾讯云的存储桶名称
TENCENT_BUCKET = os.getenv("TENCENT_BUCKET")
# 令牌,默认为None
TENCENT_TOKEN = None
# 访问协议,默认为https
TENCENT_SCHEME = "https"
# 打印获取到的配置信息,便于调试
print(TENCENT_SECRET_ID, TENCENT_SECRET_KEY, TENCENT_REGION, TENCENT_BUCKET)
# 获取当前模块的logger对象
logger = logging.getLogger(__name__)


# 定义上传文件到腾讯云COS的函数
def upload_to_cos(stream, key):
    # 创建CosConfig对象,配置COS客户端参数
    config = CosConfig(
        Region=TENCENT_REGION,
        SecretId=TENCENT_SECRET_ID,
        SecretKey=TENCENT_SECRET_KEY,
        Token=TENCENT_TOKEN,
        Scheme=TENCENT_SCHEME,
    )
    # 创建CosS3Client客户端对象
    client = CosS3Client(config)
    # 调用put_object方法上传对象到COS
    client.put_object(
        Bucket=TENCENT_BUCKET,
        Body=stream,
        Key=key,
        StorageClass="STANDARD",
        EnableMD5=False,
    )
    # 构造文件的访问URL
    url = f"https://{TENCENT_BUCKET}.cos.{TENCENT_REGION}.myqcloud.com/{key}"
    # 记录上传成功的警告日志,包含文件访问地址
    logger.warning(f"文件已上传,访问地址: {url}")
    # 返回文件的访问URL
    return url

8.2. image.py #

milvus_utils/image.py

# 从pymilvus库导入所需的类和方法
from pymilvus import (
    Collection,           # 导入Collection类,用于操作Milvus集合
    FieldSchema,          # 导入FieldSchema类,用于定义字段结构
    CollectionSchema,     # 导入CollectionSchema类,用于定义集合结构
    DataType,             # 导入DataType类,用于指定字段类型
    utility,              # 导入utility模块,用于集合相关的工具函数
)
# 导入logging模块,用于日志记录
import logging

# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 定义图片向量集合的名称
IMAGE_COLLECTION_NAME = "image_vectors"


# 定义初始化图片向量集合的函数
def init_image_collection():
    # 如果Milvus中不存在该集合,则进行创建
    if not utility.has_collection(IMAGE_COLLECTION_NAME):
        # 定义集合的字段结构
        fields = [
            # 定义主键id字段,类型为INT64,自动生成
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            # 定义embedding字段,类型为FLOAT_VECTOR,维度为2048
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2048),
            # 定义image_url字段,类型为VARCHAR,最大长度4096
            FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=4096),
        ]
        # 创建集合的schema对象,描述为“图片向量集合”
        schema = CollectionSchema(fields, description="图片向量集合")
        # 创建集合对象
        col = Collection(IMAGE_COLLECTION_NAME, schema)
        # 定义索引参数
        index_params = {
            "index_type": "IVF_FLAT",      # 索引类型为IVF_FLAT
            "metric_type": "COSINE",       # 度量方式为余弦相似度
            "params": {"nlist": 128},      # nlist参数设置为128
        }
        # 为embedding字段创建索引
        col.create_index(field_name="embedding", index_params=index_params)
        # 加载集合到内存
        col.load()
    else:
        # 如果集合已存在,则直接加载
        col = Collection(IMAGE_COLLECTION_NAME)
        col.load()


# 定义插入图片向量及图片URL到Milvus的函数
def insert_image_vectors(embeddings, image_urls):
    """
    插入图片向量及图片URL到Milvus。
    """
    # 获取图片向量集合对象
    col = Collection(IMAGE_COLLECTION_NAME)
    # 构造插入数据,包含向量和图片URL
    data = [embeddings, image_urls]
    # 执行插入操作
    col.insert(data)

8.3. image_service.py #

services/image_service.py

# 导入Flask的模板渲染、重定向、URL生成和消息闪现函数
from flask import render_template, redirect, url_for, flash
# 导入判断图片文件类型的工具函数
from utils import allowed_image_file
# 导入Werkzeug的安全文件名处理函数
from werkzeug.utils import secure_filename
# 导入腾讯云COS上传工具
from cos_utils import upload_to_cos
# 导入图片向量化工具
from embedding_utils import get_image_embedding
# 导入Milvus图片向量插入函数
from milvus_utils import insert_image_vectors
# 导入日志模块
import logging


# 定义处理图片上传的函数
def handle_image_upload(request):
    # 获取当前模块的日志记录器
    logger = logging.getLogger(__name__)
    # 判断请求方法是否为POST
    if request.method == "POST":
        # 记录收到POST请求的日志
        logger.warning("收到图片上传POST请求")
        # 检查请求中是否包含文件
        if "file" not in request.files:
            # 记录未选择文件的日志
            logger.warning("未选择文件")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 检查文件名是否为空
        if file.filename == "":
            # 记录文件名为空的日志
            logger.warning("文件名为空")
            # 闪现提示信息
            flash("未选择文件")
            # 重定向回当前页面
            return redirect(request.url)
        # 判断文件存在且为允许的图片类型
        if file and allowed_image_file(file.filename):
            # 获取安全的文件名
            filename = secure_filename(file.filename)
            try:
                # 记录开始上传图片的日志
                logger.warning("开始上传图片")
                # 上传图片到腾讯云COS,返回图片URL
                image_url = upload_to_cos(file.stream, key=filename)
                # 记录图片上传成功的日志
                logger.warning(f"图片上传成功,URL: {image_url}")
                # 获取图片的向量表示
                embedding = get_image_embedding(image_url)
                # 记录图片向量化成功的日志
                logger.warning("图片向量化成功,开始入库")
                # 向Milvus插入图片向量和图片URL
                insert_image_vectors([embedding], [image_url])
                # 记录图片向量入库成功的日志
                logger.warning("图片向量入库成功")
                # 闪现成功提示信息
                flash("文件上传并入库成功", "success")
            except Exception as e:
                # 记录向量化或入库失败的错误日志
                logger.error(f"向量化或入库失败: {e}")
                # 闪现错误提示信息
                flash(f"向量化或入库失败: {e}", "error")
            # 重定向到图片上传页面
            return redirect(url_for("upload_image"))
    # 渲染图片上传页面
    return render_template("upload_image.html")

8.4. upload_image.html #

templates/upload_image.html

<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>上传图片 - 图搜图</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .container {
      max-width: 600px;
      margin: auto;
    }

    .section {
      margin-bottom: 32px;
      padding: 24px;
      border: 1px solid #eee;
      border-radius: 8px;
    }

    .flash {
      color: red;
    }

    .flash.success {
      color: green;
    }

    .flash.error {
      color: red;
    }
  </style>
</head>

<body>
  <h1>上传图片</h1>
  <a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
  <div class="section">
    <form method="post" enctype="multipart/form-data" action="{{ url_for('upload_image') }}">
      <input type="file" name="file" accept=".jpg,.jpeg,.png,.bmp,.gif">
      <button type="submit">上传</button>
    </form>
    {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
    <ul class="flashes">
      {% for category, message in messages %}
      <li class="flash {{ category }}">{{ message }}</li>
      {% endfor %}
    </ul>
    {% endif %}
    {% endwith %}
  </div>
</body>

</html>

8.5. app.py #

app.py

# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 导入dotenv库中的load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索处理函数
from services.doc_service import handle_doc_upload, handle_doc_search
+# 从services.image_service模块导入图片上传处理函数
+from services.image_service import handle_image_upload

# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库连接正常
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)


# 创建Flask应用的工厂函数
def create_app():
    # 实例化Flask应用
    app = Flask(__name__)
    # 设置应用的密钥,从环境变量中获取
    app.secret_key = os.getenv("SECRET_KEY")

    # 定义首页路由
    @app.route("/")
    def home():
        # 记录访问首页的警告日志
        logger.warning("访问首页")
        # 渲染home.html模板
        return render_template("home.html")

    # 定义文档上传路由,支持GET和POST方法
    @app.route("/upload_doc", methods=["GET", "POST"])
    def upload_doc():
        # 记录访问上传文档页面的警告日志
        logger.warning("访问上传文档页面")
        # 调用文档上传处理函数
        return handle_doc_upload(request)

    # 定义文档检索路由,支持GET和POST方法
    @app.route("/search_doc", methods=["GET", "POST"])
    def search_doc():
        # 记录访问文本检索页面的警告日志
        logger.warning("访问文本检索页面")
        # 调用文档检索处理函数
        return handle_doc_search(request)

+   # 定义图片上传路由,支持GET和POST方法
+   @app.route("/upload_image", methods=["GET", "POST"])
+   def upload_image():
+       # 记录访问上传图片页面的警告日志
+       logger.warning("访问上传图片页面")
+       # 调用图片上传处理函数
+       return handle_image_upload(request)

    # 返回Flask应用实例
    return app


# 判断当前模块是否为主程序入口
if __name__ == "__main__":
    # 创建Flask应用实例
    app = create_app()
+   # 启动Flask应用,开启调试模式
+   app.run(debug=True)

8.6. embedding_utils.py #

embedding_utils.py

# 导入os模块,用于读取环境变量
import os
# 导入requests库,用于发送HTTP请求
import requests

# 从环境变量中获取文本嵌入API的URL
VOLC_EMBEDDINGS_API_URL = os.getenv("VOLC_EMBEDDINGS_API_URL")
# 从环境变量中获取图像嵌入API的URL
VOLC_EMBEDDINGS_VISION_API_URL = os.getenv("VOLC_EMBEDDINGS_VISION_API_URL")
# 从环境变量中获取API密钥
VOLC_API_KEY = os.getenv("VOLC_API_KEY")


# 定义获取文档嵌入向量的函数
def get_doc_embedding(doc_content):
    # 构造请求头,包含内容类型和认证信息
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {VOLC_API_KEY}",
    }
    # 构造请求体,指定模型和输入内容
    payload = {"model": "doubao-embedding-text-240715", "input": doc_content}
    # 发送POST请求到文本嵌入API
    response = requests.post(VOLC_EMBEDDINGS_API_URL, json=payload, headers=headers)
    # 如果响应状态码为200,表示请求成功
    if response.status_code == 200:
        # 解析响应的JSON数据
        data = response.json()
        # 提取嵌入向量
        embedding = data["data"][0]["embedding"]
        # 返回嵌入向量
        return embedding
    else:
        # 如果请求失败,抛出异常并输出错误信息
        raise Exception(f"Embedding API error: {response.text}")


+# 定义获取图片嵌入向量的函数
+def get_image_embedding(image_url):
+   # 构造请求体,指定模型、编码格式和图片URL
+   payload = {
+       "model": "doubao-embedding-vision-250615",
+       "encoding_format": "float",
+       "input": [
+           {
+               "type": "image_url",
+               "image_url": {"url": image_url},
+           }
+       ],
+   }
+   # 构造请求头,包含内容类型和认证信息
+   headers = {
+       "Content-Type": "application/json",
+       "Authorization": f"Bearer {VOLC_API_KEY}",
+   }
+   # 发送POST请求到图像嵌入API
+   response = requests.post(
+       VOLC_EMBEDDINGS_VISION_API_URL, json=payload, headers=headers
+   )
+   # 如果响应状态码为200,表示请求成功
+   if response.status_code == 200:
+       # 解析响应的JSON数据
+       data = response.json()
+       # 提取嵌入向量
+       embedding = data["data"]["embedding"]
+       # 返回嵌入向量
+       return embedding
+   else:
+       # 如果请求失败,抛出异常并输出错误信息和状态码
+       raise Exception(f"Embedding API error: {response.status_code} {response.text}")

8.7. init.py #

milvus_utils/init.py

# 导入pymilvus中的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从doc模块导入初始化文档集合、插入文档向量、检索文档向量的函数
from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors
+# 从image模块导入初始化图片集合和插入图片向量的函数
+from .image import init_image_collection, insert_image_vectors

# 定义全局变量,标记Milvus是否已连接
_milvus_connected = False


# 定义确保Milvus数据库连接的函数
def ensure_milvus_connection(host="localhost", port="19530"):
    # 声明使用全局变量
    global _milvus_connected
    # 如果尚未连接Milvus,则进行连接和集合初始化
    if not _milvus_connected:
        # 连接到Milvus服务器
        connections.connect(host=host, port=port)
        # 初始化文档向量集合
        init_doc_collection()
+       # 初始化图片向量集合
+       init_image_collection()
        # 设置已连接标志为True
        _milvus_connected = True


# 定义模块对外暴露的成员
__all__ = [
    # 暴露ensure_milvus_connection函数
    ensure_milvus_connection,
    # 暴露insert_doc_vectors函数
    "insert_doc_vectors",
    # 暴露search_doc_vectors函数
    "search_doc_vectors",
+   # 暴露insert_image_vectors函数
+   "insert_image_vectors",
]

8.8. home.html #

templates/home.html

<!DOCTYPE html>
<html lang="zh">

<head>
  <meta charset="UTF-8">
  <title>AI智能检索首页</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .nav {
      margin: 40px auto;
      max-width: 400px;
    }

    .nav a {
      display: block;
      margin: 20px 0;
      padding: 16px;
      background: #f0f0f0;
      border-radius: 8px;
      text-align: center;
      text-decoration: none;
      color: #333;
      font-size: 18px;
    }

    .nav a:hover {
      background: #d0eaff;
    }
  </style>
</head>

<body>
  <div class="nav">
    <h1 style="text-align: center;">AI智能检索</h1>
    <a href="{{ url_for('upload_doc') }}">上传文档</a>
    <a href="{{ url_for('search_doc') }}">文本检索</a>
+   <a href="{{ url_for('upload_image') }}">上传图片</a>
  </div>
</body>

</html>

8.9. utils.py #

utils.py

# 允许的文档扩展名集合
ALLOWED_DOC_EXTENSIONS = {"txt", "docx", "pdf"}
+# 允许的图片扩展名集合
+ALLOWED_IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "bmp", "gif"}
+

# 判断文件名是否为允许的文档类型
def allowed_doc_file(filename):
    # 检查文件名中是否包含点,并且扩展名在允许的文档扩展名集合中
    return (
        "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
    )
+

# 提取文档内容,根据不同类型分别处理
def extract_doc_content(file_path):
    """
    根据文件类型(txt, pdf, docx, pdf)提取文件内容。
    :param file_path: 文件路径
    :return: 文件内容字符串
    """
    # 获取文件扩展名并转为小写
    ext = file_path.rsplit(".", 1)[-1].lower()
    # 如果是txt文件,按utf-8编码读取全部内容
    if ext == "txt":
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    # 如果是pdf文件,使用pdfplumber提取每一页的文本
    elif ext == "pdf":
        import pdfplumber

        text = ""
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                text += page.extract_text() or ""
        return text
    # 如果是docx文件,使用python-docx提取所有段落文本
    elif ext == "docx":
        from docx import Document

        doc = Document(file_path)
        return "\n".join([para.text for para in doc.paragraphs])
    # 其他类型抛出异常
    else:
        raise ValueError("不支持的文件类型: " + ext)
+

+# 判断文件名是否为允许的图片类型
+def allowed_image_file(filename):
+   # 检查文件名中是否包含点,并且扩展名在允许的图片扩展名集合中
+   return (
+       "." in filename
+       and filename.rsplit(".", 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
+   )

9.图片检索 #

9.1. search_image.html #

templates/search_image.html

<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>图片检索 - 图搜图</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .container {
      max-width: 600px;
      margin: auto;
    }

    .section {
      margin-bottom: 32px;
      padding: 24px;
      border: 1px solid #eee;
      border-radius: 8px;
    }

    .results {
      margin-top: 24px;
    }

    .result-item {
      margin-bottom: 16px;
      padding: 12px;
      background: #f8f8f8;
      border-radius: 6px;
    }

    .result-item img {
      width: 100px;
      height: 100px;
    }

    .flash {
      color: red;
    }

    .flash.success {
      color: green;
    }

    .flash.error {
      color: red;
    }
  </style>
</head>

<body>
  <h1>图片检索</h1>
  <div class="section">
    <form action="{{ url_for('search_image') }}" method="post" enctype="multipart/form-data">
      <input type="file" name="file" accept=".jpg,.jpeg,.png,.bmp,.gif">
      <button type="submit">检索</button>
    </form>
    <div class="results">
      {% if results is defined and results %}
      <h5>检索结果:</h5>
      {% for item in results %}
      <div class="result-item">
        <img class="result-img" src="{{ item['image_url'] }}" alt="图片">
        <div>
          <div><b>相似度:</b>{{ '%.4f'|format(item['score']) }}</div>
        </div>
      </div>
      {% endfor %}
      {% elif query and results is defined and results|length == 0 %}
      <div>未检索到相关图片。</div>
      {% endif %}
    </div>
    {% with messages = get_flashed_messages(with_categories=true) %}
    {% if messages %}
    <ul class="flashes">
      {% for category, message in messages %}
      <li class="flash {{ category }}">{{ message }}</li>
      {% endfor %}
    </ul>
    {% endif %}
    {% endwith %}
  </div>
  <a href="{{ url_for('home') }}" style="display:inline-block;margin-bottom:20px;">返回首页</a>
</body>

</html>

9.2. app.py #

app.py

# 导入os模块,用于操作系统相关功能
import os
# 从flask库导入Flask、render_template和request对象
from flask import Flask, render_template, request
# 导入dotenv库中的load_dotenv函数,用于加载环境变量
from dotenv import load_dotenv
# 导入logging模块,用于日志记录
import logging
# 从milvus_utils模块导入ensure_milvus_connection函数,确保Milvus连接
from milvus_utils import ensure_milvus_connection
# 从services.doc_service模块导入文档上传和检索处理函数
from services.doc_service import handle_doc_upload, handle_doc_search
+# 从services.image_service模块导入图片上传和检索处理函数
+from services.image_service import handle_image_upload, handle_image_search

# 加载.env文件中的环境变量
load_dotenv()
# 确保Milvus数据库已连接
ensure_milvus_connection()
# 获取当前模块的日志记录器
logger = logging.getLogger(__name__)


# 创建Flask应用的工厂函数
def create_app():
    # 实例化Flask应用
    app = Flask(__name__)
    # 设置应用的密钥,从环境变量中获取
    app.secret_key = os.getenv("SECRET_KEY")

    # 定义首页路由
    @app.route("/")
    def home():
        # 记录访问首页的警告日志
        logger.warning("访问首页")
        # 渲染home.html模板
        return render_template("home.html")

    # 定义上传文档的路由,支持GET和POST方法
    @app.route("/upload_doc", methods=["GET", "POST"])
    def upload_doc():
        # 记录访问上传文档页面的警告日志
        logger.warning("访问上传文档页面")
        # 调用文档上传处理函数
        return handle_doc_upload(request)

    # 定义文本检索的路由,支持GET和POST方法
    @app.route("/search_doc", methods=["GET", "POST"])
    def search_doc():
        # 记录访问文本检索页面的警告日志
        logger.warning("访问文本检索页面")
        # 调用文档检索处理函数
        return handle_doc_search(request)

    # 定义上传图片的路由,支持GET和POST方法
    @app.route("/upload_image", methods=["GET", "POST"])
    def upload_image():
        # 记录访问上传图片页面的警告日志
        logger.warning("访问上传图片页面")
        # 调用图片上传处理函数
        return handle_image_upload(request)

+   # 定义图片检索的路由,支持GET和POST方法
+   @app.route("/search_image", methods=["GET", "POST"])
+   def search_image():
+       # 记录访问图片检索页面的警告日志
+       logger.warning("访问图片检索页面")
+       # 调用图片检索处理函数
+       return handle_image_search(request)

    # 返回Flask应用实例
    return app


# 如果当前模块作为主程序运行
if __name__ == "__main__":
    # 创建Flask应用
    app = create_app()
    # 启动应用,开启调试模式
    app.run(debug=True)

9.3. init.py #

milvus_utils/init.py

# 导入pymilvus中的connections模块,用于连接Milvus数据库
from pymilvus import connections
# 从当前包的doc模块导入文档相关的初始化、插入和检索函数
from .doc import init_doc_collection, insert_doc_vectors, search_doc_vectors
+# 从当前包的image模块导入图片相关的初始化、插入和检索函数
+from .image import init_image_collection, insert_image_vectors, search_image_vectors

# 定义一个全局变量,标记Milvus是否已连接
_milvus_connected = False


# 定义确保Milvus连接的函数,默认主机为localhost,端口为19530
def ensure_milvus_connection(host="localhost", port="19530"):
    global _milvus_connected
    # 如果尚未连接Milvus,则进行连接和初始化
    if not _milvus_connected:
        # 连接Milvus服务器
        connections.connect(host=host, port=port)
        # 初始化文档向量集合
        init_doc_collection()
        # 初始化图片向量集合
        init_image_collection()
        # 设置已连接标志为True
        _milvus_connected = True


# 定义模块对外暴露的接口列表
__all__ = [
    ensure_milvus_connection,
    "insert_doc_vectors",
    "search_doc_vectors",
    "insert_image_vectors",
+   "search_image_vectors",
]

9.4. image.py #

milvus_utils/image.py

# 从pymilvus导入Collection、FieldSchema、CollectionSchema、DataType、utility等类和方法
from pymilvus import (
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
)
# 导入logging模块用于日志记录
import logging

# 获取当前模块的logger对象
logger = logging.getLogger(__name__)
# 定义图片向量集合的名称
IMAGE_COLLECTION_NAME = "image_vectors"


# 初始化图片向量集合
def init_image_collection():
    # 如果Milvus中不存在该集合,则进行创建
    if not utility.has_collection(IMAGE_COLLECTION_NAME):
        # 定义集合的字段,包括自增主键id、图片向量embedding、图片url
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=2048),
            FieldSchema(name="image_url", dtype=DataType.VARCHAR, max_length=4096),
        ]
        # 创建集合的schema,描述信息为“图片向量集合”
        schema = CollectionSchema(fields, description="图片向量集合")
        # 创建集合对象
        col = Collection(IMAGE_COLLECTION_NAME, schema)
        # 定义索引参数,使用IVF_FLAT索引和余弦相似度
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "COSINE",
            "params": {"nlist": 128},
        }
        # 在embedding字段上创建索引
        col.create_index(field_name="embedding", index_params=index_params)
        # 加载集合到内存
        col.load()
    else:
        # 如果集合已存在,则直接加载
        col = Collection(IMAGE_COLLECTION_NAME)
        col.load()


# 插入图片向量及图片URL到Milvus
def insert_image_vectors(embeddings, image_urls):
    """
    插入图片向量及图片URL到Milvus。
    """
    # 获取图片向量集合对象
    col = Collection(IMAGE_COLLECTION_NAME)
    # 组织插入数据,顺序与schema一致
    data = [embeddings, image_urls]
    # 执行插入操作
    col.insert(data)


+# 检索最相似的图片向量
+def search_image_vectors(query_embedding, top_k=5):
+   """
+   检索最相似的图片向量。
+   """
+   # 获取图片向量集合对象
+   col = Collection(IMAGE_COLLECTION_NAME)
+   # 执行向量检索,设置检索参数
+   results = col.search(
+       data=[query_embedding],  # 查询向量
+       anns_field="embedding",  # 检索字段
+       param={"metric_type": "COSINE", "params": {"nprobe": 10}},  # 检索参数
+       limit=top_k,  # 返回top_k个结果
+       output_fields=["image_url"],  # 返回图片url字段
+   )
+   # 打印检索成功日志
+   logger.warning(f"图片检索成功")
+   # 返回检索结果
+   return results

9.5. image_service.py #

services/image_service.py

# 导入Flask的渲染模板、重定向、URL生成和消息闪现方法
from flask import render_template, redirect, url_for, flash
# 导入判断图片文件类型的工具函数
from utils import allowed_image_file
# 导入Werkzeug的安全文件名处理方法
from werkzeug.utils import secure_filename
# 导入COS对象存储上传方法
from cos_utils import upload_to_cos
# 导入图片向量化方法
from embedding_utils import get_image_embedding
+# 导入Milvus图片向量插入和检索方法
+from milvus_utils import insert_image_vectors, search_image_vectors
# 导入日志模块
import logging


# 处理图片上传的业务逻辑
def handle_image_upload(request):
    # 获取logger对象
    logger = logging.getLogger(__name__)
    # 判断请求方法是否为POST
    if request.method == "POST":
        # 记录收到POST请求日志
        logger.warning("收到图片上传POST请求")
        # 检查请求中是否包含文件
        if "file" not in request.files:
            logger.warning("未选择文件")
            flash("未选择文件")
            return redirect(request.url)
        # 获取上传的文件对象
        file = request.files["file"]
        # 检查文件名是否为空
        if file.filename == "":
            logger.warning("文件名为空")
            flash("未选择文件")
            return redirect(request.url)
        # 检查文件是否存在且为允许的图片类型
        if file and allowed_image_file(file.filename):
            # 对文件名进行安全处理
            filename = secure_filename(file.filename)
            try:
                # 记录开始上传图片日志
                logger.warning("开始上传图片")
                # 上传图片到COS,返回图片URL
                image_url = upload_to_cos(file.stream, key=filename)
                logger.warning(f"图片上传成功,URL: {image_url}")
                # 获取图片的向量表示
                embedding = get_image_embedding(image_url)
                logger.warning("图片向量化成功,开始入库")
                # 向Milvus插入图片向量和图片URL
                insert_image_vectors([embedding], [image_url])
                logger.warning("图片向量入库成功")
                # 闪现成功消息
                flash("文件上传并入库成功", "success")
            except Exception as e:
                # 记录异常日志并闪现错误消息
                logger.error(f"向量化或入库失败: {e}")
                flash(f"向量化或入库失败: {e}", "error")
            # 上传后重定向到上传页面
            return redirect(url_for("upload_image"))
    # GET请求渲染上传页面
    return render_template("upload_image.html")


+# 处理图片检索的业务逻辑
+def handle_image_search(request):
+   # 获取logger对象
+   logger = logging.getLogger(__name__)
+   # 初始化检索结果列表
+   results = []
+   # 判断请求方法是否为POST
+   if request.method == "POST":
+       # 记录收到图片检索POST请求日志
+       logger.warning("收到图片检索POST请求")
+       # 检查请求中是否包含文件
+       if "file" not in request.files:
+           logger.warning("未选择文件")
+           flash("请上传一张图片进行检索")
+           return redirect(url_for("search_image"))
+       # 获取上传的文件对象
+       file = request.files["file"]
+       # 检查文件名是否为空
+       if file.filename == "":
+           logger.warning("文件名为空")
+           flash("未选择文件")
+           return redirect(url_for("search_image"))
+       # 检查文件是否存在且为允许的图片类型
+       if file and allowed_image_file(file.filename):
+           # 对文件名进行安全处理
+           safe_filename = secure_filename(file.filename)
+           # 上传图片到COS,返回图片URL
+           image_url = upload_to_cos(file.stream, key=safe_filename)
+           logger.warning(f"图片上传成功,URL: {image_url}")
+           try:
+               # 记录开始图片向量化日志
+               logger.warning("开始图片向量化(检索)")
+               # 获取图片的向量表示
+               query_embedding = get_image_embedding(image_url)
+               logger.warning("图片向量化成功,开始Milvus检索")
+               # 调用Milvus进行图片向量检索,返回top5结果
+               milvus_results = search_image_vectors(query_embedding, top_k=5)
+               logger.warning(f"Milvus图片检索完成,命中组数: {len(milvus_results)}")
+               # 遍历检索结果,提取图片URL和相似度分数
+               for hits in milvus_results:
+                   for hit in hits:
+                       img_url = hit.entity.get("image_url", "")
+                       results.append({"image_url": img_url, "score": hit.distance})
+           except Exception as e:
+               # 记录异常日志并闪现错误消息
+               logger.error(f"图片检索失败: {e}")
+               flash(f"检索失败: {e}", "error")
+   # 渲染图片检索结果页面
+   return render_template("search_image.html", results=results)

9.6. home.html #

templates/home.html

<!DOCTYPE html>
<html lang="zh">

<head>
  <meta charset="UTF-8">
  <title>AI智能检索首页</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 40px;
    }

    .nav {
      margin: 40px auto;
      max-width: 400px;
    }

    .nav a {
      display: block;
      margin: 20px 0;
      padding: 16px;
      background: #f0f0f0;
      border-radius: 8px;
      text-align: center;
      text-decoration: none;
      color: #333;
      font-size: 18px;
    }

    .nav a:hover {
      background: #d0eaff;
    }
  </style>
</head>

<body>
  <div class="nav">
    <h1 style="text-align: center;">AI智能检索</h1>
    <a href="{{ url_for('upload_doc') }}">上传文档</a>
    <a href="{{ url_for('search_doc') }}">文本检索</a>
    <a href="{{ url_for('upload_image') }}">上传图片</a>
+   <a href="{{ url_for('search_image') }}">图片检索</a>
  </div>
</body>

</html>

访问验证

请输入访问令牌

Token不正确,请重新输入