Spaces:
Sleeping
Sleeping
Update util/vector_base.py
Browse files- util/vector_base.py +79 -78
util/vector_base.py
CHANGED
|
@@ -1,79 +1,80 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
from langchain_chroma import Chroma
|
| 3 |
-
from langchain_core.documents import Document
|
| 4 |
-
sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
|
| 5 |
-
|
| 6 |
-
from
|
| 7 |
-
import
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
requirement =
|
| 53 |
-
requirement = requirement
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
'
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
requirements_dict_v2[requirement]['
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
-
"
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from langchain_chroma import Chroma
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
# sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
|
| 5 |
+
sys.path.append('/home/user/app/util')
|
| 6 |
+
from Embeddings import TextEmb3LargeEmbedding
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
class EmbeddingFunction():
    """Adapter exposing the LangChain embedding interface on top of a model
    object that provides a single ``get_embedding(text)`` method."""

    def __init__(self, embeddingmodel):
        # The only requirement on the wrapped model is a get_embedding(text) method.
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        """Embed a single query string; the vector is returned as a plain list."""
        vector = self.embeddingmodel.get_embedding(query)
        return list(vector)

    def embed_documents(self, documents):
        """Embed each document in order, returning one vector per input string."""
        embeddings = []
        for text in documents:
            embeddings.append(self.embeddingmodel.get_embedding(text))
        return embeddings
|
| 17 |
+
|
| 18 |
+
def get_or_create_vector_base(collection_name: str, embedding, documents=None,
                              persist_root: str = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//") -> Chroma:
    """Return the Chroma vector store for *collection_name*, creating it if needed.

    If the persistence directory already exists the store is simply reopened.
    Otherwise a new store is created and *documents* are inserted one at a
    time with a one-second pause between inserts, instead of a single batched
    ``embed_documents`` call, to avoid hitting the OpenAI API rate limit
    during initialisation.

    Args:
        collection_name: Chroma collection name; also used as the directory name.
        embedding: Object implementing ``embed_query`` / ``embed_documents``.
        documents: Documents to index when the store has to be created.
        persist_root: Base directory under which collections are persisted
            (defaults to the original hard-coded location for compatibility).

    Returns:
        The opened or newly populated ``Chroma`` store.

    Raises:
        ValueError: The store does not exist on disk and *documents* is empty.
    """
    persist_directory = persist_root + collection_name
    persist_path = Path(persist_directory)
    # BUG FIX: the original tested `persist_path.exists` (the bound method,
    # which is always truthy) instead of calling it, so this guard could
    # never fire and a missing store with no documents crashed later on
    # `for document in None`.
    if not persist_path.exists() and not documents:
        raise ValueError("vector store does not exist and documents is empty")
    if persist_path.exists():
        print("vector store already exists")
        return Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
        )
    print("start creating vector store")
    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding,
        persist_directory=persist_directory,
    )
    for document in documents:
        # One document per call, then sleep, to stay under the API rate cap.
        vector_store.add_documents(documents=[document])
        time.sleep(1)
    return vector_store
|
| 46 |
+
|
| 47 |
+
if __name__=="__main__":
    import pandas as pd

    # Aggregate the CSV into {requirement_text: {'PO': set, 'safeguard': set}}.
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    requirements_dict_v2 = {}
    for index, row in requirements_data.iterrows():
        # Requirement cells look like "<prefix>- <name>"; keep the part after "- ".
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        # Collapse line breaks and tabs so the text is one clean line.
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        if requirement not in requirements_dict_v2:
            requirements_dict_v2[requirement] = {
                'PO': set(),
                'safeguard': set()
            }
        requirements_dict_v2[requirement]['PO'].add(row['PCF-Privacy Objective'].lower().rstrip() if isinstance(row['PCF-Privacy Objective'], str) else None)
        # FIX: guard against non-string (NaN) cells, consistent with the PO
        # column above; the None placeholders are filtered out when the
        # metadata is built below.
        requirements_dict_v2[requirement]['safeguard'].add(row['Safeguard'].lower().rstrip() if isinstance(row['Safeguard'], str) else None)

    # Turn the aggregated requirements into LangChain Documents.
    index = 0
    documents = []
    for key, value in requirements_dict_v2.items():
        page_content = key
        metadata = {
            "index": index,
            "version": 2,
            # Drop the None placeholders added for non-string cells.
            "PO": str([po for po in value['PO'] if po]),
            "safeguard": str([safeguard for safeguard in value['safeguard'] if safeguard])
        }
        index += 1
        document = Document(
            page_content=page_content,
            metadata=metadata
        )
        documents.append(document)

    # max_qpm=58 keeps the request rate just under the embedding endpoint cap.
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
|