62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import json
|
|
from rag_system import RAGSystem
|
|
|
|
|
|
def load_songs_from_json(file_path):
    """Load song data from a UTF-8 encoded JSON file.

    Every failure mode is reported on stdout and turned into an empty
    list, so callers can simply test the result for truthiness.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or ``[]`` when the file is missing,
        cannot be opened or read, is not valid UTF-8, or is not valid JSON.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"文件 {file_path} 未找到")
    except OSError as e:
        # Covers e.g. a directory path or a permission problem; must come
        # after FileNotFoundError, which is itself an OSError subclass.
        print(f"无法读取文件 {file_path}: {e}")
    except UnicodeDecodeError as e:
        # Raised while reading when the file is not UTF-8; it is a
        # ValueError but not a JSONDecodeError, so it needs its own clause.
        print(f"文件编码错误: {e}")
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
    return []
|
|
|
|
|
|
def _song_field(song, key, default='未知'):
    """Return ``song[key]``, or *default* when the key is absent or null."""
    value = song.get(key)
    return default if value is None else value


def convert_songs_to_documents(songs_data):
    """Convert song records into documents with 'content' and 'source' keys.

    Args:
        songs_data: Iterable of song mappings. The keys read are 'title',
            'artist', 'bpm' and 'version'. ``None`` is treated as empty.

    Returns:
        A list with one ``{'content': ..., 'source': ...}`` dict per song
        mapping. Entries that are not mappings are skipped. A field that is
        missing or explicitly null is rendered with a placeholder instead
        of the text "None".
    """
    # Imported locally so this function does not rely on the file's
    # import header being changed.
    from collections.abc import Mapping

    documents = []
    for song in songs_data or ():
        if not isinstance(song, Mapping):
            # A stray scalar/list entry has no .get(); skip it rather
            # than abort the whole conversion.
            continue
        content = (
            f"歌曲名称: {_song_field(song, 'title')}, "
            f"歌手: {_song_field(song, 'artist')}, "
            f"BPM: {_song_field(song, 'bpm')}, "
            f"版本: {_song_field(song, 'version')}"
        )
        documents.append({
            'content': content,
            'source': f"歌曲数据 - {_song_field(song, 'title', '未知歌曲')}",
        })
    return documents
|
|
|
|
|
|
def main():
    """Run the demo: seed the knowledge base from JSON, then ask questions.

    Creates a ``RAGSystem``, loads "./put.json" and, when it yields any
    data, adds the converted song documents and prints the count returned
    by ``add_documents``. Each sample question is then sent through
    ``role_play_query`` with "Reisasol" as the second argument, and the
    answer plus the retrieved documents are printed.
    """
    rag = RAGSystem()

    # Seed the knowledge base; skipped entirely when nothing was loaded.
    json_path = "./put.json"
    songs = load_songs_from_json(json_path)
    if songs:
        print("正在添加歌曲数据到知识库...")
        added = rag.add_documents(convert_songs_to_documents(songs))
        print(f"成功添加 {added} 个歌曲文档到知识库")

    # Sample queries.
    sample_questions = (
        "pandora怎么样",
        "你是谁",
        # Disabled sample: "upsertMusic怎么用,不是upsertMusic01"
    )

    for query in sample_questions:
        print(f"\n问题: {query}")
        response = rag.role_play_query(query, "Reisasol")
        print(f"答案: {response['answer']}")
        print("参考文档:")
        for rank, hit in enumerate(response['retrieved_docs'], start=1):
            # Only the first 100 characters of each document are shown.
            snippet = hit['text'][:100]
            origin = hit['source']
            print(f" {rank}. {snippet}... (来源: {origin})")
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|