from milvus_client import connect_to_milvus, create_collection, insert_documents, search_documents from document_processor import DocumentProcessor from config import LLM_MODEL from org.config import SALT_PROMPT, REISASOL_PROMPT class RAGSystem: def __init__(self): # 连接Milvus connect_to_milvus() # 创建集合 self.collection = create_collection() # 初始化文档处理器 self.processor = DocumentProcessor() # 加载集合 self.collection.load() def add_documents(self, documents): """添加文档到知识库""" # 分割文档 split_texts = self.processor.split_documents(documents) # 生成嵌入 embeddings = self.processor.generate_embeddings(split_texts) # 准备插入数据 texts = [text['content'] for text in split_texts] sources = [text['source'] for text in split_texts] # 插入数据 insert_documents(self.collection, embeddings, texts, sources) return len(split_texts) def query(self, question, top_k=3): """基础查询:生成标准答案""" # 生成查询嵌入 query_embedding = self.processor.embedder.embed_query(question) # 检索相关文档 retrieved_docs = search_documents(self.collection, query_embedding, top_k) # 构造上下文 context = "\n".join([doc['text'] for doc in retrieved_docs]) # 生成标准答案 answer = self._generate_answer(question, context) return { 'answer': answer, 'retrieved_docs': retrieved_docs, 'context': context } def convert_songs_to_documents(songs_data): """将歌曲数据转换为文档格式""" documents = [] for song in songs_data: content = f"歌曲名称: {song['title']}, 歌手: {song['artist']}, BPM: {song['bpm']}, 版本: {song['version']}" documents.append({ 'content': content, 'source': f"歌曲数据 - {song['title']}" }) return documents def add_song_data(self, songs_data): """添加歌曲数据到知识库""" # 转换数据格式 song_documents = self.convert_songs_to_documents(songs_data) # 复用现有文档处理流程 split_texts = self.processor.split_documents(song_documents) embeddings = self.processor.generate_embeddings(split_texts) texts = [text['content'] for text in split_texts] sources = [text['source'] for text in split_texts] insert_documents(self.collection, embeddings, texts, sources) return len(split_texts) # -------------------------- 新增:角色扮演式查询方法 -------------------------- def role_play_query(self, question, role, top_k=3): """ 角色扮演查询:按指定角色生成符合身份的答案 :param question: 用户问题(如“Python怎么定义函数?”) :param role: 指定角色(如“Python讲师”“产品经理”“小学生辅导员”) :param top_k: 检索相关文档的数量 :return: 包含角色化答案的结果字典 """ # 1. 复用现有检索逻辑(和query方法完全一致,确保上下文相关性) query_embedding = self.processor.embedder.embed_query(question) retrieved_docs = search_documents(self.collection, query_embedding, top_k) context = "\n".join([doc['text'] for doc in retrieved_docs]) # 2. 调用新增的“角色化答案生成方法”(区别于基础的_generate_answer) role_answer = self._generate_role_answer(question, context, role) # 3. 返回结构和query方法一致,便于后续使用 return { 'answer': role_answer, # 角色化的答案 'role': role, # 明确返回当前角色 'retrieved_docs': retrieved_docs, 'context': context } # ---------------------------------------------------------------------------- def _generate_answer(self, question, context): """基础答案生成:无角色限制""" from config import OLLAMA_BASE_URL import requests prompt = f""" 基于以下上下文回答问题。如果上下文不包含相关信息,请说明无法基于提供的资料回答。 要求:答案简洁、准确,符合技术文档规范。 上下文: {context} 问题: {question} 回答: """ try: base_url = OLLAMA_BASE_URL.rstrip('/') ollama_api_url = f"{base_url}/api/generate" request_body = { "model": LLM_MODEL, "prompt": prompt, "stream": False } response = requests.post(ollama_api_url, json=request_body, timeout=30) if response.status_code == 200: return response.json().get("response", "未获取到答案内容") else: return f"Ollama API请求失败,状态码:{response.status_code},原因:{response.text}" except requests.exceptions.ConnectionError: return f"连接Ollama服务失败({ollama_api_url}),请检查网络" except requests.exceptions.Timeout: return f"连接Ollama服务超时(30秒),可能是模型生成过慢" except Exception as e: return f"生成答案出错:{str(e)}" # -------------------------- 新增:角色化答案生成方法 -------------------------- def _generate_role_answer(self, question, context, role): """ 角色化答案生成:按指定角色调整prompt语气和内容风格 :param question: 用户问题 :param context: 检索到的相关文档上下文 :param role: 指定角色 :return: 符合角色身份的答案 """ from config import OLLAMA_BASE_URL import requests # 核心区别:在prompt中加入“角色定义”,让模型按身份生成内容 # 不同角色对应不同的语气要求(可根据需要扩展更多角色的提示词) role_prompt_map = { "纱露朵": SALT_PROMPT, "Reisasol" : REISASOL_PROMPT } # 获取当前角色的提示词(若角色不在预设中,用默认提示词) role_desc = role_prompt_map.get( role, f"你是{role},请基于上下文回答问题,保持语气符合{role}的身份。" ) # 构造角色化prompt prompt = f""" {role_desc} 核心要求:基于以下上下文回答,如果上下文不包含相关信息,请给出最合适的回答,并确保内容符合{role}的身份。”。 上下文: {context} 问题: {question} 回答: """ # 后续API调用逻辑和_generate_answer一致(复用网络请求代码) try: base_url = OLLAMA_BASE_URL.rstrip('/') ollama_api_url = f"{base_url}/api/generate" request_body = { "model": LLM_MODEL, "prompt": prompt, "stream": False } response = requests.post(ollama_api_url, json=request_body, timeout=30) if response.status_code == 200: return response.json().get("response", "未获取到答案内容") else: return f"Ollama API请求失败,状态码:{response.status_code},原因:{response.text}" except requests.exceptions.ConnectionError: return f"连接Ollama服务失败({ollama_api_url}),请检查网络" except requests.exceptions.Timeout: return f"连接Ollama服务超时(30秒),可能是模型生成过慢" except Exception as e: return f"生成角色化答案出错:{str(e)}" # ----------------------------------------------------------------------------