1
This commit is contained in:
240
songall.py
Normal file
240
songall.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# 导入必要的库
|
||||
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from fastapi import FastAPI, Query, Body # 新增 Body 导入
|
||||
from typing import List, Dict
|
||||
import uvicorn
|
||||
|
||||
# Configuration parameters
# Milvus settings
MILVUS_HOST = "100.80.156.98"
MILVUS_PORT = "19530"
COLLECTION_NAME = "song_knowledge_base"  # collection dedicated to songs
# Dimension of the "embedding" vector field; presumably the output size of
# EMBEDDING_MODEL -- verify before switching models.
EMBEDDING_DIM = 768

# Ollama settings (used only to generate embeddings; no AI answering logic)
OLLAMA_BASE_URL = "http://100.89.166.61:11434/"
EMBEDDING_MODEL = "nomic-embed-text"

# Text splitting settings
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50

# Initialize the FastAPI application
app = FastAPI(title="歌曲模糊查询API服务", version="1.0")
|
||||
|
||||
|
||||
# Milvus客户端工具函数
|
||||
def connect_to_milvus():
    """Register the "default" pymilvus connection for the configured server.

    Uses MILVUS_HOST and MILVUS_PORT from the module configuration.
    """
    connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
|
||||
|
||||
def create_song_collection():
    """Return the song collection, creating it (with its vector index) if absent.

    If a collection named COLLECTION_NAME already exists it is returned as-is;
    its schema and index are not inspected or modified.

    Otherwise a new collection is created with these fields, in this order:
      * id        - INT64 primary key, generated by Milvus (auto_id)
      * embedding - FLOAT_VECTOR of EMBEDDING_DIM dimensions
      * text      - VARCHAR, the descriptive text that was embedded
      * song_id   - INT64, the caller's original song ID
      * title     - VARCHAR, the song title

    Returns:
        The pymilvus Collection handle.
    """
    if utility.has_collection(COLLECTION_NAME):
        return Collection(COLLECTION_NAME)

    # Field definitions tailored to song data.
    id_field = FieldSchema(
        name="id",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=True,
    )
    embedding_field = FieldSchema(
        name="embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,
    )
    text_field = FieldSchema(
        name="text",
        dtype=DataType.VARCHAR,
        max_length=65535,  # stores the song information text
    )
    # Original song ID. max_length is a VARCHAR-only type parameter and has no
    # meaning on an INT64 field, so it is intentionally not passed here.
    song_id_field = FieldSchema(
        name="song_id",
        dtype=DataType.INT64,
    )
    title_field = FieldSchema(
        name="title",
        dtype=DataType.VARCHAR,
        max_length=256,  # song title (for quick matching)
    )

    schema = CollectionSchema(
        fields=[id_field, embedding_field, text_field, song_id_field, title_field],
        description="Song knowledge base collection for fuzzy search",
    )

    collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using='default',
        shards_num=2,
    )

    # Vector index used for semantic (fuzzy) matching; the L2 metric here must
    # agree with the metric used at search time.
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
|
||||
|
||||
|
||||
def insert_song_documents(collection, embeddings, texts, song_ids, titles):
    """Insert song rows into the collection and flush them.

    The four arguments are parallel column lists, passed to ``insert`` in the
    order embeddings, texts, song_ids, titles.

    Args:
        collection: Object exposing ``insert(data)`` and ``flush()``.
        embeddings: One embedding vector per song.
        texts: One descriptive text per song.
        song_ids: One original song ID per song.
        titles: One title per song.
    """
    columns = [embeddings, texts, song_ids, titles]
    collection.insert(columns)
    collection.flush()
|
||||
|
||||
|
||||
def search_song_by_fuzzy(collection, query_text, top_k=10):
    """Run a semantic (embedding-based) fuzzy search for songs.

    The query text is embedded with the configured Ollama model, then used for
    an L2 vector search against the collection's "embedding" field.

    Args:
        collection: Collection to search; it is loaded before searching.
        query_text: Free-text query to embed.
        top_k: Maximum number of hits requested from the search.

    Returns:
        A list of dicts with keys "song_id", "title", "detail" (the stored
        text) and "similarity_score", computed as 1 / (1 + distance) so that
        smaller distances give larger scores.
    """
    # Embed the query text.
    query_vector = OllamaEmbeddings(
        model=EMBEDDING_MODEL,
        base_url=OLLAMA_BASE_URL
    ).embed_query(query_text)

    # Vector search. "distance" is not an output field; it is read from each
    # hit object instead.
    collection.load()
    hit_lists = collection.search(
        data=[query_vector],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["song_id", "title", "text"]
    )

    # Flatten the per-query hit lists into plain result dicts.
    return [
        {
            "song_id": hit.entity.get("song_id"),
            "title": hit.entity.get("title"),
            "detail": hit.entity.get("text"),
            "similarity_score": 1 / (1 + hit.distance)
        }
        for hits in hit_lists
        for hit in hits
    ]
