209 lines
8.0 KiB
Python
209 lines
8.0 KiB
Python
from milvus_client import connect_to_milvus, create_collection, insert_documents, search_documents
|
||
from document_processor import DocumentProcessor
|
||
from config import LLM_MODEL
|
||
from org.config import SALT_PROMPT, REISASOL_PROMPT
|
||
|
||
|
||
class RAGSystem:
|
||
def __init__(self):
    """Set up the RAG system: connect to Milvus, build the collection, and prepare the document processor."""
    # Establish the Milvus connection before touching any collection.
    connect_to_milvus()
    self.collection = create_collection()
    self.processor = DocumentProcessor()
    # Load the collection into memory so searches can run immediately.
    self.collection.load()
def add_documents(self, documents):
    """Add documents to the knowledge base; return the number of chunks inserted."""
    # Split, embed, then insert — the standard ingestion pipeline.
    chunks = self.processor.split_documents(documents)
    vectors = self.processor.generate_embeddings(chunks)
    contents = [chunk['content'] for chunk in chunks]
    origins = [chunk['source'] for chunk in chunks]
    insert_documents(self.collection, vectors, contents, origins)
    return len(chunks)
def query(self, question, top_k=3):
    """Basic query: retrieve relevant chunks and generate a standard answer.

    :param question: the user's question
    :param top_k: number of documents to retrieve
    :return: dict with 'answer', 'retrieved_docs' and 'context'
    """
    # Embed the question, retrieve the top-k matches, and join them as context.
    vector = self.processor.embedder.embed_query(question)
    docs = search_documents(self.collection, vector, top_k)
    joined_context = "\n".join(doc['text'] for doc in docs)
    return {
        'answer': self._generate_answer(question, joined_context),
        'retrieved_docs': docs,
        'context': joined_context
    }
def convert_songs_to_documents(songs_data):
|
||
"""将歌曲数据转换为文档格式"""
|
||
documents = []
|
||
for song in songs_data:
|
||
content = f"歌曲名称: {song['title']}, 歌手: {song['artist']}, BPM: {song['bpm']}, 版本: {song['version']}"
|
||
documents.append({
|
||
'content': content,
|
||
'source': f"歌曲数据 - {song['title']}"
|
||
})
|
||
return documents
|
||
|
||
def add_song_data(self, songs_data):
    """Add song records to the knowledge base; return the number of chunks inserted."""
    # Convert song rows into the generic document format, then reuse the
    # standard split / embed / insert pipeline.
    song_docs = self.convert_songs_to_documents(songs_data)
    pieces = self.processor.split_documents(song_docs)
    vectors = self.processor.generate_embeddings(pieces)
    insert_documents(
        self.collection,
        vectors,
        [piece['content'] for piece in pieces],
        [piece['source'] for piece in pieces],
    )
    return len(pieces)
# -------------------------- Role-play query method --------------------------
def role_play_query(self, question, role, top_k=3):
    """Role-play query: generate an answer phrased in the voice of *role*.

    :param question: the user's question (e.g. "How do I define a Python function?")
    :param role: persona to answer as (e.g. "Python lecturer", "product manager")
    :param top_k: number of documents to retrieve
    :return: result dict including the role-styled answer
    """
    # Retrieval is deliberately identical to query(), so context relevance
    # is unchanged between the two entry points.
    vector = self.processor.embedder.embed_query(question)
    docs = search_documents(self.collection, vector, top_k)
    joined_context = "\n".join(doc['text'] for doc in docs)

    # Persona-flavoured generation (distinct from the plain _generate_answer).
    styled_answer = self._generate_role_answer(question, joined_context, role)

    # Same result shape as query(), plus the role that was applied.
    return {
        'answer': styled_answer,
        'role': role,
        'retrieved_docs': docs,
        'context': joined_context
    }
# ----------------------------------------------------------------------------
def _generate_answer(self, question, context):
    """Generate a plain (role-free) answer from the retrieved context via the Ollama API.

    :param question: the user's question
    :param context: retrieved document context
    :return: the generated answer, or an error description string
    """
    from config import OLLAMA_BASE_URL
    import requests

    prompt = f"""
基于以下上下文回答问题。如果上下文不包含相关信息,请说明无法基于提供的资料回答。
要求:答案简洁、准确,符合技术文档规范。

上下文:
{context}

问题: {question}

回答:
"""

    try:
        # Normalise the base URL once, then hit the non-streaming generate endpoint.
        api_url = f"{OLLAMA_BASE_URL.rstrip('/')}/api/generate"
        payload = {"model": LLM_MODEL, "prompt": prompt, "stream": False}
        resp = requests.post(api_url, json=payload, timeout=30)
        if resp.status_code != 200:
            return f"Ollama API请求失败,状态码:{resp.status_code},原因:{resp.text}"
        return resp.json().get("response", "未获取到答案内容")
    except requests.exceptions.ConnectionError:
        return f"连接Ollama服务失败({api_url}),请检查网络"
    except requests.exceptions.Timeout:
        return f"连接Ollama服务超时(30秒),可能是模型生成过慢"
    except Exception as e:
        return f"生成答案出错:{str(e)}"
# ----------------------- Role-styled answer generation -----------------------
def _generate_role_answer(self, question, context, role):
    """Generate an answer phrased in the voice of *role* via the Ollama API.

    :param question: the user's question
    :param context: retrieved document context
    :param role: persona name; known personas map to preset prompts, unknown
                 ones fall back to a generic role instruction
    :return: the role-styled answer, or an error description string

    Fix: removed a stray full-width closing quote (”。) that had leaked into
    the prompt's requirement line and was being sent verbatim to the model.
    """
    from config import OLLAMA_BASE_URL
    import requests

    # Preset persona prompts; any other role gets a generic instruction below.
    role_prompt_map = {
        "纱露朵": SALT_PROMPT,
        "Reisasol": REISASOL_PROMPT
    }

    # Fall back to a generic role description when the role is not preset.
    role_desc = role_prompt_map.get(
        role,
        f"你是{role},请基于上下文回答问题,保持语气符合{role}的身份。"
    )

    # Role-aware prompt: the persona definition leads, then the context rules.
    prompt = f"""
{role_desc}
核心要求:基于以下上下文回答,如果上下文不包含相关信息,请给出最合适的回答,并确保内容符合{role}的身份。

上下文:
{context}

问题: {question}

回答:
"""

    # Request/response handling mirrors _generate_answer.
    try:
        base_url = OLLAMA_BASE_URL.rstrip('/')
        ollama_api_url = f"{base_url}/api/generate"
        request_body = {
            "model": LLM_MODEL,
            "prompt": prompt,
            "stream": False
        }

        response = requests.post(ollama_api_url, json=request_body, timeout=30)
        if response.status_code == 200:
            return response.json().get("response", "未获取到答案内容")
        else:
            return f"Ollama API请求失败,状态码:{response.status_code},原因:{response.text}"

    except requests.exceptions.ConnectionError:
        return f"连接Ollama服务失败({ollama_api_url}),请检查网络"
    except requests.exceptions.Timeout:
        return f"连接Ollama服务超时(30秒),可能是模型生成过慢"
    except Exception as e:
        return f"生成角色化答案出错:{str(e)}"
# ----------------------------------------------------------------------------