Jack Morris commited on
Commit
546fe43
·
1 Parent(s): 6b3423f

add model use information

Browse files
Files changed (1) hide show
  1. README.md +93 -1
README.md CHANGED
@@ -8645,4 +8645,96 @@ model-index:
8645
  ---
8646
  # Contextual Document Embeddings (CDE)
8647
 
8648
- Our new model that naturally integrates "context tokens" into the embedding process. More information coming soon!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8645
  ---
8646
  # Contextual Document Embeddings (CDE)
8647
 
8648
+ Our new model that naturally integrates "context tokens" into the embedding process.
8649
+
8650
+ # Using `cde-small-v1`
8651
+
8652
+ Our embedding model needs to be used in *two stages*. The first stage is to gather some dataset information by embedding a subset of the corpus using our "first-stage" model. The second stage is to actually embed queries and documents, conditioning on the corpus information from the first stage. Note that we can do the first stage part offline and only use the second-stage weights at inference time.
8653
+
8654
+
8655
+ #### Note on prefixes
8656
+
8657
+ *Nota bene*: Like all state-of-the-art embedding models, our model was trained with task-specific prefixes. To do retrieval, you can prepend the following strings to queries & documents:
8658
+
8659
+ ```python
8660
+ query_prefix = "search_query: "
8661
+ document_prefix = "search_document: "
8662
+ ```
8663
+
8664
+ ## First stage
8665
+
8666
+ ```python
8667
+ minicorpus_size = model.config.transductive_corpus_size
8668
+ minicorpus_docs = [ ... ] # Put some strings here that are representative of your corpus, for example by calling random.sample(corpus, k=minicorpus_size)
8669
+ assert len(minicorpus_docs) == minicorpus_size # You must use exactly this many documents in the minicorpus. You can oversample if your corpus is smaller.
8670
+ minicorpus_docs = tokenizer(
8671
+ [document_prefix + doc for doc in minicorpus_docs],
8672
+ truncation=True,
8673
+ padding=True,
8674
+ max_length=512,
8675
+ return_tensors="pt"
8676
+ )
8677
+ import torch
8678
+ from tqdm.autonotebook import tqdm
8679
+
8680
+ batch_size = 32
8681
+
8682
+ dataset_embeddings = []
8683
+ for i in tqdm(range(0, len(minicorpus_docs["input_ids"]), batch_size)):
8684
+ minicorpus_docs_batch = {k: v[i:i+batch_size] for k,v in minicorpus_docs.items()}
8685
+ with torch.no_grad():
8686
+ dataset_embeddings.append(
8687
+ model.first_stage_model(**minicorpus_docs_batch)
8688
+ )
8689
+
8690
+ dataset_embeddings = torch.cat(dataset_embeddings)
8691
+
8692
+ ## Running the second stage
8693
+
8694
+ Now that we have obtained "dataset embeddings" we can embed documents and queries like normal. Remember to use the document prefix for documents:
8695
+ ```python
8696
+ docs = tokenizer(
8697
+ [document_prefix + doc for doc in docs],
8698
+ truncation=True,
8699
+ padding=True,
8700
+ max_length=512,
8701
+ return_tensors="pt"
8702
+ ).to(device)
8703
+
8704
+ with torch.no_grad():
8705
+ doc_embeddings = model.second_stage_model(
8706
+ input_ids=docs["input_ids"],
8707
+ attention_mask=docs["attention_mask"],
8708
+ dataset_embeddings=dataset_embeddings,
8709
+ )
8710
+ doc_embeddings /= doc_embeddings.norm(p=2, dim=1, keepdim=True)
8711
+ ```
8712
+
8713
+ and the query prefix for queries:
8714
+ ```python
8715
+ queries = queries.select(range(16))["text"]
8716
+ queries = tokenizer(
8717
+ [query_prefix + query for query in queries],
8718
+ truncation=True,
8719
+ padding=True,
8720
+ max_length=512,
8721
+ return_tensors="pt"
8722
+ ).to(device)
8723
+
8724
+ with torch.no_grad():
8725
+ query_embeddings = model.second_stage_model(
8726
+ input_ids=queries["input_ids"],
8727
+ attention_mask=queries["attention_mask"],
8728
+ dataset_embeddings=dataset_embeddings,
8729
+ )
8730
+ query_embeddings /= query_embeddings.norm(p=2, dim=1, keepdim=True)
8731
+ ```
8732
+
8733
+ these embeddings can be compared using dot product, since they're normalized.
8734
+
8735
+ ### What if I don't know what my corpus will be ahead of time?
8736
+
8737
+ ### Colab demo
8738
+
8739
+ We've set up a short demo in a Colab notebook showing how you might use our model:
8740
+ [Try our model in Colab:](https://colab.research.google.com/drive/1r8xwbp7_ySL9lP-ve4XMJAHjidB9UkbL?usp=sharing)