# Import the required libraries.
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain_community.embeddings import OllamaEmbeddings
|
||
from fastapi import FastAPI, Query, Body # 新增 Body 导入
|
||
from typing import List, Dict
|
||
import uvicorn
|
||
|
||
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Milvus connection and collection settings.
MILVUS_HOST: str = "100.80.156.98"
MILVUS_PORT: str = "19530"
COLLECTION_NAME: str = "song_knowledge_base"  # collection dedicated to songs
EMBEDDING_DIM: int = 768

# Ollama settings (used only to generate embeddings; no AI answer logic).
OLLAMA_BASE_URL: str = "http://100.89.166.61:11434/"
EMBEDDING_MODEL: str = "nomic-embed-text"

# Text-splitting settings.
CHUNK_SIZE: int = 500
CHUNK_OVERLAP: int = 50
|
||
|
||
# Create the FastAPI application object that the route decorators attach to.
app = FastAPI(
    title="歌曲模糊查询API服务",
    version="1.0",
)
|
||
|
||
|
||
# Milvus client helper functions
|
||
def connect_to_milvus():
    """Register the "default" pymilvus connection to the configured server.

    The target is taken from the module-level ``MILVUS_HOST`` and
    ``MILVUS_PORT`` settings.
    """
    target = {"host": MILVUS_HOST, "port": MILVUS_PORT}
    connections.connect("default", **target)
|
||
|
||
|
||
def create_song_collection():
    """Return the song collection, creating it and its vector index if absent.

    When a collection named ``COLLECTION_NAME`` already exists it is returned
    unchanged: neither its schema nor its index is inspected or modified.
    Otherwise a new collection is created with five fields (auto-generated
    ``id`` primary key, ``embedding`` vector, ``text``, ``song_id``, ``title``)
    and an IVF_FLAT / L2 index is built on ``embedding``.

    Returns:
        Collection: handle to the existing or newly created collection.
    """
    if utility.has_collection(COLLECTION_NAME):
        return Collection(COLLECTION_NAME)

    # Field order matters: column-based inserts supply the non-auto-id fields
    # positionally as embedding, text, song_id, title.
    fields = [
        FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            dim=EMBEDDING_DIM,
        ),
        # Full descriptive text of the song.
        FieldSchema(
            name="text",
            dtype=DataType.VARCHAR,
            max_length=65535,
        ),
        # Original song ID. ``max_length`` is a VARCHAR-only type parameter,
        # so it is intentionally not passed for this INT64 field.
        FieldSchema(
            name="song_id",
            dtype=DataType.INT64,
        ),
        # Song title, stored separately from the descriptive text.
        FieldSchema(
            name="title",
            dtype=DataType.VARCHAR,
            max_length=256,
        ),
    ]

    schema = CollectionSchema(
        fields=fields,
        description="Song knowledge base collection for fuzzy search",
    )

    collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using="default",
        shards_num=2,
    )

    # Vector index used by the similarity search; the metric must match the
    # metric_type used at query time (L2).
    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 128},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
|
||
|
||
|
||
def insert_song_documents(collection, embeddings, texts, song_ids, titles):
    """Insert one batch of song rows into *collection* and flush it.

    The four data arguments are parallel columns: element ``i`` of each one
    describes the same row. They are passed to ``collection.insert`` as a list
    in the order embeddings, texts, song_ids, titles.

    Args:
        collection: object exposing ``insert(data)`` and ``flush()``.
        embeddings: sequence of embedding vectors, one per row.
        texts: sequence of descriptive texts, one per row.
        song_ids: sequence of song IDs, one per row.
        titles: sequence of song titles, one per row.

    Raises:
        ValueError: if the four columns do not all have the same length.
    """
    columns = [embeddings, texts, song_ids, titles]
    row_count = len(embeddings)

    # Fail early with a clear message instead of sending ragged columns.
    if any(len(column) != row_count for column in columns):
        lengths = ", ".join(str(len(column)) for column in columns)
        raise ValueError(
            "embeddings, texts, song_ids and titles must have the same "
            f"length (got {lengths})"
        )

    # An empty batch has nothing to write, so skip the insert and flush.
    if row_count == 0:
        return

    collection.insert(columns)
    collection.flush()
|
||
|
||
|
||
def search_song_by_fuzzy(collection, query_text, top_k=10):
    """Search *collection* for songs whose embedding is close to *query_text*.

    The query text is embedded with the configured Ollama model, the
    collection is loaded, and an L2 search is run against the ``embedding``
    field.

    Args:
        collection: Milvus collection to search.
        query_text: text to embed and search for.
        top_k: maximum number of hits requested from Milvus.

    Returns:
        list[dict]: one dict per hit with keys ``song_id``, ``title``,
        ``detail`` (the stored ``text`` field) and ``similarity_score``,
        computed as ``1 / (1 + distance)``.
    """
    # Embed the query text.
    query_vector = OllamaEmbeddings(
        model=EMBEDDING_MODEL,
        base_url=OLLAMA_BASE_URL,
    ).embed_query(query_text)

    # "distance" is not an output field; it is read from each hit instead.
    collection.load()
    hit_groups = collection.search(
        data=[query_vector],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["song_id", "title", "text"],
    )

    # Flatten the per-query hit lists into plain result dicts.
    return [
        {
            "song_id": hit.entity.get("song_id"),
            "title": hit.entity.get("title"),
            "detail": hit.entity.get("text"),
            "similarity_score": 1 / (1 + hit.distance),
        }
        for group in hit_groups
        for hit in group
    ]
