AI大模型系列:(七)实战,手写一个RAG系统

2024-11-30 23:40
676
0

目的:通过自己编写一个RAG系统,来了解RAG系统的核心逻辑。

一、整体架构

本RAG 系统主要包含以下几个核心组件:

  • 文档加载器(loaders):支持三种文件格式,txt、pdf、docx
  • 文本分割器(processors):将长文本分割成适合向量化的小片段(可配置块的大小和重叠度、优先在段落、句子结束处分割、通过重叠区域保持上下文连贯性)
  • 向量化模块(embeddings):基于sentence-transformers实现,把文本分割的文本块转换为向量。
  • 向量存储模块(vectorstore):基于FAISS实现。
  • 向量检索器(retriever):向量相似度检索实现,它是连接文档处理和LLM的桥梁。检索模块负责根据用户查询从向量存储中检索相关文档。
  • 大语言模型接入(llm):LLM 接口模块
  • RAG核心模块(rag):RAG具体实现,整合检索和生成
  • Web服务(server):提供REST API接口供用户查询,配置管理。

数据流转过程

  • 文档处理:loaders → processors → embeddings → vectorstore
  • 查询处理:query → embeddings → vectorstore → retriever → rag → llm → response

项目基本结构:

rag_system/ # 项目根目录
├── src/ # 源代码目录
│ ├── embeddings/ # 文本向量化模块
│ │ ├── base.py # 向量化基类
│ │ └── transformer_embedding.py # Transformer 模型实现
│ ├── llm/ # 大语言模型模块
│ │ ├── base.py # LLM 基类
│ │ ├── ollama_llm.py # Ollama 模型实现
│ │ └── openai_llm.py # OpenAI 模型实现
│ ├── loaders/ # 文档加载模块
│ │ ├── base.py # 加载器基类
│ │ └── document_loader.py # 文档加载实现
│ ├── processors/ # 文本处理模块
│ │ ├── base.py # 处理器基类
│ │ └── text_splitter.py # 文本分割实现
│ ├── rag/ # RAG 核心模块
│ │ ├── base.py # RAG 基类
│ │ └── simple_rag.py # RAG 实现
│ ├── retriever/ # 检索模块
│ │ ├── base.py # 检索器基类
│ │ └── vector_retriever.py # 向量检索实现
│ └── server/ # Web 服务模块
│ ├── config.py # 配置管理
│ ├── main.py # 服务入口
│ └── routers/ # API 路由
├── data/ # 数据目录
│ ├── documents/ # 文档存储
│ └── vector_store/ # 向量数据存储
└── tests/ # 测试代码目录

下文仅对关键逻辑进行介绍,具体代码已上传到Gitee:源码地址

二、文档加载器实现

定义一个文档加载函数,传入文档路径,判断文档如果是如下三种文件格式:txt、pdf、docx,则进行加载。

2.1、txt文件读取

直接open读取,如果UTF-8读取失败则用GBK编码再读取一次:

def _load_txt(self, file_path: str) -> str:
        """加载TXT文件"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='gbk') as f:
                return f.read()

2.2、pdf文件读取

通过PyPDF2组件对PDF文件进行简单的读取。

PDF算是文件读取中最麻烦的,复杂的PDF文档很多控件读取都会有问题,比如复杂表格,图文混杂,在此只是做了一个简单的示例很多问题没有进行扩展。

def _load_pdf(self, file_path: str) -> str:
        """加载PDF文件"""
        try:
            if self._pdf_reader is None:
                from PyPDF2 import PdfReader
                self._pdf_reader = PdfReader
            
            reader = self._pdf_reader(file_path)
            text = []
            for page in reader.pages:
                text.append(page.extract_text())
            return '\n'.join(text)
        except ImportError:
            raise ImportError("请安装 PyPDF2: pip install PyPDF2")

2.3、docx文件读取

通过python-docx组件读取:

def _load_docx(self, file_path: str) -> str:
        """加载DOCX文件"""
        try:
            if self._docx_reader is None:
                from docx import Document
                self._docx_reader = Document
            
            doc = self._docx_reader(file_path)
            text = []
            for para in doc.paragraphs:
                if para.text.strip():
                    text.append(para.text)
            return '\n'.join(text)
        except ImportError:
            raise ImportError("请安装 python-docx: pip install python-docx")

三、文本分割器实现

文本分割器将长文本分割成适合向量化的小片段

3.1、可配置参数

  • chunk_size:控制每个文本块的最大长度
  • chunk_overlap:控制相邻块的重叠程度
  • separator:自定义分割符
def __init__(self, 
                 chunk_size: int = 1000,
                 chunk_overlap: int = 200,
                 separators: List[str] = None):
        """
        初始化文本分割器
        参数:
            chunk_size: 每个文本块的最大字符数
            chunk_overlap: 相邻文本块的重叠字符数
            separators: 分割文本时使用的分隔符列表,按优先级排序
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.separators = separators or ["\n\n", "\n", ". ", ", ", " ", "。"]

