cheesyFishes committed
Commit d9f893a · 1 Parent(s): 7a47d4d

move device around more often

Files changed (1)
  1. app.py +16 -3
app.py CHANGED
@@ -22,7 +22,7 @@ elif torch.backends.mps.is_available():
 
 image_embed_model = HuggingFaceEmbedding(
     model_name="llamaindex/vdr-2b-multi-v1",
-    device=device,
+    device="cpu",
     trust_remote_code=True,
     token=os.getenv("HUGGINGFACE_TOKEN"),
     model_kwargs={"torch_dtype": torch.float16},
@@ -31,7 +31,7 @@ image_embed_model = HuggingFaceEmbedding(
 
 text_embed_model = HuggingFaceEmbedding(
     model_name="BAAI/bge-small-en",
-    device=device,
+    device="cpu",
     trust_remote_code=True,
     token=os.getenv("HUGGINGFACE_TOKEN"),
     embed_batch_size=2,
@@ -80,6 +80,9 @@ def create_index(file, llama_parse_key, progress=gr.Progress()):
             image_docs.append(ImageDocument(text=image_dict["name"], image_path=image_dict["path"]))
 
         # Create index
+        # move models to the device for embedding
+        image_embed_model._model.to(device)
+        text_embed_model._model.to(device)
         progress(0.9, desc="Creating final index...")
         index = MultiModalVectorStoreIndex.from_documents(
             text_docs + image_docs,
@@ -92,11 +95,17 @@ def create_index(file, llama_parse_key, progress=gr.Progress()):
 
     except Exception as e:
         return None, f"Error creating index: {str(e)}"
+    finally:
+        # move models back to CPU
+        image_embed_model._model.to("cpu")
+        text_embed_model._model.to("cpu")
 
 def run_search(index, query, text_top_k, image_top_k):
     if not index:
         return "Please create or select an index first.", [], []
-
+    # move models to the device for retrieval
+    index._image_embed_model._model.to(device)
+    index._embed_model._model.to(device)
     retriever = index.as_retriever(
         similarity_top_k=text_top_k,
         image_similarity_top_k=image_top_k,
@@ -105,6 +114,10 @@ def run_search(index, query, text_top_k, image_top_k):
     image_nodes = retriever.text_to_image_retrieve(query)
     text_nodes = retriever.text_retrieve(query)
 
+    # move models back to CPU
+    index._image_embed_model._model.to("cpu")
+    index._embed_model._model.to("cpu")
+
     # Extract text and scores from nodes
     text_results = [{"text": node.text, "score": f"{node.score:.3f}"} for node in text_nodes]
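The pattern this commit applies by hand (keep both embedding models resident on CPU and move them to the accelerator only for the duration of the heavy work) can be factored into a small helper so the paired .to(device) / .to("cpu") calls cannot drift out of sync. A minimal sketch, not part of the commit: it mirrors the device selection app.py already does, assumes the wrapped torch.nn.Module is reachable via the embedding wrapper's _model attribute, and the on_device helper name is hypothetical.

from contextlib import contextmanager

import torch

# mirror app.py's device selection: prefer CUDA, then Apple MPS, else CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

@contextmanager
def on_device(model: torch.nn.Module, target: str):
    """Temporarily move `model` to `target`, always returning it to CPU."""
    model.to(target)
    try:
        yield model
    finally:
        # runs even if embedding/retrieval raises, so accelerator
        # memory is released before the next model needs it
        model.to("cpu")
        if target == "cuda":
            torch.cuda.empty_cache()

# hypothetical usage inside run_search:
#     with on_device(index._embed_model._model, device):
#         text_nodes = retriever.text_retrieve(query)

With a helper like this, the before/after moves in run_search and the finally: block in create_index collapse into with blocks, which also sidesteps referencing index before it has been assigned when an exception fires early.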