|
||
|
||
# Document processor (simplified; only handles song text)
|
||
class SongDocumentProcessor:
    """Builds the descriptive text and embedding for each song record."""

    def __init__(self):
        """Create the Ollama embedder and text splitter from module settings."""
        self.embedder = OllamaEmbeddings(
            model=EMBEDDING_MODEL,
            base_url=OLLAMA_BASE_URL,
        )
        # Configured here but not used by process_songs.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", " ", ""],
        )

    def process_songs(self, songs_data: List[Dict]):
        """Convert song records into parallel columns plus their embeddings.

        Each record contributes one descriptive text built from its ``id`` and
        ``title``. A missing key and an explicit ``None`` value are treated
        alike: the ID falls back to ``0`` and the title to ``"未知歌曲"``.

        Args:
            songs_data: song records; each is read via ``.get("id")`` and
                ``.get("title")``.

        Returns:
            tuple: ``(embeddings, texts, song_ids, titles)``, where
            ``embeddings`` is whatever ``self.embedder.embed_documents(texts)``
            returns and the other three are lists with one entry per record.
        """
        texts = []
        song_ids = []
        titles = []
        for song in songs_data:
            # ``get(key, default)`` would let an explicit null through as
            # None, so check for None explicitly.
            song_id = song.get("id")
            if song_id is None:
                song_id = 0
            title = song.get("title")
            if title is None:
                title = "未知歌曲"

            # Descriptive text that gets embedded (other fields could be added).
            texts.append(f"歌曲ID: {song_id}, 歌曲名称: {title}")
            song_ids.append(song_id)
            titles.append(title)

        embeddings = self.embedder.embed_documents(texts)
        return embeddings, texts, song_ids, titles
|
||
|
||
|
||
# Module-level start-up: connect to Milvus, then create the shared collection
# handle and song processor.
connect_to_milvus()
song_collection = create_song_collection()
song_processor = SongDocumentProcessor()
|
||
|
||
|
||
# API endpoint definitions
|
||
@app.post(
    "/api/songs/insert",
    summary="录入歌曲数据",
    # Passed explicitly so the published API description keeps its original
    # text while the docstring below is maintained in English.
    description=(
        "录入歌曲数据到知识库:\n"
        '- 接收歌曲列表,格式为[{"id": 歌曲ID, "title": "歌曲名称"}]\n'
        "- 自动处理并存储到Milvus,支持后续模糊查询"
    ),
)
def insert_songs(
    songs: List[Dict] = Body(
        ...,
        description='歌曲列表,格式:[{"id":0,"title":"实例歌曲"}]',
    )
):
    """Store a list of songs in the knowledge base.

    The request body is a list of song dicts. The songs are turned into
    texts and embeddings by ``song_processor`` and inserted into
    ``song_collection``.

    Returns:
        dict: ``code`` / ``message`` / ``data`` envelope. ``code`` is 400 with
        ``data`` of ``None`` when the list is empty; otherwise 200, with the
        inserted count and the first song as an example.
    """
    if not songs:
        return {"code": 400, "message": "歌曲数据不能为空", "data": None}

    # Build the (embeddings, texts, song_ids, titles) columns and insert them.
    columns = song_processor.process_songs(songs)
    insert_song_documents(song_collection, *columns)

    total = len(songs)
    return {
        "code": 200,
        "message": f"成功录入 {total} 首歌曲",
        "data": {
            "inserted_count": total,
            "example": songs[:1],  # first record, echoed back as a sample
        },
    }
|
||
|
||
|
||
@app.get(
    "/api/songs/search",
    summary="模糊查询歌曲",
    # Passed explicitly so the published API description keeps its original
    # text while the docstring below is maintained in English.
    description=(
        "模糊查询歌曲(基于语义相似度):\n"
        "- 输入关键词,返回语义最相似的歌曲列表\n"
        "- 支持同义词、拼写误差等模糊场景匹配"
    ),
)
def fuzzy_search_songs(
    keyword: str = Query(..., description="查询关键词(歌曲名称模糊匹配)"),
    top_k: int = Query(10, ge=1, le=50, description="返回匹配数量,1-50之间"),
):
    """Return the songs most similar to *keyword* by embedding distance.

    Args:
        keyword: search text; a blank or whitespace-only value is rejected.
        top_k: number of matches to request, constrained to 1-50 by the
            query-parameter declaration.

    Returns:
        dict: ``code`` / ``message`` / ``data`` envelope. ``code`` is 400 with
        ``data`` of ``None`` for a blank keyword; otherwise 200, with the
        keyword, the match count and the matched songs.
    """
    if keyword.strip() == "":
        return {"code": 400, "message": "查询关键词不能为空", "data": None}

    found = search_song_by_fuzzy(song_collection, keyword, top_k)
    found_count = len(found)

    payload = {
        "keyword": keyword,
        "matched_count": found_count,
        "songs": found,
    }
    return {
        "code": 200,
        "message": f"找到 {found_count} 首匹配歌曲",
        "data": payload,
    }
|
||
|
||
|
||
# Entry point (runs the API service)
|
||
def main(host: str = "0.0.0.0", port: int = 58329):
    """Print the service URLs and start the API server with uvicorn.

    Args:
        host: interface address passed to ``uvicorn.run``.
        port: TCP port passed to ``uvicorn.run`` and shown in the printed
            URLs, so the banner and the listening port cannot drift apart.
    """
    print("启动歌曲模糊查询API服务...")
    # The banner always shows the loopback address, whatever *host* is.
    print(f"服务地址:http://127.0.0.1:{port}")
    print(f"API文档:http://127.0.0.1:{port}/docs")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()