jxm committed
Commit 629557b · verified · 1 parent: 5f56f85

Create sentence_transformers_impl.py

Files changed (1):
  sentence_transformers_impl.py  +155 -0

sentence_transformers_impl.py (ADDED):
from __future__ import annotations

import json
import logging
import os
from typing import Any, Optional

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)


class Transformer(nn.Module):
    """Hugging Face AutoModel to generate token embeddings.

    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Hugging Face model name or path
            (https://huggingface.co/models)
        model_args: Keyword arguments passed to the Hugging Face
            Transformers model
        tokenizer_args: Keyword arguments passed to the Hugging Face
            Transformers tokenizer
        config_args: Keyword arguments passed to the Hugging Face
            Transformers config
        cache_dir: Cache dir for Hugging Face Transformers to store/load
            models
        **kwargs: Further keyword arguments (e.g. max_seq_length, do_lower_case,
            tokenizer_name_or_path, backend) are accepted for compatibility with
            the Sentence Transformers module interface but are not used here.
    """

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        model_args: dict[str, Any] | None = None,
        tokenizer_args: dict[str, Any] | None = None,
        config_args: dict[str, Any] | None = None,
        cache_dir: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__()
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        # This model ships custom modeling code, so the caller must opt in to
        # trusting remote code explicitly.
        if not model_args.get("trust_remote_code", False):
            raise ValueError(
                "You need to set `trust_remote_code=True` to load this model."
            )

        self.config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir, **model_args)

        # Note: the tokenizer is pinned to bert-base-uncased rather than being
        # loaded from model_name_or_path.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "bert-base-uncased",
            cache_dir=cache_dir,
            **tokenizer_args,
        )

    def __repr__(self) -> str:
        return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "

    def forward(self, features: dict[str, torch.Tensor], dataset_embeddings: Optional[torch.Tensor] = None, **kwargs) -> dict[str, torch.Tensor]:
        """Computes sentence embeddings and adds them to the features dict."""
        # If we don't have dataset embeddings, run the first-stage model;
        # if we do, run the second-stage model conditioned on them.
        if dataset_embeddings is None:
            sentence_embedding = self.auto_model.first_stage_model(
                input_ids=features["input_ids"],
                attention_mask=features["attention_mask"],
            )
        else:
            sentence_embedding = self.auto_model.second_stage_model(
                input_ids=features["input_ids"],
                attention_mask=features["attention_mask"],
                dataset_embeddings=dataset_embeddings,
            )

        features["sentence_embedding"] = sentence_embedding
        return features

    def get_word_embedding_dimension(self) -> int:
        return self.auto_model.config.hidden_size

    def tokenize(
        self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
        """Tokenizes texts and maps tokens to token ids."""
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        max_seq_length = self.config.max_seq_length
        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=max_seq_length,
            )
        )
        return output

    def get_config_dict(self) -> dict[str, Any]:
        return {}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)

        with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @classmethod
    def load(cls, input_path: str) -> Transformer:
        sbert_config_path = os.path.join(input_path, "sentence_bert_config.json")
        if not os.path.exists(sbert_config_path):
            return cls(model_name_or_path=input_path)

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)
        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")
        return cls(model_name_or_path=input_path, **config)
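
For orientation, here is a minimal sketch of driving this module directly, mirroring the two-stage pattern in `forward()` above. The repository id, the sample texts, and the assumption that batched first-stage outputs can be reused as `dataset_embeddings` are illustrative placeholders, not part of the committed file.

import torch

from sentence_transformers_impl import Transformer

# Hypothetical repository id; replace with the Hub repo that ships this module
# together with its remote modeling code.
MODEL_REPO = "<user>/<model>"

module = Transformer(
    model_name_or_path=MODEL_REPO,
    model_args={"trust_remote_code": True},   # required by __init__ above
    config_args={"trust_remote_code": True},  # in case the repo defines a custom config class
)

corpus = ["Paris is the capital of France.", "The Louvre is a museum in Paris."]
queries = ["Where is the Louvre?"]

with torch.no_grad():
    # First stage: embed a sample of corpus documents ("dataset embeddings").
    corpus_features = module.tokenize(corpus)
    dataset_embeddings = module(corpus_features)["sentence_embedding"]

    # Second stage: embed the actual inputs, conditioned on those embeddings.
    # (The exact shape expected for dataset_embeddings is defined by the
    # remote second-stage model.)
    query_features = module.tokenize(queries)
    query_embeddings = module(query_features, dataset_embeddings=dataset_embeddings)["sentence_embedding"]

print(query_embeddings.shape)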