Files
AI-Search/org/milvus_client.py
2025-12-15 14:38:31 +08:00

107 lines
2.6 KiB
Python

from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
from config import MILVUS_HOST, MILVUS_PORT, COLLECTION_NAME, EMBEDDING_DIM
def connect_to_milvus():
    """Register the "default" pymilvus connection using the configured host and port."""
    connections.connect(
        alias="default",
        host=MILVUS_HOST,
        port=MILVUS_PORT,
    )
def create_collection():
    """Return the collection named COLLECTION_NAME, creating it when absent.

    If the collection already exists it is returned as-is. Otherwise a new
    collection is created with four fields (an auto-generated INT64 primary
    key ``id``, a FLOAT_VECTOR ``embedding`` of dimension EMBEDDING_DIM, and
    VARCHAR ``text`` / ``source`` columns) and an IVF_FLAT index using the
    L2 metric is built on ``embedding``.

    NOTE: to rebuild the collection from scratch, drop the old one first
    with ``utility.drop_collection(COLLECTION_NAME)``.
    """
    # Reuse the existing collection rather than redefining its schema.
    if utility.has_collection(COLLECTION_NAME):
        return Collection(COLLECTION_NAME)

    # Field definitions, in the column order the schema exposes them.
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=256),
    ]
    schema = CollectionSchema(fields=fields, description="Knowledge base collection")

    new_collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using='default',
        shards_num=2,
    )

    # Build the vector index on the embedding column.
    new_collection.create_index(
        field_name="embedding",
        index_params={
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {"nlist": 128},
        },
    )
    return new_collection
def insert_documents(collection, embeddings, texts, sources):
    """Insert a batch of documents into ``collection`` and flush it.

    Args:
        collection: Object exposing ``insert(data)`` and ``flush()``.
        embeddings: Values for the embedding column.
        texts: Values for the text column.
        sources: Values for the source column.

    The three columns are handed to ``collection.insert`` as a single list
    in the order embeddings, texts, sources; presumably they are parallel
    sequences of equal length (not checked here).
    """
    collection.insert([embeddings, texts, sources])
    # Flush right after the insert, as the original does.
    collection.flush()
def search_documents(collection, query_embedding, top_k=5):
    """Search ``collection`` for the entries nearest to ``query_embedding``.

    Args:
        collection: Object exposing ``load()`` and ``search(...)``.
        query_embedding: A single query vector; it is wrapped in a list
            before being passed to ``collection.search``.
        top_k: Value passed as the search ``limit``.

    Returns:
        A flat list of dicts with keys ``'text'``, ``'source'`` and
        ``'distance'``, one per hit, in the order the search returned them.
    """
    # The collection is loaded on every call before searching.
    collection.load()

    hit_groups = collection.search(
        data=[query_embedding],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["text", "source"],
    )

    # Flatten the per-query hit lists into one list of plain dicts.
    return [
        {
            'text': match.entity.get('text'),
            'source': match.entity.get('source'),
            'distance': match.distance,
        }
        for group in hit_groups
        for match in group
    ]