2、灵活的分割策略

  • 可配置的块大小和重叠度
  • 智能寻找分割点(优先在段落、句子结束处分割)
  • 保持上下文连贯性(通过重叠区域)
def _find_split_point(self, text: str, start: int, end: int) -> int:
        """
        在指定范围内寻找最佳分割点
        参数:
            text: 完整文本
            start: 起始位置
            end: 结束位置
        返回:
            最佳分割点的位置
        """
        if end >= len(text):
            return len(text)

        # 遍历所有分隔符,按优先级查找
        for separator in self.separators:
            # 在指定范围内查找最后一个分隔符
            position = text[start:end].rfind(separator)
            if position != -1:
                return start + position + len(separator)

        # 如果找不到任何分隔符,就在最大长度处直接分割
        return end

3、结构化输出

  • 每个文本块都包含索引信息
  • 保留原始文本的结构特征
def split(self, text: str) -> List[Dict[str, Any]]:
        """
        将文本分割成多个小块
        参数:
            text: 要分割的文本
        返回:
            包含文本块信息的字典列表
        """
        if not text:
            return []

        chunks = []
        start = 0
        chunk_index = 0
        text_length = len(text)

        while start < text_length:
            # 计算当前块的结束位置
            end = min(start + self.chunk_size, text_length)
            
            # 如果已经到达文本末尾,直接处理剩余文本
            if end >= text_length:
                split_point = text_length
            else:
                split_point = self._find_split_point(text, start, end)
            
            # 提取当前文本块
            chunk_text = text[start:split_point].strip()
            
            # 只添加非空的文本块
            if chunk_text:
                chunk = {
                    "content": chunk_text,
                    "index": chunk_index,
                    "metadata": {
                        "start_char": start,
                        "end_char": split_point,
                        "chunk_size": len(chunk_text)
                    }
                }
                chunks.append(chunk)
                chunk_index += 1
            
            # 更新下一个块的起始位置
            if split_point == start:  # 防止无限循环
                split_point = min(start + self.chunk_size, text_length)
            
            start = split_point
            if start < text_length:  # 考虑重叠
                start = max(start - self.chunk_overlap, 0)
            
            # 防止死循环
            if start >= text_length:
                break

        return chunks

四、向量化模块

此模块是把文本分割的文本块转换为向量。

基于 sentence-transformers 实现的一个向量化器。sentence-transformers官网:https://www.sbert.net/。

4.1、为什么选择Sentence-Transformers?

1、适合RAG系统需求

  • RAG系统需要准确理解文本语义
  • 需要处理长短不一的文本片段
  • 需要高质量的语义相似度计算

2、开箱即用

  • 有现成的中文预训练模型
  • API简单易用
  • 社区支持好,文档完善

3、效果好

  • 在中文语义理解上表现优秀
  • 能处理同义词、上下文
  • 相似度计算准确

对于RAG系统这样需要深度语义理解的应用,Sentence-Transformers是一个很好的选择,尽管它可能需要更多计算资源,但带来的效果提升是值得的。

