明天你会感谢今天奋力拼搏的你。
ヾ(o◕∀◕)ノヾ
目的:通过自己编写一个RAG系统,来了解RAG系统的核心逻辑。
本RAG 系统主要包含以下几个核心组件:
数据流转过程
项目基本结构:
rag_system/ # 项目根目录
├── src/ # 源代码目录
│ ├── embeddings/ # 文本向量化模块
│ │ ├── base.py # 向量化基类
│ │ └── transformer_embedding.py # Transformer 模型实现
│ ├── llm/ # 大语言模型模块
│ │ ├── base.py # LLM 基类
│ │ ├── ollama_llm.py # Ollama 模型实现
│ │ └── openai_llm.py # OpenAI 模型实现
│ ├── loaders/ # 文档加载模块
│ │ ├── base.py # 加载器基类
│ │ └── document_loader.py # 文档加载实现
│ ├── processors/ # 文本处理模块
│ │ ├── base.py # 处理器基类
│ │ └── text_splitter.py # 文本分割实现
│ ├── rag/ # RAG 核心模块
│ │ ├── base.py # RAG 基类
│ │ └── simple_rag.py # RAG 实现
│ ├── retriever/ # 检索模块
│ │ ├── base.py # 检索器基类
│ │ └── vector_retriever.py # 向量检索实现
│ └── server/ # Web 服务模块
│ ├── config.py # 配置管理
│ ├── main.py # 服务入口
│ └── routers/ # API 路由
├── data/ # 数据目录
│ ├── documents/ # 文档存储
│ └── vector_store/ # 向量数据存储
└── tests/ # 测试代码目录
下文仅对关键逻辑进行介绍,具体代码已上传到Gitee:源码地址
定义一个文档加载函数,传入文档路径,判断文档如果是如下三种文件格式:txt、pdf、docx,则进行加载。
直接open读取,如果UTF-8读取失败则用GBK编码再读取一次:
def _load_txt(self, file_path: str) -> str:
"""加载TXT文件"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except UnicodeDecodeError:
with open(file_path, 'r', encoding='gbk') as f:
return f.read()
通过PyPDF2组件对PDF文件进行简单的读取。
PDF算是文件读取中最麻烦的,复杂的PDF文档很多控件读取都会有问题,比如复杂表格,图文混杂,在此只是做了一个简单的示例很多问题没有进行扩展。
def _load_pdf(self, file_path: str) -> str:
"""加载PDF文件"""
try:
if self._pdf_reader is None:
from PyPDF2 import PdfReader
self._pdf_reader = PdfReader
reader = self._pdf_reader(file_path)
text = []
for page in reader.pages:
text.append(page.extract_text())
return '\n'.join(text)
except ImportError:
raise ImportError("请安装 PyPDF2: pip install PyPDF2")
通过python-docx组件读取:
def _load_docx(self, file_path: str) -> str:
"""加载DOCX文件"""
try:
if self._docx_reader is None:
from docx import Document
self._docx_reader = Document
doc = self._docx_reader(file_path)
text = []
for para in doc.paragraphs:
if para.text.strip():
text.append(para.text)
return '\n'.join(text)
except ImportError:
raise ImportError("请安装 python-docx: pip install python-docx")
文本分割器将长文本分割成适合向量化的小片段
def __init__(self,
chunk_size: int = 1000,
chunk_overlap: int = 200,
separators: List[str] = None):
"""
初始化文本分割器
参数:
chunk_size: 每个文本块的最大字符数
chunk_overlap: 相邻文本块的重叠字符数
separators: 分割文本时使用的分隔符列表,按优先级排序
"""
super().__init__()
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.separators = separators or ["\n\n", "\n", ". ", ", ", " ", "。"]
def _find_split_point(self, text: str, start: int, end: int) -> int:
"""
在指定范围内寻找最佳分割点
参数:
text: 完整文本
start: 起始位置
end: 结束位置
返回:
最佳分割点的位置
"""
if end >= len(text):
return len(text)
# 遍历所有分隔符,按优先级查找
for separator in self.separators:
# 在指定范围内查找最后一个分隔符
position = text[start:end].rfind(separator)
if position != -1:
return start + position + len(separator)
# 如果找不到任何分隔符,就在最大长度处直接分割
return end
def split(self, text: str) -> List[Dict[str, Any]]:
"""
将文本分割成多个小块
参数:
text: 要分割的文本
返回:
包含文本块信息的字典列表
"""
if not text:
return []
chunks = []
start = 0
chunk_index = 0
text_length = len(text)
while start < text_length:
# 计算当前块的结束位置
end = min(start + self.chunk_size, text_length)
# 如果已经到达文本末尾,直接处理剩余文本
if end >= text_length:
split_point = text_length
else:
split_point = self._find_split_point(text, start, end)
# 提取当前文本块
chunk_text = text[start:split_point].strip()
# 只添加非空的文本块
if chunk_text:
chunk = {
"content": chunk_text,
"index": chunk_index,
"metadata": {
"start_char": start,
"end_char": split_point,
"chunk_size": len(chunk_text)
}
}
chunks.append(chunk)
chunk_index += 1
# 更新下一个块的起始位置
if split_point == start: # 防止无限循环
split_point = min(start + self.chunk_size, text_length)
start = split_point
if start < text_length: # 考虑重叠
start = max(start - self.chunk_overlap, 0)
# 防止死循环
if start >= text_length:
break
return chunks
此模块是把文本分割的文本块转换为向量。
基于 sentence-transformers 实现的一个向量化器。sentence-transformers官网:https://www.sbert.net/。
1、适合RAG系统需求
2、开箱即用
3、效果好
对于RAG系统这样需要深度语义理解的应用,Sentence-Transformers是一个很好的选择,尽管它可能需要更多计算资源,但带来的效果提升是值得的。
要注意的一点,sentence-transformers第一次执行时会去访问Hugging Face下载模型(需要科学上网!!!)如果无法科学上网,可以从ModelScope下载,如下示例:
# 国内模型下载,先通过modelscope下载模型,再加载模型
from modelscope import snapshot_download
model_dir = snapshot_download(self.model_name, cache_dir=cache_dir)
self._model = SentenceTransformer(model_dir, trust_remote_code=True)
Sentence-Transformers的使用很简单,官网也有示例,无需科学上网,需要进一步了解请移步官网:https://www.sbert.net/docs/quickstart.html
核心逻辑其实就2步,加载模型,和生成向量。
加载模型,其实就是构造SentenceTransformer类的对象:
def __init__(self,
model_name: str = "shibing624/text2vec-base-chinese",
device: str = None,
cache_dir: str = "models/embeddings"):
"""
初始化Transformer编码器
参数:
model_name: 模型名称或路径
device: 设备 ('cuda' 或 'cpu')
cache_dir: 模型缓存目录
"""
import os
os.makedirs(cache_dir, exist_ok=True)
self.model_name = model_name
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"正在加载模型 {model_name}...")
self.model = SentenceTransformer(
model_name,
cache_folder=cache_dir,
device=self.device
)
文本生成向量,其实就是调用SentenceTransformer对象的encode方法:
def encode(self, texts: Union[str, List[str]], **kwargs) -> Union[List[float], List[List[float]]]:
"""
将文本编码为向量
参数:
texts: 单个文本或文本列表
**kwargs: 传递给模型的额外参数
返回:
文本向量或向量列表
"""
# 确保输入是列表形式
if isinstance(texts, str):
texts = [texts]
# 使用模型生成向量嵌入
try:
embeddings = self.model.encode(
texts,
convert_to_numpy=True,
show_progress_bar=False,
**kwargs
)
return embeddings.tolist()
except Exception as e:
print(f"生成向量嵌入时出错: {str(e)}")
raise
向量数据库的介绍和选型可以查看另一篇文章。
在此基于 FAISS 实现向量存储,原因:这毕竟只是一个demo代码,FAISS作为一个嵌入式向量数据库,使用简单无需再另外部署服务。
归一化向量(Normalization)是一个将向量缩放到单位长度(通常是1)的过程。
为什么要归一化?
def add_texts(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None) -> None:
"""添加文本到向量存储"""
if not texts:
return
try:
# 确保metadata列表长度与texts相同
if metadata is None:
metadata = [{} for _ in texts]
elif len(metadata) != len(texts):
raise ValueError("metadata长度必须与texts长度相同")
# 生成向量嵌入
embeddings = np.array(self.embedding_model.encode(texts))
if len(embeddings.shape) == 1:
embeddings = embeddings.reshape(1, -1)
embeddings = embeddings.astype(np.float32)
# 归一化向量
faiss.normalize_L2(embeddings)
# 打印调试信息
print(f"添加向量数量: {len(texts)}")
print(f"向量维度: {embeddings.shape}")
# 添加到FAISS索引
self.index.add(embeddings)
# 保存文本和metadata
self.texts.extend(texts)
self.metadata.extend(metadata)
# 保存到磁盘
self.save()
print(f"成功添加文档,当前总数: {len(self.texts)}")
except Exception as e:
print(f"添加文本时出错: {str(e)}")
raise
def save(self):
"""保存索引和数据到磁盘"""
try:
# 保存FAISS索引
faiss.write_index(self.index, self.index_path)
# 保存文本和metadata
data_path = f"{self.index_path}.data"
with open(data_path, 'wb') as f:
pickle.dump({
'texts': self.texts,
'metadata': self.metadata
}, f)
print(f"成功保存索引到: {self.index_path}")
print(f"成功保存数据到: {data_path}")
except Exception as e:
print(f"保存向量存储时出错: {str(e)}")
raise
def _load_stored_data(self):
"""从磁盘加载保存的数据"""
try:
data_path = f"{self.index_path}.data"
if os.path.exists(data_path):
with open(data_path, 'rb') as f:
data = pickle.load(f)
self.texts = data.get('texts', [])
self.metadata = data.get('metadata', [])
print(f"成功加载数据,文档数量: {len(self.texts)}")
else:
print(f"警告: 未找到数据文件 {data_path}")
except Exception as e:
print(f"加载存储数据时出错: {str(e)}")
raise
def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
"""基于相似度搜索"""
try:
print("\n=== 开始向量搜索 ===")
print(f"查询文本: {query}")
print(f"请求返回数量: {k}")
if self.index.ntotal == 0:
print("警告: 向量存储为空")
return []
# 打印当前状态
print(f"当前索引中的向量数量: {self.index.ntotal}")
print(f"文本数量: {len(self.texts)}")
print(f"元数据数量: {len(self.metadata)}")
# 生成查询向量
print("正在生成查询向量...")
query_embedding = np.array(self.embedding_model.encode(query))
print(f"原始查询向量维度: {query_embedding.shape}")
if len(query_embedding.shape) == 1:
query_embedding = query_embedding.reshape(1, -1)
query_embedding = query_embedding.astype(np.float32)
print(f"处理后查询向量维度: {query_embedding.shape}")
# 归一化查询向量
print("正在归一化查询向量...")
faiss.normalize_L2(query_embedding)
# 搜索最近邻
print("开始搜索最近邻...")
k = min(k, self.index.ntotal)
print(f"实际搜索数量: {k}")
scores, indices = self.index.search(query_embedding, k)
print(f"搜索结果 - scores: {scores.shape}, indices: {indices.shape}")
# 返回结果
results = []
print("\n处理搜索结果...")
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
print(f"处理第 {i+1} 个结果:")
print(f" - 索引: {idx}")
print(f" - 分数: {score}")
if idx >= 0 and idx < len(self.texts):
try:
result = {
"content": self.texts[idx],
"metadata": self.metadata[idx],
"score": float(score)
}
print(f" - 内容长度: {len(result['content'])}")
print(f" - 元数据: {result['metadata']}")
results.append(result)
except Exception as e:
print(f" - 处理结果时出错: {str(e)}")
print(f"\n找到 {len(results)} 个相关文档")
return results
except Exception as e:
print("\n=== 向量搜索出错 ===")
print(f"错误类型: {type(e).__name__}")
print(f"错误信息: {str(e)}")
if 'query_embedding' in locals():
print(f"查询向量维度: {query_embedding.shape}")
print(f"索引维度: {self.dimension}")
print(f"索引类型: {type(self.index).__name__}")
raise
检索模块负责根据用户查询从向量存储中检索相关文档,其实就是调用上文5.3中的向量搜索函数。
def get_relevant_documents(self, query: str, k: Optional[int] = None) -> List[Dict[str, Any]]:
"""
检索相关文档
"""
try:
print("\n=== 开始文档检索 ===")
k = k or self.k
print(f"检索参数 - 查询: {query}, k: {k}")
results = self.vectorstore.similarity_search(query, k=k)
print(f"检索到原始结果数量: {len(results)}")
# 确保返回结果格式统一
formatted_results = []
for i, doc in enumerate(results):
print(f"\n处理第 {i+1} 个检索结果:")
try:
formatted_doc = {
"content": doc.get("content", ""),
"metadata": doc.get("metadata", {}),
"score": doc.get("score", 0.0)
}
print(f" - 内容长度: {len(formatted_doc['content'])}")
print(f" - 元数据: {formatted_doc['metadata']}")
print(f" - 分数: {formatted_doc['score']}")
formatted_results.append(formatted_doc)
except Exception as e:
print(f" - 格式化结果时出错: {str(e)}")
print(f"\n最终返回 {len(formatted_results)} 个文档")
return formatted_results
except Exception as e:
print("\n=== 文档检索出错 ===")
print(f"错误类型: {type(e).__name__}")
print(f"错误信息: {str(e)}")
return []
本项目中接入了OpenAI、月之暗面和Ollama,通过环境配置LLM_MODEL设置连接哪个LLM(可选: openai、ollama、moonshot)。
OpenAI需要科技才能连接,默认openAI的配置都注释了。1、需要把requirements.txt中openai的依赖注释打开。2、需要去OpenAI获得API_KEY配置到环境配置OPENAI_API_KEY中。
月之暗面是国产大模型,API控制台地址:platform.deepseek.com/,不需要科技、注册简单并且有免费额度,获取API_KEY后配置在MOONSHOT_API_KEY中。
在此以本地Ollama为例进行介绍。
前置条件:需要在本地搭建一套Ollama服务,具体的搭建方式请移步另一篇文章:《Ollama + Continue搭建一个离线开源的AI编程助手》
def _build_prompt(
self,
question: str,
context: List[Dict[str, Any]]
) -> str:
"""
构建提示词
参数:
question: 用户问题
context: 相关文档上下文
返回:
完整的提示词
"""
context_str = "\n".join(
f"文档{i+1}:{doc['content']}"
for i, doc in enumerate(context)
)
return f"""请基于以下参考文档回答问题。如果无法从参考文档中得到答案,请说明。
参考文档:
{context_str}
问题:{question}
回答:"""
def generate(
self,
prompt: str,
context: Optional[List[Dict[str, Any]]] = None,
**kwargs
) -> str:
"""
生成回答
参数:
prompt: 用户问题
context: 相关文档上下文
**kwargs: 其他参数
返回:
生成的回答
"""
# 构建完整提示词
if context:
full_prompt = self._build_prompt(prompt, context)
else:
full_prompt = prompt
try:
# 调用Ollama API
response = requests.post(
f"{self.base_url}/api/generate",
json={
"model": self.model,
"prompt": full_prompt,
"system": self.system_prompt,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"stream": True, # 启用流式输出
**kwargs
}
)
if response.status_code != 200:
raise Exception(f"API调用失败: {response.status_code}")
# 处理流式响应
response_text = ""
for line in response.text.strip().split('\n'):
if line:
try:
data = json.loads(line)
if "response" in data:
response_text += data["response"]
# 实时打印生成的文本
print(data["response"], end="", flush=True)
except json.JSONDecodeError:
continue
print() # 换行
return response_text.strip()
except Exception as e:
error_msg = f"Ollama API调用出错: {str(e)}"
print(error_msg)
return error_msg
通过此模块把上文中的模块整合串联,实现2个功能:1、把文档添加到知识库(向量数据库,本地保存)。2、处理用户查询。
其实就是把文档内容和元数据调用上文5.1的add_texts函数存入向量数据库。
def add_documents(
self,
documents: List[str],
metadata: Optional[List[Dict[str, Any]]] = None
) -> None:
"""
添加文档到RAG系统
参数:
documents: 文档内容列表
metadata: 文档元数据列表
"""
try:
print("\n=== 添加文档开始 ===")
print(f"文档数量: {len(documents)}")
if metadata:
print(f"元数据数量: {len(metadata)}")
# 添加到向量存储
self.vectorstore.add_texts(documents, metadata)
print("=== 添加文档完成 ===")
print(f"当前总文档数: {self.vectorstore.total_docs}")
except Exception as e:
print("\n=== 添加文档出错 ===")
print(f"错误类型: {type(e).__name__}")
print(f"错误信息: {str(e)}")
raise
def query(
self,
question: str,
k: Optional[int] = None
) -> Dict[str, Any]:
"""
处理查询请求
参数:
question: 查询问题
k: 返回的相关文档数量,如果为None则使用默认值
返回:
包含答案、来源和相关文档的字典
"""
try:
print("\n=== RAG查询开始 ===")
print(f"问题: {question}")
print(f"k值: {k or self.k}")
# 1. 检索相关文档
k = k or self.k
relevant_docs = self.retriever.get_relevant_documents(question, k=k)
print(f"检索到 {len(relevant_docs)} 个相关文档")
if not relevant_docs:
print("未找到相关文档")
return {
"answer": "抱歉,没有找到相关信息。",
"sources": [],
"relevant_docs": []
}
# 2. 构建提示
context = "\n\n".join([
f"文档 {i+1}:\n{doc['content']}"
for i, doc in enumerate(relevant_docs)
])
prompt = f"""基于以下信息回答问题。如果无法从提供的信息中找到答案,请说"抱歉,我无法从提供的信息中找到答案。"
相关信息:
{context}
问题: {question}
回答:"""
print("\n=== 生成回答 ===")
# 3. 生成回答
answer = self.llm.generate(prompt)
# 4. 准备返回结果
sources = []
for doc in relevant_docs:
if isinstance(doc.get("metadata"), dict):
sources.append(doc["metadata"])
else:
sources.append({})
result = {
"answer": answer,
"sources": sources,
"relevant_docs": [
{
"content": doc["content"],
"metadata": doc.get("metadata", {}),
"score": doc.get("score", 0.0)
}
for doc in relevant_docs
]
}
print("\n=== RAG查询完成 ===")
print(f"找到文档数: {len(relevant_docs)}")
print(f"回答长度: {len(answer)}")
return result
except Exception as e:
print("\n=== RAG查询出错 ===")
print(f"错误类型: {type(e).__name__}")
print(f"错误信息: {str(e)}")
raise
提供REST API接口供用户查询,配置管理。
逻辑非常简单,就是定义一个查询REST API,直接调用上文8.2的函数
@router.post("/query")
async def query(
request: QueryRequest,
rag_system: SimpleRAG = Depends(get_rag_system)
):
"""查询接口"""
try:
result = rag_system.query(
question=request.question,
k=request.k
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
main.py中调用初始化RAG系统的方法,然后启动WEB服务。
async def initialize_rag_system(
progress_callback: Optional[Callable] = None
) -> SimpleRAG:
"""初始化RAG系统"""
try:
print("正在初始化RAG系统...")
# 初始化组件
embedding_model = TransformerEmbedding()
vectorstore = FAISSStore(embedding_model)
retriever = VectorRetriever(vectorstore=vectorstore, k=4)
llm = OllamaLLM(model="llama3.1", temperature=0.7)
# 初始化RAG
rag_system = SimpleRAG(
embedding_model=embedding_model,
vectorstore=vectorstore,
retriever=retriever,
llm=llm,
k=4
)
# 加载新增的文档
await load_new_documents(rag_system, progress_callback)
return rag_system
except Exception as e:
print(f"初始化RAG系统时出错: {str(e)}")
raise
全部评论