Files
AI-Search/songall.py
2025-12-15 14:38:31 +08:00

240 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入必要的库
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from fastapi import FastAPI, Query, Body # 新增 Body 导入
from typing import List, Dict
import uvicorn
# Configuration parameters
# Milvus settings
MILVUS_HOST = "100.80.156.98"
MILVUS_PORT = "19530"
COLLECTION_NAME = "song_knowledge_base" # dedicated collection for songs
# Dimension of the "embedding" vector field.
# NOTE(review): presumably matches the output size of EMBEDDING_MODEL - confirm.
EMBEDDING_DIM = 768
# Ollama settings (used only to generate embeddings; no AI answer logic)
OLLAMA_BASE_URL = "http://100.89.166.61:11434/"
EMBEDDING_MODEL = "nomic-embed-text"
# Text-splitting settings
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
# Initialize the FastAPI application
app = FastAPI(title="歌曲模糊查询API服务", version="1.0")
# Milvus client helper functions
def connect_to_milvus():
    """Open the "default" pymilvus connection to the configured Milvus server.

    Host and port come from the module-level ``MILVUS_HOST`` and
    ``MILVUS_PORT`` constants.
    """
    connections.connect(
        alias="default",
        host=MILVUS_HOST,
        port=MILVUS_PORT,
    )
def create_song_collection():
    """Return the song collection, creating it and its vector index if absent.

    When a collection named ``COLLECTION_NAME`` already exists, a handle to it
    is returned immediately; its schema and index are not inspected.
    Otherwise a new collection is created with these fields, in this order:

    * ``id``        - INT64 primary key, auto-generated by Milvus
    * ``embedding`` - FLOAT_VECTOR of ``EMBEDDING_DIM`` dimensions
    * ``text``      - VARCHAR holding the song detail text
    * ``song_id``   - INT64 holding the caller-supplied song ID
    * ``title``     - VARCHAR holding the song title

    An IVF_FLAT index with the L2 metric is then built on ``embedding``.

    Returns:
        Collection: handle to the existing or newly created collection.
    """
    if utility.has_collection(COLLECTION_NAME):
        return Collection(COLLECTION_NAME)

    # Field order matters: column-based inserts supply the non-auto fields
    # positionally in exactly this order.
    fields = [
        FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            dim=EMBEDDING_DIM,
        ),
        FieldSchema(
            name="text",
            dtype=DataType.VARCHAR,
            max_length=65535,  # stores the song information text
        ),
        # max_length is a VARCHAR parameter and has no meaning on an INT64
        # field, so it is intentionally not passed here.
        FieldSchema(
            name="song_id",
            dtype=DataType.INT64,  # original song ID
        ),
        FieldSchema(
            name="title",
            dtype=DataType.VARCHAR,
            max_length=256,  # song title
        ),
    ]
    schema = CollectionSchema(
        fields=fields,
        description="Song knowledge base collection for fuzzy search",
    )
    collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using='default',
        shards_num=2,
    )

    # Vector index used by the semantic search; the search side must use the
    # same metric type (L2).
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
def insert_song_documents(collection, embeddings, texts, song_ids, titles):
    """Insert song rows into ``collection`` and flush them.

    The four arguments are parallel column lists. They are handed to
    ``collection.insert`` as one list in the order embeddings, texts,
    song IDs, titles, after which ``collection.flush()`` is called.

    Args:
        collection: object exposing ``insert(data)`` and ``flush()``.
        embeddings: one embedding vector per song.
        texts: one detail text per song.
        song_ids: one song ID per song.
        titles: one title per song.
    """
    collection.insert([embeddings, texts, song_ids, titles])
    collection.flush()
def search_song_by_fuzzy(collection, query_text, top_k=10):
    """Run a semantic (embedding-based) fuzzy search for songs.

    The query text is embedded with a freshly constructed ``OllamaEmbeddings``
    client, the collection is loaded, and an L2 vector search is run against
    the ``embedding`` field.

    Args:
        collection: Milvus collection to search.
        query_text: free-text query to embed.
        top_k: maximum number of hits requested from Milvus.

    Returns:
        list[dict]: one dict per hit with keys ``song_id``, ``title``,
        ``detail`` and ``similarity_score``. The score is
        ``1 / (1 + distance)``, so a smaller distance gives a larger score.
    """
    query_vector = OllamaEmbeddings(
        model=EMBEDDING_MODEL,
        base_url=OLLAMA_BASE_URL,
    ).embed_query(query_text)

    collection.load()
    result_sets = collection.search(
        data=[query_vector],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=top_k,
        # The distance is read from each hit object, not requested as a field.
        output_fields=["song_id", "title", "text"],
    )

    return [
        {
            "song_id": hit.entity.get("song_id"),
            "title": hit.entity.get("title"),
            "detail": hit.entity.get("text"),
            "similarity_score": 1 / (1 + hit.distance),
        }
        for hit_group in result_sets
        for hit in hit_group
    ]