要注意的一点,sentence-transformers第一次执行时会去‌访问Hugging Face下载模型(需要科学上网!!!如果无法科学上网,可以从ModelScope下载,如下示例:

	# 国内模型下载,先通过modelscope下载模型,再加载模型
	from modelscope import snapshot_download
	model_dir = snapshot_download(self.model_name, cache_dir=cache_dir)
	self._model = SentenceTransformer(model_dir, trust_remote_code=True)

Sentence-Transformers的使用很简单,官网也有示例,无需科学上网,需要进一步了解请移步官网:https://www.sbert.net/docs/quickstart.html

4.2、核心逻辑

核心逻辑其实就2步,加载模型,和生成向量。

加载模型,其实就是构造SentenceTransformer类的对象:

def __init__(self, 
                 model_name: str = "shibing624/text2vec-base-chinese",
                 device: str = None,
                 cache_dir: str = "models/embeddings"):
        """
        初始化Transformer编码器
        参数:
            model_name: 模型名称或路径
            device: 设备 ('cuda' 或 'cpu')
            cache_dir: 模型缓存目录
        """
        import os
        os.makedirs(cache_dir, exist_ok=True)
        
        self.model_name = model_name
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        
        print(f"正在加载模型 {model_name}...")
        self.model = SentenceTransformer(
            model_name,
            cache_folder=cache_dir,
            device=self.device
        )

文本生成向量,其实就是调用SentenceTransformer对象的encode方法:

def encode(self, texts: Union[str, List[str]], **kwargs) -> Union[List[float], List[List[float]]]:
        """
        将文本编码为向量
        参数:
            texts: 单个文本或文本列表
            **kwargs: 传递给模型的额外参数
        返回:
            文本向量或向量列表
        """
        # 确保输入是列表形式
        if isinstance(texts, str):
            texts = [texts]
        
        # 使用模型生成向量嵌入
        try:
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                show_progress_bar=False,
                **kwargs
            )
            return embeddings.tolist()
        except Exception as e:
            print(f"生成向量嵌入时出错: {str(e)}")
            raise

五、向量存储模块

向量数据库的介绍和选型可以查看另一篇文章

在此基于 FAISS 实现向量存储,原因:这毕竟只是一个demo代码,FAISS作为一个嵌入式向量数据库,使用简单无需再另外部署服务。

5.1、文本存储功能

  1. 接收文本列表和元数据
  2. 生成文本的向量
  3. 归一化向量
  4. 添加到FAISS索引
  5. 保存文本和元数据
  6. 持久化到磁盘

归一化向量(Normalization)是一个将向量缩放到单位长度(通常是1)的过程。

为什么要归一化?

  • 使不同长度的向量可比较
  • 只关注向量的方向,而不是大小
  • 提高相似度计算的准确性
  • 减少数值计算中的精度问题
def add_texts(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None) -> None:
        """添加文本到向量存储"""
        if not texts:
            return
        
        try:
            # 确保metadata列表长度与texts相同
            if metadata is None:
                metadata = [{} for _ in texts]
            elif len(metadata) != len(texts):
                raise ValueError("metadata长度必须与texts长度相同")
            
            # 生成向量嵌入
            embeddings = np.array(self.embedding_model.encode(texts))
            if len(embeddings.shape) == 1:
                embeddings = embeddings.reshape(1, -1)
            embeddings = embeddings.astype(np.float32)
            
            # 归一化向量
            faiss.normalize_L2(embeddings)
            
            # 打印调试信息
            print(f"添加向量数量: {len(texts)}")
            print(f"向量维度: {embeddings.shape}")
            
            # 添加到FAISS索引
            self.index.add(embeddings)
            
            # 保存文本和metadata
            self.texts.extend(texts)
            self.metadata.extend(metadata)
            
            # 保存到磁盘
            self.save()
            
            print(f"成功添加文档,当前总数: {len(self.texts)}")
            
        except Exception as e:
            print(f"添加文本时出错: {str(e)}")
            raise

5.2、数据持久化

  1. 持久化保存:保存FAISS索引到文件,然后用pickle序列化文本和元数据保存到独立文件中。
  2. 持久化加载:从磁盘加载数据时反序列化恢复文本和元数据
