This commit is contained in:
2025-12-15 14:38:31 +08:00
commit 22778e22fb
17 changed files with 938 additions and 0 deletions

209
org/rag_system.py Normal file
View File

@@ -0,0 +1,209 @@
from milvus_client import connect_to_milvus, create_collection, insert_documents, search_documents
from document_processor import DocumentProcessor
from config import LLM_MODEL
from org.config import SALT_PROMPT, REISASOL_PROMPT
class RAGSystem:
def __init__(self):
# 连接Milvus
connect_to_milvus()
# 创建集合
self.collection = create_collection()
# 初始化文档处理器
self.processor = DocumentProcessor()
# 加载集合
self.collection.load()
def add_documents(self, documents):
"""添加文档到知识库"""
# 分割文档
split_texts = self.processor.split_documents(documents)
# 生成嵌入
embeddings = self.processor.generate_embeddings(split_texts)
# 准备插入数据
texts = [text['content'] for text in split_texts]
sources = [text['source'] for text in split_texts]
# 插入数据
insert_documents(self.collection, embeddings, texts, sources)
return len(split_texts)
def query(self, question, top_k=3):
"""基础查询:生成标准答案"""
# 生成查询嵌入
query_embedding = self.processor.embedder.embed_query(question)
# 检索相关文档
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
# 构造上下文
context = "\n".join([doc['text'] for doc in retrieved_docs])
# 生成标准答案
answer = self._generate_answer(question, context)
return {
'answer': answer,
'retrieved_docs': retrieved_docs,
'context': context
}
def convert_songs_to_documents(songs_data):
"""将歌曲数据转换为文档格式"""
documents = []
for song in songs_data:
content = f"歌曲名称: {song['title']}, 歌手: {song['artist']}, BPM: {song['bpm']}, 版本: {song['version']}"
documents.append({
'content': content,
'source': f"歌曲数据 - {song['title']}"
})
return documents
def add_song_data(self, songs_data):
"""添加歌曲数据到知识库"""
# 转换数据格式
song_documents = self.convert_songs_to_documents(songs_data)
# 复用现有文档处理流程
split_texts = self.processor.split_documents(song_documents)
embeddings = self.processor.generate_embeddings(split_texts)
texts = [text['content'] for text in split_texts]
sources = [text['source'] for text in split_texts]
insert_documents(self.collection, embeddings, texts, sources)
return len(split_texts)
# -------------------------- 新增:角色扮演式查询方法 --------------------------
def role_play_query(self, question, role, top_k=3):
"""
角色扮演查询:按指定角色生成符合身份的答案
:param question: 用户问题如“Python怎么定义函数
:param role: 指定角色如“Python讲师”“产品经理”“小学生辅导员”
:param top_k: 检索相关文档的数量
:return: 包含角色化答案的结果字典
"""
# 1. 复用现有检索逻辑和query方法完全一致确保上下文相关性
query_embedding = self.processor.embedder.embed_query(question)
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
context = "\n".join([doc['text'] for doc in retrieved_docs])
# 2. 调用新增的“角色化答案生成方法”区别于基础的_generate_answer
role_answer = self._generate_role_answer(question, context, role)
# 3. 返回结构和query方法一致便于后续使用
return {
'answer': role_answer, # 角色化的答案
'role': role, # 明确返回当前角色
'retrieved_docs': retrieved_docs,
'context': context
}
# ----------------------------------------------------------------------------
def _generate_answer(self, question, context):
"""基础答案生成:无角色限制"""
from config import OLLAMA_BASE_URL
import requests
prompt = f"""
基于以下上下文回答问题。如果上下文不包含相关信息,请说明无法基于提供的资料回答。
要求:答案简洁、准确,符合技术文档规范。
上下文:
{context}
问题: {question}
回答:
"""
try:
base_url = OLLAMA_BASE_URL.rstrip('/')
ollama_api_url = f"{base_url}/api/generate"
request_body = {
"model": LLM_MODEL,
"prompt": prompt,
"stream": False
}
response = requests.post(ollama_api_url, json=request_body, timeout=30)
if response.status_code == 200:
return response.json().get("response", "未获取到答案内容")
else:
return f"Ollama API请求失败状态码{response.status_code},原因:{response.text}"
except requests.exceptions.ConnectionError:
return f"连接Ollama服务失败{ollama_api_url}),请检查网络"
except requests.exceptions.Timeout:
return f"连接Ollama服务超时30秒可能是模型生成过慢"
except Exception as e:
return f"生成答案出错:{str(e)}"
# -------------------------- 新增:角色化答案生成方法 --------------------------
def _generate_role_answer(self, question, context, role):
"""
角色化答案生成按指定角色调整prompt语气和内容风格
:param question: 用户问题
:param context: 检索到的相关文档上下文
:param role: 指定角色
:return: 符合角色身份的答案
"""
from config import OLLAMA_BASE_URL
import requests
# 核心区别在prompt中加入“角色定义”让模型按身份生成内容
# 不同角色对应不同的语气要求(可根据需要扩展更多角色的提示词)
role_prompt_map = {
"纱露朵": SALT_PROMPT,
"Reisasol" : REISASOL_PROMPT
}
# 获取当前角色的提示词(若角色不在预设中,用默认提示词)
role_desc = role_prompt_map.get(
role,
f"你是{role},请基于上下文回答问题,保持语气符合{role}的身份。"
)
# 构造角色化prompt
prompt = f"""
{role_desc}
核心要求:基于以下上下文回答,如果上下文不包含相关信息,请给出最合适的回答,并确保内容符合{role}的身份。”。
上下文:
{context}
问题: {question}
回答:
"""
# 后续API调用逻辑和_generate_answer一致复用网络请求代码
try:
base_url = OLLAMA_BASE_URL.rstrip('/')
ollama_api_url = f"{base_url}/api/generate"
request_body = {
"model": LLM_MODEL,
"prompt": prompt,
"stream": False
}
response = requests.post(ollama_api_url, json=request_body, timeout=30)
if response.status_code == 200:
return response.json().get("response", "未获取到答案内容")
else:
return f"Ollama API请求失败状态码{response.status_code},原因:{response.text}"
except requests.exceptions.ConnectionError:
return f"连接Ollama服务失败{ollama_api_url}),请检查网络"
except requests.exceptions.Timeout:
return f"连接Ollama服务超时30秒可能是模型生成过慢"
except Exception as e:
return f"生成角色化答案出错:{str(e)}"
# ----------------------------------------------------------------------------