107 lines
2.6 KiB
Python
107 lines
2.6 KiB
Python
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
|
|
from config import MILVUS_HOST, MILVUS_PORT, COLLECTION_NAME, EMBEDDING_DIM
|
|
|
|
|
|
def connect_to_milvus():
    """Open the pymilvus connection registered under the ``default`` alias.

    The server address is taken from ``MILVUS_HOST`` and ``MILVUS_PORT``,
    both imported from the project ``config`` module.
    """
    connections.connect(alias="default", host=MILVUS_HOST, port=MILVUS_PORT)
|
|
|
|
|
|
def create_collection():
    """Return the collection named ``COLLECTION_NAME``, creating it if absent.

    If the collection already exists it is returned as-is: it is neither
    dropped nor re-indexed here. To rebuild it (for example after changing
    the schema), drop the old collection first with
    ``utility.drop_collection(COLLECTION_NAME)``.

    A newly created collection has four fields:

    * ``id``        - INT64 primary key, generated by Milvus (``auto_id``).
    * ``embedding`` - FLOAT_VECTOR with ``EMBEDDING_DIM`` dimensions.
    * ``text``      - VARCHAR, at most 65535 characters.
    * ``source``    - VARCHAR, at most 256 characters.

    An IVF_FLAT index with the L2 metric and ``nlist=128`` is then created
    on the ``embedding`` field.

    Returns:
        The ``Collection`` object bound to ``COLLECTION_NAME``.
    """
    # Reuse the existing collection instead of re-creating it.
    if utility.has_collection(COLLECTION_NAME):
        return Collection(COLLECTION_NAME)

    # Field definitions, in schema order.
    fields = [
        FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            dim=EMBEDDING_DIM,
        ),
        FieldSchema(
            name="text",
            dtype=DataType.VARCHAR,
            max_length=65535,
        ),
        FieldSchema(
            name="source",
            dtype=DataType.VARCHAR,
            max_length=256,
        ),
    ]
    schema = CollectionSchema(
        fields=fields,
        description="Knowledge base collection",
    )

    new_collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
        using="default",
        shards_num=2,
    )

    # Build the vector index on the embedding field.
    new_collection.create_index(
        field_name="embedding",
        index_params={
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {"nlist": 128},
        },
    )
    return new_collection
|
|
|
|
|
|
def insert_documents(collection, embeddings, texts, sources):
    """Insert a batch of documents into ``collection`` and flush it.

    The data is handed to ``collection.insert`` column-wise, in the order
    embedding, text, source. That matches the non-primary fields of the
    schema built by ``create_collection``; no ``id`` column is supplied
    because that schema declares the primary key with ``auto_id=True``.

    Args:
        collection: Object exposing ``insert`` and ``flush`` (a pymilvus
            ``Collection`` in practice).
        embeddings: Column of embedding vectors.
        texts: Column of document texts.
        sources: Column of source identifiers.

    NOTE(review): the three columns are presumably parallel (same length,
    i-th entries describing the same document) - this function does not
    check that itself.
    """
    collection.insert([embeddings, texts, sources])
    collection.flush()
|
|
|
|
|
|
def search_documents(collection, query_embedding, top_k=5, *, nprobe=10):
    """Search ``collection`` for the documents closest to ``query_embedding``.

    The collection is loaded, then a single-vector search is run against the
    ``embedding`` field using the L2 metric.

    Args:
        collection: Object exposing ``load`` and ``search`` (a pymilvus
            ``Collection`` in practice).
        query_embedding: The query vector; it is wrapped in a one-element
            list before being passed to ``collection.search``.
        top_k: Maximum number of hits to request (passed as ``limit``).
        nprobe: Value of the ``nprobe`` search parameter. Defaults to 10,
            the value that was previously hard-coded. The index created by
            ``create_collection`` uses ``nlist=128``; Milvus documents
            ``nprobe`` as the number of those clusters probed per query.

    Returns:
        A list of dicts with keys ``'text'``, ``'source'`` and
        ``'distance'``, in the order the hits are returned by the search.
    """
    collection.load()

    result_sets = collection.search(
        data=[query_embedding],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": nprobe}},
        limit=top_k,
        output_fields=["text", "source"],
    )

    # One result set comes back per query vector; flatten them into one list.
    return [
        {
            'text': hit.entity.get('text'),
            'source': hit.entity.get('source'),
            'distance': hit.distance,
        }
        for hits in result_sets
        for hit in hits
    ]
|