def save(self):
        """保存索引和数据到磁盘"""
        try:
            # 保存FAISS索引
            faiss.write_index(self.index, self.index_path)
            
            # 保存文本和metadata
            data_path = f"{self.index_path}.data"
            with open(data_path, 'wb') as f:
                pickle.dump({
                    'texts': self.texts,
                    'metadata': self.metadata
                }, f)
                
            print(f"成功保存索引到: {self.index_path}")
            print(f"成功保存数据到: {data_path}")
            
        except Exception as e:
            print(f"保存向量存储时出错: {str(e)}")
            raise
    
    def _load_stored_data(self):
        """从磁盘加载保存的数据"""
        try:
            data_path = f"{self.index_path}.data"
            if os.path.exists(data_path):
                with open(data_path, 'rb') as f:
                    data = pickle.load(f)
                    self.texts = data.get('texts', [])
                    self.metadata = data.get('metadata', [])
                print(f"成功加载数据,文档数量: {len(self.texts)}")
            else:
                print(f"警告: 未找到数据文件 {data_path}")
                
        except Exception as e:
            print(f"加载存储数据时出错: {str(e)}")
            raise

5.3、向量搜索

  1. 输入:查询文本和返回数量k
  2. 生成查询向量
  3. 归一化向量
  4. 执行最近邻搜索
  5. 返回相似度得分和文档内容
def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
        """基于相似度搜索"""
        try:
            print("\n=== 开始向量搜索 ===")
            print(f"查询文本: {query}")
            print(f"请求返回数量: {k}")
            
            if self.index.ntotal == 0:
                print("警告: 向量存储为空")
                return []
            
            # 打印当前状态
            print(f"当前索引中的向量数量: {self.index.ntotal}")
            print(f"文本数量: {len(self.texts)}")
            print(f"元数据数量: {len(self.metadata)}")
            
            # 生成查询向量
            print("正在生成查询向量...")
            query_embedding = np.array(self.embedding_model.encode(query))
            print(f"原始查询向量维度: {query_embedding.shape}")
            
            if len(query_embedding.shape) == 1:
                query_embedding = query_embedding.reshape(1, -1)
            query_embedding = query_embedding.astype(np.float32)
            print(f"处理后查询向量维度: {query_embedding.shape}")
            
            # 归一化查询向量
            print("正在归一化查询向量...")
            faiss.normalize_L2(query_embedding)
            
            # 搜索最近邻
            print("开始搜索最近邻...")
            k = min(k, self.index.ntotal)
            print(f"实际搜索数量: {k}")
            scores, indices = self.index.search(query_embedding, k)
            print(f"搜索结果 - scores: {scores.shape}, indices: {indices.shape}")
            
            # 返回结果
            results = []
            print("\n处理搜索结果...")
            for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
                print(f"处理第 {i+1} 个结果:")
                print(f"  - 索引: {idx}")
                print(f"  - 分数: {score}")
                
                if idx >= 0 and idx < len(self.texts):
                    try:
                        result = {
                            "content": self.texts[idx],
                            "metadata": self.metadata[idx],
                            "score": float(score)
                        }
                        print(f"  - 内容长度: {len(result['content'])}")
                        print(f"  - 元数据: {result['metadata']}")
                        results.append(result)
                    except Exception as e:
                        print(f"  - 处理结果时出错: {str(e)}")
            
            print(f"\n找到 {len(results)} 个相关文档")
            return results
            
        except Exception as e:
            print("\n=== 向量搜索出错 ===")
            print(f"错误类型: {type(e).__name__}")
            print(f"错误信息: {str(e)}")
            if 'query_embedding' in locals():
                print(f"查询向量维度: {query_embedding.shape}")
            print(f"索引维度: {self.dimension}")
            print(f"索引类型: {type(self.index).__name__}")
            raise

六、向量检索器

检索模块负责根据用户查询从向量存储中检索相关文档,其实就是调用上文5.3中的向量搜索函数。

