Files
AI-Search/org/rag_system.py
2025-12-15 14:38:31 +08:00

209 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from milvus_client import connect_to_milvus, create_collection, insert_documents, search_documents
from document_processor import DocumentProcessor
from config import LLM_MODEL
from org.config import SALT_PROMPT, REISASOL_PROMPT
class RAGSystem:
def __init__(self):
# 连接Milvus
connect_to_milvus()
# 创建集合
self.collection = create_collection()
# 初始化文档处理器
self.processor = DocumentProcessor()
# 加载集合
self.collection.load()
def add_documents(self, documents):
"""添加文档到知识库"""
# 分割文档
split_texts = self.processor.split_documents(documents)
# 生成嵌入
embeddings = self.processor.generate_embeddings(split_texts)
# 准备插入数据
texts = [text['content'] for text in split_texts]
sources = [text['source'] for text in split_texts]
# 插入数据
insert_documents(self.collection, embeddings, texts, sources)
return len(split_texts)
def query(self, question, top_k=3):
"""基础查询:生成标准答案"""
# 生成查询嵌入
query_embedding = self.processor.embedder.embed_query(question)
# 检索相关文档
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
# 构造上下文
context = "\n".join([doc['text'] for doc in retrieved_docs])
# 生成标准答案
answer = self._generate_answer(question, context)
return {
'answer': answer,
'retrieved_docs': retrieved_docs,
'context': context
}
def convert_songs_to_documents(songs_data):
"""将歌曲数据转换为文档格式"""
documents = []
for song in songs_data:
content = f"歌曲名称: {song['title']}, 歌手: {song['artist']}, BPM: {song['bpm']}, 版本: {song['version']}"
documents.append({
'content': content,
'source': f"歌曲数据 - {song['title']}"
})
return documents
def add_song_data(self, songs_data):
"""添加歌曲数据到知识库"""
# 转换数据格式
song_documents = self.convert_songs_to_documents(songs_data)
# 复用现有文档处理流程
split_texts = self.processor.split_documents(song_documents)
embeddings = self.processor.generate_embeddings(split_texts)
texts = [text['content'] for text in split_texts]
sources = [text['source'] for text in split_texts]
insert_documents(self.collection, embeddings, texts, sources)
return len(split_texts)
# -------------------------- 新增:角色扮演式查询方法 --------------------------
def role_play_query(self, question, role, top_k=3):
"""
角色扮演查询:按指定角色生成符合身份的答案
:param question: 用户问题如“Python怎么定义函数
:param role: 指定角色如“Python讲师”“产品经理”“小学生辅导员”
:param top_k: 检索相关文档的数量
:return: 包含角色化答案的结果字典
"""
# 1. 复用现有检索逻辑和query方法完全一致确保上下文相关性
query_embedding = self.processor.embedder.embed_query(question)
retrieved_docs = search_documents(self.collection, query_embedding, top_k)
context = "\n".join([doc['text'] for doc in retrieved_docs])
# 2. 调用新增的“角色化答案生成方法”区别于基础的_generate_answer
role_answer = self._generate_role_answer(question, context, role)
# 3. 返回结构和query方法一致便于后续使用
return {
'answer': role_answer, # 角色化的答案
'role': role, # 明确返回当前角色
'retrieved_docs': retrieved_docs,
'context': context
}
# ----------------------------------------------------------------------------
def _generate_answer(self, question, context):
"""基础答案生成:无角色限制"""
from config import OLLAMA_BASE_URL
import requests
prompt = f"""
基于以下上下文回答问题。如果上下文不包含相关信息,请说明无法基于提供的资料回答。
要求:答案简洁、准确,符合技术文档规范。
上下文:
{context}
问题: {question}
回答:
"""
try:
base_url = OLLAMA_BASE_URL.rstrip('/')
ollama_api_url = f"{base_url}/api/generate"
request_body = {
"model": LLM_MODEL,
"prompt": prompt,
"stream": False
}
response = requests.post(ollama_api_url, json=request_body, timeout=30)
if response.status_code == 200:
return response.json().get("response", "未获取到答案内容")
else:
return f"Ollama API请求失败状态码{response.status_code},原因:{response.text}"
except requests.exceptions.ConnectionError:
return f"连接Ollama服务失败{ollama_api_url}),请检查网络"
except requests.exceptions.Timeout:
return f"连接Ollama服务超时30秒可能是模型生成过慢"
except Exception as e:
return f"生成答案出错:{str(e)}"
# -------------------------- 新增:角色化答案生成方法 --------------------------
def _generate_role_answer(self, question, context, role):
"""
角色化答案生成按指定角色调整prompt语气和内容风格
:param question: 用户问题
:param context: 检索到的相关文档上下文
:param role: 指定角色
:return: 符合角色身份的答案
"""
from config import OLLAMA_BASE_URL
import requests
# 核心区别在prompt中加入“角色定义”让模型按身份生成内容
# 不同角色对应不同的语气要求(可根据需要扩展更多角色的提示词)
role_prompt_map = {
"纱露朵": SALT_PROMPT,
"Reisasol" : REISASOL_PROMPT
}
# 获取当前角色的提示词(若角色不在预设中,用默认提示词)
role_desc = role_prompt_map.get(
role,
f"你是{role},请基于上下文回答问题,保持语气符合{role}的身份。"
)
# 构造角色化prompt
prompt = f"""
{role_desc}
核心要求:基于以下上下文回答,如果上下文不包含相关信息,请给出最合适的回答,并确保内容符合{role}的身份。”。
上下文:
{context}
问题: {question}
回答:
"""
# 后续API调用逻辑和_generate_answer一致复用网络请求代码
try:
base_url = OLLAMA_BASE_URL.rstrip('/')
ollama_api_url = f"{base_url}/api/generate"
request_body = {
"model": LLM_MODEL,
"prompt": prompt,
"stream": False
}
response = requests.post(ollama_api_url, json=request_body, timeout=30)
if response.status_code == 200:
return response.json().get("response", "未获取到答案内容")
else:
return f"Ollama API请求失败状态码{response.status_code},原因:{response.text}"
except requests.exceptions.ConnectionError:
return f"连接Ollama服务失败{ollama_api_url}),请检查网络"
except requests.exceptions.Timeout:
return f"连接Ollama服务超时30秒可能是模型生成过慢"
except Exception as e:
return f"生成角色化答案出错:{str(e)}"
# ----------------------------------------------------------------------------