1
This commit is contained in:
209
org/rag_system.py
Normal file
209
org/rag_system.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from milvus_client import connect_to_milvus, create_collection, insert_documents, search_documents
|
||||
from document_processor import DocumentProcessor
|
||||
from config import LLM_MODEL
|
||||
from org.config import SALT_PROMPT, REISASOL_PROMPT
|
||||
|
||||
|
||||
class RAGSystem:
    """Retrieval-augmented generation system backed by a Milvus vector store."""

    def __init__(self):
        """Connect to Milvus, set up the collection, and prepare the processor."""
        # Establish the Milvus connection before any collection work.
        connect_to_milvus()

        # Create (or fetch) the target collection.
        self.collection = create_collection()

        # Handles document splitting and embedding generation.
        self.processor = DocumentProcessor()

        # Bring the collection into memory so searches can run.
        self.collection.load()
|
||||
|
||||
def add_documents(self, documents):
|
||||
"""添加文档到知识库"""
|
||||
# 分割文档
|
||||
split_texts = self.processor.split_documents(documents)
|
||||
|
||||
# 生成嵌入
|
||||
embeddings = self.processor.generate_embeddings(split_texts)
|
||||
|
||||
# 准备插入数据
|
||||
texts = [text['content'] for text in split_texts]
|
||||
sources = [text['source'] for text in split_texts]
|
||||
|
||||
# 插入数据
|
||||
insert_documents(self.collection, embeddings, texts, sources)
|
||||
|
||||
return len(split_texts)
|
||||
|
||||
def query(self, question, top_k=3):
|
||||
"""基础查询:生成标准答案"""
|
||||
# 生成查询嵌入
|
||||
query_embedding = self.processor.embedder.embed_query(question)
|
||||
|
||||
# 检索相关文档
|
||||
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
|
||||
|
||||
# 构造上下文
|
||||
context = "\n".join([doc['text'] for doc in retrieved_docs])
|
||||
|
||||
# 生成标准答案
|
||||
answer = self._generate_answer(question, context)
|
||||
|
||||
return {
|
||||
'answer': answer,
|
||||
'retrieved_docs': retrieved_docs,
|
||||
'context': context
|
||||
}
|
||||
|
||||
def convert_songs_to_documents(songs_data):
|
||||
"""将歌曲数据转换为文档格式"""
|
||||
documents = []
|
||||
for song in songs_data:
|
||||
content = f"歌曲名称: {song['title']}, 歌手: {song['artist']}, BPM: {song['bpm']}, 版本: {song['version']}"
|
||||
documents.append({
|
||||
'content': content,
|
||||
'source': f"歌曲数据 - {song['title']}"
|
||||
})
|
||||
return documents
|
||||
|
||||
def add_song_data(self, songs_data):
|
||||
"""添加歌曲数据到知识库"""
|
||||
# 转换数据格式
|
||||
song_documents = self.convert_songs_to_documents(songs_data)
|
||||
|
||||
# 复用现有文档处理流程
|
||||
split_texts = self.processor.split_documents(song_documents)
|
||||
embeddings = self.processor.generate_embeddings(split_texts)
|
||||
|
||||
texts = [text['content'] for text in split_texts]
|
||||
sources = [text['source'] for text in split_texts]
|
||||
|
||||
insert_documents(self.collection, embeddings, texts, sources)
|
||||
return len(split_texts)
|
||||
|
||||
# -------------------------- 新增:角色扮演式查询方法 --------------------------
|
||||
def role_play_query(self, question, role, top_k=3):
|
||||
"""
|
||||
角色扮演查询:按指定角色生成符合身份的答案
|
||||
:param question: 用户问题(如“Python怎么定义函数?”)
|
||||
:param role: 指定角色(如“Python讲师”“产品经理”“小学生辅导员”)
|
||||
:param top_k: 检索相关文档的数量
|
||||
:return: 包含角色化答案的结果字典
|
||||
"""
|
||||
# 1. 复用现有检索逻辑(和query方法完全一致,确保上下文相关性)
|
||||
query_embedding = self.processor.embedder.embed_query(question)
|
||||
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
|
||||
context = "\n".join([doc['text'] for doc in retrieved_docs])
|
||||
|
||||
# 2. 调用新增的“角色化答案生成方法”(区别于基础的_generate_answer)
|
||||
role_answer = self._generate_role_answer(question, context, role)
|
||||
|
||||
# 3. 返回结构和query方法一致,便于后续使用
|
||||
return {
|
||||
'answer': role_answer, # 角色化的答案
|
||||
'role': role, # 明确返回当前角色
|
||||
'retrieved_docs': retrieved_docs,
|
||||
'context': context
|
||||
}
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
def _generate_answer(self, question, context):
|
||||
"""基础答案生成:无角色限制"""
|
||||
from config import OLLAMA_BASE_URL
|
||||
import requests
|
||||
|
||||
prompt = f"""
|
||||
基于以下上下文回答问题。如果上下文不包含相关信息,请说明无法基于提供的资料回答。
|
||||
要求:答案简洁、准确,符合技术文档规范。
|
||||
|
||||
上下文:
|
||||
{context}
|
||||
|
||||
问题: {question}
|
||||
|
||||
回答:
|
||||
"""
|
||||
|
||||
try:
|
||||
base_url = OLLAMA_BASE_URL.rstrip('/')
|
||||
ollama_api_url = f"{base_url}/api/generate"
|
||||
request_body = {
|
||||
"model": LLM_MODEL,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(ollama_api_url, json=request_body, timeout=30)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("response", "未获取到答案内容")
|
||||
else:
|
||||
return f"Ollama API请求失败,状态码:{response.status_code},原因:{response.text}"
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
return f"连接Ollama服务失败({ollama_api_url}),请检查网络"
|
||||
except requests.exceptions.Timeout:
|
||||
return f"连接Ollama服务超时(30秒),可能是模型生成过慢"
|
||||
except Exception as e:
|
||||
return f"生成答案出错:{str(e)}"
|
||||
|
||||
# -------------------------- 新增:角色化答案生成方法 --------------------------
|
||||
def _generate_role_answer(self, question, context, role):
|
||||
"""
|
||||
角色化答案生成:按指定角色调整prompt语气和内容风格
|
||||
:param question: 用户问题
|
||||
:param context: 检索到的相关文档上下文
|
||||
:param role: 指定角色
|
||||
:return: 符合角色身份的答案
|
||||
"""
|
||||
from config import OLLAMA_BASE_URL
|
||||
import requests
|
||||
|
||||
# 核心区别:在prompt中加入“角色定义”,让模型按身份生成内容
|
||||
# 不同角色对应不同的语气要求(可根据需要扩展更多角色的提示词)
|
||||
role_prompt_map = {
|
||||
"纱露朵": SALT_PROMPT,
|
||||
"Reisasol" : REISASOL_PROMPT
|
||||
}
|
||||
|
||||
# 获取当前角色的提示词(若角色不在预设中,用默认提示词)
|
||||
role_desc = role_prompt_map.get(
|
||||
role,
|
||||
f"你是{role},请基于上下文回答问题,保持语气符合{role}的身份。"
|
||||
)
|
||||
|
||||
# 构造角色化prompt
|
||||
prompt = f"""
|
||||
{role_desc}
|
||||
核心要求:基于以下上下文回答,如果上下文不包含相关信息,请给出最合适的回答,并确保内容符合{role}的身份。”。
|
||||
|
||||
上下文:
|
||||
{context}
|
||||
|
||||
问题: {question}
|
||||
|
||||
回答:
|
||||
"""
|
||||
|
||||
# 后续API调用逻辑和_generate_answer一致(复用网络请求代码)
|
||||
try:
|
||||
base_url = OLLAMA_BASE_URL.rstrip('/')
|
||||
ollama_api_url = f"{base_url}/api/generate"
|
||||
request_body = {
|
||||
"model": LLM_MODEL,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(ollama_api_url, json=request_body, timeout=30)
|
||||
if response.status_code == 200:
|
||||
return response.json().get("response", "未获取到答案内容")
|
||||
else:
|
||||
return f"Ollama API请求失败,状态码:{response.status_code},原因:{response.text}"
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
return f"连接Ollama服务失败({ollama_api_url}),请检查网络"
|
||||
except requests.exceptions.Timeout:
|
||||
return f"连接Ollama服务超时(30秒),可能是模型生成过慢"
|
||||
except Exception as e:
|
||||
return f"生成角色化答案出错:{str(e)}"
|
||||
# ----------------------------------------------------------------------------
|
||||
Reference in New Issue
Block a user