# Document processor (simplified; only handles song text)
class SongDocumentProcessor:
    """Turn raw song dicts into the column lists stored in Milvus."""

    def __init__(self):
        """Create the Ollama embedder and a recursive character text splitter."""
        self.embedder = OllamaEmbeddings(
            model=EMBEDDING_MODEL,
            base_url=OLLAMA_BASE_URL,
        )
        # Configured here but not used by process_songs.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", " ", ""],
        )

    def process_songs(self, songs_data: List[Dict]):
        """Build texts, IDs and titles for the songs and embed the texts.

        Args:
            songs_data: song dicts; ``id`` defaults to ``0`` and ``title``
                defaults to ``"未知歌曲"`` when a key is missing.

        Returns:
            tuple: ``(embeddings, texts, song_ids, titles)`` where the last
            three are parallel lists and ``embeddings`` is whatever
            ``self.embedder.embed_documents(texts)`` returns.
        """
        # Read each song once, id first and then title.
        id_title_pairs = [
            (entry.get("id", 0), entry.get("title", "未知歌曲"))
            for entry in songs_data
        ]
        song_ids = [sid for sid, _ in id_title_pairs]
        titles = [name for _, name in id_title_pairs]
        # Detail text that gets embedded (other fields could be appended here).
        texts = [
            f"歌曲ID: {sid}, 歌曲名称: {name}"
            for sid, name in id_title_pairs
        ]

        vectors = self.embedder.embed_documents(texts)
        return vectors, texts, song_ids, titles
# Initialize the Milvus connection, the song collection and the processor.
# These are top-level statements, so they run when the module is imported.
connect_to_milvus()
song_collection = create_song_collection()
song_processor = SongDocumentProcessor()
# API endpoint definitions
#
# POST /api/songs/insert: embed the submitted songs and store them in Milvus.
# An empty list yields a {"code": 400, ...} payload; otherwise the payload
# reports how many songs were inserted and echoes the first one as an example.
# The docstring below is kept verbatim because FastAPI publishes it as the
# operation description in the generated API docs.
@app.post("/api/songs/insert", summary="录入歌曲数据")
def insert_songs(
    songs: List[Dict] = Body(..., description="歌曲列表,格式:[{\"id\":0,\"title\":\"实例歌曲\"}]")
):
    """
    录入歌曲数据到知识库:
    - 接收歌曲列表,格式为[{"id": 歌曲ID, "title": "歌曲名称"}]
    - 自动处理并存储到Milvus支持后续模糊查询
    """
    if not songs:
        return {"code": 400, "message": "歌曲数据不能为空", "data": None}

    # process_songs returns (embeddings, texts, song_ids, titles), which is
    # exactly the positional tail insert_song_documents expects.
    insert_song_documents(song_collection, *song_processor.process_songs(songs))

    inserted = len(songs)
    payload = {
        "inserted_count": inserted,
        "example": songs[:1],  # first submitted song, returned as a sample
    }
    return {
        "code": 200,
        "message": f"成功录入 {inserted} 首歌曲",
        "data": payload,
    }
# GET /api/songs/search: semantic fuzzy search over the stored songs.
# A keyword that is empty after stripping whitespace yields a
# {"code": 400, ...} payload; otherwise the matches are returned with their count.
# The docstring below is kept verbatim because FastAPI publishes it as the
# operation description in the generated API docs.
@app.get("/api/songs/search", summary="模糊查询歌曲")
def fuzzy_search_songs(
    keyword: str = Query(..., description="查询关键词(歌曲名称模糊匹配)"),
    top_k: int = Query(10, ge=1, le=50, description="返回匹配数量1-50之间")
):
    """
    模糊查询歌曲(基于语义相似度):
    - 输入关键词,返回语义最相似的歌曲列表
    - 支持同义词、拼写误差等模糊场景匹配
    """
    if keyword.strip() == "":
        return {"code": 400, "message": "查询关键词不能为空", "data": None}

    # The keyword is passed to the search exactly as received (not stripped).
    hits = search_song_by_fuzzy(song_collection, keyword, top_k)
    hit_count = len(hits)
    return {
        "code": 200,
        "message": f"找到 {hit_count} 首匹配歌曲",
        "data": {
            "keyword": keyword,
            "matched_count": hit_count,
            "songs": hits,
        },
    }
# Entry point: run the API service
def main(host: str = "0.0.0.0", port: int = 58329) -> None:
    """Print the service URLs and start the uvicorn server for ``app``.

    Args:
        host: interface uvicorn binds to. Defaults to all interfaces.
        port: TCP port to listen on. The printed URLs use the same port so
            they cannot drift from the port actually served.
    """
    print("启动歌曲模糊查询API服务...")
    # The banner points at loopback regardless of the bind host.
    print(f"服务地址http://127.0.0.1:{port}")
    print(f"API文档http://127.0.0.1:{port}/docs")
    # Start the uvicorn server (blocks until shutdown).
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()