def get_relevant_documents(self, query: str, k: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        检索相关文档
        """
        try:
            print("\n=== 开始文档检索 ===")
            k = k or self.k
            print(f"检索参数 - 查询: {query}, k: {k}")
            
            results = self.vectorstore.similarity_search(query, k=k)
            print(f"检索到原始结果数量: {len(results)}")
            
            # 确保返回结果格式统一
            formatted_results = []
            for i, doc in enumerate(results):
                print(f"\n处理第 {i+1} 个检索结果:")
                try:
                    formatted_doc = {
                        "content": doc.get("content", ""),
                        "metadata": doc.get("metadata", {}),
                        "score": doc.get("score", 0.0)
                    }
                    print(f"  - 内容长度: {len(formatted_doc['content'])}")
                    print(f"  - 元数据: {formatted_doc['metadata']}")
                    print(f"  - 分数: {formatted_doc['score']}")
                    formatted_results.append(formatted_doc)
                except Exception as e:
                    print(f"  - 格式化结果时出错: {str(e)}")
            
            print(f"\n最终返回 {len(formatted_results)} 个文档")
            return formatted_results
            
        except Exception as e:
            print("\n=== 文档检索出错 ===")
            print(f"错误类型: {type(e).__name__}")
            print(f"错误信息: {str(e)}")
            return []

七、大语言模型接入

本项目中接入了OpenAI、月之暗面Ollama,通过环境配置LLM_MODEL设置连接哪个LLM(可选: openai、ollama、moonshot)。

OpenAI需要科技才能连接,默认openAI的配置都注释了。1、需要把requirements.txt中openai的依赖注释打开。2、需要去OpenAI获得API_KEY配置到环境配置OPENAI_API_KEY中。

月之暗面是国产大模型,API控制台地址:platform.deepseek.com/,不需要科技、注册简单并且有免费额度,获取API_KEY后配置在MOONSHOT_API_KEY中。

在此以本地Ollama为例进行介绍。

前置条件:需要在本地搭建一套Ollama服务,具体的搭建方式请移步另一篇文章:《Ollama + Continue搭建一个离线开源的AI编程助手》

7.1、提示词生成

  1. 整合用户问题
  2. 添加上下文信息
  3. 格式化提示词结构
def _build_prompt(
        self,
        question: str,
        context: List[Dict[str, Any]]
    ) -> str:
        """
        构建提示词
        参数:
            question: 用户问题
            context: 相关文档上下文
        返回:
            完整的提示词
        """
        context_str = "\n".join(
            f"文档{i+1}:{doc['content']}"
            for i, doc in enumerate(context)
        )
        
        return f"""请基于以下参考文档回答问题。如果无法从参考文档中得到答案,请说明。

参考文档:
{context_str}

问题:{question}

回答:"""

7.2、生成回答

  1. 接收用户问题和上下文
  2. 构建完整提示词
  3. 调用Ollama API
  4. 支持流式输出
def generate(
        self,
        prompt: str,
        context: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> str:
        """
        生成回答
        参数:
            prompt: 用户问题
            context: 相关文档上下文
            **kwargs: 其他参数
        返回:
            生成的回答
        """
        # 构建完整提示词
        if context:
            full_prompt = self._build_prompt(prompt, context)
        else:
            full_prompt = prompt
            
        try:
            # 调用Ollama API
            response = requests.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model,
                    "prompt": full_prompt,
                    "system": self.system_prompt,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "stream": True,  # 启用流式输出
                    **kwargs
                }
            )
            
            if response.status_code != 200:
                raise Exception(f"API调用失败: {response.status_code}")
                
            # 处理流式响应
            response_text = ""
            for line in response.text.strip().split('\n'):
                if line:
                    try:
                        data = json.loads(line)
                        if "response" in data:
                            response_text += data["response"]
                            # 实时打印生成的文本
                            print(data["response"], end="", flush=True)
                    except json.JSONDecodeError:
                        continue
            
            print()  # 换行
            return response_text.strip()
            
        except Exception as e:
            error_msg = f"Ollama API调用出错: {str(e)}"
            print(error_msg)
            return error_msg

八、RAG核心模块

通过此模块把上文中的模块整合串联,实现2个功能:1、把文档添加到知识库(向量数据库,本地保存)。2、处理用户查询。

8.1、把文档添加到知识库

其实就是把文档内容和元数据调用上文5.1的add_texts函数存入向量数据库。

def add_documents(
        self, 
        documents: List[str], 
        metadata: Optional[List[Dict[str, Any]]] = None
    ) -> None:
        """
        添加文档到RAG系统
        
        参数:
            documents: 文档内容列表
            metadata: 文档元数据列表
        """
        try:
            print("\n=== 添加文档开始 ===")
            print(f"文档数量: {len(documents)}")
            if metadata:
                print(f"元数据数量: {len(metadata)}")
            
            # 添加到向量存储
            self.vectorstore.add_texts(documents, metadata)
            
            print("=== 添加文档完成 ===")
            print(f"当前总文档数: {self.vectorstore.total_docs}")
            
        except Exception as e:
            print("\n=== 添加文档出错 ===")
            print(f"错误类型: {type(e).__name__}")
            print(f"错误信息: {str(e)}")
            raise

8.2、处理用户查询

  1. 调用上文6、向量检索器的检索函数,检索相关文档
  2. 构建调用LLM的提示词
  3. 调用LLM(上文7.2)生成回答
def query(
        self,
        question: str,
        k: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        处理查询请求
        
        参数:
            question: 查询问题
            k: 返回的相关文档数量,如果为None则使用默认值
            
        返回:
            包含答案、来源和相关文档的字典
        """
        try:
            print("\n=== RAG查询开始 ===")
            print(f"问题: {question}")
            print(f"k值: {k or self.k}")
            
            # 1. 检索相关文档
            k = k or self.k
            relevant_docs = self.retriever.get_relevant_documents(question, k=k)
            print(f"检索到 {len(relevant_docs)} 个相关文档")
            
            if not relevant_docs:
                print("未找到相关文档")
                return {
                    "answer": "抱歉,没有找到相关信息。",
                    "sources": [],
                    "relevant_docs": []
                }
            
            # 2. 构建提示
            context = "\n\n".join([
                f"文档 {i+1}:\n{doc['content']}"
                for i, doc in enumerate(relevant_docs)
            ])
            
            prompt = f"""基于以下信息回答问题。如果无法从提供的信息中找到答案,请说"抱歉,我无法从提供的信息中找到答案。"

相关信息:
{context}

问题: {question}

回答:"""
            
            print("\n=== 生成回答 ===")
            # 3. 生成回答
            answer = self.llm.generate(prompt)
            
            # 4. 准备返回结果
            sources = []
            for doc in relevant_docs:
                if isinstance(doc.get("metadata"), dict):
                    sources.append(doc["metadata"])
                else:
                    sources.append({})
            
            result = {
                "answer": answer,
                "sources": sources,
                "relevant_docs": [
                    {
                        "content": doc["content"],
                        "metadata": doc.get("metadata", {}),
                        "score": doc.get("score", 0.0)
                    }
                    for doc in relevant_docs
                ]
            }
            
            print("\n=== RAG查询完成 ===")
            print(f"找到文档数: {len(relevant_docs)}")
            print(f"回答长度: {len(answer)}")
            
            return result
            
        except Exception as e:
            print("\n=== RAG查询出错 ===")
            print(f"错误类型: {type(e).__name__}")
            print(f"错误信息: {str(e)}")
            raise

九、Web服务

提供REST API接口供用户查询,配置管理。

9.1、定义查询接口

逻辑非常简单,就是定义一个查询REST API,直接调用上文8.2的函数

@router.post("/query")
async def query(
    request: QueryRequest,
    rag_system: SimpleRAG = Depends(get_rag_system)
):
    """查询接口"""
    try:
        result = rag_system.query(
            question=request.question,
            k=request.k
        )
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

9.2、WEB服务启动时初始化RAG系统

main.py中调用初始化RAG系统的方法,然后启动WEB服务。

async def initialize_rag_system(
    progress_callback: Optional[Callable] = None
) -> SimpleRAG:
    """初始化RAG系统"""
    try:
        print("正在初始化RAG系统...")
        
        # 初始化组件
        embedding_model = TransformerEmbedding()
        vectorstore = FAISSStore(embedding_model)
        retriever = VectorRetriever(vectorstore=vectorstore, k=4)
        llm = OllamaLLM(model="llama3.1", temperature=0.7)
        
        # 初始化RAG
        rag_system = SimpleRAG(
            embedding_model=embedding_model,
            vectorstore=vectorstore,
            retriever=retriever,
            llm=llm,
            k=4
        )
        
        # 加载新增的文档
        await load_new_documents(rag_system, progress_callback)
        
        return rag_system
        
    except Exception as e:
        print(f"初始化RAG系统时出错: {str(e)}")
        raise

 

全部评论