100 lines
2.6 KiB
Python
100 lines
2.6 KiB
Python
# file: app.py
|
||
from flask import Flask, request, jsonify
|
||
from rag_system import RAGSystem
|
||
|
||
# Flask application object; the route handlers in this file register on it.
app = Flask(import_name=__name__)

# One RAG backend instance, created at import time and shared by all the
# request handlers defined here.
# NOTE(review): RAGSystem's thread-safety is not visible from this file —
# confirm before running under a multi-threaded server.
rag_system = RAGSystem()
|
||
|
||
|
||
@app.route('/add_documents', methods=['POST'])
def add_documents():
    """Add documents to the knowledge base.

    Request body: a JSON object with a ``documents`` field (a non-empty list).

    Returns:
        201 with ``message`` and ``added_count`` on success.
        400 if the body is not a JSON object, or ``documents`` is missing,
        empty, or not a list.
        500 with the exception message if the RAG system call fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising an HTTP error, which a broad handler would misreport as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    documents = data.get('documents', [])
    if not documents:
        return jsonify({'error': 'No documents provided'}), 400
    if not isinstance(documents, list):
        return jsonify({'error': 'documents must be a list'}), 400

    try:
        count = rag_system.add_documents(documents)
        return jsonify({
            'message': f'Successfully added {count} document chunks to the knowledge base',
            'added_count': count,
        }), 201
    except Exception as e:
        # Server-side failure: keep the traceback in the log, not just the client reply.
        app.logger.exception('add_documents failed')
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/query', methods=['POST'])
def query():
    """Query the knowledge base.

    Request body: a JSON object with a ``question`` field (string) and an
    optional ``top_k`` field (positive integer, default 3).

    Returns:
        200 with the result of ``rag_system.query`` on success.
        400 if the body is not a JSON object, ``question`` is missing, empty,
        or not a string, or ``top_k`` is not a positive integer.
        500 with the exception message if the RAG system call fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising an HTTP error, which a broad handler would misreport as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    question = data.get('question')
    top_k = data.get('top_k', 3)

    if not question:
        return jsonify({'error': 'Question is required'}), 400
    if not isinstance(question, str):
        return jsonify({'error': 'question must be a string'}), 400
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 1:
        return jsonify({'error': 'top_k must be a positive integer'}), 400

    try:
        result = rag_system.query(question, top_k)
        return jsonify(result), 200
    except Exception as e:
        # Server-side failure: keep the traceback in the log, not just the client reply.
        app.logger.exception('query failed')
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/add_songs', methods=['POST'])
def add_songs():
    """Add song data to the knowledge base.

    Request body: a JSON object with a ``songs_data`` field (a non-empty list).

    Returns:
        201 with ``message`` and ``added_count`` on success.
        400 if the body is not a JSON object, or ``songs_data`` is missing,
        empty, or not a list.
        500 with the exception message if the RAG system call fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising an HTTP error, which a broad handler would misreport as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    songs_data = data.get('songs_data', [])
    if not songs_data:
        return jsonify({'error': 'No songs data provided'}), 400
    if not isinstance(songs_data, list):
        return jsonify({'error': 'songs_data must be a list'}), 400

    try:
        count = rag_system.add_song_data(songs_data)
        return jsonify({
            'message': f'Successfully added {count} song documents to the knowledge base',
            'added_count': count,
        }), 201
    except Exception as e:
        # Server-side failure: keep the traceback in the log, not just the client reply.
        app.logger.exception('add_songs failed')
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
@app.route('/role_query', methods=['POST'])
def role_query():
    """Run a role-play query against the knowledge base.

    Request body: a JSON object with ``question`` and ``role`` fields (strings)
    and an optional ``top_k`` field (positive integer, default 3).

    Returns:
        200 with the result of ``rag_system.role_play_query`` on success.
        400 if the body is not a JSON object, ``question`` or ``role`` is
        missing, empty, or not a string, or ``top_k`` is not a positive
        integer.
        500 with the exception message if the RAG system call fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising an HTTP error, which a broad handler would misreport as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    question = data.get('question')
    role = data.get('role')
    top_k = data.get('top_k', 3)

    if not question or not role:
        return jsonify({'error': 'Question and role are required'}), 400
    if not isinstance(question, str) or not isinstance(role, str):
        return jsonify({'error': 'question and role must be strings'}), 400
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 1:
        return jsonify({'error': 'top_k must be a positive integer'}), 400

    try:
        result = rag_system.role_play_query(question, role, top_k)
        return jsonify(result), 200
    except Exception as e:
        # Server-side failure: keep the traceback in the log, not just the client reply.
        app.logger.exception('role_query failed')
        return jsonify({'error': str(e)}), 500
|
||
|
||
|
||
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it must never be on by default while the server listens
    # on all interfaces. Opt in for local development with FLASK_DEBUG=1.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=5000, debug=debug)
|