This commit is contained in:
2025-12-15 14:38:31 +08:00
commit 22778e22fb
17 changed files with 938 additions and 0 deletions

240
songall.py Normal file
View File

@@ -0,0 +1,240 @@
# 导入必要的库
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from fastapi import FastAPI, Query, Body # 新增 Body 导入
from typing import List, Dict
import uvicorn
# 配置参数
# Milvus配置
MILVUS_HOST = "100.80.156.98"
MILVUS_PORT = "19530"
COLLECTION_NAME = "song_knowledge_base" # 歌曲专属集合
EMBEDDING_DIM = 768
# Ollama配置仅用于生成嵌入无AI回答逻辑
OLLAMA_BASE_URL = "http://100.89.166.61:11434/"
EMBEDDING_MODEL = "nomic-embed-text"
# 文本分割配置
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
# 初始化FastAPI应用
app = FastAPI(title="歌曲模糊查询API服务", version="1.0")
# Milvus客户端工具函数
def connect_to_milvus():
"""连接到Milvus数据库"""
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
def create_song_collection():
"""创建歌曲专属Milvus集合"""
if utility.has_collection(COLLECTION_NAME):
return Collection(COLLECTION_NAME)
# 定义字段(适配歌曲数据)
id_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True
)
embedding_field = FieldSchema(
name="embedding",
dtype=DataType.FLOAT_VECTOR,
dim=EMBEDDING_DIM
)
text_field = FieldSchema(
name="text",
dtype=DataType.VARCHAR,
max_length=65535 # 存储歌曲信息文本
)
song_id_field = FieldSchema(
name="song_id",
dtype=DataType.INT64, # 歌曲原始ID
max_length=64
)
title_field = FieldSchema(
name="title",
dtype=DataType.VARCHAR,
max_length=256 # 歌曲名称(用于快速匹配)
)
schema = CollectionSchema(
fields=[id_field, embedding_field, text_field, song_id_field, title_field],
description="Song knowledge base collection for fuzzy search"
)
collection = Collection(
name=COLLECTION_NAME,
schema=schema,
using='default',
shards_num=2
)
# 创建向量索引(用于模糊语义匹配)
index_params = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128}
}
collection.create_index(field_name="embedding", index_params=index_params)
return collection
def insert_song_documents(collection, embeddings, texts, song_ids, titles):
"""插入歌曲文档到集合"""
insert_data = [
embeddings,
texts,
song_ids,
titles
]
collection.insert(insert_data)
collection.flush()
def search_song_by_fuzzy(collection, query_text, top_k=10):
"""模糊查询歌曲(基于语义嵌入匹配)"""
# 生成查询文本的嵌入向量
embedder = OllamaEmbeddings(
model=EMBEDDING_MODEL,
base_url=OLLAMA_BASE_URL
)
query_embedding = embedder.embed_query(query_text)
# Milvus向量搜索移除 output_fields 中的 "distance"
collection.load()
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10}
}
results = collection.search(
data=[query_embedding],
anns_field="embedding",
param=search_params,
limit=top_k,
output_fields=["song_id", "title", "text"] # 去掉 "distance"
)
# 格式化结果distance 从 hit 对象中获取,无需从 entity 中提取)
matched_songs = []
for hits in results:
for hit in hits:
matched_songs.append({
"song_id": hit.entity.get("song_id"),
"title": hit.entity.get("title"),
"detail": hit.entity.get("text"),
"similarity_score": 1 / (1 + hit.distance) # hit.distance 直接获取
})
return matched_songs
# 文档处理器(简化版,仅用于歌曲文本处理)
class SongDocumentProcessor:
def __init__(self):
self.embedder = OllamaEmbeddings(
model=EMBEDDING_MODEL,
base_url=OLLAMA_BASE_URL
)
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
separators=["\n\n", "\n", " ", ""]
)
def process_songs(self, songs_data: List[Dict]):
"""处理歌曲数据,生成嵌入向量"""
# 构造歌曲文本信息
texts = []
song_ids = []
titles = []
for song in songs_data:
song_id = song.get("id", 0)
title = song.get("title", "未知歌曲")
# 拼接歌曲详情文本(可扩展其他字段)
detail_text = f"歌曲ID: {song_id}, 歌曲名称: {title}"
texts.append(detail_text)
song_ids.append(song_id)
titles.append(title)
# 生成嵌入向量
embeddings = self.embedder.embed_documents(texts)
return embeddings, texts, song_ids, titles
# 初始化Milvus连接和集合
connect_to_milvus()
song_collection = create_song_collection()
song_processor = SongDocumentProcessor()
# API接口定义
@app.post("/api/songs/insert", summary="录入歌曲数据")
def insert_songs(
songs: List[Dict] = Body(..., description="歌曲列表,格式:[{\"id\":0,\"title\":\"实例歌曲\"}]")
):
"""
录入歌曲数据到知识库:
- 接收歌曲列表,格式为[{"id": 歌曲ID, "title": "歌曲名称"}]
- 自动处理并存储到Milvus支持后续模糊查询
"""
if not songs:
return {"code": 400, "message": "歌曲数据不能为空", "data": None}
# 处理歌曲数据
embeddings, texts, song_ids, titles = song_processor.process_songs(songs)
# 插入Milvus
insert_song_documents(song_collection, embeddings, texts, song_ids, titles)
return {
"code": 200,
"message": f"成功录入 {len(songs)} 首歌曲",
"data": {
"inserted_count": len(songs),
"example": songs[:1] # 返回第一条作为示例
}
}
@app.get("/api/songs/search", summary="模糊查询歌曲")
def fuzzy_search_songs(
keyword: str = Query(..., description="查询关键词(歌曲名称模糊匹配)"),
top_k: int = Query(10, ge=1, le=50, description="返回匹配数量1-50之间")
):
"""
模糊查询歌曲(基于语义相似度):
- 输入关键词,返回语义最相似的歌曲列表
- 支持同义词、拼写误差等模糊场景匹配
"""
if not keyword.strip():
return {"code": 400, "message": "查询关键词不能为空", "data": None}
# 执行模糊搜索
matched_songs = search_song_by_fuzzy(song_collection, keyword, top_k)
return {
"code": 200,
"message": f"找到 {len(matched_songs)} 首匹配歌曲",
"data": {
"keyword": keyword,
"matched_count": len(matched_songs),
"songs": matched_songs
}
}
# 主函数运行API服务
def main():
print("启动歌曲模糊查询API服务...")
print(f"服务地址http://127.0.0.1:58329")
print(f"API文档http://127.0.0.1:58329/docs")
# 启动uvicorn服务
uvicorn.run(app, host="0.0.0.0", port=58329)
if __name__ == "__main__":
main()