"""Build a Qdrant vector database over PDF documents using ClinicalBERT embeddings."""

import torch
from langchain.embeddings.base import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from transformers import AutoTokenizer, AutoModel

class ClinicalBertEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by ClinicalBERT with mean pooling."""

    def __init__(self, model_name: str = "medicalai/ClinicalBERT"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def embed(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.mean_pooling(outputs, inputs["attention_mask"])
        return embeddings.squeeze().numpy()

    def mean_pooling(self, model_output, attention_mask):
        # First element of model_output holds the per-token embeddings.
        token_embeddings = model_output[0]
        # Average the token embeddings, masking out padding positions.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def embed_documents(self, texts):
        # LangChain expects plain Python lists of floats, not numpy arrays.
        return [self.embed(text).tolist() for text in texts]

    def embed_query(self, text):
        return self.embed(text).tolist()

# Initialize our custom embeddings class
embeddings = ClinicalBertEmbeddings()

# Note: an earlier draft wrapped ClinicalBERT in a SentenceTransformer with a
# mean-pooling module and used SentenceTransformerEmbeddings; the custom
# ClinicalBertEmbeddings class above replaces that approach.
print(f"Embeddings initialised: {embeddings.__class__.__name__}")
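
# Optional sanity check (a sketch, not part of the original pipeline): ClinicalBERT's
# hidden size is 768, so a query embedding should have 768 dimensions. The query
# text below is purely illustrative.
probe = embeddings.embed_query("patient presents with chest pain")
print(f"Embedding dimension: {len(probe)}")  # expected: 768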

# Load every PDF under Data/ and split it into overlapping chunks.
loader = DirectoryLoader("Data/", glob="**/*.pdf", show_progress=True, loader_cls=UnstructuredFileLoader)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=70)
texts = text_splitter.split_documents(documents)
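
# A small progress check (sketch): report how many chunks the splitter produced.
print(f"Split {len(documents)} documents into {len(texts)} chunks")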

print("Create Vector Database")
url = "http://localhost:6333"
client = QdrantClient(url=url, prefer_grpc=False)
collection_name = "vector_db"

try:
# Check if the collection already exists
    collection_info = client.get_collection(collection_name=collection_name)
    print(f"Collection '{collection_name}' already exists.")
except Exception as e:
    print(f"Collection '{collection_name}' does not exist. Creating a new one.")
    print(f"Error: {e}")
    client.create_collection(
        collection_name=collection_name,
        vectors_config= VectorParams(
            size=768,
            distance="Cosine"
        )
    )
    qdrant = Qdrant.from_documents(
    texts,
    embeddings,
    url=url,
    prefer_grpc=False,
    collection_name=collection_name,
    # optimizer_config=optimizer_config
    )
    keyword_retriever = BM25Retriever.from_documents(texts)
    keyword_retriever.k =  3
    print("vector database created.................")