|
||||
|
||||
# 文档处理器(简化版,仅用于歌曲文本处理)
|
||||
class SongDocumentProcessor:
    """Converts raw song dicts into the parallel columns that get stored."""

    def __init__(self):
        """Create the Ollama embedder and the text splitter from module config."""
        self.embedder = OllamaEmbeddings(
            model=EMBEDDING_MODEL,
            base_url=OLLAMA_BASE_URL
        )
        # NOTE: process_songs does not use the splitter; it is kept as an
        # attribute so existing users of this class still find it.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", " ", ""]
        )

    def process_songs(self, songs_data: List[Dict]):
        """Build descriptive texts for the songs and embed them.

        Args:
            songs_data: Song dicts; each may carry "id" and "title". A missing
                or ``None`` "id" becomes 0 and a missing or ``None`` "title"
                becomes "未知歌曲", so that an explicit JSON null is never
                forwarded as a column value.

        Returns:
            A tuple ``(embeddings, texts, song_ids, titles)`` of parallel
            lists, one entry per input song. For empty input all four lists
            are empty and the embedding service is not called.
        """
        texts = []
        song_ids = []
        titles = []
        for song in songs_data:
            song_id = song.get("id")
            if song_id is None:
                song_id = 0
            title = song.get("title")
            if title is None:
                title = "未知歌曲"
            # Descriptive text that gets embedded (other fields can be added).
            texts.append(f"歌曲ID: {song_id}, 歌曲名称: {title}")
            song_ids.append(song_id)
            titles.append(title)

        # Nothing to embed: skip the round trip to the embedding service.
        if not texts:
            return [], [], [], []

        embeddings = self.embedder.embed_documents(texts)
        return embeddings, texts, song_ids, titles
|
||||
|
||||
|
||||
# Initialize the Milvus connection and collection at import time.
# The connection must be established before the collection is looked up or
# created, so the order of the first two statements matters.
connect_to_milvus()
song_collection = create_song_collection()
# Shared processor instance used by the insert endpoint.
song_processor = SongDocumentProcessor()
|
||||
|
||||
|
||||
# API接口定义
|
||||
@app.post("/api/songs/insert", summary="录入歌曲数据")
def insert_songs(
    songs: List[Dict] = Body(..., description="歌曲列表,格式:[{\"id\":0,\"title\":\"实例歌曲\"}]")
):
    """
    录入歌曲数据到知识库:
    - 接收歌曲列表,格式为[{"id": 歌曲ID, "title": "歌曲名称"}]
    - 自动处理并存储到Milvus,支持后续模糊查询
    """
    # The docstring above is left verbatim: FastAPI publishes it as the
    # operation description, so it is user-facing text rather than a comment.

    # Reject an empty list. The error is reported in the response body's
    # "code" field, not by raising.
    if not songs:
        return {"code": 400, "message": "歌曲数据不能为空", "data": None}

    # Build (embeddings, texts, song_ids, titles) and store them in Milvus.
    columns = song_processor.process_songs(songs)
    insert_song_documents(song_collection, *columns)

    # Echo the count and the first song back as an example.
    summary = {
        "inserted_count": len(songs),
        "example": songs[:1]
    }
    return {
        "code": 200,
        "message": f"成功录入 {len(songs)} 首歌曲",
        "data": summary
    }
|
||||
|
||||
|
||||
@app.get("/api/songs/search", summary="模糊查询歌曲")
def fuzzy_search_songs(
    keyword: str = Query(..., description="查询关键词(歌曲名称模糊匹配)"),
    top_k: int = Query(10, ge=1, le=50, description="返回匹配数量,1-50之间")
):
    """
    模糊查询歌曲(基于语义相似度):
    - 输入关键词,返回语义最相似的歌曲列表
    - 支持同义词、拼写误差等模糊场景匹配
    """
    # The docstring above is left verbatim: FastAPI publishes it as the
    # operation description, so it is user-facing text rather than a comment.

    # Reject whitespace-only keywords. The error is reported in the response
    # body's "code" field, not by raising.
    if keyword.strip() == "":
        return {"code": 400, "message": "查询关键词不能为空", "data": None}

    # Run the semantic search; the keyword is passed through untrimmed.
    matched_songs = search_song_by_fuzzy(song_collection, keyword, top_k)

    result = {
        "keyword": keyword,
        "matched_count": len(matched_songs),
        "songs": matched_songs
    }
    return {
        "code": 200,
        "message": f"找到 {len(matched_songs)} 首匹配歌曲",
        "data": result
    }
|
||||
|
||||
|
||||
# 主函数(运行API服务)
|
||||
def main():
    """Print the local service and docs URLs, then serve the app with uvicorn.

    The server binds to 0.0.0.0, while the printed URLs use the loopback
    address for convenience when running locally.
    """
    # Single source for the port so the printed URLs cannot drift from the
    # port actually passed to uvicorn.
    port = 58329
    local_base = f"http://127.0.0.1:{port}"

    print("启动歌曲模糊查询API服务...")
    print(f"服务地址:{local_base}")
    print(f"API文档:{local_base}/docs")
    # Start the uvicorn server.
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user