1
This commit is contained in:
106
org/milvus_client.py
Normal file
106
org/milvus_client.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType, utility
|
||||
from config import MILVUS_HOST, MILVUS_PORT, COLLECTION_NAME, EMBEDDING_DIM
|
||||
|
||||
|
||||
def connect_to_milvus():
|
||||
"""连接到Milvus数据库"""
|
||||
connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
|
||||
|
||||
|
||||
def create_collection():
|
||||
"""创建Milvus集合"""
|
||||
# # 如果集合存在,先删除(仅首次运行需要,后续可注释)
|
||||
# if utility.has_collection(COLLECTION_NAME):
|
||||
# utility.drop_collection(COLLECTION_NAME) # 添加这行代码删除旧集合
|
||||
|
||||
"""创建Milvus集合"""
|
||||
if utility.has_collection(COLLECTION_NAME):
|
||||
return Collection(COLLECTION_NAME)
|
||||
|
||||
# 定义字段
|
||||
id_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True
|
||||
)
|
||||
|
||||
embedding_field = FieldSchema(
|
||||
name="embedding",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=EMBEDDING_DIM
|
||||
)
|
||||
|
||||
text_field = FieldSchema(
|
||||
name="text",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535
|
||||
)
|
||||
|
||||
source_field = FieldSchema(
|
||||
name="source",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=256
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields=[id_field, embedding_field, text_field, source_field],
|
||||
description="Knowledge base collection"
|
||||
)
|
||||
|
||||
collection = Collection(
|
||||
name=COLLECTION_NAME,
|
||||
schema=schema,
|
||||
using='default',
|
||||
shards_num=2
|
||||
)
|
||||
|
||||
# 创建索引
|
||||
index_params = {
|
||||
"index_type": "IVF_FLAT",
|
||||
"metric_type": "L2",
|
||||
"params": {"nlist": 128}
|
||||
}
|
||||
|
||||
collection.create_index(field_name="embedding", index_params=index_params)
|
||||
return collection
|
||||
|
||||
|
||||
def insert_documents(collection, embeddings, texts, sources):
|
||||
"""插入文档到集合"""
|
||||
insert_data = [
|
||||
embeddings,
|
||||
texts,
|
||||
sources
|
||||
]
|
||||
|
||||
collection.insert(insert_data)
|
||||
collection.flush()
|
||||
|
||||
|
||||
def search_documents(collection, query_embedding, top_k=5):
|
||||
"""搜索相似文档"""
|
||||
collection.load()
|
||||
search_params = {
|
||||
"metric_type": "L2",
|
||||
"params": {"nprobe": 10}
|
||||
}
|
||||
|
||||
results = collection.search(
|
||||
data=[query_embedding],
|
||||
anns_field="embedding",
|
||||
param=search_params,
|
||||
limit=top_k,
|
||||
output_fields=["text", "source"]
|
||||
)
|
||||
|
||||
retrieved_docs = []
|
||||
for hits in results:
|
||||
for hit in hits:
|
||||
retrieved_docs.append({
|
||||
'text': hit.entity.get('text'),
|
||||
'source': hit.entity.get('source'),
|
||||
'distance': hit.distance
|
||||
})
|
||||
|
||||
return retrieved_docs
|
||||
Reference in New Issue
Block a user