emozilla commited on
Commit
2010c83
·
1 Parent(s): 5169b80

update inference code

Browse files
Files changed (14) hide show
  1. aliases.py +7 -0
  2. beam_search.py +1078 -0
  3. checkpoint.py +1671 -0
  4. config.py +1106 -0
  5. configuration_olmo.py +7 -1
  6. exceptions.py +50 -0
  7. initialization.py +95 -0
  8. model.py +1778 -0
  9. modeling_olmo.py +97 -67
  10. tokenization_olmo_fast.py +0 -16
  11. tokenizer.py +180 -0
  12. tokenizer_config.json +6 -3
  13. torch_util.py +139 -0
  14. util.py +655 -0
aliases.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from os import PathLike
2
+ from typing import Union
3
+
4
+ __all__ = ["PathOrStr"]
5
+
6
+
7
+ PathOrStr = Union[str, PathLike]
beam_search.py ADDED
@@ -0,0 +1,1078 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This is a self-contained and flexible beam search implementation adapted from
3
+ AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
+ """
5
+
6
+ import copy
7
+ import warnings
8
+ from abc import abstractmethod
9
+ from inspect import signature
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
+
12
+ import torch
13
+
14
+ __all__ = [
15
+ "Sampler",
16
+ "DeterministicSampler",
17
+ "MultinomialSampler",
18
+ "TopKSampler",
19
+ "TopPSampler",
20
+ "GumbelSampler",
21
+ "FinalSequenceScorer",
22
+ "SequenceLogProbabilityScorer",
23
+ "LengthNormalizedSequenceLogProbabilityScorer",
24
+ "Constraint",
25
+ "RepeatedNGramBlockingConstraint",
26
+ "BeamSearch",
27
+ ]
28
+
29
+ StateType = Dict[str, torch.Tensor]
30
+ StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
+ StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
+
33
+ StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
+ """
35
+ The type of step function that can be passed to [`BeamSearch.search`](#search).
36
+
37
+ This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
+ or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
+ """
40
+
41
+ ConstraintStateType = List[List[Dict[str, Any]]]
42
+
43
+
44
+ class Sampler:
45
+ """
46
+ An abstract class that can be used to sample candidates (either nodes or beams)
47
+ within `BeamSearch`.
48
+
49
+ A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
+
51
+ `init_state()` takes three arguments:
52
+
53
+ - a tensor of starting log probs with shape `(batch_size,, num_classes)`,
54
+ - the batch size, an int,
55
+ - and the number of classes, also an int.
56
+
57
+ It returns a state dictionary with any state tensors needed for subsequent
58
+ calls to `sample_nodes()` and `sample_beams()`.
59
+
60
+ By default this method just returns an empty dictionary.
61
+
62
+ Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
+
64
+ - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
+ - an integer representing the number of samples to take for each example in the batch,
66
+ - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
+ track of state.
68
+
69
+ For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
+ `num_examples = beam_size * per_node_beam_size`.
71
+
72
+ The return value should be a tuple containing:
73
+
74
+ - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
+ - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
+ - and the updated state dictionary.
77
+
78
+ A default implementation of `sample_beams` is provided, which just deterministically
79
+ picks the `k` examples with highest log probability.
80
+ """
81
+
82
+ def init_state(
83
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
+ ) -> StateType:
85
+ del start_class_log_probabilities, batch_size, num_classes
86
+ return {}
87
+
88
+ @abstractmethod
89
+ def sample_nodes(
90
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
+ raise NotImplementedError
93
+
94
+ def sample_beams(
95
+ self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
+ del state
98
+ selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
+ return selected_log_probs, selected_indices, {}
100
+
101
+
102
+ class DeterministicSampler(Sampler):
103
+ """
104
+ A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
+ log probability.
106
+ """
107
+
108
+ def sample_nodes(
109
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
+ del state
112
+ selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
+ return selected_log_probs, selected_indices, {}
114
+
115
+
116
+ class MultinomialSampler(Sampler):
117
+ """
118
+ A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
+ in the default, non-deterministic way.
120
+
121
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
+ above 1.0 produces a flatter probability distribution.
123
+ :param with_replacement: Whether to sample with replacement.
124
+
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ temperature: float = 1.0,
130
+ with_replacement: bool = False,
131
+ ) -> None:
132
+ self.temperature = temperature
133
+ self.with_replacement = with_replacement
134
+
135
+ def sample_nodes(
136
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
+ if self.temperature != 1.0:
139
+ _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
+ else:
141
+ _probabilities = log_probs.exp()
142
+
143
+ selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
+
145
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
+
147
+
148
+ class TopKSampler(Sampler):
149
+ """
150
+ A `Sampler` which redistributes the probability mass function for nodes among the
151
+ top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
+
153
+ Beams are sampled in the default, deterministic way.
154
+
155
+ :param k: The number of top choices to be selected from.
156
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
+ above 1.0 produces a flatter probability distribution.
158
+ :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ k: int = 1,
164
+ temperature: float = 1.0,
165
+ with_replacement: bool = False,
166
+ ):
167
+ self.k = k
168
+ self.temperature = temperature or 1.0
169
+ self.with_replacement = with_replacement
170
+
171
+ def sample_nodes(
172
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
+ if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
+ raise ValueError(
176
+ "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
+ )
178
+
179
+ # shape (both): (batch_size, k)
180
+ top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
+
182
+ # Apply temperature if necessary.
183
+ # shape: (batch_size, k)
184
+ if self.temperature != 1.0:
185
+ top_k_log_probs = top_k_log_probs / self.temperature
186
+
187
+ # Re-normalize the subset.
188
+ # shape: (batch_size, k)
189
+ normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
+
191
+ # Sample from the re-normalized subset.
192
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
+ # shape: (batch_size, per_node_beam_size)
194
+ sampled_indices = torch.multinomial(
195
+ normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
+ )
197
+
198
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
+ # shape: (batch_size, per_node_beam_size)
200
+ indices = top_k_indices.gather(-1, sampled_indices)
201
+
202
+ return log_probs.gather(1, indices), indices, state
203
+
204
+
205
+ class TopPSampler(Sampler):
206
+ """
207
+ A `Sampler` which redistributes the probability mass function for nodes among
208
+ the top choices with a cumulative probability of at least `p`, then samples from that subset
209
+ after re-normalizing the probabilities.
210
+
211
+ Beams are sampled in the default, deterministic way.
212
+
213
+ :param p:
214
+ The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
+ examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
+ insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
+ `per_node_beam_size` examples will be chosen.
218
+ :param temperature:
219
+ A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
+ above 1.0 produces a flatter probability distribution.
221
+ :param with_replacement:
222
+ If set to `True`, samples will be selected with replacement from the top choices.
223
+
224
+ """
225
+
226
+ def __init__(
227
+ self,
228
+ p: float = 0.9,
229
+ temperature: float = 1.0,
230
+ with_replacement: bool = False,
231
+ ):
232
+ if p < 0.0 or p > 1.0:
233
+ raise ValueError("p must be a positive float no greater than 1.0")
234
+ self.p = p
235
+ self.temperature = temperature or 1.0
236
+ self.with_replacement = with_replacement
237
+
238
+ def sample_nodes(
239
+ self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
+ if not per_node_beam_size <= log_probs.size()[1]:
242
+ raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
+
244
+ # First apply temperature coefficient:
245
+ if self.temperature != 1.0:
246
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
+ else:
248
+ _log_probs = log_probs
249
+
250
+ # Sort the probabilities in descending order to then find cumulative sum
251
+ log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
+
253
+ # shape: (batch_size, num_classes)
254
+ probabilities_descending = log_probs_descending.exp()
255
+ probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
+
257
+ # Create a mask for filtering out probabilities that don't make the top `p`.
258
+ # shape: (batch_size, num_classes)
259
+ exclusion_mask = probabilities_summed >= self.p
260
+
261
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
+ exclusion_mask[..., 0] = False
264
+
265
+ # Make sure there's at least `per_node_beam_size` options to be selected.
266
+ if not self.with_replacement:
267
+ exclusion_mask[..., :per_node_beam_size] = False
268
+
269
+ log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
+
271
+ # Now re-normalized the included log probs.
272
+ # shape: (batch_size, num_classes)
273
+ filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
+
275
+ # Sample from the re-normalized subset.
276
+ # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
+ # shape: (batch_size, per_node_beam_size)
278
+ sampled_indices = torch.multinomial(
279
+ filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
+ )
281
+
282
+ # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
+ # shape: (batch_size, per_node_beam_size)
284
+ selected_indices = sorting_indices.gather(-1, sampled_indices)
285
+
286
+ # Return (selected log probabilities, selected classes)
287
+ # shape: (len(log_probs),1) , (len(log_probs), 1)
288
+ return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
+
290
+
291
+ class GumbelSampler(Sampler):
292
+ """
293
+ A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
+ [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
+ Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2010]
296
+ (https://api.semanticscholar.org/CorpusID:76662039).
297
+
298
+ :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
+ above 1.0 produces a flatter probability distribution.
300
+ """
301
+
302
+ def __init__(self, temperature: float = 1.0):
303
+ self.temperature = temperature
304
+
305
+ def init_state(
306
+ self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
+ ) -> StateType:
308
+ # shape: (batch_size, num_classes)
309
+ zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
+
311
+ # shape: (batch_size, num_classes)
312
+ G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
+
314
+ return {"G_phi_S": G_phi_S}
315
+
316
+ def sample_nodes(
317
+ self,
318
+ log_probs: torch.Tensor,
319
+ per_node_beam_size: int,
320
+ state: StateType,
321
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
+ # First apply temperature coefficient:
323
+ # shape: (batch_size * beam_size, num_classes)
324
+ if self.temperature != 1.0:
325
+ _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
+ else:
327
+ _log_probs = log_probs
328
+
329
+ # shape: (group_size,)
330
+ phi_S = state["phi_S"]
331
+
332
+ # shape: (group_size, num_classes)
333
+ phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
+
335
+ # shape: (group_size, num_classes)
336
+ phi_S_new = phi_S + _log_probs
337
+
338
+ # shape: (group_size, 1)
339
+ G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
+
341
+ # shape: (group_size, num_classes)
342
+ G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
+
344
+ # Replace NaNs with very negative number.
345
+ # shape: (group_size, num_classes)
346
+ # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
+
348
+ # shape (both): (group_size, per_node_beam_size)
349
+ top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
+
351
+ # shape: (group_size, per_node_beam_size)
352
+ top_log_probs = log_probs.gather(1, top_indices)
353
+
354
+ return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
+
356
+ def sample_beams(
357
+ self,
358
+ log_probs: torch.Tensor,
359
+ beam_size: int,
360
+ state: StateType,
361
+ ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
+ """
363
+ Returns the beams with the highest perturbed log probabilities.
364
+ """
365
+ # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
+
367
+ batch_size = log_probs.size()[0]
368
+
369
+ # shape: (batch_size * beam_size, per_node_beam_size)
370
+ G_phi_S = state["G_phi_S"]
371
+
372
+ # shape: (batch_size, beam_size * per_node_beam_size)
373
+ G_phi_S = G_phi_S.reshape_as(log_probs)
374
+
375
+ # shape (both): (batch_size, beam_size)
376
+ G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
+
378
+ # shape: (batch_size, beam_size)
379
+ selected_log_probs = log_probs.gather(1, selected_indices)
380
+
381
+ # Now sort the selected beams by their true log prob.
382
+ # shape (all): (batch_size, beam_size)
383
+ selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
+ selected_indices = selected_indices.gather(1, sort_indices)
385
+ G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
+
387
+ # shape: (batch_size * beam_size,)
388
+ G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
+
390
+ # shape: (batch_size * beam_size,)
391
+ phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
+
393
+ return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
+
395
+ def gumbel(self, phi) -> torch.Tensor:
396
+ """
397
+ Sample `Gumbel(phi)`.
398
+
399
+ `phi` should have shape `(batch_size, num_classes)`.
400
+ """
401
+ return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
+
403
+ def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
+ """
405
+ Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
+
407
+ `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
+ shape `(batch_size, 1)`.
409
+ """
410
+ # Shape: (batch_size, num_classes)
411
+ G_phi = self.gumbel(phi)
412
+
413
+ # Now we find the maximum from these samples.
414
+ # Shape: (batch_size, )
415
+ Z, _ = G_phi.max(dim=-1)
416
+
417
+ # Shape: (batch_size, num_classes)
418
+ v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
+
420
+ # Shape: (batch_size, num_classes)
421
+ return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
+
423
+
424
+ class FinalSequenceScorer:
425
+ """
426
+ An abstract class that can be used to score the final generated sequences found
427
+ by beam search. Given the predicted sequences and the corresponding log probabilities of
428
+ those sequences, the class calculates and returns the final score of the sequences.
429
+
430
+ The default implementation scores the sequences using the sum of the log probabilities of
431
+ the sequence, which is passed as input.
432
+ """
433
+
434
+ @abstractmethod
435
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
+ """
437
+ Score the final predictions found by beam search.
438
+ Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
+
440
+ :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
+ :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
+ of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
+ :param end_index: The index of the end symbol.
444
+
445
+ """
446
+ raise NotImplementedError
447
+
448
+
449
+ class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
+ """
451
+ A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
+ across the sequence's tokens.
453
+ """
454
+
455
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
+ del predictions, end_index
457
+ # The sum of the sequence log probabilities is the input parameter, so just
458
+ # return it.
459
+ return log_probabilities
460
+
461
+
462
+ class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
+ """
464
+ A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
+ tokens in the sequence. It optionally includes a length penalty which promotes
466
+ or demotes sequences based on their lengths. The final score for a sequence will
467
+ be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
+ here includes the end token.
469
+
470
+ :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
+ A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
+ """
473
+
474
+ def __init__(self, length_penalty: float = 1.0):
475
+ super().__init__()
476
+ self.length_penalty = length_penalty
477
+
478
+ def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
+ # shape: (batch_size, beam_size)
480
+ lengths = (predictions != end_index).long().sum(dim=2)
481
+
482
+ # If the sequence ended during beam search, the `log_probabilities` will include
483
+ # the transition to the end token. Therefore, in such situations, `lengths` is
484
+ # actually off by 1. This corrects for that.
485
+ # shape: (batch_size, beam_size)
486
+ is_end_token = predictions[:, :, -1] == end_index
487
+ lengths += is_end_token.long()
488
+
489
+ # shape: (batch_size, beam_size)
490
+ average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
+ return average_log_probs
492
+
493
+
494
+ class Constraint:
495
+ """
496
+ An abstract class that can be used to enforce constraints on the output predictions
497
+ by manipulating the class log probabilities during beam search.
498
+
499
+ A `Constraint` just has three methods that need to be implemented by subclasses:
500
+ `init_state()`, `apply()` and `_update_state()`.
501
+
502
+ `init_state()` takes one argument:
503
+
504
+ - the batch size, an int
505
+
506
+ It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
+ calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
+ Each inner list should be of length 1.
509
+
510
+ `apply()` takes two arguments:
511
+
512
+ - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
+ and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
+ - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
+ log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
+
517
+ The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
+ for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
+ the corresponding log probability to a negligible value such as `float("-inf")` or
520
+ `torch.finfo(class_log_probabilities.dtype).min`.
521
+
522
+ `_update_state()` takes two arguments:
523
+
524
+ - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
+ copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
+ directly edited in-place without affecting the others.
527
+ - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
+ step of beam search.
529
+
530
+ The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
+ length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
+
533
+ """
534
+
535
+ @abstractmethod
536
+ def init_state(
537
+ self,
538
+ batch_size: int,
539
+ ) -> ConstraintStateType:
540
+ raise NotImplementedError
541
+
542
+ @abstractmethod
543
+ def apply(
544
+ self,
545
+ state: ConstraintStateType,
546
+ class_log_probabilities: torch.Tensor,
547
+ ) -> torch.Tensor:
548
+ raise NotImplementedError
549
+
550
+ @staticmethod
551
+ def _copy_state(
552
+ state: ConstraintStateType,
553
+ batch_size: int,
554
+ beam_size: int,
555
+ last_backpointer: Optional[torch.Tensor] = None,
556
+ ) -> ConstraintStateType:
557
+ """
558
+ Copies the `state` . This method copies the data in `state` using `copy.deepcopy()`. If this
559
+ is not appropriate for your constraint, you will need to implement the copying yourself.
560
+ """
561
+ new_state = []
562
+ for i in range(batch_size):
563
+ batch_state = []
564
+ for j in range(beam_size):
565
+ if last_backpointer is None:
566
+ # This is the first prediction, so the backpointer is 0
567
+ backpointer = 0
568
+ else:
569
+ backpointer = last_backpointer[i, j].item()
570
+ batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
+ new_state.append(batch_state)
572
+ return new_state
573
+
574
+ def update_state(
575
+ self,
576
+ state: ConstraintStateType,
577
+ last_prediction: torch.Tensor,
578
+ last_backpointer: Optional[torch.Tensor] = None,
579
+ ) -> ConstraintStateType:
580
+ batch_size, beam_size = last_prediction.size()
581
+ new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
+ return self._update_state(new_state, last_prediction)
583
+
584
+ @abstractmethod
585
+ def _update_state(
586
+ self,
587
+ state: ConstraintStateType,
588
+ last_prediction: torch.Tensor,
589
+ ) -> ConstraintStateType:
590
+ raise NotImplementedError
591
+
592
+
593
+ class RepeatedNGramBlockingConstraint(Constraint):
594
+ def __init__(self, ngram_size: int, **kwargs) -> None:
595
+ super().__init__(**kwargs)
596
+ self.ngram_size = ngram_size
597
+
598
+ def init_state(
599
+ self,
600
+ batch_size: int,
601
+ ) -> ConstraintStateType:
602
+ return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
+
604
+ def apply(
605
+ self,
606
+ state: ConstraintStateType,
607
+ class_log_probabilities: torch.Tensor,
608
+ ) -> torch.Tensor:
609
+ for i, batch in enumerate(state):
610
+ for j, beam in enumerate(batch):
611
+ current_prefix = tuple(beam["current_prefix"])
612
+ seen_ngrams = beam["seen_ngrams"]
613
+ try:
614
+ disallowed_indices = seen_ngrams[current_prefix]
615
+ class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
+ class_log_probabilities.dtype
617
+ ).min
618
+ except KeyError:
619
+ # We have not seen this prefix before, so there is no index
620
+ # that needs to be blocked
621
+ pass
622
+ return class_log_probabilities
623
+
624
+ def _update_state(
625
+ self,
626
+ state: ConstraintStateType,
627
+ last_prediction: torch.Tensor,
628
+ ) -> ConstraintStateType:
629
+ for i, batch in enumerate(state):
630
+ for j, beam in enumerate(batch):
631
+ prediction = last_prediction[i, j].item()
632
+ prefix = beam["current_prefix"]
633
+ seen_ngrams = beam["seen_ngrams"]
634
+
635
+ if len(prefix) == self.ngram_size - 1:
636
+ # This is a new ngram that we have to remember
637
+ if tuple(prefix) not in seen_ngrams:
638
+ seen_ngrams[tuple(prefix)] = []
639
+ seen_ngrams[tuple(prefix)].append(prediction)
640
+
641
+ # Create the new prefix, removing the oldest index if the prefix
642
+ # is too long
643
+ prefix.append(prediction)
644
+ if len(prefix) == self.ngram_size:
645
+ prefix.pop(0)
646
+ return state
647
+
648
+
649
+ class BeamSearch:
650
+ """
651
+ Implements the beam search algorithm for decoding the most likely sequences.
652
+
653
+ :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
+
655
+ :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
+ of the predicted sequences.
657
+
658
+ :param beam_size: The width of the beam used.
659
+
660
+ :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
+ If not given, this just defaults to `beam_size`. Setting this parameter
662
+ to a number smaller than `beam_size` may give better results, as it can introduce
663
+ more diversity into the search. See
664
+ [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
+ (https://api.semanticscholar.org/CorpusID:2229477).
666
+
667
+ :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
+ If not specified, `DeterministicSampler` will be used, which just takes the
669
+ `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
+
671
+ Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
+ [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
+
674
+ :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
+ the predicted sequences. This does not include the start or end tokens. If `None`,
676
+ no minimum is enforced.
677
+
678
+ :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
+ The output from this module is what is returned by the `search` method. If not
680
+ specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
+ by the sum of the token log probabilities.
682
+
683
+ :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
+ provided, no constraints will be enforced.
685
+
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ end_index: int,
691
+ *,
692
+ max_steps: int = 50,
693
+ beam_size: int = 10,
694
+ per_node_beam_size: Optional[int] = None,
695
+ sampler: Optional[Sampler] = None,
696
+ min_steps: Optional[int] = None,
697
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
+ constraints: Optional[List[Constraint]] = None,
699
+ ) -> None:
700
+ if not max_steps > 0:
701
+ raise ValueError("max_steps must be positive")
702
+ if not beam_size > 0:
703
+ raise ValueError("beam_size must be positive")
704
+ if per_node_beam_size is not None and not per_node_beam_size > 0:
705
+ raise ValueError("per_node_beam_size must be positive")
706
+ if min_steps is not None:
707
+ if not min_steps >= 0:
708
+ raise ValueError("min_steps must be non-negative")
709
+ if not min_steps <= max_steps:
710
+ raise ValueError("min_steps must be less than or equal to max_steps")
711
+
712
+ self._end_index = end_index
713
+ self.max_steps = max_steps
714
+ self.beam_size = beam_size
715
+ self.per_node_beam_size = per_node_beam_size or beam_size
716
+ self.sampler = sampler or DeterministicSampler()
717
+ self.min_steps = min_steps or 0
718
+ self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
+ self.constraints = constraints or []
720
+
721
+ @staticmethod
722
+ def _reconstruct_sequences(predictions, backpointers):
723
+ # Reconstruct the sequences.
724
+ # shape: [(batch_size, beam_size, 1)]
725
+ reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
+
727
+ if not backpointers:
728
+ return reconstructed_predictions
729
+
730
+ # shape: (batch_size, beam_size)
731
+ cur_backpointers = backpointers[-1]
732
+
733
+ for timestep in range(len(predictions) - 2, 0, -1):
734
+ # shape: (batch_size, beam_size, 1)
735
+ cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
+
737
+ reconstructed_predictions.append(cur_preds)
738
+
739
+ # shape: (batch_size, beam_size)
740
+ cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
+
742
+ # shape: (batch_size, beam_size, 1)
743
+ final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
+
745
+ reconstructed_predictions.append(final_preds)
746
+
747
+ return reconstructed_predictions
748
+
749
+ def search(
750
+ self,
751
+ start_predictions: torch.Tensor,
752
+ start_state: StateType,
753
+ step: StepFunctionType,
754
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
755
+ """
756
+ Given a starting state and a step function, apply beam search to find the
757
+ most likely target sequences.
758
+
759
+ Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
+ has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
+ has shape `(batch_size, beam_size)`.
762
+
763
+ .. note::
764
+ If your step function returns `-inf` for some log probabilities
765
+ (like if you're using a masked log-softmax) then some of the "best"
766
+ sequences returned may also have `-inf` log probability. Specifically
767
+ this happens when the beam size is smaller than the number of actions
768
+ with finite log probability (non-zero probability) returned by the step function.
769
+ Therefore if you're using a mask you may want to check the results from `search`
770
+ and potentially discard sequences with non-finite log probability.
771
+
772
+ :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
+ Usually the initial predictions are just the index of the "start" token
774
+ in the target vocabulary.
775
+
776
+ :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
+ should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
+ number of dimensions.
779
+
780
+ :param step: A function that is responsible for computing the next most likely tokens,
781
+ given the current state and the predictions from the last time step.
782
+ The function should accept two or three arguments:
783
+
784
+ - a tensor of shape `(group_size,)` or representing the index of the predicted
785
+ tokens from the last time step,
786
+ - the current state, a `StateType`, and
787
+ - optionally, the timestep, an `int`.
788
+
789
+ The `group_size` will be `batch_size * beam_size`, except in the initial
790
+ step, for which it will just be `batch_size`.
791
+
792
+ The function is expected to return a tuple, where the first element
793
+ is a tensor of shape `(group_size, vocab_size)` containing
794
+ the log probabilities of the tokens for the next step, and the second
795
+ element is the updated state. The tensor in the state should have shape
796
+ `(group_size, *)`, where `*` means any other number of dimensions.
797
+
798
+ """
799
+ step_signature = signature(step)
800
+ if len(step_signature.parameters) < 3:
801
+ # If the step function we're given does not take the time step argument, wrap it
802
+ # in one that does.
803
+ old_step = cast(StepFunctionTypeNoTimestep, step)
804
+
805
+ def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
+ del time_step
807
+ return old_step(last_predictions, state)
808
+
809
+ return self._search(start_predictions, start_state, new_step)
810
+ else:
811
+ return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
+
813
+ def _search(
814
+ self,
815
+ start_predictions: torch.Tensor,
816
+ start_state: StateType,
817
+ step: StepFunctionTypeWithTimestep,
818
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
819
+ batch_size = start_predictions.size()[0]
820
+
821
+ # List of (batch_size, beam_size) tensors. One for each time step. Does not
822
+ # include the start symbols, which are implicit.
823
+ predictions: List[torch.Tensor] = []
824
+
825
+ # List of (batch_size, beam_size) tensors. One for each time step. None for
826
+ # the first. Stores the index n for the parent prediction, i.e.
827
+ # predictions[t-1][i][n], that it came from.
828
+ backpointers: List[torch.Tensor] = []
829
+
830
+ constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
831
+
832
+ # Calculate the first timestep. This is done outside the main loop
833
+ # because we are going from a single decoder input (the output from the
834
+ # encoder) to the top `beam_size` decoder outputs. On the other hand,
835
+ # within the main loop we are going from the `beam_size` elements of the
836
+ # beam to `beam_size`^2 candidates from which we will select the top
837
+ # `beam_size` elements for the next iteration.
838
+ # shape: (batch_size, num_classes)
839
+ start_class_log_probabilities, state = step(start_predictions, start_state, 0)
840
+
841
+ num_classes = start_class_log_probabilities.size()[1]
842
+
843
+ # Make sure `per_node_beam_size` is not larger than `num_classes`.
844
+ if self.per_node_beam_size > num_classes:
845
+ raise ValueError(
846
+ f"Vocab size ({num_classes:d}) too small "
847
+ f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
848
+ f"Please decrease beam_size or per_node_beam_size."
849
+ )
850
+
851
+ sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
852
+
853
+ # Apply all constraints.
854
+ if self.constraints:
855
+ # shape: (batch_size, 1, num_classes)
856
+ expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
857
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
858
+ expanded_start_class_log_probabilities = constraint.apply(
859
+ constraint_state, expanded_start_class_log_probabilities
860
+ )
861
+ start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
862
+
863
+ # Prevent selecting the end symbol if there is any min_steps constraint
864
+ if self.min_steps >= 1:
865
+ start_class_log_probabilities[:, self._end_index] = torch.finfo(
866
+ start_class_log_probabilities.dtype
867
+ ).min
868
+
869
+ # Get the initial predicted classed and their log probabilities.
870
+ # shape: (batch_size, beam_size), (batch_size, beam_size)
871
+ (
872
+ start_top_log_probabilities,
873
+ start_predicted_classes,
874
+ sampler_state,
875
+ ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
876
+
877
+ if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
878
+ warnings.warn(
879
+ "Empty sequences predicted. You may want to increase the beam size or ensure "
880
+ "your step function is working properly.",
881
+ RuntimeWarning,
882
+ )
883
+ return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
884
+
885
+ # The log probabilities for the last time step.
886
+ # shape: (batch_size, beam_size)
887
+ last_log_probabilities = start_top_log_probabilities
888
+
889
+ # shape: [(batch_size, beam_size)]
890
+ predictions.append(start_predicted_classes)
891
+
892
+ # Log probability tensor that mandates that the end token is selected.
893
+ # shape: (batch_size * beam_size, num_classes)
894
+ log_probs_after_end = start_class_log_probabilities.new_full(
895
+ (batch_size * self.beam_size, num_classes),
896
+ torch.finfo(start_class_log_probabilities.dtype).min,
897
+ )
898
+ log_probs_after_end[:, self._end_index] = 0.0
899
+
900
+ # Set the same state for each element in the beam.
901
+ self._update_initial_state(state, batch_size)
902
+
903
+ for i, constraint in enumerate(self.constraints):
904
+ constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
905
+
906
+ for timestep in range(self.max_steps - 1):
907
+ # shape: (batch_size * beam_size,)
908
+ last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
909
+
910
+ # If every predicted token from the last step is `self._end_index`,
911
+ # then we can stop early.
912
+ if (last_predictions == self._end_index).all():
913
+ break
914
+ # Take a step. This get the predicted log probs of the next classes
915
+ # and updates the state.
916
+ # shape: (batch_size * beam_size, num_classes)
917
+ class_log_probabilities, state = step(last_predictions, state, timestep + 1)
918
+
919
+ # Apply all constraints.
920
+ if self.constraints:
921
+ # shape: (batch_size, beam_size, num_classes)
922
+ reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
923
+ for constraint, constraint_state in zip(self.constraints, constraint_states):
924
+ reshaped_class_log_probabilities = constraint.apply(
925
+ constraint_state, reshaped_class_log_probabilities
926
+ )
927
+ # shape: (batch_size * beam_size, num_classes)
928
+ class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
929
+
930
+ # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
931
+ # of the sequence (because `timestep` is 0-indexed and we generated the first token
932
+ # before the for loop). Here we block the end index if the search is not allowed to
933
+ # terminate on this iteration.
934
+ if timestep + 2 <= self.min_steps:
935
+ class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
936
+
937
+ # shape: (batch_size * beam_size, num_classes)
938
+ last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
939
+ batch_size * self.beam_size, num_classes
940
+ )
941
+
942
+ # Here we are finding any beams where we predicted the end token in
943
+ # the previous timestep and replacing the distribution with a
944
+ # one-hot distribution, forcing the beam to predict the end token
945
+ # this timestep as well.
946
+ # shape: (batch_size * beam_size, num_classes)
947
+ cleaned_log_probabilities = torch.where(
948
+ last_predictions_expanded == self._end_index,
949
+ log_probs_after_end,
950
+ class_log_probabilities,
951
+ )
952
+
953
+ # shape (both): (batch_size * beam_size, per_node_beam_size)
954
+ top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
955
+ cleaned_log_probabilities, self.per_node_beam_size, sampler_state
956
+ )
957
+
958
+ # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
959
+ # so that we can add them to the current log probs for this timestep.
960
+ # This lets us maintain the log probability of each element on the beam.
961
+ # shape: (batch_size * beam_size, per_node_beam_size)
962
+ expanded_last_log_probabilities = (
963
+ last_log_probabilities.unsqueeze(2)
964
+ .expand(batch_size, self.beam_size, self.per_node_beam_size)
965
+ .reshape(batch_size * self.beam_size, self.per_node_beam_size)
966
+ )
967
+
968
+ # shape: (batch_size * beam_size, per_node_beam_size)
969
+ summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
970
+
971
+ # shape: (batch_size, beam_size * per_node_beam_size)
972
+ reshaped_summed = summed_top_log_probabilities.reshape(
973
+ batch_size, self.beam_size * self.per_node_beam_size
974
+ )
975
+
976
+ # shape: (batch_size, beam_size * per_node_beam_size)
977
+ reshaped_predicted_classes = predicted_classes.reshape(
978
+ batch_size, self.beam_size * self.per_node_beam_size
979
+ )
980
+
981
+ # Keep only the top `beam_size` beam indices.
982
+ # shape (both): (batch_size, beam_size)
983
+ (
984
+ restricted_beam_log_probs,
985
+ restricted_beam_indices,
986
+ sampler_state,
987
+ ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
988
+
989
+ # Use the beam indices to extract the corresponding classes.
990
+ # shape: (batch_size, beam_size)
991
+ restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
992
+
993
+ predictions.append(restricted_predicted_classes)
994
+
995
+ # shape: (batch_size, beam_size)
996
+ last_log_probabilities = restricted_beam_log_probs
997
+
998
+ # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
999
+ # indices with a common ancestor are grouped together. Hence
1000
+ # dividing by per_node_beam_size gives the ancestor. (Note that this is integer
1001
+ # division as the tensor is a LongTensor.)
1002
+ # shape: (batch_size, beam_size)
1003
+ backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
1004
+ backpointers.append(backpointer)
1005
+
1006
+ # Keep only the pieces of the state tensors corresponding to the
1007
+ # ancestors created this iteration.
1008
+ self._update_state(state, backpointer)
1009
+
1010
+ for i, constraint in enumerate(self.constraints):
1011
+ constraint_states[i] = constraint.update_state(
1012
+ constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
1013
+ )
1014
+
1015
+ # Warn about "-inf" log probabilities if not using any constraints (negligible
1016
+ # log probabilities are expected when using constraints).
1017
+ if not self.constraints and (
1018
+ not torch.isfinite(last_log_probabilities).all()
1019
+ or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
1020
+ ):
1021
+ warnings.warn(
1022
+ "Negligible log probabilities encountered ('-inf' or equivalent). "
1023
+ "Some final sequences may not make sense. "
1024
+ "This can happen when the beam size is larger than the number of valid (non-zero "
1025
+ "probability) transitions that the step function produces.",
1026
+ RuntimeWarning,
1027
+ )
1028
+
1029
+ reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
1030
+
1031
+ # shape: (batch_size, beam_size, max_steps)
1032
+ all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
1033
+
1034
+ # Calculate the final sequence scores
1035
+ # shape: (batch_size, beam_size)
1036
+ final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
1037
+
1038
+ # Sort the sequences based on the final scores so the best scoring
1039
+ # sequence is at index 0
1040
+ sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
1041
+ sorted_all_predictions = torch.gather(
1042
+ all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
1043
+ )
1044
+
1045
+ return sorted_all_predictions, sorted_final_scores
1046
+
1047
+ def _update_initial_state(self, state: StateType, batch_size: int):
1048
+ """
1049
+ Expand tensors in a state dictionary from `(batch_size, *)` to `(batch_size * beam_size, *)`.
1050
+ """
1051
+ for key, state_tensor in state.items():
1052
+ if state_tensor is None:
1053
+ continue
1054
+ # shape: (batch_size * beam_size, *)
1055
+ _, *last_dims = state_tensor.size()
1056
+ state[key] = (
1057
+ state_tensor.unsqueeze(1)
1058
+ .expand(batch_size, self.beam_size, *last_dims)
1059
+ .reshape(batch_size * self.beam_size, *last_dims)
1060
+ )
1061
+
1062
+ def _update_state(self, state: StateType, backpointer: torch.Tensor):
1063
+ batch_size = backpointer.size()[0]
1064
+
1065
+ for key, state_tensor in state.items():
1066
+ if state_tensor is None:
1067
+ continue
1068
+ _, *last_dims = state_tensor.size()
1069
+ # shape: (batch_size, beam_size, *)
1070
+ expanded_backpointer = backpointer.view(batch_size, self.beam_size, *([1] * len(last_dims))).expand(
1071
+ batch_size, self.beam_size, *last_dims
1072
+ )
1073
+ # shape: (batch_size * beam_size, *)
1074
+ state[key] = (
1075
+ state_tensor.reshape(batch_size, self.beam_size, *last_dims)
1076
+ .gather(1, expanded_backpointer)
1077
+ .reshape(batch_size * self.beam_size, *last_dims)
1078
+ )
checkpoint.py ADDED
@@ -0,0 +1,1671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import io
3
+ import logging
4
+ import pickle
5
+ import shutil
6
+ import traceback
7
+ from abc import ABCMeta, abstractmethod
8
+ from collections import defaultdict
9
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
10
+ from contextlib import contextmanager
11
+ from copy import deepcopy
12
+ from dataclasses import dataclass, field, replace
13
+ from functools import reduce
14
+ from multiprocessing import shared_memory
15
+ from pathlib import Path
16
+ from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.distributed.checkpoint as dist_cp
21
+ import torch.multiprocessing as mp
22
+ from packaging import version
23
+ from torch.distributed import _remote_device
24
+ from torch.distributed._shard._utils import narrow_tensor_by_index
25
+ from torch.distributed._shard.metadata import ShardMetadata
26
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
27
+ from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
28
+ from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
29
+ from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
30
+ from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
31
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
32
+ from torch.distributed.fsdp import StateDictType
33
+ from torch.distributed.fsdp.api import (
34
+ FullOptimStateDictConfig,
35
+ FullStateDictConfig,
36
+ ShardedOptimStateDictConfig,
37
+ ShardedStateDictConfig,
38
+ )
39
+ from torch.futures import Future
40
+
41
+ try:
42
+ from torch.distributed.fsdp.flat_param import FlatParamHandle # type: ignore
43
+ except ModuleNotFoundError:
44
+ from torch.distributed.fsdp._flat_param import FlatParamHandle # type: ignore
45
+
46
+ from . import util
47
+
48
+ from .aliases import PathOrStr
49
+ from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
50
+ from .exceptions import OLMoCheckpointError
51
+ from .optim import Optimizer, fix_optim_state_dict
52
+ from .safetensors_util import safetensors_file_to_state_dict
53
+ from .torch_util import (
54
+ barrier,
55
+ gc_cuda,
56
+ get_fs_local_rank,
57
+ get_global_rank,
58
+ get_world_size,
59
+ )
60
+ from .util import (
61
+ _get_s3_client,
62
+ default_thread_count,
63
+ dir_is_empty,
64
+ get_bytes_range,
65
+ get_progress_bar,
66
+ resource_path,
67
+ upload,
68
+ wait_for,
69
+ )
70
+
71
+ __all__ = [
72
+ "save_fsdp_model_and_optim_state",
73
+ "load_fsdp_model_and_optim_state",
74
+ "load_fsdp_optim_state",
75
+ "save_state_dict",
76
+ "load_state_dict",
77
+ "load_model_state",
78
+ "RemoteFileSystemWriter",
79
+ "RemoteFileSystemReader",
80
+ "Checkpointer",
81
+ "FullCheckpointer",
82
+ "TorchNewStyleShardedCheckpointer",
83
+ "TorchLegacyShardedCheckpointer",
84
+ "LocalShardedCheckpointer",
85
+ "build_sharded_checkpointer",
86
+ ]
87
+
88
+
89
+ log = logging.getLogger(__name__)
90
+
91
+ MODEL_AND_OPTIM_FOLDER = "model_and_optim"
92
+
93
+
94
+ def save_fsdp_model_and_optim_state(
95
+ checkpoint_dir: PathOrStr,
96
+ fsdp_model: FSDP,
97
+ optim: Optimizer,
98
+ *,
99
+ upload_to: Optional[str] = None,
100
+ save_overwrite: bool = False,
101
+ ):
102
+ """
103
+ Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
104
+ functions. This should be used during distributed training and should be called by all ranks.
105
+
106
+ :param checkpoint_dir: The directory to save to.
107
+ :param fsdp_model: The FSDP model.
108
+ :param optim: The FSDP model's optimizer.
109
+ :param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
110
+ :param save_overwrite: Overwrite existing files.
111
+
112
+ :raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
113
+ """
114
+ checkpoint_dir = Path(checkpoint_dir)
115
+ target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
116
+ if save_overwrite:
117
+ if get_fs_local_rank() == 0:
118
+ shutil.rmtree(target_dir, ignore_errors=True)
119
+ elif not dir_is_empty(target_dir):
120
+ raise FileExistsError(target_dir)
121
+ barrier()
122
+ if get_fs_local_rank() == 0:
123
+ target_dir.mkdir(exist_ok=True, parents=True)
124
+ barrier()
125
+ with FSDP.state_dict_type(
126
+ fsdp_model,
127
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
128
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
129
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
130
+ ):
131
+ model_and_optim_state = {
132
+ "model": fsdp_model.state_dict(),
133
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
134
+ }
135
+ dist_cp.save_state_dict(
136
+ model_and_optim_state,
137
+ RemoteFileSystemWriter(
138
+ target_dir,
139
+ upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
140
+ save_overwrite=save_overwrite,
141
+ ),
142
+ )
143
+
144
+
145
+ def load_fsdp_model_and_optim_state(
146
+ checkpoint_dir: PathOrStr,
147
+ fsdp_model: FSDP,
148
+ optim: Optimizer,
149
+ *,
150
+ local_cache: Optional[PathOrStr] = None,
151
+ load_optimizer_state: bool = True,
152
+ ):
153
+ """
154
+ Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
155
+ functions. This should be used during distributed training and should be called by all ranks.
156
+
157
+ :param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
158
+ :param fsdp_model: The FSDP model.
159
+ :param optim: The FSDP model's optimizer.
160
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
161
+ remote "directory" but there might be a cached version of the same artifacts.
162
+ :param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
163
+
164
+ :raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
165
+ """
166
+ load_path = str(checkpoint_dir).rstrip("/")
167
+ local_cache = None if local_cache is None else Path(local_cache)
168
+ with FSDP.state_dict_type(
169
+ fsdp_model,
170
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
171
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
172
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
173
+ ):
174
+ # Load the model state dict in place.
175
+ log.info("Loading model state...")
176
+ model_state = {"model": fsdp_model.state_dict()}
177
+ dist_cp.load_state_dict(
178
+ model_state,
179
+ RemoteFileSystemReader(
180
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
181
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
182
+ ),
183
+ )
184
+ fsdp_model.load_state_dict(model_state["model"])
185
+
186
+ if not load_optimizer_state:
187
+ return
188
+
189
+ # Load optim state dict in place.
190
+ log.info("Loading sharded optimizer state...")
191
+ optim_state = load_sharded_optimizer_state_dict(
192
+ model_state_dict=model_state["model"],
193
+ optimizer_key="optim",
194
+ storage_reader=RemoteFileSystemReader(
195
+ f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
196
+ local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
197
+ ),
198
+ )
199
+ del model_state
200
+ gc_cuda()
201
+ load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
202
+
203
+
204
+ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
205
+ log.info("Flattening sharded optimizer state...")
206
+ # NOTE: Careful! The order of the these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
207
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
208
+ flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
209
+ else:
210
+ flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
211
+ del optim_state
212
+ gc.collect()
213
+ log.info("Loading flattened optimizer state...")
214
+ # Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
215
+ # which takes up unnecessary GPU memory.
216
+ for state in flattened_osd["state"].values():
217
+ for k in state.keys():
218
+ v = state[k]
219
+ if isinstance(v, torch.Tensor):
220
+ state[k] = v.to(device="cpu")
221
+ gc_cuda()
222
+ optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
223
+
224
+
225
+ def save_state_dict(
226
+ checkpoint_dir: PathOrStr,
227
+ fname: str,
228
+ state_dict: Dict[str, Any],
229
+ *,
230
+ upload_to: Optional[str] = None,
231
+ save_overwrite: bool = False,
232
+ synchronize: bool = True,
233
+ ):
234
+ """
235
+ Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
236
+ This can be used during distributed training or not. If during distributed training the ``fname`` should be unique
237
+ for each rank.
238
+
239
+ :param checkpoint_dir: The directory to save to.
240
+ :param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
241
+ :param state_dict: The state dict to save.
242
+ :param upload_to: Optional, a remote "directory" to upload the file to.
243
+ :param save_overwrite: Overwrite existing files.
244
+ :param synchronize: If ``False``, don't do any distributed synchronization. Use this when only calling
245
+ this function from a single rank.
246
+
247
+ :raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
248
+ """
249
+ checkpoint_dir = Path(checkpoint_dir)
250
+ target_path = checkpoint_dir / fname
251
+ if save_overwrite:
252
+ target_path.unlink(missing_ok=True)
253
+ elif target_path.is_file():
254
+ raise FileExistsError(target_path)
255
+ if synchronize:
256
+ barrier()
257
+ target_path.parent.mkdir(exist_ok=True, parents=True)
258
+ if synchronize:
259
+ barrier()
260
+ torch.save(state_dict, target_path)
261
+ if upload_to is not None:
262
+ upload_target = f"{upload_to.rstrip('/')}/{fname}"
263
+ log.info(f"Uploading {target_path} to {upload_target}...")
264
+ upload(target_path, upload_target, save_overwrite=save_overwrite)
265
+
266
+
267
+ def load_state_dict(
268
+ checkpoint_dir: PathOrStr,
269
+ fname: str,
270
+ *,
271
+ local_cache: Optional[PathOrStr] = None,
272
+ map_location: Optional[str] = None,
273
+ ):
274
+ """
275
+ Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
276
+ This can be used during distributed training or not.
277
+
278
+ :param checkpoint_dir: A local or remote checkpoint directory.
279
+ :param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
280
+ :param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
281
+ remote "directory" but there might be a cached version of the same artifacts.
282
+
283
+ :raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
284
+ """
285
+ if fname.endswith(".pt"):
286
+ # Try safetensors version first.
287
+ try:
288
+ path = resource_path(
289
+ str(checkpoint_dir).rstrip("/"), fname[:-2] + "safetensors", local_cache=local_cache
290
+ )
291
+ return safetensors_file_to_state_dict(path, map_location=map_location)
292
+ except FileNotFoundError:
293
+ pass
294
+
295
+ path = resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache)
296
+ return torch.load(path, map_location=map_location)
297
+
298
+
299
+ def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
300
+ """
301
+ Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
302
+ Note that ``model`` should not be wrapped with FSDP.
303
+ """
304
+ state_dict = {"model": model.state_dict()}
305
+ dist_cp.load_state_dict(
306
+ state_dict,
307
+ RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
308
+ no_dist=True,
309
+ )
310
+ model.load_state_dict(state_dict["model"])
311
+
312
+
313
+ class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
314
+ """
315
+ A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
316
+ directly to a cloud bucket when ``upload_to`` is specified.
317
+ """
318
+
319
+ def __init__(
320
+ self,
321
+ path: PathOrStr,
322
+ single_file_per_rank: bool = True,
323
+ sync_files: bool = True,
324
+ thread_count: Optional[int] = None,
325
+ per_thread_copy_ahead: int = 10_000_000,
326
+ upload_to: Optional[str] = None,
327
+ save_overwrite: bool = False,
328
+ ) -> None:
329
+ if thread_count is not None and thread_count <= 0:
330
+ raise ValueError("thread count must be at least 1")
331
+ super().__init__(
332
+ path,
333
+ single_file_per_rank=single_file_per_rank,
334
+ sync_files=sync_files,
335
+ # NOTE: we default to 1 thread here instead of whatever `default_thread_count()`
336
+ # returns because uploading big checkpoint files with multiple threads causes
337
+ # boto3 to fail in weird ways.
338
+ thread_count=thread_count or 1,
339
+ per_thread_copy_ahead=per_thread_copy_ahead,
340
+ )
341
+ self.upload_to = None if upload_to is None else upload_to.rstrip("/")
342
+ self.save_overwrite = save_overwrite
343
+
344
+ def write_data(
345
+ self,
346
+ plan: dist_cp.SavePlan,
347
+ planner: dist_cp.SavePlanner,
348
+ ) -> Future[List[WriteResult]]:
349
+ fut = super().write_data(plan, planner)
350
+ if self.upload_to is not None:
351
+ files_to_upload = set()
352
+ for write_result in fut.wait():
353
+ files_to_upload.add(write_result.storage_data.relative_path)
354
+
355
+ # Create the global S3 client up front to work around a threading issue in boto.
356
+ if self.upload_to.startswith("s3://"):
357
+ _get_s3_client("s3")
358
+ elif self.upload_to.startswith("r2://"):
359
+ _get_s3_client("r2")
360
+
361
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
362
+ futures = []
363
+ for fname in files_to_upload:
364
+ source = self.path / fname
365
+ target = f"{self.upload_to}/{fname}"
366
+ log.info(f"Uploading {source} to {target}...")
367
+ futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
368
+ for f in as_completed(futures):
369
+ try:
370
+ f.result()
371
+ except BaseException:
372
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
373
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
374
+ # sure we're raising a simple error type that can be pickled.
375
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
376
+ return fut
377
+
378
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
379
+ super().finish(metadata, results)
380
+ if self.upload_to is not None:
381
+ source = self.path / ".metadata"
382
+ target = f"{self.upload_to}/.metadata"
383
+ log.info(f"Uploading {source} to {target}...")
384
+ upload(source, target, save_overwrite=self.save_overwrite)
385
+
386
+
387
+ class RemoteFileSystemReader(dist_cp.StorageReader):
388
+ """
389
+ A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
390
+ that can read data directly from cloud storage as well as a local directory.
391
+ """
392
+
393
+ def __init__(
394
+ self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
395
+ ):
396
+ super().__init__()
397
+ if thread_count is not None and thread_count <= 0:
398
+ raise ValueError("thread count must be at least 1")
399
+ self.path = str(path).rstrip("/")
400
+ self.cache = None if local_cache is None else Path(local_cache)
401
+ self.thread_count = thread_count or default_thread_count()
402
+ self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
403
+ self._metadata: Optional[Metadata] = None
404
+
405
+ def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
406
+ if self.cache is not None and (path := self.cache / relative_path).is_file():
407
+ return get_bytes_range(path, offset, length)
408
+ else:
409
+ return get_bytes_range(f"{self.path}/{relative_path}", offset, length)
410
+
411
+ def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
412
+ sinfo = self.storage_data[read_item.storage_index]
413
+ content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
414
+ return (read_item, content)
415
+
416
+ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
417
+ # Create the global S3 client up front to work around a threading issue in boto.
418
+ if isinstance(self.path, str):
419
+ if self.path.startswith("s3://"):
420
+ _get_s3_client("s3")
421
+ elif self.path.startswith("r2://"):
422
+ _get_s3_client("r2")
423
+
424
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
425
+ read_item_content_futures = []
426
+ for read_item in plan.items:
427
+ read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
428
+ read_item_content_results = []
429
+ for f in as_completed(read_item_content_futures):
430
+ try:
431
+ read_item_content_results.append(f.result())
432
+ except BaseException:
433
+ # NOTE: we might get an error here that can't be pickled, which causes a different failure
434
+ # later when PyTorch tries to reduce that error across ranks. So here we just make
435
+ # sure we're raising a simple error type that can be pickled.
436
+ raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
437
+
438
+ # Modified from `FileSystemReader.read_data()`
439
+ for read_item, content in read_item_content_results:
440
+ bytes = io.BytesIO(content)
441
+ bytes.seek(0)
442
+ if read_item.type == LoadItemType.BYTE_IO:
443
+ planner.load_bytes(read_item, bytes)
444
+ else:
445
+ tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
446
+ tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
447
+ target_tensor = planner.resolve_tensor(read_item).detach()
448
+
449
+ assert (
450
+ target_tensor.size() == tensor.size()
451
+ ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
452
+ target_tensor.copy_(tensor)
453
+ planner.commit_tensor(read_item, target_tensor)
454
+
455
+ fut: Future = Future()
456
+ fut.set_result(None)
457
+ return fut
458
+
459
+ def read_metadata(self) -> Metadata:
460
+ if self._metadata is None:
461
+ with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
462
+ self._metadata = pickle.load(metadata_file)
463
+ return self._metadata
464
+
465
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
466
+ del is_coordinator
467
+ self.storage_data = metadata.storage_data
468
+ assert self.storage_data is not None
469
+
470
+ def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
471
+ return plan
472
+
473
+ def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
474
+ return global_plan
475
+
476
+
477
+ class Checkpointer(metaclass=ABCMeta):
478
+ def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
479
+ self.cfg = cfg
480
+ self.thread_count = thread_count or default_thread_count()
481
+
482
+ @abstractmethod
483
+ def save_checkpoint(
484
+ self,
485
+ dir: PathOrStr,
486
+ fsdp_model: FSDP,
487
+ optim: Optimizer,
488
+ train_state: Dict[str, Any],
489
+ *,
490
+ upload_to: Optional[str] = None,
491
+ ) -> None:
492
+ raise NotImplementedError
493
+
494
+ @abstractmethod
495
+ def restore_checkpoint(
496
+ self,
497
+ load_path: PathOrStr,
498
+ fsdp_model: FSDP,
499
+ optim: Optimizer,
500
+ *,
501
+ local_cache: Optional[PathOrStr] = None,
502
+ load_optimizer_state: bool = True,
503
+ ) -> Dict[str, Any]:
504
+ """
505
+ Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
506
+ """
507
+ raise NotImplementedError
508
+
509
+ def unshard_checkpoint(
510
+ self,
511
+ load_path: PathOrStr,
512
+ *,
513
+ local_cache: Optional[PathOrStr] = None,
514
+ load_optimizer_state: bool = True,
515
+ load_trainer_state: bool = True,
516
+ device: Optional[torch.device] = None,
517
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
518
+ """
519
+ Unshard a checkpoint.
520
+
521
+ Note this is not marked abstract because child classes are not required to implemented this.
522
+ """
523
+ del load_path, local_cache, load_optimizer_state, load_trainer_state, device
524
+ raise NotImplementedError
525
+
526
+ @contextmanager
527
+ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
528
+ # Make sure checkpoint directory doesn't exist unless it's okay to overwrite it.
529
+ checkpoint_dir = Path(dir)
530
+ if not dir_is_empty(checkpoint_dir):
531
+ if self.cfg.save_overwrite:
532
+ if get_fs_local_rank() == 0:
533
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
534
+ else:
535
+ raise FileExistsError(checkpoint_dir)
536
+ # No need to mkdir here since we'll directly replace the temporary directory with
537
+ # this directory below.
538
+ barrier()
539
+
540
+ # Prepare temporary directory. We don't have to be as careful here, we can
541
+ # just remove it if it already exists.
542
+ checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
543
+ if get_fs_local_rank() == 0:
544
+ shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
545
+ checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
546
+
547
+ barrier()
548
+
549
+ # Yield temporary directory for `.save_checkpoint()` to use.
550
+ yield checkpoint_dir_tmp
551
+
552
+ barrier()
553
+
554
+ # Finally if all went well replace the temporary directory with the actual
555
+ # checkpoint directory.
556
+ if get_fs_local_rank() == 0:
557
+ # Replace temp directory with target checkpoint directory.
558
+ try:
559
+ checkpoint_dir_tmp.replace(checkpoint_dir)
560
+ except FileNotFoundError:
561
+ # Caught when another (file-system) local rank 0 has already replaced the tmp directory.
562
+ # This can happen when nodes are saving to a common NFS drive but otherwise have distinct
563
+ # file-systems.
564
+ if not checkpoint_dir.exists():
565
+ raise
566
+
567
+ # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
568
+ # replacing the temp directory with the final directory from rank 0 might not be immediately
569
+ # realized in the file systems of the other ranks.
570
+ # So we wait here across all ranks until that final checkpoint directory is visible.
571
+ wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)
572
+
573
+ barrier()
574
+
575
+ def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
576
+ if get_global_rank() == 0:
577
+ log.info("Saving config...")
578
+ self.cfg.save(config_path := Path(dir) / "config.yaml")
579
+ if upload_to is not None:
580
+ upload_target = f"{upload_to}/config.yaml"
581
+ log.info(f"Uploading {config_path} to {upload_target}")
582
+ upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)
583
+
584
+
585
+ class FullCheckpointer(Checkpointer):
586
+ """
587
+ A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.
588
+ """
589
+
590
+ def save_checkpoint(
591
+ self,
592
+ dir: PathOrStr,
593
+ fsdp_model: FSDP,
594
+ optim: Optimizer,
595
+ trainer_state: Dict[str, Any],
596
+ *,
597
+ upload_to: Optional[str] = None,
598
+ ) -> None:
599
+ with self._temporary_wd(dir) as checkpoint_dir:
600
+ with FSDP.state_dict_type(
601
+ fsdp_model,
602
+ state_dict_type=StateDictType.FULL_STATE_DICT,
603
+ state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
604
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
605
+ ):
606
+ # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
607
+ # First the model state.
608
+ model_state_dict = fsdp_model.state_dict()
609
+ if get_global_rank() == 0:
610
+ log.info("Saving model state...")
611
+ save_state_dict(
612
+ checkpoint_dir,
613
+ "model.pt",
614
+ model_state_dict,
615
+ upload_to=upload_to,
616
+ save_overwrite=self.cfg.save_overwrite,
617
+ synchronize=False,
618
+ )
619
+ del model_state_dict
620
+ barrier()
621
+
622
+ # Then the optimizer state.
623
+ optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
624
+ if get_global_rank() == 0:
625
+ log.info("Saving optim state...")
626
+ save_state_dict(
627
+ checkpoint_dir,
628
+ "optim.pt",
629
+ optim_state_dict,
630
+ upload_to=upload_to,
631
+ save_overwrite=self.cfg.save_overwrite,
632
+ synchronize=False,
633
+ )
634
+ del optim_state_dict
635
+ barrier()
636
+
637
+ # Save trainer state.
638
+ if get_global_rank() == 0:
639
+ log.info("Saving trainer state...")
640
+ save_state_dict(
641
+ checkpoint_dir,
642
+ "train.pt",
643
+ trainer_state,
644
+ upload_to=upload_to,
645
+ save_overwrite=self.cfg.save_overwrite,
646
+ synchronize=False,
647
+ )
648
+ # Save config.
649
+ self._save_config(checkpoint_dir, upload_to=upload_to)
650
+
651
+ def restore_checkpoint(
652
+ self,
653
+ load_path: PathOrStr,
654
+ fsdp_model: FSDP,
655
+ optim: Optimizer,
656
+ *,
657
+ local_cache: Optional[PathOrStr] = None,
658
+ load_optimizer_state: bool = True,
659
+ ) -> Dict[str, Any]:
660
+ with FSDP.state_dict_type(
661
+ fsdp_model,
662
+ state_dict_type=StateDictType.FULL_STATE_DICT,
663
+ state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
664
+ optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
665
+ ):
666
+ with torch.no_grad():
667
+ # fill everything with NaN, so we can check afterwards that every parameter has been restored
668
+ for module_name, module in fsdp_model.named_modules():
669
+ if not isinstance(module, FSDP):
670
+ continue
671
+ for param in module.params:
672
+ param.fill_(torch.nan)
673
+
674
+ # restore params from checkpoint
675
+ state_dict_to_load = load_state_dict(
676
+ load_path, "model.pt", local_cache=local_cache, map_location="cpu"
677
+ )
678
+ (
679
+ state_dict_to_load,
680
+ og_keys_to_new,
681
+ ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)
682
+
683
+ for module_name, module in fsdp_model.named_modules():
684
+ if not isinstance(module, FSDP):
685
+ continue
686
+ for param in module.params:
687
+ assert param._is_flat_param
688
+ for fqn, spi in zip(param._fqns, param._shard_param_infos):
689
+ if not spi.in_shard:
690
+ continue
691
+ key = f"{module_name}.{fqn}"
692
+ key = key.replace("_fsdp_wrapped_module.", "")
693
+ key = key.lstrip(".")
694
+ t = state_dict_to_load[key]
695
+ t = t.flatten()
696
+ param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
697
+ t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
698
+ )
699
+
700
+ # make sure that every parameter has been restored
701
+ for module_name, module in fsdp_model.named_modules():
702
+ if not isinstance(module, FSDP):
703
+ continue
704
+ for param in module.params:
705
+ if torch.isnan(param).any():
706
+ raise ValueError(
707
+ f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
708
+ )
709
+
710
+ # Load optimizer state.
711
+ if load_optimizer_state:
712
+ optim_state_dict_to_load = load_state_dict(
713
+ load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
714
+ )
715
+ optim_state_dict_to_load = self._make_optim_state_dict_compatible(
716
+ optim_state_dict_to_load,
717
+ og_keys_to_new,
718
+ )
719
+ load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
720
+ del optim_state_dict_to_load
721
+
722
+ # Load other state.
723
+ try:
724
+ trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
725
+ except FileNotFoundError:
726
+ # for backwards compatibility
727
+ trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
728
+ barrier()
729
+ return trainer_state
730
+
731
+ def _make_optim_state_dict_compatible(
732
+ self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
733
+ ) -> Dict[str, Any]:
734
+ # This state dict comes in two forms: one where the state keys are integers and one where the
735
+ # keys are fully qualified parameter names. The latter case is easier to deal with here so we
736
+ # first transform the integer key form into the FQN key form.
737
+ if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
738
+ id_to_fqn: Dict[int, str] = {}
739
+ for group in optim_state_dict["param_groups"]:
740
+ new_param_names = []
741
+ for fqn, id in zip(group["param_names"], group["params"]):
742
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
743
+ id_to_fqn[id] = fqn
744
+ new_param_names.append(fqn)
745
+ group["param_names"] = new_param_names
746
+ group["params"] = new_param_names
747
+ for id in list(optim_state_dict["state"].keys()):
748
+ optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
749
+ else:
750
+ # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
751
+ for group in optim_state_dict["param_groups"]:
752
+ group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
753
+ group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
754
+ assert group["param_names"] == group["params"]
755
+ for key in list(optim_state_dict["state"].keys()):
756
+ optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
757
+ "state"
758
+ ].pop(key)
759
+
760
+ # Now we can transform the state dict by renaming parameters according to `og_keys_to_new`.
761
+ # First fix param names in the state.
762
+ for og_key, new_keys in og_keys_to_new.items():
763
+ og_state = optim_state_dict["state"].pop(og_key, None)
764
+ if og_state is None:
765
+ continue
766
+ for i, new_key in enumerate(new_keys):
767
+ if i == len(new_keys) - 1:
768
+ optim_state_dict["state"][new_key] = og_state
769
+ else:
770
+ optim_state_dict["state"][new_key] = deepcopy(og_state)
771
+ # Now fix param names in the param groups.
772
+ for group in optim_state_dict["param_groups"]:
773
+ og_names = group["params"]
774
+ new_names = []
775
+ for og_key in og_names:
776
+ for new_key in og_keys_to_new[og_key]:
777
+ new_names.append(new_key)
778
+ group["params"] = new_names
779
+ group["param_names"] = new_names
780
+
781
+ return optim_state_dict
782
+
783
+ def load_checkpoint(
784
+ self,
785
+ load_path: PathOrStr,
786
+ *,
787
+ local_cache: Optional[PathOrStr] = None,
788
+ load_optimizer_state: bool = True,
789
+ device: Optional[torch.device] = None,
790
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
791
+ device = device if device is not None else torch.device("cpu")
792
+ model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device) # type: ignore
793
+ optim_state = None
794
+ if load_optimizer_state:
795
+ optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device) # type: ignore
796
+ return model_state, optim_state
797
+
798
+
799
+ class TorchNewStyleShardedCheckpointer(Checkpointer):
800
+ """
801
+ A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.
802
+ """
803
+
804
+ def save_checkpoint(
805
+ self,
806
+ dir: PathOrStr,
807
+ fsdp_model: FSDP,
808
+ optim: Optimizer,
809
+ trainer_state: Dict[str, Any],
810
+ *,
811
+ upload_to: Optional[str] = None,
812
+ ) -> None:
813
+ with self._temporary_wd(dir) as checkpoint_dir:
814
+ # Save model and optim state.
815
+ save_fsdp_model_and_optim_state(
816
+ checkpoint_dir,
817
+ fsdp_model,
818
+ optim,
819
+ upload_to=upload_to,
820
+ save_overwrite=self.cfg.save_overwrite,
821
+ )
822
+
823
+ # Save trainer state.
824
+ log.info("Saving trainer state...")
825
+ save_state_dict(
826
+ checkpoint_dir,
827
+ f"train/rank{get_global_rank()}.pt",
828
+ trainer_state,
829
+ upload_to=upload_to,
830
+ save_overwrite=self.cfg.save_overwrite,
831
+ )
832
+
833
+ # Save config.
834
+ self._save_config(checkpoint_dir, upload_to=upload_to)
835
+
836
+ def restore_checkpoint(
837
+ self,
838
+ load_path: PathOrStr,
839
+ fsdp_model: FSDP,
840
+ optim: Optimizer,
841
+ *,
842
+ local_cache: Optional[PathOrStr] = None,
843
+ load_optimizer_state: bool = True,
844
+ ) -> Dict[str, Any]:
845
+ # Load model and optimizer state in place.
846
+ log.info("Loading model and optimizer state...")
847
+ load_fsdp_model_and_optim_state(
848
+ load_path,
849
+ fsdp_model,
850
+ optim,
851
+ local_cache=local_cache,
852
+ load_optimizer_state=load_optimizer_state,
853
+ )
854
+
855
+ # Load trainer state dict.
856
+ log.info("Loading trainer state...")
857
+ try:
858
+ trainer_state = load_state_dict(
859
+ load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
860
+ )
861
+ except FileNotFoundError:
862
+ # Fall back to rank 0 train state.
863
+ # This can happen when we're restoring a checkpoint with a different world size.
864
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
865
+ barrier()
866
+ return trainer_state
867
+
868
+
869
+ class TorchLegacyShardedCheckpointer(Checkpointer):
870
+ """
871
+ A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
872
+ and optim state.
873
+
874
+ The world size must be kept consistent when using this checkpointer.
875
+ """
876
+
877
+ def save_checkpoint(
878
+ self,
879
+ dir: PathOrStr,
880
+ fsdp_model: FSDP,
881
+ optim: Optimizer,
882
+ trainer_state: Dict[str, Any],
883
+ *,
884
+ upload_to: Optional[str] = None,
885
+ ) -> None:
886
+ with self._temporary_wd(dir) as checkpoint_dir:
887
+ with FSDP.state_dict_type(
888
+ fsdp_model,
889
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
890
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
891
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
892
+ ):
893
+ state_dict = {
894
+ "model": fsdp_model.state_dict(),
895
+ "optim": FSDP.optim_state_dict(fsdp_model, optim),
896
+ **trainer_state,
897
+ }
898
+ save_state_dict(
899
+ checkpoint_dir,
900
+ f"rank{get_global_rank()}.pt",
901
+ state_dict,
902
+ upload_to=upload_to,
903
+ save_overwrite=self.cfg.save_overwrite,
904
+ )
905
+
906
+ # Save config.
907
+ self._save_config(checkpoint_dir, upload_to=upload_to)
908
+
909
+ def restore_checkpoint(
910
+ self,
911
+ load_path: PathOrStr,
912
+ fsdp_model: FSDP,
913
+ optim: Optimizer,
914
+ *,
915
+ local_cache: Optional[PathOrStr] = None,
916
+ load_optimizer_state: bool = True,
917
+ ) -> Dict[str, Any]:
918
+ with FSDP.state_dict_type(
919
+ fsdp_model,
920
+ state_dict_type=StateDictType.SHARDED_STATE_DICT,
921
+ state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
922
+ optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
923
+ ):
924
+ # Deserialize state dict.
925
+ state_dict = load_state_dict(
926
+ load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
927
+ )
928
+
929
+ # Load model and optimizer state.
930
+ log.info("Loading model state...")
931
+ fsdp_model.load_state_dict(state_dict["model"])
932
+ del state_dict["model"]
933
+ if load_optimizer_state:
934
+ log.info("Loading optimizer state...")
935
+ load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
936
+ del state_dict["optim"]
937
+
938
+ barrier()
939
+ return state_dict
940
+
941
+ def unshard_checkpoint(
942
+ self,
943
+ load_path: PathOrStr,
944
+ *,
945
+ local_cache: Optional[PathOrStr] = None,
946
+ load_optimizer_state: bool = True,
947
+ load_trainer_state: bool = True,
948
+ device: Optional[torch.device] = None,
949
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
950
+ assert local_cache is None, "this method currently only supports local files"
951
+ full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
952
+ model_state = full_state_dict.pop("model")
953
+ optim_state = full_state_dict.pop("optim")
954
+ return (
955
+ model_state,
956
+ optim_state if load_optimizer_state else None,
957
+ full_state_dict if load_trainer_state else None,
958
+ )
959
+
960
+ def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
961
+ key = tuple() if key is None else key
962
+ if isinstance(state, (list, tuple, set)):
963
+ for i, sub_state in enumerate(state):
964
+ self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
965
+ elif isinstance(state, dict):
966
+ for name in state.keys():
967
+ self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
968
+ elif isinstance(state, ShardedTensor):
969
+ self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
970
+ return
971
+ else:
972
+ return
973
+
974
+ def _get_shard_placement_and_rank_sizes(
975
+ self, shards_metadata: List[ShardMetadata], world_size: int
976
+ ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
977
+ def shard_size(shard_md):
978
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
979
+
980
+ rank_sizes = [0 for _ in range(world_size)]
981
+ shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
982
+ for shard_md in shards_metadata:
983
+ shard_rank = cast(_remote_device, shard_md.placement).rank()
984
+ assert shard_rank is not None
985
+ if shard_rank >= world_size:
986
+ raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")
987
+
988
+ shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
989
+ rank_sizes[shard_rank] += shard_size(shard_md)
990
+
991
+ return shard_placement, rank_sizes
992
+
993
+ def _copy_sharded_tensor_to_shared_mem(
994
+ self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
995
+ ) -> Any:
996
+ shard0_md = sharded_tensor.metadata()
997
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
998
+ shard0_md.shards_metadata, world_size
999
+ )
1000
+
1001
+ rank_size = rank_sizes[rank]
1002
+ assert rank_size >= 0
1003
+ if rank_size == 0:
1004
+ return
1005
+
1006
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1007
+ numpy_type = np.float32
1008
+
1009
+ sharded_memory_name = "-".join(key + (str(rank),))
1010
+
1011
+ shm = shared_memory.SharedMemory(
1012
+ create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
1013
+ )
1014
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1015
+
1016
+ for local_shard in sharded_tensor.local_shards():
1017
+ shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
1018
+ assert shard_rank == rank
1019
+
1020
+ src = local_shard.tensor.flatten()
1021
+ shard_offset = shard_placement[local_shard.metadata][1]
1022
+
1023
+ np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
1024
+
1025
+ shm.close()
1026
+
1027
+ def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
1028
+ shard_number = int(shard_filepath.name[4:-3])
1029
+ log.info("Starting unsharding shard number %d to shared memory", shard_number)
1030
+
1031
+ with self._patch_sharded_tensor_load():
1032
+ shard = torch.load(shard_filepath, map_location="cpu")
1033
+ log.debug("Done loading shard number %d", shard_number)
1034
+
1035
+ self._copy_sharded_tensors_to_shared_mem(
1036
+ shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
1037
+ )
1038
+ log.info("Done unsharding shard number %d to shared memory", shard_number)
1039
+
1040
+ def _unshard_using_sharded_mem(
1041
+ self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
1042
+ ) -> Any:
1043
+ return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))
1044
+
1045
+ def _unshard_state_using_shared_mem(
1046
+ self, state: Any, world_size: int, device: torch.device, key: Tuple
1047
+ ) -> Any:
1048
+ if isinstance(state, (list, tuple, set)):
1049
+ return state.__class__(
1050
+ self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
1051
+ for i, sub_state in enumerate(state)
1052
+ )
1053
+ elif isinstance(state, dict):
1054
+ return {
1055
+ name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
1056
+ for name in state.keys()
1057
+ }
1058
+ elif isinstance(state, ShardedTensor):
1059
+ return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
1060
+ elif isinstance(state, torch.Tensor):
1061
+ return state.to(device=device)
1062
+ else:
1063
+ return state
1064
+
1065
+ def _unshard_tensor_using_shared_mem(
1066
+ self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
1067
+ ) -> torch.Tensor:
1068
+ shard0_md = sharded_tensor.metadata()
1069
+
1070
+ def shard_size(shard_md):
1071
+ return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
1072
+
1073
+ shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
1074
+ shard0_md.shards_metadata, world_size
1075
+ )
1076
+
1077
+ assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
1078
+ numpy_type = np.float32
1079
+
1080
+ out = torch.empty(
1081
+ *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
1082
+ )
1083
+ dims = len(sharded_tensor.metadata().size)
1084
+ for shard_md, (rank, rank_offset) in shard_placement.items():
1085
+ if rank >= world_size:
1086
+ raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")
1087
+
1088
+ sharded_memory_name = "-".join(key + (str(rank),))
1089
+ shm = shared_memory.SharedMemory(name=sharded_memory_name)
1090
+
1091
+ rank_size = rank_sizes[rank]
1092
+ assert rank_size >= 0
1093
+ if rank_size == 0:
1094
+ continue
1095
+
1096
+ np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
1097
+
1098
+ tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
1099
+ tensor = tensor.view(shard_md.shard_sizes)
1100
+
1101
+ out_narrow_view = out
1102
+ for dim in range(dims):
1103
+ out_narrow_view = out_narrow_view.narrow(
1104
+ dim,
1105
+ shard_md.shard_offsets[dim],
1106
+ shard_md.shard_sizes[dim],
1107
+ )
1108
+
1109
+ out_narrow_view.copy_(tensor)
1110
+
1111
+ shm.close()
1112
+ shm.unlink()
1113
+
1114
+ return out
1115
+
1116
+ @contextmanager
1117
+ def _patch_sharded_tensor_load(self):
1118
+ """
1119
+ Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
1120
+ """
1121
+
1122
+ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
1123
+ ret = func(*args)
1124
+ if type(ret) is not new_type:
1125
+ ret = ret.as_subclass(new_type)
1126
+
1127
+ # Shortcut the construction of ShardedTensor
1128
+ # This is in the top 5 of my worst hacks.
1129
+ if isinstance(ret, ShardedTensor):
1130
+ ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
1131
+ return ret
1132
+
1133
+ # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
1134
+ # Tensor does define __setstate__ even though it doesn't define
1135
+ # __getstate__. So only use __setstate__ if it is NOT the one defined
1136
+ # on Tensor
1137
+ if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
1138
+ ret.__setstate__(state)
1139
+ else:
1140
+ ret = torch._utils._set_obj_state(ret, state)
1141
+ return ret
1142
+
1143
+ original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
1144
+ try:
1145
+ torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
1146
+ yield
1147
+ finally:
1148
+ torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2
1149
+
1150
+ def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
1151
+ """
1152
+ The current unsharding implementation consists of:
1153
+
1154
+ 1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
1155
+ 2. Loading 1 shard on the main process as a base unsharded object.
1156
+ 3. Using the sharded tensors in shared memory to populate the base unsharded object.
1157
+
1158
+ This implementation replaced a prior implementation that instead loaded
1159
+ all shards using threads, because that implementation turned out to
1160
+ be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
1161
+ The current implementation is slower than the old one in many scenarios,
1162
+ but is significantly faster in the above mentioned case (e.g. 30 minutes)
1163
+ if there are enough CPUs.
1164
+ """
1165
+
1166
+ input_dir = Path(input_dir)
1167
+ skip_keys = skip_keys or set()
1168
+
1169
+ shard_filepaths = list(input_dir.glob("rank*.pt"))
1170
+ world_size = len(shard_filepaths)
1171
+ if world_size == 0:
1172
+ raise RuntimeError("No shards found for unsharding")
1173
+
1174
+ log.info("Number of shards: %d", world_size)
1175
+ shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
1176
+ min_ram_required_estimate_gb = shard_size_gb * world_size
1177
+ log.info(
1178
+ "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
1179
+ )
1180
+
1181
+ log.info("Copying sharded tensors to shared memory using multiple processes")
1182
+ # Copy sharded data to shared memory using multiple processes, so this process can load
1183
+ # from memory rather than disk. We spawn a new process instead of forking since shared memory
1184
+ # appears to get deleted when forked processes end for some reason.
1185
+ executor = ProcessPoolExecutor(
1186
+ mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
1187
+ )
1188
+ futures = []
1189
+ for shard_filepath in shard_filepaths:
1190
+ shard_rank = int(shard_filepath.name[4:-3])
1191
+
1192
+ if shard_rank >= world_size:
1193
+ raise RuntimeError(
1194
+ f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
1195
+ )
1196
+
1197
+ futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))
1198
+
1199
+ for f in as_completed(futures):
1200
+ f.result()
1201
+ executor.shutdown()
1202
+
1203
+ log.info("Loading a shard on the main process to be unsharded state")
1204
+ with self._patch_sharded_tensor_load():
1205
+ state = torch.load(shard_filepaths[0], map_location="cpu")
1206
+
1207
+ for key in skip_keys:
1208
+ if key in state:
1209
+ del state[key]
1210
+
1211
+ log.info("Unsharding from %d shards ...", world_size)
1212
+ return self._unshard_using_sharded_mem(state, world_size, device, input_dir)
1213
+
1214
+
1215
+ @dataclass
1216
+ class _LocalShardedCheckpointerMetadata(BaseConfig):
1217
+ world_size: int = field(default_factory=get_world_size)
1218
+
1219
+
1220
+ @dataclass
1221
+ class _FlatParamShard:
1222
+ full_shape: torch.Size
1223
+ shard_offsets: Tuple[int, int]
1224
+ shard_data: Optional[torch.Tensor]
1225
+
1226
+ def copy_into(self, full_tensor: torch.Tensor) -> None:
1227
+ assert self.shard_data is not None
1228
+ full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
1229
+ assert self.shard_data.shape == full_tensor_shard_view.shape
1230
+ full_tensor_shard_view.copy_(self.shard_data)
1231
+
1232
+
1233
+ class LocalShardedCheckpointer(Checkpointer):
1234
+ """
1235
+ A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
1236
+ The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.
1237
+
1238
+ The world size must be kept consistent when using this checkpointer. However, you can easily
1239
+ reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
1240
+ using :meth:`unshard_checkpoint()` (no distributed initialization required).
1241
+ """
1242
+
1243
+ # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
1244
+ _FLAT_PARAM_METADATA_TO_SAVE = (
1245
+ "_fqns",
1246
+ "_shard_param_offsets",
1247
+ "_shard_indices",
1248
+ "_numels",
1249
+ "_numels_with_padding",
1250
+ "_shapes",
1251
+ "_shard_numel_padded",
1252
+ "_shard_param_infos",
1253
+ )
1254
+
1255
+ def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
1256
+ """
1257
+ Returns a list of FSDP modules with their FQN.
1258
+ """
1259
+ modules = []
1260
+ for name, module in fsdp_model.named_modules():
1261
+ if isinstance(module, FSDP):
1262
+ modules.append((name, module))
1263
+ return modules
1264
+
1265
+ def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
1266
+ from torch.distributed.fsdp._runtime_utils import _lazy_init
1267
+
1268
+ # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
1269
+ # an FSDP state dict through the built-in methods.
1270
+ if torch.cuda.is_available():
1271
+ torch.cuda.synchronize()
1272
+ _lazy_init(fsdp_model, fsdp_model)
1273
+
1274
+ def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
1275
+ if version.parse(torch.__version__) < version.parse("2.1.0"):
1276
+ return fsdp_model._handles # type: ignore
1277
+ elif version.parse(torch.__version__) < version.parse("2.3.0"):
1278
+ # Handle could be None if the FSDP wrapper doesn't manage any parameters.
1279
+ if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
1280
+ return [fsdp_model._handle] # type: ignore
1281
+ else:
1282
+ return []
1283
+ else:
1284
+ # Need to verify FSDP internals with newer versions.
1285
+ raise NotImplementedError
1286
+
1287
+ @torch.no_grad()
1288
+ def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
1289
+ self._prepare_fsdp_model(fsdp_model)
1290
+ module_data = []
1291
+ for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
1292
+ handle_data = []
1293
+ for handle in self._fsdp_handles(fsdp_module):
1294
+ data: Dict[str, Any] = {}
1295
+ # This is a `FlatParameter` instance.
1296
+ # See `torch.distributed.fsdp.flat_param` for the API.
1297
+ flat_param = handle.flat_param
1298
+ data["flat_param.data"] = flat_param.detach()
1299
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1300
+ if hasattr(flat_param, key):
1301
+ data[f"flat_param.{key}"] = getattr(flat_param, key)
1302
+ handle_data.append(data)
1303
+ module_data.append({"handles": handle_data, "name": module_fqn})
1304
+ return {"modules": module_data}
1305
+
1306
+ @torch.no_grad()
1307
+ def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
1308
+ """Load the state produced from `self._get_flat_param_state_to_save()`."""
1309
+ self._prepare_fsdp_model(fsdp_model)
1310
+ fsdp_modules = self._fsdp_modules(fsdp_model)
1311
+ assert len(model_state["modules"]) == len(fsdp_modules)
1312
+ for (_, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
1313
+ handles = self._fsdp_handles(fsdp_module)
1314
+ assert len(handles) == len(module_data["handles"])
1315
+ for handle, data in zip(handles, module_data["handles"]):
1316
+ flat_param = handle.flat_param
1317
+ # Make sure metadata matches.
1318
+ for key in self._FLAT_PARAM_METADATA_TO_SAVE:
1319
+ if hasattr(flat_param, key):
1320
+ assert getattr(flat_param, key) == data[f"flat_param.{key}"]
1321
+ # Load the flat sharded data.
1322
+ flat_param.copy_(data["flat_param.data"])
1323
+
1324
+ def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
1325
+ if get_fs_local_rank() == 0:
1326
+ log.info("Saving metadata...")
1327
+ metadata = _LocalShardedCheckpointerMetadata()
1328
+ metadata.save(metadata_path := Path(dir) / "metadata.yaml")
1329
+ if upload_to is not None and get_global_rank() == 0:
1330
+ upload_target = f"{upload_to}/metadata.yaml"
1331
+ log.info(f"Uploading {metadata_path} to {upload_target}")
1332
+ upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)
1333
+
1334
+ def _load_metadata(
1335
+ self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
1336
+ ) -> _LocalShardedCheckpointerMetadata:
1337
+ metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
1338
+ return _LocalShardedCheckpointerMetadata.load(metadata_path)
1339
+
1340
+ def save_checkpoint(
1341
+ self,
1342
+ dir: PathOrStr,
1343
+ fsdp_model: FSDP,
1344
+ optim: Optimizer,
1345
+ trainer_state: Dict[str, Any],
1346
+ *,
1347
+ upload_to: Optional[str] = None,
1348
+ ) -> None:
1349
+ with self._temporary_wd(dir) as checkpoint_dir:
1350
+ # Gather local FSDP flat params data to save.
1351
+ # We also save some flat param metadata like the corresponding fully qualified names (fqns)
1352
+ # of each original parameter so we can validate that the sharding is the same when loading
1353
+ # one of these checkpoints.
1354
+ log.info("Saving local FSDP flat params data...")
1355
+ save_state_dict(
1356
+ checkpoint_dir,
1357
+ f"model/rank{get_global_rank()}.pt",
1358
+ self._get_flat_param_state_to_save(fsdp_model),
1359
+ upload_to=upload_to,
1360
+ save_overwrite=self.cfg.save_overwrite,
1361
+ )
1362
+
1363
+ # Save optimizer state.
1364
+ log.info("Saving local optimizer state...")
1365
+ save_state_dict(
1366
+ checkpoint_dir,
1367
+ f"optim/rank{get_global_rank()}.pt",
1368
+ optim.state_dict(),
1369
+ upload_to=upload_to,
1370
+ save_overwrite=self.cfg.save_overwrite,
1371
+ )
1372
+
1373
+ # Save trainer state.
1374
+ log.info("Saving trainer state...")
1375
+ save_state_dict(
1376
+ checkpoint_dir,
1377
+ f"train/rank{get_global_rank()}.pt",
1378
+ trainer_state,
1379
+ upload_to=upload_to,
1380
+ save_overwrite=self.cfg.save_overwrite,
1381
+ )
1382
+
1383
+ # Save metadata.
1384
+ self._save_metadata(checkpoint_dir, upload_to=upload_to)
1385
+
1386
+ # Save config. We do this last b/c the presence of a config in a remote checkpoint
1387
+ # "directory" indicates that the folder is valid, as a opposed to a partially
1388
+ # uploaded checkpoint directory that failed before completing.
1389
+ self._save_config(checkpoint_dir, upload_to=upload_to)
1390
+
1391
+ def restore_checkpoint(
1392
+ self,
1393
+ load_path: PathOrStr,
1394
+ fsdp_model: FSDP,
1395
+ optim: Optimizer,
1396
+ *,
1397
+ local_cache: Optional[PathOrStr] = None,
1398
+ load_optimizer_state: bool = True,
1399
+ ) -> Dict[str, Any]:
1400
+ # Load metadata and make sure checkpoint is compatible.
1401
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1402
+ assert metadata.world_size == get_world_size()
1403
+
1404
+ # Load local FSDP flat param data.
1405
+ log.info("Loading local FSDP flat params data...")
1406
+ model_state = load_state_dict(
1407
+ load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1408
+ )
1409
+ self._load_flat_param_state(fsdp_model, model_state)
1410
+ del model_state
1411
+
1412
+ # Load local optim state.
1413
+ if load_optimizer_state:
1414
+ log.info("Loading local optimizer state...")
1415
+ optim_state = load_state_dict(
1416
+ load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
1417
+ )
1418
+ # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
1419
+ # in every rank, and keep this in the optimizer state. But this causes issues when loading the
1420
+ # state since torch sees the state is non-empty for some params which would normally be empty,
1421
+ # and then assumes it should have all of the other state tensors for that param, which is doesn't.
1422
+ # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
1423
+ # Not the end of the world but there's probably a better way around this without resetting
1424
+ # the metric.
1425
+ for param_id in list(optim_state["state"].keys()):
1426
+ state = optim_state["state"][param_id]
1427
+ if "grad_norm_exp_avg" in state:
1428
+ del state["grad_norm_exp_avg"]
1429
+ if len(state) == 0:
1430
+ del optim_state["state"][param_id]
1431
+ optim.load_state_dict(optim_state)
1432
+ del optim_state
1433
+
1434
+ # Load local trainer state.
1435
+ log.info("Loading local trainer state...")
1436
+ trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
1437
+ barrier()
1438
+ return trainer_state
1439
+
1440
+ def _iter_flat_param_shards(
1441
+ self, model_state: Dict[str, Any]
1442
+ ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
1443
+ for module_data in model_state["modules"]:
1444
+ module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
1445
+ for handle in module_data["handles"]:
1446
+ flat_data = handle["flat_param.data"]
1447
+ if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
1448
+ # If there's padding in the flat param it should be on the right.
1449
+ assert (flat_data[-num_padding:] == 0).all()
1450
+ # NOTE: this changes depending on the torch version, but we don't do a version
1451
+ # check since we might be trying to unshard an old checkpoint that was stored
1452
+ # with a different torch version than we're currently running with.
1453
+ if "flat_param._shard_indices" in handle:
1454
+ # torch <=2.0.1
1455
+ param_start = handle["flat_param._shard_indices"][0]
1456
+ current_flat_index = 0
1457
+ for relative_fqn, full_shape, (offset_start, offset_end) in zip(
1458
+ handle["flat_param._fqns"][param_start:],
1459
+ handle["flat_param._shapes"][param_start:],
1460
+ handle["flat_param._shard_param_offsets"],
1461
+ ):
1462
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1463
+ numel_shard = offset_end - offset_start + 1
1464
+ flat_param_shard = _FlatParamShard(
1465
+ full_shape=full_shape,
1466
+ shard_offsets=(offset_start, offset_end),
1467
+ shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
1468
+ )
1469
+ current_flat_index += numel_shard
1470
+ yield root_fqn, flat_param_shard
1471
+ else:
1472
+ # torch >=2.1.0
1473
+ for relative_fqn, full_shape, shard_param_info in zip(
1474
+ handle["flat_param._fqns"],
1475
+ handle["flat_param._shapes"],
1476
+ handle["flat_param._shard_param_infos"],
1477
+ ):
1478
+ if not shard_param_info.in_shard:
1479
+ continue
1480
+ root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
1481
+ flat_param_shard = _FlatParamShard(
1482
+ full_shape=full_shape,
1483
+ shard_offsets=(
1484
+ shard_param_info.intra_param_start_idx,
1485
+ shard_param_info.intra_param_end_idx,
1486
+ ),
1487
+ shard_data=flat_data[
1488
+ shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
1489
+ + shard_param_info.numel_in_shard
1490
+ ],
1491
+ )
1492
+ yield root_fqn, flat_param_shard
1493
+
1494
+ def unshard_checkpoint(
1495
+ self,
1496
+ load_path: PathOrStr,
1497
+ *,
1498
+ local_cache: Optional[PathOrStr] = None,
1499
+ load_optimizer_state: bool = True,
1500
+ load_trainer_state: bool = True,
1501
+ device: Optional[torch.device] = None,
1502
+ ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
1503
+ device = device or torch.device("cpu")
1504
+ metadata = self._load_metadata(load_path, local_cache=local_cache)
1505
+
1506
+ # Gather paths model state, potentially downloading them.
1507
+ log.info("Gathering model state dicts...")
1508
+ model_state_paths = self._gather_state_dict_paths(
1509
+ load_path, "model", metadata.world_size, local_cache=local_cache
1510
+ )
1511
+
1512
+ # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
1513
+ log.info("Materializing full parameters...")
1514
+ full_model_state: Dict[str, torch.Tensor] = {}
1515
+ # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
1516
+ # the full optimizer state below without having to reload the model state dicts.
1517
+ flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
1518
+ for rank, path in enumerate(model_state_paths):
1519
+ log.info(f"Loading shards from rank {rank}...")
1520
+ model_state = torch.load(path, map_location="cpu")
1521
+ for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
1522
+ if root_fqn not in full_model_state:
1523
+ log.info(
1524
+ f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
1525
+ )
1526
+ assert flat_param_shard.shard_data is not None
1527
+ full_model_state[root_fqn] = torch.empty(
1528
+ flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
1529
+ )
1530
+ # Fill with NaNs so we can validate that the whole parameter has been populated
1531
+ # afterwards.
1532
+ full_model_state[root_fqn].fill_(torch.nan)
1533
+ # Copy over the local shard to the relevant part of the full parameter.
1534
+ full_param = full_model_state[root_fqn]
1535
+ log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
1536
+ flat_param_shard.copy_into(full_param)
1537
+ flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
1538
+
1539
+ log.info("Validating full parameters...")
1540
+ for key, tensor in full_model_state.items():
1541
+ if torch.isnan(tensor).any():
1542
+ raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")
1543
+
1544
+ trainer_state: Optional[Dict[str, Any]] = None
1545
+ if load_trainer_state:
1546
+ trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
1547
+
1548
+ if not load_optimizer_state:
1549
+ return full_model_state, None, trainer_state
1550
+
1551
+ log.info("Gathering optim state dicts...")
1552
+ optim_state_paths = self._gather_state_dict_paths(
1553
+ load_path, "optim", metadata.world_size, local_cache=local_cache
1554
+ )
1555
+
1556
+ log.info("Materializing full optim state...")
1557
+ full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
1558
+ fqn_to_id: Dict[str, int] = {}
1559
+ id_to_fqn: Dict[int, str] = {}
1560
+ for rank, path in enumerate(optim_state_paths):
1561
+ log.info(f"Loading sharded optim state from rank {rank}...")
1562
+ optim_state = torch.load(path, map_location="cpu")
1563
+
1564
+ # Initialize param groups.
1565
+ # We assume parameter groups are the same across all ranks.
1566
+ # The only thing that differs across ranks is the state for each local sharded param.
1567
+ if "param_groups" not in full_optim_state:
1568
+ full_optim_state["param_groups"] = optim_state["param_groups"]
1569
+ else:
1570
+ assert full_optim_state["param_groups"] == optim_state["param_groups"]
1571
+
1572
+ # Generate mapping of parameter FQNs to optimizer param IDs and vice-versa.
1573
+ if not fqn_to_id or not id_to_fqn:
1574
+ for group in full_optim_state["param_groups"]:
1575
+ for fqn, id in zip(group["param_names"], group["params"]):
1576
+ fqn = fqn.replace("_fsdp_wrapped_module.", "")
1577
+ fqn_to_id[fqn] = id
1578
+ id_to_fqn[id] = fqn
1579
+
1580
+ # Iterate over local shard state and copy into the full state.
1581
+ for id, shard_state in optim_state["state"].items():
1582
+ fqn = id_to_fqn[id]
1583
+ flat_param_shard = flat_params_data[rank].get(fqn) # type: ignore[assignment]
1584
+ full_state = full_optim_state["state"][id]
1585
+ for key, shard_value in shard_state.items():
1586
+ assert isinstance(shard_value, torch.Tensor)
1587
+ if shard_value.shape == torch.Size([]):
1588
+ # Add singleton tensors directly to full state. These should be the same across
1589
+ # all ranks.
1590
+ assert key in ("step", "grad_norm_exp_avg") # sanity check
1591
+ if key not in full_state:
1592
+ full_state[key] = shard_value.to(device)
1593
+ else:
1594
+ assert full_state[key] == shard_value
1595
+ else:
1596
+ # Otherwise we have a sharded param state.
1597
+ # If the corresponding full param state hasn't been materialized yet, do so now.
1598
+ assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
1599
+ if key not in full_state:
1600
+ log.info(
1601
+ f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
1602
+ )
1603
+ full_state[key] = torch.empty(
1604
+ flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
1605
+ )
1606
+ full_state_value = full_state[key]
1607
+
1608
+ # Copy over the local shard state to the relevant part of the full parameter state.
1609
+ log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
1610
+ replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
1611
+
1612
+ # Lastly, clean up the parameter names in param groups.
1613
+ for group in full_optim_state["param_groups"]:
1614
+ group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]
1615
+
1616
+ return full_model_state, full_optim_state, trainer_state
1617
+
1618
+ def _get_state_dict_path(
1619
+ self,
1620
+ load_path: PathOrStr,
1621
+ state_dict_type: str,
1622
+ rank: int,
1623
+ *,
1624
+ local_cache: Optional[PathOrStr] = None,
1625
+ progress=None,
1626
+ ) -> Tuple[int, Path]:
1627
+ fname = f"{state_dict_type}/rank{rank}.pt"
1628
+ return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)
1629
+
1630
+ def _gather_state_dict_paths(
1631
+ self,
1632
+ load_path: PathOrStr,
1633
+ state_dict_type: str,
1634
+ world_size: int,
1635
+ *,
1636
+ local_cache: Optional[PathOrStr] = None,
1637
+ ) -> List[Path]:
1638
+ progress = get_progress_bar()
1639
+ with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
1640
+ futures = []
1641
+ for rank in range(world_size):
1642
+ future = executor.submit(
1643
+ self._get_state_dict_path,
1644
+ load_path,
1645
+ state_dict_type,
1646
+ rank,
1647
+ local_cache=local_cache,
1648
+ progress=progress,
1649
+ )
1650
+ futures.append(future)
1651
+
1652
+ results: Dict[int, Path] = {}
1653
+ for future in as_completed(futures):
1654
+ rank, path = future.result()
1655
+ results[rank] = path
1656
+
1657
+ return [results[rank] for rank in range(world_size)]
1658
+
1659
+
1660
+ def build_sharded_checkpointer(
1661
+ cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
1662
+ ) -> Checkpointer:
1663
+ name = name or cfg.sharded_checkpointer
1664
+ if name == ShardedCheckpointerType.torch_new:
1665
+ return TorchNewStyleShardedCheckpointer(cfg)
1666
+ elif name == ShardedCheckpointerType.torch_legacy:
1667
+ return TorchLegacyShardedCheckpointer(cfg)
1668
+ elif name == ShardedCheckpointerType.local:
1669
+ return LocalShardedCheckpointer(cfg)
1670
+ else:
1671
+ raise NotImplementedError(name)
config.py ADDED
@@ -0,0 +1,1106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import asdict, dataclass, field
4
+ from glob import glob
5
+ from pathlib import Path
6
+ from typing import (
7
+ Any,
8
+ Dict,
9
+ Iterable,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ cast,
17
+ )
18
+
19
+ import torch
20
+ from omegaconf import DictConfig, ListConfig
21
+ from omegaconf import OmegaConf as om
22
+ from omegaconf.errors import OmegaConfBaseException
23
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
24
+
25
+ from .aliases import PathOrStr
26
+ from .beam_search import Sampler
27
+ from .exceptions import OLMoConfigurationError
28
+ from .util import StrEnum
29
+
30
+ __all__ = [
31
+ "ActivationType",
32
+ "ActivationCheckpointingStrategy",
33
+ "BlockType",
34
+ "LayerNormType",
35
+ "InitFnType",
36
+ "ModelConfig",
37
+ "OptimizerType",
38
+ "OptimizerConfig",
39
+ "SchedulerType",
40
+ "SchedulerConfig",
41
+ "DataConfig",
42
+ "EvaluatorConfig",
43
+ "TokenizerConfig",
44
+ "TrainConfig",
45
+ "PaddingDirection",
46
+ "TruncationDirection",
47
+ "SpeedMonitorConfig",
48
+ "WandbConfig",
49
+ "CompilerConfig",
50
+ "WandbConfig",
51
+ "FSDPPrecision",
52
+ "FSDPWrapStrategy",
53
+ "FSDPConfig",
54
+ "CheckpointType",
55
+ ]
56
+
57
+ C = TypeVar("C", bound="BaseConfig")
58
+ D = TypeVar("D", bound="DictConfig|ListConfig")
59
+
60
+
61
+ class BaseConfig:
62
+ @classmethod
63
+ def _register_resolvers(cls, validate_paths: bool = True):
64
+ # Expands path globs into a list.
65
+ def path_glob(*paths) -> List[str]:
66
+ out = []
67
+ for path in paths:
68
+ matches = sorted(glob(path))
69
+ if not matches and validate_paths:
70
+ raise FileNotFoundError(f"{path} does not match any files or dirs")
71
+ out.extend(matches)
72
+ return out
73
+
74
+ # Chooses the first path in the arguments that exists.
75
+ def path_choose(*paths) -> str:
76
+ from .util import is_url
77
+
78
+ for path in paths:
79
+ if is_url(path) or Path(path).exists():
80
+ return path
81
+ if validate_paths:
82
+ raise FileNotFoundError(", ".join(paths))
83
+ else:
84
+ return ""
85
+
86
+ # Finds the latest checkpoint in a folder.
87
+ def path_last_checkpoint(path) -> str:
88
+ from .util import find_latest_checkpoint
89
+
90
+ latest_checkpoint = find_latest_checkpoint(path)
91
+ if latest_checkpoint is None:
92
+ if validate_paths:
93
+ raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
94
+ else:
95
+ return ""
96
+ else:
97
+ return str(latest_checkpoint)
98
+
99
+ om.register_new_resolver("path.glob", path_glob, replace=True)
100
+ om.register_new_resolver("path.choose", path_choose, replace=True)
101
+ om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)
102
+
103
+ @classmethod
104
+ def update_legacy_settings(cls, config: D) -> D:
105
+ """
106
+ Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
107
+ """
108
+ return config
109
+
110
+ @classmethod
111
+ def new(cls: Type[C], **kwargs) -> C:
112
+ cls._register_resolvers()
113
+ conf = om.structured(cls)
114
+ try:
115
+ if kwargs:
116
+ conf = om.merge(conf, kwargs)
117
+ return cast(C, om.to_object(conf))
118
+ except OmegaConfBaseException as e:
119
+ raise OLMoConfigurationError(str(e))
120
+
121
+ @classmethod
122
+ def load(
123
+ cls: Type[C],
124
+ path: PathOrStr,
125
+ overrides: Optional[List[str]] = None,
126
+ key: Optional[str] = None,
127
+ validate_paths: bool = True,
128
+ ) -> C:
129
+ """Load from a YAML file."""
130
+ cls._register_resolvers(validate_paths=validate_paths)
131
+ schema = om.structured(cls)
132
+ try:
133
+ raw = om.load(str(path))
134
+ if key is not None:
135
+ raw = raw[key] # type: ignore
136
+ raw = cls.update_legacy_settings(raw)
137
+ conf = om.merge(schema, raw)
138
+ if overrides:
139
+ conf = om.merge(conf, om.from_dotlist(overrides))
140
+ return cast(C, om.to_object(conf))
141
+ except OmegaConfBaseException as e:
142
+ raise OLMoConfigurationError(str(e))
143
+
144
+ def save(self, path: PathOrStr) -> None:
145
+ """Save to a YAML file."""
146
+ om.save(config=self, f=str(path))
147
+
148
+ def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
149
+ out = asdict(self) # type: ignore
150
+ if exclude is not None:
151
+ for name in exclude:
152
+ if name in out:
153
+ del out[name]
154
+ return out
155
+
156
+
157
+ class LayerNormType(StrEnum):
158
+ default = "default"
159
+ """
160
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
161
+ """
162
+
163
+ low_precision = "low_precision"
164
+ """
165
+ A low-precision version of the default LayerNorm.
166
+ """
167
+
168
+ rms = "rms"
169
+ """
170
+ An RMSNorm implementation. When using ``torch.compile`` this is
171
+ probably the fastest implementation.
172
+ """
173
+
174
+
175
+ class ActivationType(StrEnum):
176
+ gelu = "gelu"
177
+ relu = "relu"
178
+ swiglu = "swiglu"
179
+
180
+
181
+ class BlockType(StrEnum):
182
+ sequential = "sequential"
183
+
184
+ llama = "llama"
185
+ """
186
+ A block similar to the sequential block with slightly different
187
+ implementations of operations like attention to imitate the behavior of Llama.
188
+ """
189
+
190
+
191
+ class InitFnType(StrEnum):
192
+ mitchell = "mitchell"
193
+ """
194
+ The strategy suggested to us by Mitchell Wortsman from UW.
195
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
196
+ on the size of the weights as well as the depth of the layer.
197
+ """
198
+
199
+ normal = "normal"
200
+ """
201
+ All weights are initialized from the same normal distribution.
202
+ """
203
+
204
+ kaiming_normal = "kaiming_normal"
205
+ """
206
+ All weights are initialized with the Kaiming method from a normal distribution.
207
+ Note this currently won't work with FSDP.
208
+ """
209
+
210
+ fan_in = "fan_in"
211
+ """
212
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
213
+ is the input dimensionality of the kernel.
214
+ """
215
+
216
+ full_megatron = "full_megatron"
217
+ """
218
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
219
+ """
220
+
221
+
222
+ @dataclass
223
+ class ModelConfig(BaseConfig):
224
+ """
225
+ OLMo (model) configuration.
226
+ """
227
+
228
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
229
+
230
+ d_model: int = 768
231
+ """
232
+ The hidden size of the model.
233
+ """
234
+
235
+ n_heads: int = 12
236
+ """
237
+ The number of self-attention heads.
238
+ """
239
+
240
+ n_kv_heads: Optional[int] = None
241
+ """
242
+ The number of heads to use for keys and values. Defaults to `n_heads`.
243
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
244
+ Set this to 1 for multi-query attention.
245
+ Set it to some in-between value for Llama2-style grouped query attention.
246
+ """
247
+
248
+ clip_qkv: Optional[float] = None
249
+ """
250
+ Clip QKV to this value when set.
251
+ """
252
+
253
+ n_layers: int = 12
254
+ """
255
+ The number of layers/blocks.
256
+ """
257
+
258
+ mlp_ratio: int = 4
259
+ """
260
+ The ratio of the inner MLP dimensionality to ``d_model``.
261
+ This is only used when ``mlp_hidden_size`` is not set.
262
+ """
263
+
264
+ mlp_hidden_size: Optional[int] = None
265
+ """
266
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
267
+ """
268
+
269
+ activation_type: ActivationType = ActivationType.swiglu
270
+ """
271
+ The activation function to use within the MLP layers.
272
+ """
273
+
274
+ block_type: BlockType = BlockType.sequential
275
+ """
276
+ The transformer block implementation.
277
+ """
278
+
279
+ block_group_size: int = 1
280
+ """
281
+ The number of blocks to group together into a single parent block.
282
+ This has no affect on the number of parameters in the model and is only used to wrap groups
283
+ of blocks together with a single FSDP wrapper during training.
284
+ """
285
+
286
+ alibi: bool = False
287
+ """
288
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
289
+ """
290
+
291
+ alibi_bias_max: float = 8.0
292
+ """
293
+ Maximum absolute value of ALiBi bias.
294
+ """
295
+
296
+ rope: bool = False
297
+ """
298
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
299
+ """
300
+
301
+ rope_full_precision: bool = True
302
+ """
303
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
304
+ apply RoPE at the precision of the input.
305
+ """
306
+
307
+ flash_attention: bool = False
308
+ """
309
+ If ``True``, use ``FlashAttention``.
310
+ """
311
+
312
+ attention_dropout: float = 0.1
313
+ """
314
+ The dropout probability within the attention modules.
315
+ """
316
+
317
+ multi_query_attention: Optional[bool] = None
318
+ """
319
+ Deprecated. Use n_kv_heads instead.
320
+ """
321
+
322
+ attention_layer_norm: bool = False
323
+ """
324
+ Apply layer norm to the keys and queries within the attention mechanism.
325
+ This can help stabilize training.
326
+ """
327
+
328
+ residual_dropout: float = 0.1
329
+ """
330
+ The dropout probability for the MLP and attention output within each block.
331
+ """
332
+
333
+ embedding_dropout: float = 0.1
334
+ """
335
+ The dropout probability for embeddings.
336
+ """
337
+
338
+ layer_norm_type: LayerNormType = LayerNormType.default
339
+ """
340
+ The layernorm implementation to use.
341
+ """
342
+
343
+ layer_norm_with_affine: bool = True
344
+ """
345
+ Whether to include bias and weight parameters for the layer norms.
346
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
347
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
348
+ to ``False``.
349
+ """
350
+
351
+ attention_layer_norm_with_affine: bool = True
352
+ """
353
+ Toggle affine transform for the QK norms.
354
+ """
355
+
356
+ max_sequence_length: int = 1024
357
+ """
358
+ The maximum input sequence length supported by the model.
359
+ """
360
+
361
+ include_bias: bool = True
362
+ """
363
+ Whether or not to include bias parameters in linear layers.
364
+ In PaLM, they got rid of all bias terms because they found that large
365
+ models tend to have near 0 bias terms anyway.
366
+ """
367
+
368
+ bias_for_layer_norm: Optional[bool] = None
369
+ """
370
+ Whether or not to include bias parameters in layer norm.
371
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
372
+ layer norm.
373
+ When this is None (the default), it inherits the setting from include_bias.
374
+ """
375
+
376
+ scale_logits: bool = False
377
+ """
378
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
379
+ """
380
+
381
+ vocab_size: int = 50257
382
+ """
383
+ Vocabulary size of the model.
384
+ """
385
+
386
+ embedding_size: Optional[int] = 50304
387
+ """
388
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
389
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
390
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
391
+ substantially.
392
+ """
393
+
394
+ weight_tying: bool = True
395
+ """
396
+ Whether to tie output linear weights to the input embedding.
397
+ """
398
+
399
+ eos_token_id: int = 50256
400
+ """
401
+ The ID of the end-of-sentence special token.
402
+ """
403
+
404
+ pad_token_id: int = 50256
405
+ """
406
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
407
+ """
408
+
409
+ init_device: Optional[str] = None
410
+ """
411
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
412
+ """
413
+
414
+ init_fn: InitFnType = InitFnType.normal
415
+ """
416
+ The weight initialization strategy.
417
+ """
418
+
419
+ init_std: float = 0.02
420
+ """
421
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
422
+ as "normal".
423
+ """
424
+
425
+ init_cutoff_factor: Optional[float] = None
426
+ """
427
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
428
+ as "normal". Setting this to None means values are not cutoff.
429
+ """
430
+
431
+ precision: Optional[str] = None
432
+ """
433
+ Precision used to train/evaluate with. You shouldn't set this directly.
434
+ See :data:`TrainConfig.precision` instead.
435
+ """
436
+
437
+ ternary: bool = False
438
+ """
439
+ Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf)
440
+ """
441
+
442
+ @property
443
+ def effective_n_kv_heads(self) -> int:
444
+ if self.n_kv_heads is None:
445
+ if self.multi_query_attention is True:
446
+ return 1
447
+ else:
448
+ return self.n_heads
449
+ else:
450
+ if self.multi_query_attention is None:
451
+ return self.n_kv_heads
452
+ if self.multi_query_attention:
453
+ n_kv_heads_should_be = 1
454
+ else:
455
+ n_kv_heads_should_be = self.n_heads
456
+ if self.n_kv_heads == n_kv_heads_should_be:
457
+ return n_kv_heads_should_be
458
+ else:
459
+ raise OLMoConfigurationError(
460
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
461
+ )
462
+
463
+
464
+ class OptimizerType(StrEnum):
465
+ lionw = "lionw"
466
+ adamw = "adamw"
467
+
468
+
469
+ @dataclass
470
+ class OptimizerConfig(BaseConfig):
471
+ name: OptimizerType = OptimizerType.lionw
472
+ learning_rate: float = 1.0e-4
473
+ weight_decay: float = 0.01
474
+ betas: Tuple[float, float] = (0.9, 0.95)
475
+
476
+ no_decay_norm_and_bias: Optional[bool] = None
477
+ """
478
+ Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
479
+ """
480
+
481
+ decay_norm_and_bias: bool = False
482
+ decay_embeddings: bool = False
483
+ metrics_log_interval: Optional[int] = None
484
+ """
485
+ The interval with which to collect and log detailed parameter-specific metrics.
486
+ This only applies when logging to W&B, since these metrics won't be logged to the console.
487
+ If not set, defaults to the wandb `log_interval`.
488
+ """
489
+
490
+ def __post_init__(self):
491
+ self.betas = tuple(self.betas) # type: ignore[assignment]
492
+
493
+ @classmethod
494
+ def update_legacy_settings(cls, config: D) -> D:
495
+ new_config = config.copy()
496
+ if om.is_dict(new_config):
497
+ assert isinstance(new_config, DictConfig)
498
+
499
+ if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
500
+ new_config.name = "lionw"
501
+ if hasattr(new_config, "eps"):
502
+ del new_config.eps
503
+
504
+ return new_config
505
+
506
+
507
+ class SchedulerType(StrEnum):
508
+ cosine_with_warmup = "cosine_with_warmup"
509
+ linear_with_warmup = "linear_with_warmup"
510
+ inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
511
+ max_scheduler = "max_scheduler"
512
+ constant = "constant"
513
+
514
+
515
+ class SchedulerUnits(StrEnum):
516
+ steps = "steps"
517
+ tokens = "tokens"
518
+
519
+
520
+ @dataclass
521
+ class SchedulerConfig(BaseConfig):
522
+ name: SchedulerType = SchedulerType.cosine_with_warmup
523
+ units: SchedulerUnits = SchedulerUnits.steps
524
+ t_warmup: Union[int, float] = 100
525
+ t_max: Optional[Union[int, float]] = None
526
+ alpha_f: float = 0.1
527
+
528
+ grad_clip_warmup_steps: Optional[Union[int, float]] = None
529
+ """
530
+ The warmup period for which the max grad norm (or norm ratio) will be set to its
531
+ warmup value of `max_grad_norm * grad_clip_warmup_factor`.
532
+ """
533
+
534
+ grad_clip_warmup_factor: Optional[float] = None
535
+ """
536
+ The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
537
+ vs after the warmup period.
538
+ """
539
+
540
+
541
+ class PaddingDirection(StrEnum):
542
+ right = "right"
543
+ left = "left"
544
+
545
+
546
+ @dataclass
547
+ class DataConfig(BaseConfig):
548
+ paths: Optional[List[str]] = None
549
+ datasets: Optional[Dict[str, List[str]]] = None
550
+ label_mask_paths: Optional[List[str]] = None
551
+ pad_direction: PaddingDirection = PaddingDirection.right
552
+ generate_attention_mask: bool = False
553
+ num_workers: int = 0
554
+ drop_last: bool = False
555
+ pin_memory: bool = False
556
+ prefetch_factor: Optional[int] = None
557
+ persistent_workers: bool = False
558
+ timeout: int = 0
559
+ seed: Optional[int] = None
560
+
561
+
562
+ class EvaluatorType(StrEnum):
563
+ downstream = "downstream"
564
+ lm = "lm"
565
+
566
+
567
+ @dataclass
568
+ class EvaluatorConfig(BaseConfig):
569
+ label: str
570
+ type: EvaluatorType = EvaluatorType.lm
571
+ data: DataConfig = field(default_factory=DataConfig)
572
+ device_eval_batch_size: Optional[int] = None
573
+ subset_num_batches: Optional[int] = None
574
+
575
+
576
+ class TruncationDirection(StrEnum):
577
+ right = "right"
578
+ left = "left"
579
+
580
+
581
+ @dataclass
582
+ class TokenizerConfig(BaseConfig):
583
+ identifier: str = "gpt2"
584
+ truncate_direction: TruncationDirection = TruncationDirection.right
585
+
586
+
587
+ @dataclass
588
+ class WandbConfig(BaseConfig):
589
+ project: Optional[str] = None
590
+ entity: Optional[str] = "ai2-llm"
591
+ group: Optional[str] = None
592
+ name: Optional[str] = None
593
+ tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
594
+ log_artifacts: bool = False
595
+ rank_zero_only: bool = True
596
+ log_interval: int = 1
597
+
598
+
599
+ @dataclass
600
+ class SpeedMonitorConfig(BaseConfig):
601
+ window_size: int = 100
602
+ gpu_flops_available: Optional[Union[float, int]] = None
603
+
604
+
605
+ @dataclass
606
+ class CompilerConfig(BaseConfig):
607
+ mode: Optional[str] = None
608
+ """
609
+ The mode to compile the model in. At the moment this can be "default",
610
+ "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
611
+ (the fastest for larger models, but takes a long time to compile).
612
+ """
613
+
614
+ fullgraph: bool = False
615
+ """
616
+ Whether it is OK to break model into several subgraphs when compiling.
617
+ Note that this is not compatible with FSDP.
618
+ """
619
+
620
+ backend: str = "inductor"
621
+ """
622
+ The backend to use.
623
+ """
624
+
625
+
626
+ class FSDPWrapStrategy(StrEnum):
627
+ by_block = "by_block"
628
+ """
629
+ Wrap each OLMo block with its own FSDP instance.
630
+ """
631
+
632
+ by_block_and_size = "by_block_and_size"
633
+ """
634
+ Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
635
+ """
636
+
637
+ by_block_group = "by_block_group"
638
+ """
639
+ Wrap each block group together into its own FSDP instance.
640
+ This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
641
+ """
642
+
643
+ by_block_group_and_size = "by_block_group_and_size"
644
+ """
645
+ Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
646
+ """
647
+
648
+ size_based = "size_based"
649
+ """
650
+ Used PyTorch's default size-based auto wrap policy.
651
+ """
652
+
653
+ one_in_two = "one_in_two"
654
+ one_in_three = "one_in_three"
655
+ one_in_four = "one_in_four"
656
+ one_in_five = "one_in_five"
657
+
658
+
659
+ class FSDPPrecision(StrEnum):
660
+ pure = "pure"
661
+ """
662
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
663
+ and ``buffer_dtype`` all set to the autocast precision data type.
664
+ """
665
+
666
+ mixed = "mixed"
667
+ """
668
+ Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, and ``buffer_dtype``
669
+ set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
670
+ """
671
+
672
+
673
+ @dataclass
674
+ class FSDPConfig(BaseConfig):
675
+ use_orig_params: bool = True
676
+ """
677
+ This must be ``True`` if using ``compile`` or you want to track the parameter norm during training.
678
+ """
679
+
680
+ sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
681
+
682
+ wrapping_strategy: Optional[FSDPWrapStrategy] = None
683
+ """
684
+ The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
685
+ FSDP instance.
686
+ """
687
+
688
+ precision: FSDPPrecision = FSDPPrecision.pure
689
+
690
+
691
+ class CheckpointType(StrEnum):
692
+ sharded = "sharded"
693
+ unsharded = "unsharded"
694
+ sharded_ephemeral = "sharded_ephemeral"
695
+
696
+
697
+ class ShardedCheckpointerType(StrEnum):
698
+ torch_new = "torch_new"
699
+ torch_legacy = "torch_legacy"
700
+ local = "local"
701
+
702
+
703
+ class ActivationCheckpointingStrategy(StrEnum):
704
+ whole_layer = "whole_layer"
705
+ """
706
+ Checkpoint every transformer layer.
707
+ """
708
+
709
+ one_in_two = "one_in_two"
710
+ """
711
+ Checkpoint one in two transformer layers.
712
+ """
713
+
714
+ one_in_three = "one_in_three"
715
+ """
716
+ Checkpoint one in three transformer layers.
717
+ """
718
+
719
+ one_in_four = "one_in_four"
720
+ """
721
+ Checkpoint one in four transformer layers.
722
+ """
723
+
724
+ two_in_three = "two_in_three"
725
+ """
726
+ Checkpoint two out of every three transformer layers.
727
+ """
728
+
729
+ three_in_four = "three_in_four"
730
+ """
731
+ Checkpoint three out of four of every transformer layers.
732
+ """
733
+
734
+ fine_grained = "fine_grained"
735
+ """
736
+ Focus checkpointing on where it is cheap to recompute and saves most memory.
737
+ """
738
+
739
+
740
+ @dataclass
741
+ class TrainConfig(BaseConfig):
742
+ """
743
+ OLMo training configuration.
744
+ """
745
+
746
+ run_name: Optional[str] = None
747
+ """
748
+ The name of the run.
749
+ """
750
+
751
+ seed: int = 6198
752
+ """
753
+ Used to seed all initial RNG states.
754
+ """
755
+
756
+ epoch: Optional[int] = None
757
+ """
758
+ Increment this when starting a new epoch.
759
+ """
760
+
761
+ dry_run: bool = False
762
+ """
763
+ If ``True``, don't actually train.
764
+ """
765
+
766
+ model: ModelConfig = field(default_factory=ModelConfig)
767
+ """
768
+ OLMo Model configuration.
769
+ """
770
+
771
+ optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
772
+ """
773
+ Optimizer configuration.
774
+ """
775
+
776
+ scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
777
+ """
778
+ Learning rate scheduler configuration.
779
+ """
780
+
781
+ data: DataConfig = field(default_factory=DataConfig)
782
+ """
783
+ Training data configuration.
784
+ """
785
+
786
+ restore_dataloader: bool = True
787
+ """
788
+ When restarting, restore the data loader to where it left off.
789
+ If you restarting in order to train on a different dataset, set this to ``False``.
790
+ """
791
+
792
+ fast_forward_batches: Optional[int] = None
793
+ """
794
+ When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
795
+ This can be useful when restarting due to a loss spike in order to skip the data that
796
+ corresponded to the spike.
797
+ """
798
+
799
+ evaluators: List[EvaluatorConfig] = field(default_factory=list)
800
+ """
801
+ Evaluation configurations.
802
+ """
803
+
804
+ eval_interval: int = 1000
805
+ """
806
+ How often (in terms of batches) to run evaluations.
807
+ """
808
+
809
+ tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
810
+ """
811
+ Tokenizer configuration.
812
+ """
813
+
814
+ save_folder: str = "./"
815
+ """
816
+ The directory to save checkpoints to.
817
+ """
818
+
819
+ remote_save_folder: Optional[str] = None
820
+ """
821
+ A folder in a cloud bucket to upload saved checkpoints to.
822
+ """
823
+
824
+ canceled_check_interval: int = 50
825
+ """
826
+ How often (in batches) to check if the run has been canceled or reached its time limit.
827
+ """
828
+
829
+ save_interval: int = 1000
830
+ """
831
+ How often (in terms of steps) to save sharded training state checkpoints.
832
+ """
833
+
834
+ save_interval_unsharded: Optional[int] = None
835
+ """
836
+ How often (if at all) to save unsharded training state checkpoint.
837
+ For large models it can be costly to save these, so it usually makes sense to save
838
+ these less often than regular (sharded) training checkpoints.
839
+ """
840
+
841
+ save_interval_ephemeral: Optional[int] = None
842
+ """
843
+ How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
844
+ as those saved every `save_interval` except that at most only the most recent one of these is kept.
845
+ This is useful when you want to checkpoint often for restarts in case of failures, but don't
846
+ want to keep the majority of these checkpoints.
847
+
848
+ For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
849
+ a temporary checkpoint every 100 steps in case your job fails. In that case you would
850
+ set `save_interval=1000` and `save_interval_ephemeral=100`.
851
+ """
852
+
853
+ save_num_checkpoints_to_keep: int = -1
854
+ """
855
+ How many sharded checkpoints to keep.
856
+ """
857
+
858
+ save_num_unsharded_checkpoints_to_keep: int = -1
859
+ """
860
+ How many unsharded checkpoints to keep.
861
+ """
862
+
863
+ save_overwrite: bool = False
864
+ """
865
+ If ``True``, overwrite any conflicting checkpoint files.
866
+ """
867
+
868
+ force_save_unsharded: bool = False
869
+ """
870
+ Save an unsharded checkpoint before training (even during a dry run).
871
+ Use this option with `--load-path={PATH}` and `--dry_run` to convert a sharded
872
+ checkpoint into an unsharded checkpoint.
873
+ """
874
+
875
+ no_pre_train_checkpoint: bool = False
876
+ """
877
+ Skip saving pre-train checkpoint.
878
+ """
879
+
880
+ load_path: Optional[str] = None
881
+ """
882
+ The path to a training checkpoint to restore/resume from.
883
+
884
+ Note that you can make use of the "path.last_checkpoint" Omegaconfig YAML resolver here, which takes
885
+ a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
886
+ For example,
887
+
888
+ ```bash
889
+ --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
890
+ ```
891
+ """
892
+
893
+ load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
894
+ """
895
+ The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
896
+ """
897
+
898
+ reset_optimizer_state: bool = False
899
+ """
900
+ When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
901
+ We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
902
+ curve (according to the current learning rate schedule settings), and continues from there.
903
+ """
904
+
905
+ reset_trainer_state: bool = False
906
+ """
907
+ When this is set we don't restore the trainer state from a checkpoint.
908
+ """
909
+
910
+ sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
911
+ """
912
+ The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
913
+ """
914
+
915
+ new_style_checkpoints: Optional[bool] = None
916
+ """
917
+ Deprecated. Use ``sharded_checkpointer`` instead.
918
+ """
919
+
920
+ max_duration: Union[int, str] = 10000
921
+ """
922
+ How long to train for.
923
+
924
+ If specified without a unit (the default), the units are assumed to be steps.
925
+ You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
926
+ 2 trillion tokens.
927
+ """
928
+
929
+ global_train_batch_size: int = 512
930
+ """
931
+ The effective global batch size.
932
+ """
933
+
934
+ device_train_batch_size: Optional[int] = None # calculated automatically
935
+ """
936
+ Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
937
+ """
938
+
939
+ device_train_microbatch_size: int = 16
940
+ """
941
+ The number of instances passed to the model in a single forward-backward pass. You should set
942
+ this as large as you can based on available GPU memory.
943
+ """
944
+
945
+ device_eval_batch_size: int = 16
946
+ """
947
+ The number of evaluation instances passed to the model in a single forward pass on each device.
948
+ """
949
+
950
+ eval_subset_num_batches: int = -1
951
+ """
952
+ The number of batches to use for downstream evaluation from each dataset.
953
+ """
954
+
955
+ eval_on_load: bool = False
956
+ """
957
+ When resuming from a checkpoint, run the evaluation loop right away.
958
+ """
959
+
960
+ device_train_grad_accum: Optional[int] = None # calculated automatically
961
+ """
962
+ Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
963
+ """
964
+
965
+ max_grad_norm: Optional[float] = None
966
+ """
967
+ Clip gradient norms to this value if set.
968
+ """
969
+
970
+ max_grad_norm_ratio: Optional[float] = None
971
+ """
972
+ If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
973
+ This takes priority over `max_grad_norm` when set.
974
+ """
975
+
976
+ precision: Optional[str] = None
977
+ """
978
+ Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
979
+ """
980
+
981
+ wandb: Optional[WandbConfig] = None
982
+ """
983
+ Weights & Biases configuration.
984
+ """
985
+
986
+ speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
987
+ """
988
+ Speed monitor configuration.
989
+ """
990
+
991
+ console_log_interval: int = 1
992
+ """
993
+ How often to log to the console.
994
+ """
995
+
996
+ compile: Optional[CompilerConfig] = None
997
+ """
998
+ Settings for compiling the model with ``torch.compile()``.
999
+ """
1000
+
1001
+ fsdp: FSDPConfig = field(default_factory=FSDPConfig)
1002
+ """
1003
+ Fully sharded data parallel settings.
1004
+ """
1005
+
1006
+ softmax_auxiliary_loss: bool = False
1007
+ """
1008
+ If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
1009
+ normalizing term to be close to 0.
1010
+ """
1011
+
1012
+ time_limit: Optional[float] = 60 * 60 * 47.5
1013
+ """
1014
+ The maximum amount of time to train for before saving a checkpoint and ending early.
1015
+ On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
1016
+ to write out a final checkpoint.
1017
+ """
1018
+
1019
+ extra_steps_after_cancel: int = 10
1020
+ """
1021
+ Under certain conditions when a run is canceled we train for a few extra steps after saving
1022
+ the final checkpoint so that when the run is restarted from the latest checkpoint we have some
1023
+ overlap in metrics.
1024
+ """
1025
+
1026
+ early_stopping_factor: Optional[float] = None
1027
+
1028
+ save_data_indices: bool = True
1029
+ """
1030
+ Save training data indices from each batch for each worker.
1031
+ """
1032
+
1033
+ python_profiling: bool = False
1034
+ """
1035
+ Whether to run the Python profiler on batches 6, 7, and 8.
1036
+ """
1037
+
1038
+ torch_profiling: bool = False
1039
+ """
1040
+ Whether to run the PyTorch profiler on batches 6, 7, and 8.
1041
+ """
1042
+
1043
+ stop_at: Optional[int] = None
1044
+ """
1045
+ Stop at a specific step.
1046
+ """
1047
+
1048
+ stop_after: Optional[int] = None
1049
+ """
1050
+ Stop after a specific number of steps.
1051
+ """
1052
+
1053
+ activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
1054
+ """
1055
+ The activation checkpointing strategy to use.
1056
+ """
1057
+
1058
+ fused_loss: Optional[bool] = None
1059
+ """
1060
+ Whether to use the fused CE loss function from `flash-attn`.
1061
+ """
1062
+
1063
+ @property
1064
+ def autocast_precision(self) -> torch.dtype:
1065
+ if self.precision == "amp_bf16":
1066
+ return torch.bfloat16
1067
+ elif self.precision == "amp_fp16":
1068
+ return torch.float16
1069
+ elif self.precision == "fp32":
1070
+ return torch.float32
1071
+ else:
1072
+ raise ValueError(f"Unexpected precision type '{self.precision}'")
1073
+
1074
+ @property
1075
+ def fsdp_precision(self) -> MixedPrecision:
1076
+ if self.fsdp.precision == FSDPPrecision.pure:
1077
+ return MixedPrecision(
1078
+ param_dtype=self.autocast_precision,
1079
+ reduce_dtype=self.autocast_precision,
1080
+ buffer_dtype=self.autocast_precision,
1081
+ )
1082
+ elif self.fsdp.precision == FSDPPrecision.mixed:
1083
+ return MixedPrecision(
1084
+ param_dtype=self.autocast_precision,
1085
+ reduce_dtype=torch.float32,
1086
+ buffer_dtype=self.autocast_precision,
1087
+ )
1088
+ else:
1089
+ raise NotImplementedError(f"{self.fsdp.precision}")
1090
+
1091
+ @classmethod
1092
+ def update_legacy_settings(cls, config: D) -> D:
1093
+ new_config = config.copy()
1094
+ if om.is_dict(new_config):
1095
+ assert isinstance(new_config, DictConfig)
1096
+
1097
+ if hasattr(new_config, "activation_checkpointing"):
1098
+ if new_config.activation_checkpointing is False:
1099
+ new_config.activation_checkpointing = None
1100
+ if new_config.activation_checkpointing is True:
1101
+ new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer
1102
+
1103
+ if hasattr(new_config, "optimizer"):
1104
+ new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)
1105
+
1106
+ return new_config
configuration_olmo.py CHANGED
@@ -5,7 +5,13 @@ OLMo configuration
5
  from transformers import AutoConfig, PretrainedConfig
6
  from transformers.utils import logging
7
 
8
- from olmo.config import ModelConfig
 
 
 
 
 
 
9
 
10
  logger = logging.get_logger(__name__)
11
 
 
5
  from transformers import AutoConfig, PretrainedConfig
6
  from transformers.utils import logging
7
 
8
+ from .config import ModelConfig
9
+ from .aliases import PathOrStr
10
+ from .beam_search import Sampler
11
+ from .exceptions import OLMoError
12
+ from .initialization import ModuleType
13
+ from .util import StrEnum
14
+ from .torch_util import seed_all
15
 
16
  logger = logging.get_logger(__name__)
17
 
exceptions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = [
2
+ "OLMoError",
3
+ "OLMoConfigurationError",
4
+ "OLMoCliError",
5
+ "OLMoEnvironmentError",
6
+ "OLMoNetworkError",
7
+ "OLMoCheckpointError",
8
+ ]
9
+
10
+
11
+ class OLMoError(Exception):
12
+ """
13
+ Base class for all custom OLMo exceptions.
14
+ """
15
+
16
+
17
+ class OLMoConfigurationError(OLMoError):
18
+ """
19
+ An error with a configuration file.
20
+ """
21
+
22
+
23
+ class OLMoCliError(OLMoError):
24
+ """
25
+ An error from incorrect CLI usage.
26
+ """
27
+
28
+
29
+ class OLMoEnvironmentError(OLMoError):
30
+ """
31
+ An error from incorrect environment variables.
32
+ """
33
+
34
+
35
+ class OLMoNetworkError(OLMoError):
36
+ """
37
+ An error with a network request.
38
+ """
39
+
40
+
41
+ class OLMoCheckpointError(OLMoError):
42
+ """
43
+ An error occurred reading or writing from a checkpoint.
44
+ """
45
+
46
+
47
+ class OLMoThreadError(Exception):
48
+ """
49
+ Raised when a thread fails.
50
+ """
initialization.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .config import InitFnType, ModelConfig
8
+ from .util import StrEnum
9
+
10
+ __all__ = ["init_weights", "ModuleType"]
11
+
12
+
13
+ class ModuleType(StrEnum):
14
+ in_module = "in"
15
+ out_module = "out"
16
+ emb = "emb"
17
+ final_out = "final_out"
18
+
19
+
20
+ def init_weights(
21
+ config: ModelConfig,
22
+ module: Union[nn.Linear, nn.Embedding],
23
+ d: Optional[int] = None,
24
+ layer_id: Optional[int] = None,
25
+ std_factor: float = 1.0,
26
+ type_of_module: Optional[ModuleType] = None,
27
+ ) -> None:
28
+ """
29
+ Initialize weights of a linear or embedding module.
30
+
31
+ :param config: The model config.
32
+ :param module: The linear or embedding submodule to initialize.
33
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
34
+ for fused layers.
35
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
36
+ ``1 / sqrt(2 * (layer_id + 1))``.
37
+ """
38
+ d = d if d is not None else config.d_model
39
+ if config.init_fn == InitFnType.normal:
40
+ std = config.init_std * std_factor
41
+ if config.init_cutoff_factor is not None:
42
+ cutoff_value = config.init_cutoff_factor * std
43
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
44
+ else:
45
+ nn.init.normal_(module.weight, mean=0.0, std=std)
46
+ elif config.init_fn == InitFnType.mitchell:
47
+ std = std_factor / math.sqrt(d)
48
+ if layer_id is not None:
49
+ std = std / math.sqrt(2 * (layer_id + 1))
50
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
51
+ elif config.init_fn == InitFnType.kaiming_normal:
52
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
53
+ elif config.init_fn == InitFnType.fan_in:
54
+ std = std_factor / math.sqrt(d)
55
+ nn.init.normal_(module.weight, mean=0.0, std=std)
56
+ elif config.init_fn == InitFnType.full_megatron:
57
+ if type_of_module is None:
58
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
59
+
60
+ cutoff_factor = config.init_cutoff_factor
61
+ if cutoff_factor is None:
62
+ cutoff_factor = 3
63
+
64
+ if type_of_module == ModuleType.in_module:
65
+ # for att_proj (same as QKV), ff_proj
66
+ std = config.init_std
67
+ elif type_of_module == ModuleType.out_module:
68
+ # for attn_out, ff_out
69
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
70
+ elif type_of_module == ModuleType.emb:
71
+ # positional embeddings (wpe)
72
+ # token embeddings (wte)
73
+ std = config.init_std
74
+ elif type_of_module == ModuleType.final_out:
75
+ # final output (ff_out)
76
+ std = config.d_model**-0.5
77
+ else:
78
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
79
+ nn.init.trunc_normal_(
80
+ module.weight,
81
+ mean=0.0,
82
+ std=std,
83
+ a=-cutoff_factor * std,
84
+ b=cutoff_factor * std,
85
+ )
86
+ else:
87
+ raise NotImplementedError(config.init_fn)
88
+
89
+ if isinstance(module, nn.Linear):
90
+ if module.bias is not None:
91
+ nn.init.zeros_(module.bias)
92
+
93
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
94
+ with torch.no_grad():
95
+ module.weight.div_(math.sqrt(2 * config.n_layers))
model.py ADDED
@@ -0,0 +1,1778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from
3
+ [MosaiclML](https://github.com/mosaicml/examples.git) and
4
+ [minGPT](https://github.com/karpathy/minGPT.git)
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import math
11
+ import sys
12
+ from abc import abstractmethod
13
+ from collections import defaultdict
14
+ from functools import partial
15
+ from typing import (
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ NamedTuple,
21
+ Optional,
22
+ Sequence,
23
+ Set,
24
+ Tuple,
25
+ cast,
26
+ )
27
+
28
+ import torch
29
+ import torch.backends.cuda
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch import einsum
33
+
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast
35
+
36
+ from .aliases import PathOrStr
37
+ from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
38
+ from .config import (
39
+ ActivationCheckpointingStrategy,
40
+ ActivationType,
41
+ BlockType,
42
+ CheckpointType,
43
+ FSDPWrapStrategy,
44
+ LayerNormType,
45
+ ModelConfig,
46
+ )
47
+ from .exceptions import OLMoConfigurationError
48
+ from .initialization import ModuleType, init_weights
49
+ from .torch_util import ensure_finite_
50
+
51
+ if sys.version_info.minor > 8:
52
+ from collections.abc import MutableMapping
53
+ elif sys.version_info.minor == 8:
54
+ from typing import MutableMapping
55
+ else:
56
+ raise SystemExit("This script supports Python 3.8 or higher")
57
+
58
+ __all__ = [
59
+ "LayerNormBase",
60
+ "LayerNorm",
61
+ "RMSLayerNorm",
62
+ "RotaryEmbedding",
63
+ "Activation",
64
+ "GELU",
65
+ "ReLU",
66
+ "SwiGLU",
67
+ "BitLinear158",
68
+ "OLMoBlock",
69
+ "OLMoSequentialBlock",
70
+ "OLMoParallelBlock",
71
+ "OLMo",
72
+ "OLMoOutput",
73
+ "OLMoGenerateOutput",
74
+ ]
75
+
76
+
77
+ log = logging.getLogger(__name__)
78
+
79
+
80
+ def activation_checkpoint_function(cfg: ModelConfig):
81
+ preserve_rng_state = (
82
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
83
+ )
84
+ from torch.utils.checkpoint import checkpoint
85
+
86
+ return partial(
87
+ checkpoint,
88
+ preserve_rng_state=preserve_rng_state,
89
+ use_reentrant=False,
90
+ )
91
+
92
+
93
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
94
+ """
95
+ Cache for attention biases and other things that would normally be stored as buffers.
96
+ We avoid using buffers because we've run into various issues doing so with FSDP.
97
+ In general it appears the way FSDP handles buffers is not well-defined.
98
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
99
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
100
+ NaNs when they're synchronized due to casting or some other issue.
101
+ """
102
+
103
+
104
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
105
+ if config.init_device is not None and config.init_device != "meta":
106
+ return torch.device(config.init_device)
107
+ else:
108
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
+
110
+
111
+ class Dropout(nn.Dropout):
112
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
113
+ if self.p == 0.0:
114
+ return input
115
+ else:
116
+ return F.dropout(input, self.p, self.training, self.inplace)
117
+
118
+
119
+ class LayerNormBase(nn.Module):
120
+ def __init__(
121
+ self,
122
+ config: ModelConfig,
123
+ *,
124
+ size: Optional[int] = None,
125
+ elementwise_affine: Optional[bool] = True,
126
+ eps: float = 1e-05,
127
+ ):
128
+ super().__init__()
129
+ self.config = config
130
+ self.eps = eps
131
+ self.normalized_shape = (size or config.d_model,)
132
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
133
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
134
+ use_bias = self.config.bias_for_layer_norm
135
+ if use_bias is None:
136
+ use_bias = self.config.include_bias
137
+ if use_bias:
138
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
139
+ else:
140
+ self.register_parameter("bias", None)
141
+ else:
142
+ self.register_parameter("bias", None)
143
+ self.register_parameter("weight", None)
144
+
145
+ @abstractmethod
146
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
147
+ raise NotImplementedError
148
+
149
+ @classmethod
150
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
151
+ if config.layer_norm_type == LayerNormType.default:
152
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
153
+ elif config.layer_norm_type == LayerNormType.low_precision:
154
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
155
+ elif config.layer_norm_type == LayerNormType.rms:
156
+ return RMSLayerNorm(config, size=size, **kwargs)
157
+ else:
158
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
159
+
160
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
161
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
162
+ # `is_autocast_cpu_enabled()` for CPU autocast.
163
+ # See https://github.com/pytorch/pytorch/issues/110966.
164
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
165
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
166
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
167
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
168
+ else:
169
+ return tensor
170
+
171
+ def reset_parameters(self):
172
+ if self.weight is not None:
173
+ torch.nn.init.ones_(self.weight) # type: ignore
174
+ if self.bias is not None:
175
+ torch.nn.init.zeros_(self.bias) # type: ignore
176
+
177
+
178
+ class LayerNorm(LayerNormBase):
179
+ """
180
+ The default :class:`LayerNorm` implementation which can optionally run in low precision.
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ config: ModelConfig,
186
+ size: Optional[int] = None,
187
+ low_precision: bool = False,
188
+ elementwise_affine: Optional[bool] = None,
189
+ eps: float = 1e-05,
190
+ ):
191
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
192
+ self.low_precision = low_precision
193
+
194
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
195
+ if self.low_precision:
196
+ module_device = x.device
197
+ downcast_x = self._cast_if_autocast_enabled(x)
198
+ downcast_weight = (
199
+ self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
200
+ )
201
+ downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
202
+ with torch.autocast(enabled=False, device_type=module_device.type):
203
+ return F.layer_norm(
204
+ downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
205
+ )
206
+ else:
207
+ return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
208
+
209
+
210
+ class RMSLayerNorm(LayerNormBase):
211
+ """
212
+ RMS layer norm, a simplified :class:`LayerNorm` implementation
213
+ """
214
+
215
+ def __init__(
216
+ self,
217
+ config: ModelConfig,
218
+ size: Optional[int] = None,
219
+ elementwise_affine: Optional[bool] = None,
220
+ eps: float = 1e-5,
221
+ ):
222
+ super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
223
+
224
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
225
+ with torch.autocast(enabled=False, device_type=x.device.type):
226
+ og_dtype = x.dtype
227
+ x = x.to(torch.float32)
228
+ variance = x.pow(2).mean(-1, keepdim=True)
229
+ x = x * torch.rsqrt(variance + self.eps)
230
+ x = x.to(og_dtype)
231
+
232
+ if self.weight is not None:
233
+ if self.bias is not None:
234
+ return self.weight * x + self.bias
235
+ else:
236
+ return self.weight * x
237
+ else:
238
+ return x
239
+
240
+
241
+ class RotaryEmbedding(nn.Module):
242
+ """
243
+ [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
244
+ """
245
+
246
+ def __init__(self, config: ModelConfig, cache: BufferCache):
247
+ super().__init__()
248
+ self.config = config
249
+ self.__cache = cache
250
+ # Warm up cache.
251
+ self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
252
+
253
+ def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
254
+ if (
255
+ (pos_sin := self.__cache.get("rope_pos_sin")) is not None
256
+ and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
257
+ and pos_sin.shape[-2] >= seq_len
258
+ and pos_cos.shape[-2] >= seq_len
259
+ ):
260
+ if pos_sin.device != device:
261
+ pos_sin = pos_sin.to(device)
262
+ self.__cache["rope_pos_sin"] = pos_sin
263
+ if pos_cos.device != device:
264
+ pos_cos = pos_cos.to(device)
265
+ self.__cache["rope_pos_cos"] = pos_cos
266
+ return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
267
+
268
+ with torch.autocast(device.type, enabled=False):
269
+ dim = self.config.d_model // self.config.n_heads
270
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
271
+ seq = torch.arange(seq_len, device=device, dtype=torch.float)
272
+ freqs = einsum("i , j -> i j", seq, inv_freq)
273
+ positions = torch.cat((freqs, freqs), dim=-1)
274
+ pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
275
+ self.__cache["rope_pos_sin"] = pos_sin
276
+ self.__cache["rope_pos_cos"] = pos_cos
277
+ return pos_sin, pos_cos
278
+
279
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
280
+ B, nh, T, hs = x.size()
281
+ x = x.view(B, nh, T, 2, hs // 2)
282
+ x1, x2 = x.unbind(dim=-2)
283
+ return torch.cat((-x2, x1), dim=-1)
284
+
285
+ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
286
+ return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
287
+
288
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ if self.config.rope_full_precision:
290
+ q_, k_ = q.float(), k.float()
291
+ else:
292
+ q_, k_ = q, k
293
+
294
+ with torch.autocast(q.device.type, enabled=False):
295
+ query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
296
+ pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
297
+ pos_sin = pos_sin.type_as(q_)
298
+ pos_cos = pos_cos.type_as(q_)
299
+ q_ = self.apply_rotary_pos_emb(
300
+ pos_sin[:, :, key_len - query_len : key_len, :],
301
+ pos_cos[:, :, key_len - query_len : key_len, :],
302
+ q_,
303
+ )
304
+ k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
305
+ return q_.type_as(q), k_.type_as(k)
306
+
307
+
308
+ class Activation(nn.Module):
309
+ def __init__(self, config: ModelConfig):
310
+ super().__init__()
311
+ self.config = config
312
+
313
+ @abstractmethod
314
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
315
+ raise NotImplementedError
316
+
317
+ @property
318
+ @abstractmethod
319
+ def output_multiplier(self) -> float:
320
+ raise NotImplementedError
321
+
322
+ @classmethod
323
+ def build(cls, config: ModelConfig) -> Activation:
324
+ if config.activation_type == ActivationType.gelu:
325
+ return cast(Activation, GELU(approximate="none"))
326
+ elif config.activation_type == ActivationType.relu:
327
+ return cast(Activation, ReLU(inplace=False))
328
+ elif config.activation_type == ActivationType.swiglu:
329
+ return SwiGLU(config)
330
+ else:
331
+ raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
332
+
333
+
334
+ class GELU(nn.GELU):
335
+ @property
336
+ def output_multiplier(self) -> float:
337
+ return 1.0
338
+
339
+
340
+ class ReLU(nn.ReLU):
341
+ @property
342
+ def output_multiplier(self) -> float:
343
+ return 1.0
344
+
345
+
346
+ class SwiGLU(Activation):
347
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
348
+ x, gate = x.chunk(2, dim=-1)
349
+ return F.silu(gate) * x
350
+
351
+ @property
352
+ def output_multiplier(self) -> float:
353
+ return 0.5
354
+
355
+
356
+ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
357
+ att_bias = torch.triu(
358
+ torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
359
+ diagonal=1,
360
+ )
361
+ att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
362
+ return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
363
+
364
+
365
+ def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
366
+ if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
367
+ if causal_bias.device != device:
368
+ causal_bias = causal_bias.to(device)
369
+ cache["causal_attention_bias"] = causal_bias
370
+ return causal_bias
371
+ with torch.autocast(device.type, enabled=False):
372
+ causal_bias = causal_attention_bias(seq_len, device)
373
+ cache["causal_attention_bias"] = causal_bias
374
+ return causal_bias
375
+
376
+
377
+ def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
378
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
379
+
380
+ # shape: (1, 1, seq_len, seq_len)
381
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
382
+ alibi_bias.abs_().mul_(-1)
383
+
384
+ # shape: (n_heads,)
385
+ m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
386
+ m.mul_(config.alibi_bias_max / config.n_heads)
387
+
388
+ # shape: (1, n_heads, seq_len, seq_len)
389
+ return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
390
+
391
+ def activation_quant(x):
392
+ """Per−token quantization to 8 bits. No grouping is needed for quantization.
393
+ Args:
394
+ x: an activation tensor with shape [n, d]
395
+ Returns:
396
+ y: a quantized activation tensor with shape [n, d]
397
+ """
398
+ scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
399
+ y = (x * scale).round().clamp_(-128, 127) / scale
400
+ return y
401
+
402
+ def weight_quant(w):
403
+ """Per−tensor quantization to 1.58 bits. No grouping is needed for quantization.
404
+ Args:
405
+ w: a weight tensor with shape [d, k]
406
+ Returns:
407
+ u: a quantized weight with shape [d, k]
408
+ """
409
+ scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
410
+ u = (w * scale).round().clamp_(-1, 1) / scale
411
+ return u
412
+
413
+
414
+ class BitLinear158(nn.Linear):
415
+ """
416
+ This is only for training, and kernel optimization is needed for efficiency.
417
+ """
418
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
419
+ device=None, dtype=None, config=None):
420
+ super().__init__(in_features, out_features, bias, device, dtype)
421
+ self.norm = RMSLayerNorm(config, elementwise_affine=False)
422
+
423
+ def forward(self, x):
424
+ """
425
+ Args:
426
+ x: an input tensor with shape [n, d]
427
+ Returns:
428
+ y: an output tensor with shape [n, d]
429
+ """
430
+ w = self.weight # a weight tensor with shape [d, k]
431
+ x_norm = self.norm(x)
432
+ # Atrick for implementing Straight−Through−Estimator (STE) using detach()
433
+ x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
434
+ w_quant = w + (weight_quant(w) - w).detach()
435
+ y = F.linear(x_quant, w_quant)
436
+ return y
437
+
438
+
439
+ class OLMoBlock(nn.Module):
440
+ """
441
+ A base class for transformer block implementations.
442
+ """
443
+
444
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
445
+ super().__init__()
446
+ self.layer_id = layer_id
447
+ self.config = config
448
+ self.hidden_size = (
449
+ config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
450
+ )
451
+ self.__cache = cache
452
+ assert config.d_model % config.n_heads == 0
453
+
454
+ self._activation_checkpoint_fn = None
455
+
456
+ Linear = BitLinear158 if config.ternary else nn.Linear
457
+
458
+ # Dropout.
459
+ self.dropout = Dropout(config.residual_dropout)
460
+
461
+ # Layer norms.
462
+ self.k_norm: Optional[LayerNormBase] = None
463
+ self.q_norm: Optional[LayerNormBase] = None
464
+ if config.attention_layer_norm:
465
+ self.k_norm = LayerNormBase.build(
466
+ config,
467
+ size=config.d_model // config.n_heads if config.multi_query_attention else None,
468
+ elementwise_affine=config.attention_layer_norm_with_affine,
469
+ )
470
+ self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
471
+
472
+ # Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
473
+ if config.clip_qkv is not None:
474
+ assert config.clip_qkv > 0
475
+
476
+ # Activation function.
477
+ self.act = Activation.build(config)
478
+ assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
479
+
480
+ # Attention output projection.
481
+ self.attn_out = Linear(
482
+ config.d_model, config.d_model, bias=config.include_bias, device=config.init_device,
483
+ config=config
484
+ )
485
+
486
+ # Feed-forward output projection.
487
+ self.ff_out = Linear(
488
+ int(self.act.output_multiplier * self.hidden_size),
489
+ config.d_model,
490
+ bias=config.include_bias,
491
+ device=config.init_device,
492
+ config=config,
493
+ )
494
+ self.ff_out._is_residual = True # type: ignore
495
+
496
+ # Rotary embeddings.
497
+ if self.config.rope:
498
+ self.rotary_emb = RotaryEmbedding(config, self.__cache)
499
+
500
+ def reset_parameters(self):
501
+ if self.k_norm is not None:
502
+ self.k_norm.reset_parameters()
503
+ if self.q_norm is not None:
504
+ self.q_norm.reset_parameters()
505
+ init_weights(
506
+ self.config,
507
+ self.attn_out,
508
+ d=self.config.d_model,
509
+ layer_id=self.layer_id,
510
+ type_of_module=ModuleType.out_module,
511
+ )
512
+ init_weights(
513
+ self.config,
514
+ self.ff_out,
515
+ d=self.ff_out.in_features,
516
+ layer_id=self.layer_id,
517
+ type_of_module=ModuleType.out_module,
518
+ )
519
+
520
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
521
+ if strategy == ActivationCheckpointingStrategy.fine_grained:
522
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
523
+ else:
524
+ self._activation_checkpoint_fn = None
525
+
526
+ @classmethod
527
+ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
528
+ target_dtype = input_dtype
529
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
530
+ # `is_autocast_cpu_enabled()` for CPU autocast.
531
+ # See https://github.com/pytorch/pytorch/issues/110966.
532
+ if bias.device.type == "cuda" and torch.is_autocast_enabled():
533
+ target_dtype = torch.get_autocast_gpu_dtype()
534
+ elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
535
+ target_dtype = torch.get_autocast_cpu_dtype()
536
+ if bias.dtype != target_dtype:
537
+ bias = bias.to(target_dtype)
538
+ ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
539
+ return bias
540
+
541
+ def _scaled_dot_product_attention(
542
+ self,
543
+ q: torch.Tensor,
544
+ k: torch.Tensor,
545
+ v: torch.Tensor,
546
+ attn_mask: Optional[torch.Tensor] = None,
547
+ dropout_p: float = 0.0,
548
+ is_causal: bool = False,
549
+ ) -> torch.Tensor:
550
+ """
551
+ Computes scaled dot product attention on query, key and value tensors, using an optional
552
+ attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
553
+
554
+ This method is based on PyTorch's `scaled_dot_product_attention`.
555
+ """
556
+ return F.scaled_dot_product_attention(
557
+ q,
558
+ k,
559
+ v,
560
+ attn_mask=attn_mask,
561
+ dropout_p=dropout_p,
562
+ is_causal=is_causal,
563
+ )
564
+
565
+ def attention(
566
+ self,
567
+ q: torch.Tensor,
568
+ k: torch.Tensor,
569
+ v: torch.Tensor,
570
+ attention_bias: Optional[torch.Tensor] = None,
571
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
+ use_cache: bool = False,
573
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
574
+ B, T, C = q.size() # batch size, sequence length, d_model
575
+ dtype = k.dtype
576
+
577
+ # Optionally apply layer norm to keys and queries.
578
+ if self.q_norm is not None and self.k_norm is not None:
579
+ q = self.q_norm(q).to(dtype=dtype)
580
+ k = self.k_norm(k).to(dtype=dtype)
581
+
582
+ # Move head forward to be next to the batch dim.
583
+ # shape: (B, nh, T, hs)
584
+ q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
585
+ if self.config.multi_query_attention:
586
+ # shape: (B, 1, T, hs)
587
+ k = k.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
588
+ # shape: (B, 1, T, hs)
589
+ v = v.view(B, T, 1, C // self.config.n_heads).transpose(1, 2)
590
+ else:
591
+ # shape: (B, nh, T, hs)
592
+ k = k.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
593
+ # shape: (B, nh, T, hs)
594
+ v = v.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
595
+
596
+ if layer_past is not None:
597
+ past_key, past_value = layer_past
598
+ k = torch.cat((past_key, k), dim=-2)
599
+ v = torch.cat((past_value, v), dim=-2)
600
+
601
+ present = (k, v) if use_cache else None
602
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
603
+
604
+ if self.config.rope:
605
+ # Apply rotary embeddings.
606
+ q, k = self.rotary_emb(q, k)
607
+
608
+ if attention_bias is not None:
609
+ # Resize and cast attention bias.
610
+ # The current dtype of the attention bias might not match the dtype that the SDP attn function will
611
+ # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
612
+ # as down-casting the attention bias to the autocast precision will result in -infs, which will
613
+ # cause the SDP attn function to produce NaNs.
614
+ attention_bias = self._cast_attn_bias(
615
+ attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
616
+ )
617
+
618
+ # Get the attention scores.
619
+ # shape: (B, nh, T, hs)
620
+ att = self._scaled_dot_product_attention(
621
+ q,
622
+ k,
623
+ v,
624
+ attn_mask=attention_bias,
625
+ dropout_p=0.0 if not self.training else self.config.attention_dropout,
626
+ is_causal=attention_bias is None,
627
+ )
628
+
629
+ # Re-assemble all head outputs side-by-side.
630
+ att = att.transpose(1, 2).contiguous().view(B, T, C)
631
+
632
+ # Apply output projection.
633
+ return self.attn_out(att), present
634
+
635
+ @abstractmethod
636
+ def forward(
637
+ self,
638
+ x: torch.Tensor,
639
+ attention_bias: Optional[torch.FloatTensor] = None,
640
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
641
+ use_cache: bool = False,
642
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
643
+ raise NotImplementedError
644
+
645
+ @classmethod
646
+ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OLMoBlock:
647
+ if config.block_type == BlockType.sequential:
648
+ return OLMoSequentialBlock(layer_id, config, cache)
649
+ elif config.block_type == BlockType.parallel:
650
+ return OLMoParallelBlock(layer_id, config, cache)
651
+ elif config.block_type == BlockType.llama:
652
+ return OLMoLlamaBlock(layer_id, config, cache)
653
+ else:
654
+ raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
655
+
656
+
657
+ class OLMoSequentialBlock(OLMoBlock):
658
+ """
659
+ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
660
+ (plus another skip connection).
661
+ """
662
+
663
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
664
+ super().__init__(layer_id, config, cache)
665
+ # Layer norms.
666
+ self.attn_norm = LayerNorm.build(config)
667
+ self.ff_norm = LayerNorm.build(config)
668
+ Linear = BitLinear158 if config.ternary else nn.Linear
669
+ # Attention input projection. Projects x -> (q, k, v)
670
+ if config.multi_query_attention:
671
+ self.fused_dims = (config.d_model, config.d_model // config.n_heads, config.d_model // config.n_heads)
672
+ else:
673
+ self.fused_dims = (config.d_model, config.d_model, config.d_model)
674
+ self.att_proj = Linear(
675
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
676
+ config=config
677
+ )
678
+ # Feed-forward input projection.
679
+ self.ff_proj = Linear(
680
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
681
+ config=config
682
+ )
683
+
684
+ def reset_parameters(self):
685
+ super().reset_parameters()
686
+ self.attn_norm.reset_parameters()
687
+ self.ff_norm.reset_parameters()
688
+ # NOTE: the standard deviation for these weights does not depend on the layer.
689
+ init_weights(
690
+ self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
691
+ )
692
+ init_weights(
693
+ self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
694
+ )
695
+
696
+ def forward(
697
+ self,
698
+ x: torch.Tensor,
699
+ attention_bias: Optional[torch.Tensor] = None,
700
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
701
+ use_cache: bool = False,
702
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
703
+ # Get query, key, value projections.
704
+ # shape:
705
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
706
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
707
+ # k, v: (batch_size, seq_len, d_model // n_heads)
708
+ if self._activation_checkpoint_fn is not None:
709
+ qkv = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x))
710
+ else:
711
+ qkv = self.att_proj(self.attn_norm(x))
712
+
713
+ if self.config.clip_qkv is not None:
714
+ qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
715
+
716
+ q, k, v = qkv.split(self.fused_dims, dim=-1)
717
+
718
+ # Get attention scores.
719
+ if self._activation_checkpoint_fn is not None:
720
+ att, cache = self._activation_checkpoint_fn( # type: ignore
721
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
722
+ )
723
+ else:
724
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
725
+
726
+ # Add attention scores.
727
+ # shape: (B, T, C)
728
+ x = x + self.dropout(att)
729
+
730
+ # Add feed-forward projection.
731
+ # shape: (batch_size, seq_len, d_model)
732
+ og_x = x
733
+ if self._activation_checkpoint_fn is not None:
734
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
735
+ else:
736
+ x = self.ff_norm(x)
737
+ x = self.ff_proj(x)
738
+ if self._activation_checkpoint_fn is not None:
739
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
740
+ else:
741
+ x = self.act(x)
742
+ x = self.ff_out(x)
743
+ x = self.dropout(x)
744
+ x = og_x + x
745
+
746
+ return x, cache
747
+
748
+
749
+ class OLMoParallelBlock(OLMoBlock):
750
+ """
751
+ This is a transformer block where the output is computed as ``MLP(LN(x)) + Attention(LN(x))``
752
+ as in the PaLM architecture, as opposed to the typical ``MLP(LN(x + Attention(LN(x))))``
753
+ as in :class:`OLMoSequentialBlock` (ignoring some skip connections).
754
+
755
+ The decoupling of the MLP and Attention functions allow us to fuse the separate input projections
756
+ into a single linear layer to increase throughput. In this configuration it's also straight-forward
757
+ to fuse the output projections, but we found that didn't help.
758
+ """
759
+
760
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
761
+ super().__init__(layer_id, config, cache)
762
+ self.norm = LayerNorm.build(config)
763
+ Linear = BitLinear158 if config.ternary else nn.Linear
764
+ # Fused attention and feed-forward projection.
765
+ # NOTE: we could also fuse the attention and feed-forward output projections but we
766
+ # found that didn't help, possibly because of the overhead of joining the `att` and
767
+ # `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
768
+ if config.multi_query_attention:
769
+ self.fused_dims = (
770
+ config.d_model,
771
+ config.d_model // config.n_heads,
772
+ config.d_model // config.n_heads,
773
+ self.hidden_size,
774
+ )
775
+ else:
776
+ self.fused_dims = (config.d_model, config.d_model, config.d_model, self.hidden_size)
777
+ self.fused_attn_ff_proj = Linear(
778
+ config.d_model, sum(self.fused_dims), bias=config.include_bias, device=config.init_device,
779
+ config=config
780
+ )
781
+
782
+ def reset_parameters(self):
783
+ super().reset_parameters()
784
+ self.norm.reset_parameters()
785
+ # NOTE: the standard deviation for these weights does not depend on the layer.
786
+ init_weights(
787
+ self.config,
788
+ self.fused_attn_ff_proj,
789
+ d=self.config.d_model,
790
+ layer_id=None,
791
+ type_of_module=ModuleType.in_module,
792
+ )
793
+
794
+ def forward(
795
+ self,
796
+ x: torch.Tensor,
797
+ attention_bias: Optional[torch.Tensor] = None,
798
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
799
+ use_cache: bool = False,
800
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
801
+ # Get query, key, value, and feed-forward projections.
802
+ # shape of q, k, v:
803
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
804
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
805
+ # k, v: (batch_size, seq_len, d_model // n_heads)
806
+ # shape of ff: (batch_size, seq_len, hidden_size)
807
+ if self._activation_checkpoint_fn is not None:
808
+ q, k, v, ff = self.fused_attn_ff_proj(self._activation_checkpoint_fn(self.norm, x)).split(
809
+ self.fused_dims, dim=-1
810
+ )
811
+ else:
812
+ q, k, v, ff = self.fused_attn_ff_proj(self.norm(x)).split(self.fused_dims, dim=-1)
813
+
814
+ if self.config.clip_qkv is not None:
815
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
816
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
817
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
818
+
819
+ # Get attention scores.
820
+ # shape: (B, T, C)
821
+ if self._activation_checkpoint_fn is not None:
822
+ att, cache = self._activation_checkpoint_fn( # type: ignore
823
+ self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
824
+ )
825
+ else:
826
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
827
+
828
+ # Apply output projections (and activation function) and sum the results.
829
+ # We keep these projections separate because we found that we got better throughput this
830
+ # way compared to fusing them.
831
+ if self._activation_checkpoint_fn is not None:
832
+ return (
833
+ x + self.dropout(self.ff_out(self._activation_checkpoint_fn(self.act, ff))) + self.dropout(att),
834
+ cache,
835
+ )
836
+ else:
837
+ return (
838
+ x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att),
839
+ cache,
840
+ )
841
+
842
+
843
+ class OLMoLlamaBlock(OLMoBlock):
844
+ """
845
+ This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
846
+ (plus another skip connection). This block is similar to `OLMoSequentialBlock`
847
+ but some operations have slightly different implementations to imitate the
848
+ behavior of Llama.
849
+ """
850
+
851
+ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
852
+ super().__init__(layer_id, config, cache)
853
+ # Layer norms.
854
+ self.attn_norm = LayerNorm.build(config)
855
+ self.ff_norm = LayerNorm.build(config)
856
+ self.__cache = cache
857
+ Linear = BitLinear158 if config.ternary else nn.Linear
858
+
859
+ # Attention input projection. Projects x -> (q, k, v)
860
+ if config.multi_query_attention:
861
+ q_proj_out_dim = config.d_model
862
+ k_proj_out_dim = config.d_model // config.n_heads
863
+ v_proj_out_dim = config.d_model // config.n_heads
864
+ else:
865
+ q_proj_out_dim = config.d_model
866
+ k_proj_out_dim = config.d_model
867
+ v_proj_out_dim = config.d_model
868
+ self.q_proj = Linear(
869
+ config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device,
870
+ config=config
871
+ )
872
+ self.k_proj = Linear(
873
+ config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device,
874
+ config=config
875
+ )
876
+ self.v_proj = Linear(
877
+ config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device,
878
+ config=config
879
+ )
880
+
881
+ # Feed-forward input projection.
882
+ self.ff_proj = Linear(
883
+ config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device,
884
+ config=config
885
+ )
886
+
887
+ def reset_parameters(self):
888
+ super().reset_parameters()
889
+ if self.attn_norm:
890
+ self.attn_norm.reset_parameters()
891
+ self.ff_norm.reset_parameters()
892
+ # NOTE: the standard deviation for these weights does not depend on the layer.
893
+ init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
894
+ init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
895
+ init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
896
+ init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
897
+
898
+ def _scaled_dot_product_attention(
899
+ self,
900
+ q: torch.Tensor,
901
+ k: torch.Tensor,
902
+ v: torch.Tensor,
903
+ attn_mask: Optional[torch.Tensor] = None,
904
+ dropout_p: float = 0.0,
905
+ is_causal: bool = False,
906
+ ) -> torch.Tensor:
907
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
908
+
909
+ if is_causal:
910
+ assert attn_mask is None
911
+
912
+ query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
913
+ attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
914
+ elif attn_mask is not None:
915
+ attn_bias = attn_mask.to(q.dtype)
916
+ else:
917
+ attn_bias = torch.zeros_like(attn_weights)
918
+
919
+ attn_weights += attn_bias
920
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
921
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
922
+ return torch.matmul(attn_weights, v)
923
+
924
+ def forward(
925
+ self,
926
+ x: torch.Tensor,
927
+ attention_bias: Optional[torch.Tensor] = None,
928
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
929
+ use_cache: bool = False,
930
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
931
+ # Get query, key, value projections.
932
+ # shape:
933
+ # - for regular attn q, k, v: (batch_size, seq_len, d_model)
934
+ # - for multi-query attn q: (batch_size, seq_len, d_model)
935
+ # k, v: (batch_size, seq_len, d_model // n_heads)
936
+ x_normed = self.attn_norm(x)
937
+ q = self.q_proj(x_normed)
938
+ k = self.k_proj(x_normed)
939
+ v = self.v_proj(x_normed)
940
+
941
+ if self.config.clip_qkv is not None:
942
+ q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
943
+ k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
944
+ v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
945
+
946
+ # Get attention scores.
947
+ att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
948
+
949
+ # Add attention scores.
950
+ # shape: (B, T, C)
951
+ x = x + self.dropout(att)
952
+
953
+ # Add feed-forward projection.
954
+ # shape: (batch_size, seq_len, d_model)
955
+ og_x = x
956
+ if self._activation_checkpoint_fn is not None:
957
+ x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
958
+ else:
959
+ x = self.ff_norm(x)
960
+ x = self.ff_proj(x)
961
+ if self._activation_checkpoint_fn is not None:
962
+ x = self._activation_checkpoint_fn(self.act, x) # type: ignore
963
+ else:
964
+ x = self.act(x)
965
+ x = self.ff_out(x)
966
+ x = self.dropout(x)
967
+ x = og_x + x
968
+
969
+ return x, cache
970
+
971
+
972
+ class OLMoOutput(NamedTuple):
973
+ logits: torch.FloatTensor
974
+ """
975
+ A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
976
+ for the next token *before* normalization via (log) softmax.
977
+ """
978
+
979
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
980
+ """
981
+ Attention keys and values from each block.
982
+ """
983
+
984
+ hidden_states: Optional[Tuple[torch.Tensor]]
985
+ """
986
+ Hidden states from each block.
987
+ """
988
+
989
+
990
+ class OLMoGenerateOutput(NamedTuple):
991
+ token_ids: torch.LongTensor
992
+ """
993
+ The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
994
+ These do *not* include the original input IDs.
995
+ """
996
+
997
+ scores: torch.FloatTensor
998
+ """
999
+ The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
1000
+ """
1001
+
1002
+
1003
+ class OLMoBlockGroup(nn.ModuleList):
1004
+ def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
1005
+ super().__init__(modules)
1006
+ self.config = config
1007
+ self.layer_offset = layer_offset
1008
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1009
+ self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
1010
+
1011
+ def forward(
1012
+ self,
1013
+ x: torch.Tensor,
1014
+ attention_bias: Optional[torch.FloatTensor] = None,
1015
+ layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
1016
+ use_cache: bool = False,
1017
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
1018
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1019
+ for block_idx, block in enumerate(self):
1020
+ layer_past = None if layers_past is None else layers_past[block_idx]
1021
+ block_idx += self.layer_offset
1022
+ if (
1023
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1024
+ or (
1025
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1026
+ and block_idx % 2 == 0
1027
+ )
1028
+ or (
1029
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1030
+ and block_idx % 3 == 0
1031
+ )
1032
+ or (
1033
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1034
+ and block_idx % 4 == 0
1035
+ )
1036
+ ):
1037
+ # shape: (batch_size, seq_len, d_model)
1038
+ x, cache = self._activation_checkpoint_fn( # type: ignore
1039
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1040
+ )
1041
+ else:
1042
+ # shape: (batch_size, seq_len, d_model)
1043
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1044
+ if attn_key_values is not None:
1045
+ assert cache is not None
1046
+ attn_key_values.append(cache)
1047
+ return x, attn_key_values
1048
+
1049
+ def reset_parameters(self):
1050
+ for block in self:
1051
+ block.reset_parameters()
1052
+
1053
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1054
+ self.activation_checkpointing_strategy = strategy
1055
+ for block in self:
1056
+ block.set_activation_checkpointing(strategy)
1057
+
1058
+
1059
+ class OLMo(nn.Module):
1060
+ def __init__(self, config: ModelConfig, init_params: bool = True):
1061
+ super().__init__()
1062
+ self.config = config
1063
+ self.__cache = BufferCache()
1064
+
1065
+ # Validate config.
1066
+ if self.config.alibi and self.config.flash_attention:
1067
+ raise OLMoConfigurationError("ALiBi is currently not supported with FlashAttention")
1068
+
1069
+ if self.config.alibi and self.config.rope:
1070
+ raise OLMoConfigurationError("ALiBi and RoPE are mutually exclusive")
1071
+
1072
+ if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
1073
+ if self.config.embedding_size < self.config.vocab_size:
1074
+ raise OLMoConfigurationError("embedding size should be at least as big as vocab size")
1075
+ elif self.config.embedding_size % 128 != 0:
1076
+ import warnings
1077
+
1078
+ warnings.warn(
1079
+ "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
1080
+ )
1081
+
1082
+ self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
1083
+ self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
1084
+
1085
+ if not (
1086
+ 0 < self.config.block_group_size <= self.config.n_layers
1087
+ and self.config.n_layers % self.config.block_group_size == 0
1088
+ ):
1089
+ raise OLMoConfigurationError("n layers must be divisible by block group size")
1090
+
1091
+ torch.backends.cuda.enable_flash_sdp(self.config.flash_attention)
1092
+ torch.backends.cuda.enable_mem_efficient_sdp(False) # this is super slow so make sure torch won't use it
1093
+
1094
+ self.transformer = nn.ModuleDict(
1095
+ dict(
1096
+ wte=nn.Embedding(
1097
+ config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
1098
+ ),
1099
+ emb_drop=Dropout(config.embedding_dropout),
1100
+ ln_f=LayerNorm.build(config),
1101
+ )
1102
+ )
1103
+
1104
+ blocks = [OLMoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
1105
+ if self.config.block_group_size > 1:
1106
+ block_groups = [
1107
+ OLMoBlockGroup(config, i, blocks[i : i + config.block_group_size])
1108
+ for i in range(0, config.n_layers, config.block_group_size)
1109
+ ]
1110
+ self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
1111
+ else:
1112
+ self.transformer.update({"blocks": nn.ModuleList(blocks)})
1113
+
1114
+ if not (self.config.alibi or self.config.rope):
1115
+ self.transformer.update(
1116
+ {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
1117
+ )
1118
+ if not config.weight_tying:
1119
+ self.transformer.update(
1120
+ {
1121
+ "ff_out": nn.Linear(
1122
+ config.d_model,
1123
+ config.embedding_size or config.vocab_size,
1124
+ bias=config.include_bias,
1125
+ device=config.init_device,
1126
+ )
1127
+ }
1128
+ )
1129
+ # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
1130
+ if init_params and self.config.init_device != "meta":
1131
+ self.reset_parameters()
1132
+ self.__num_fwd_flops: Optional[int] = None
1133
+
1134
+ # Warm up cache.
1135
+ if self.config.alibi:
1136
+ get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
1137
+ self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1138
+
1139
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1140
+ self.activation_checkpointing_strategy = strategy
1141
+ if self.config.block_group_size != 1:
1142
+ for block_group in self.transformer.block_groups:
1143
+ block_group.set_activation_checkpointing(strategy)
1144
+ else:
1145
+ for block in self.transformer.blocks:
1146
+ block.set_activation_checkpointing(strategy)
1147
+
1148
+ @property
1149
+ def device(self) -> torch.device:
1150
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1151
+ if device.type == "meta":
1152
+ return _non_meta_init_device(self.config)
1153
+ else:
1154
+ return device
1155
+
1156
+ def reset_parameters(self):
1157
+ log.info("Initializing model parameters...")
1158
+ # Top-level embeddings / linear layers.
1159
+ init_weights(
1160
+ self.config,
1161
+ self.transformer.wte, # type: ignore
1162
+ std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
1163
+ type_of_module=ModuleType.emb,
1164
+ )
1165
+ if hasattr(self.transformer, "wpe"):
1166
+ init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb) # type: ignore
1167
+
1168
+ # Top-level layer norm.
1169
+ self.transformer.ln_f.reset_parameters() # type: ignore
1170
+
1171
+ # Output weights.
1172
+ if hasattr(self.transformer, "ff_out"):
1173
+ init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore
1174
+
1175
+ # Let the blocks handle themselves.
1176
+ if self.config.block_group_size == 1:
1177
+ for block in self.transformer.blocks:
1178
+ block.reset_parameters()
1179
+ else:
1180
+ for block_group in self.transformer.block_groups:
1181
+ block_group.reset_parameters()
1182
+
1183
+ def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
1184
+ if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
1185
+ -1
1186
+ ] >= seq_len:
1187
+ if alibi_bias.device != device:
1188
+ alibi_bias = alibi_bias.to(device)
1189
+ self.__cache["alibi_attention_bias"] = alibi_bias
1190
+ return alibi_bias
1191
+ with torch.autocast(device.type, enabled=False):
1192
+ alibi_bias = alibi_attention_bias(seq_len, self.config, device)
1193
+ self.__cache["alibi_attention_bias"] = alibi_bias
1194
+ return alibi_bias
1195
+
1196
+ def forward(
1197
+ self,
1198
+ input_ids: torch.LongTensor,
1199
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1200
+ attention_mask: Optional[torch.Tensor] = None,
1201
+ attention_bias: Optional[torch.Tensor] = None,
1202
+ past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
1203
+ use_cache: bool = False,
1204
+ last_logits_only: bool = False,
1205
+ output_hidden_states: Optional[bool] = None,
1206
+ ) -> OLMoOutput:
1207
+ """
1208
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1209
+ :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
1210
+ embeddings. When provided, it is treated as the output of the input embedding layer.
1211
+ :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
1212
+ which input IDs are masked. A `1` value in the mask means that
1213
+ the corresponding input ID should *not* be ignored. A `0` means
1214
+ that the corresponding input ID is masked.
1215
+
1216
+ This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
1217
+ library.
1218
+ :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
1219
+ `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
1220
+ to introduce causal or other biases.
1221
+
1222
+ If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
1223
+ indicates that the i-th element in the sequence is allowed to attend to the j-th
1224
+ element in the sequence.
1225
+
1226
+ If the tensor is a float tensor, it will just be added to the attention
1227
+ scores before the softmax.
1228
+
1229
+ The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
1230
+ :param past_key_values: Pre-computed keys and values for each attention block.
1231
+ Can be used to speed up sequential decoding. The `input_ids` which have
1232
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
1233
+ :param use_cache: If `True`, return key and value tensors for each block.
1234
+ :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
1235
+ This can speed up decoding when you only care about the next token.
1236
+ """
1237
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else False
1238
+
1239
+ if past_key_values:
1240
+ assert len(past_key_values) == self.config.n_layers
1241
+
1242
+ batch_size, seq_len = input_ids.size() if inputs_embeds is None else inputs_embeds.size()[:2]
1243
+ if past_key_values is None:
1244
+ past_length = 0
1245
+ else:
1246
+ past_length = past_key_values[0][0].size(-2)
1247
+
1248
+ # Get embeddings of input.
1249
+ # shape: (batch_size, seq_len, d_model)
1250
+ x = self.transformer.wte(input_ids) if inputs_embeds is None else inputs_embeds # type: ignore
1251
+
1252
+ if not (self.config.alibi or self.config.rope):
1253
+ # Get positional embeddings.
1254
+ # shape: (1, seq_len)
1255
+ pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
1256
+ # shape: (1, seq_len, d_model)
1257
+ pos_emb = self.transformer.wpe(pos) # type: ignore
1258
+ x = pos_emb + x
1259
+
1260
+ # Add input + positional embeddings and apply dropout.
1261
+ # shape: (batch_size, seq_len, d_model)
1262
+ x = self.transformer.emb_drop(x) # type: ignore
1263
+
1264
+ # Transform the attention mask into what the blocks expect.
1265
+ if attention_mask is not None:
1266
+ # shape: (batch_size, 1, 1, seq_len)
1267
+ attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
1268
+ attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
1269
+
1270
+ # Merge attention mask with attention bias.
1271
+ if (
1272
+ attention_bias is not None
1273
+ or attention_mask is not None
1274
+ or self.config.alibi
1275
+ # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
1276
+ # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
1277
+ # scores correctly.
1278
+ or past_key_values is not None
1279
+ ):
1280
+ if attention_bias is None and self.config.alibi:
1281
+ attention_bias = get_causal_attention_bias(
1282
+ self.__cache, past_length + seq_len, x.device
1283
+ ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
1284
+ elif attention_bias is None:
1285
+ attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
1286
+ elif attention_bias.dtype in (torch.int8, torch.bool):
1287
+ attention_bias = attention_bias.to(dtype=torch.float)
1288
+ attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
1289
+
1290
+ # Transform to the right shape and data type.
1291
+ mask_len = seq_len
1292
+ if attention_mask is not None:
1293
+ mask_len = attention_mask.shape[-1]
1294
+ elif past_key_values is not None:
1295
+ mask_len = past_key_values[0][0].shape[-2] + seq_len
1296
+ attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
1297
+
1298
+ # Add in the masking bias.
1299
+ if attention_mask is not None:
1300
+ attention_bias = attention_bias + attention_mask
1301
+ # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
1302
+ # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
1303
+ # it can produce NaNs.
1304
+ ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
1305
+
1306
+ attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
1307
+
1308
+ # decoder layers
1309
+ all_hidden_states = []
1310
+
1311
+ # Apply blocks one-by-one.
1312
+ if self.config.block_group_size == 1:
1313
+ for block_idx, block in enumerate(self.transformer.blocks):
1314
+ if output_hidden_states:
1315
+ # add hidden states
1316
+ all_hidden_states.append(x)
1317
+
1318
+ layer_past = None if past_key_values is None else past_key_values[block_idx]
1319
+ if (
1320
+ (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
1321
+ or (
1322
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
1323
+ and block_idx % 2 == 0
1324
+ )
1325
+ or (
1326
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
1327
+ and block_idx % 3 == 0
1328
+ )
1329
+ or (
1330
+ self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
1331
+ and block_idx % 4 == 0
1332
+ )
1333
+ ):
1334
+ # shape: (batch_size, seq_len, d_model)
1335
+ x, cache = self._activation_checkpoint_fn(
1336
+ block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
1337
+ )
1338
+ else:
1339
+ # shape: (batch_size, seq_len, d_model)
1340
+ x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
1341
+ if attn_key_values is not None:
1342
+ assert cache is not None
1343
+ attn_key_values.append(cache)
1344
+ else:
1345
+ for group_idx, block_group in enumerate(self.transformer.block_groups):
1346
+ if output_hidden_states:
1347
+ # add hidden states
1348
+ all_hidden_states.append(x)
1349
+
1350
+ layers_past = (
1351
+ None
1352
+ if past_key_values is None
1353
+ else past_key_values[
1354
+ group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
1355
+ ]
1356
+ )
1357
+ x, cache = block_group(
1358
+ x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
1359
+ )
1360
+ if attn_key_values is not None:
1361
+ assert cache is not None
1362
+ attn_key_values.extend(cache)
1363
+
1364
+ if last_logits_only:
1365
+ # shape: (batch_size, 1, d_model)
1366
+ x = x[:, -1, :].unsqueeze(1)
1367
+
1368
+ # Apply final layer norm.
1369
+ # shape: (batch_size, seq_len or 1, d_model)
1370
+ x = self.transformer.ln_f(x) # type: ignore
1371
+ if output_hidden_states:
1372
+ # add final hidden state post-final-layernorm, following HuggingFace's convention
1373
+ all_hidden_states.append(x)
1374
+
1375
+ # Get logits.
1376
+ # shape: (batch_size, seq_len or 1, vocab_size)
1377
+ if self.config.weight_tying:
1378
+ logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
1379
+ else:
1380
+ logits = self.transformer.ff_out(x) # type: ignore
1381
+ if self.config.scale_logits:
1382
+ logits.mul_(1 / math.sqrt(self.config.d_model))
1383
+
1384
+ return BaseModelOutputWithPast(
1385
+ last_hidden_state=x,
1386
+ past_key_values=tuple(attn_key_values) if attn_key_values is not None else None,
1387
+ hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
1388
+ )
1389
+
1390
+ def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
1391
+ if wrap_strategy is None:
1392
+ return None
1393
+
1394
+ # The 'recurse' mode for the wrap function does not behave like you'd expect.
1395
+ # Even if we return False, it may still recurse because PyTorch does what it wants,
1396
+ # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
1397
+ # but not other linear layers within a block.
1398
+ # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just
1399
+ # return True in 'recurse' mode for simplicity.
1400
+ size_based_module_to_wrap = {self.transformer.wte}
1401
+ if hasattr(self.transformer, "ff_out"):
1402
+ size_based_module_to_wrap.add(self.transformer.ff_out)
1403
+
1404
+ if wrap_strategy == FSDPWrapStrategy.by_block:
1405
+
1406
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1407
+ del nonwrapped_numel
1408
+ wrap = isinstance(module, OLMoBlock)
1409
+ if recurse:
1410
+ return True
1411
+ else:
1412
+ return wrap
1413
+
1414
+ return fsdp_wrap_fn
1415
+ elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:
1416
+
1417
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1418
+ del nonwrapped_numel
1419
+ wrap = isinstance(module, (OLMoBlock,)) or module in size_based_module_to_wrap
1420
+ if recurse:
1421
+ return True
1422
+ else:
1423
+ return wrap
1424
+
1425
+ return fsdp_wrap_fn
1426
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group:
1427
+ if self.config.block_group_size <= 1:
1428
+ raise OLMoConfigurationError(
1429
+ "'by_block_group' FSDP wrapping strategy requires block group size greater than 1"
1430
+ )
1431
+
1432
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1433
+ del nonwrapped_numel
1434
+ wrap = isinstance(module, OLMoBlockGroup)
1435
+ if recurse:
1436
+ return True
1437
+ else:
1438
+ return wrap
1439
+
1440
+ return fsdp_wrap_fn
1441
+ elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
1442
+ if self.config.block_group_size <= 1:
1443
+ raise OLMoConfigurationError(
1444
+ "'by_block_group_and_size' FSDP wrapping strategy requires block group size greater than 1"
1445
+ )
1446
+
1447
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1448
+ del nonwrapped_numel
1449
+ wrap = isinstance(module, (OLMoBlockGroup,)) or module in size_based_module_to_wrap
1450
+ if recurse:
1451
+ return True
1452
+ else:
1453
+ return wrap
1454
+
1455
+ return fsdp_wrap_fn
1456
+ elif wrap_strategy == FSDPWrapStrategy.size_based:
1457
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
1458
+
1459
+ return size_based_auto_wrap_policy
1460
+ elif wrap_strategy in {
1461
+ FSDPWrapStrategy.one_in_two,
1462
+ FSDPWrapStrategy.one_in_three,
1463
+ FSDPWrapStrategy.one_in_four,
1464
+ FSDPWrapStrategy.one_in_five,
1465
+ }:
1466
+ c = {
1467
+ FSDPWrapStrategy.one_in_two: 2,
1468
+ FSDPWrapStrategy.one_in_three: 3,
1469
+ FSDPWrapStrategy.one_in_four: 4,
1470
+ FSDPWrapStrategy.one_in_five: 5,
1471
+ }[wrap_strategy]
1472
+
1473
+ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
1474
+ del nonwrapped_numel
1475
+ wrap = isinstance(module, OLMoBlock) and module.layer_id % c == 0
1476
+ if recurse:
1477
+ return True
1478
+ else:
1479
+ return wrap
1480
+
1481
+ return fsdp_wrap_fn
1482
+ else:
1483
+ raise NotImplementedError(wrap_strategy)
1484
+
1485
+ def num_params(self, include_embedding: bool = True) -> int:
1486
+ """
1487
+ Get the total number of parameters.
1488
+ """
1489
+ params = (np for np in self.named_parameters())
1490
+ if not include_embedding:
1491
+ params = filter( # type: ignore
1492
+ lambda np: ".wte." not in np[0] and ".wpe." not in np[0],
1493
+ params,
1494
+ )
1495
+ return sum(p.numel() for _, p in params)
1496
+
1497
+ @property
1498
+ def num_fwd_flops(self):
1499
+ if self.__num_fwd_flops:
1500
+ return self.__num_fwd_flops
1501
+ n_params = self.num_params()
1502
+ # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network
1503
+ # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param
1504
+ # this gets us FLOPs / token
1505
+ params_flops_per_token = 2 * n_params
1506
+ params_flops_per_seq = params_flops_per_token * self.config.max_sequence_length
1507
+ # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)
1508
+ attn_flops_per_seq = (
1509
+ self.config.n_layers * 2 * 2 * (self.config.d_model * (self.config.max_sequence_length**2))
1510
+ )
1511
+ self.__num_fwd_flops = params_flops_per_seq + attn_flops_per_seq
1512
+ return self.__num_fwd_flops
1513
+
1514
+ def generate(
1515
+ self,
1516
+ input_ids: torch.LongTensor,
1517
+ attention_mask: Optional[torch.Tensor] = None,
1518
+ attention_bias: Optional[torch.Tensor] = None,
1519
+ max_steps: int = 10,
1520
+ beam_size: int = 1,
1521
+ per_node_beam_size: Optional[int] = None,
1522
+ sampler: Optional[Sampler] = None,
1523
+ min_steps: Optional[int] = None,
1524
+ final_sequence_scorer: Optional[FinalSequenceScorer] = None,
1525
+ constraints: Optional[List[Constraint]] = None,
1526
+ ) -> OLMoGenerateOutput:
1527
+ """
1528
+ Generate token IDs using beam search.
1529
+
1530
+ Note that by default ``beam_size`` is set to 1, which is greedy decoding.
1531
+
1532
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
1533
+ :param attention_mask: A optional tensor of shape `(batch_size, seq_len)`, the same
1534
+ as for the forward method.
1535
+ :param attention_bias: A tensor of shape
1536
+ `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
1537
+ the same as for the forward method except only one shape is excepted here.
1538
+
1539
+ For an explanation of the other arguments, see :class:`BeamSearch`.
1540
+ """
1541
+ beam_search = BeamSearch(
1542
+ self.config.eos_token_id,
1543
+ max_steps=max_steps,
1544
+ beam_size=beam_size,
1545
+ per_node_beam_size=per_node_beam_size,
1546
+ sampler=sampler,
1547
+ min_steps=min_steps,
1548
+ final_sequence_scorer=final_sequence_scorer,
1549
+ constraints=constraints,
1550
+ )
1551
+
1552
+ # Validate inputs.
1553
+ batch_size, seq_len = input_ids.shape
1554
+ if attention_mask is not None:
1555
+ assert attention_mask.shape == (batch_size, seq_len)
1556
+ if attention_bias is not None:
1557
+ assert len(attention_bias.shape) == 4
1558
+ assert attention_bias.shape[:2] == (batch_size, 1)
1559
+ assert (
1560
+ seq_len + beam_search.max_steps
1561
+ <= attention_bias.shape[2]
1562
+ == attention_bias.shape[3]
1563
+ <= self.config.max_sequence_length
1564
+ )
1565
+
1566
+ tokens_generated = 0
1567
+
1568
+ def flatten_past_key_values(
1569
+ past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
1570
+ ) -> Dict[str, torch.Tensor]:
1571
+ out = {}
1572
+ for i, (key, value) in enumerate(past_key_values):
1573
+ out[f"past_key_{i}"] = key
1574
+ out[f"past_value_{i}"] = value
1575
+ return out
1576
+
1577
+ def unflatten_past_key_values(
1578
+ past_key_values: Dict[str, torch.Tensor],
1579
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
1580
+ out = []
1581
+ for i in range(self.config.n_layers):
1582
+ past_key = past_key_values[f"past_key_{i}"]
1583
+ past_value = past_key_values[f"past_value_{i}"]
1584
+ out.append((past_key, past_value))
1585
+ return out
1586
+
1587
+ def step(
1588
+ last_predictions: torch.Tensor, state: dict[str, torch.Tensor]
1589
+ ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
1590
+ nonlocal tokens_generated
1591
+
1592
+ attention_mask = state.get("attention_mask")
1593
+ attention_bias = state.get("attention_bias")
1594
+
1595
+ if tokens_generated > 0:
1596
+ past_key_values = unflatten_past_key_values(state)
1597
+ input_ids = last_predictions.unsqueeze(1)
1598
+ if attention_mask is not None:
1599
+ group_size = input_ids.shape[0]
1600
+ attention_mask = torch.cat((attention_mask, attention_mask.new_ones((group_size, 1))), dim=-1)
1601
+ else:
1602
+ past_key_values = None
1603
+ input_ids = state["input_ids"]
1604
+
1605
+ tokens_generated += 1
1606
+
1607
+ # Run forward pass of model to get logits, then normalize to get log probs.
1608
+ output = self(
1609
+ input_ids,
1610
+ attention_mask=attention_mask,
1611
+ attention_bias=attention_bias,
1612
+ past_key_values=past_key_values,
1613
+ use_cache=True,
1614
+ last_logits_only=True,
1615
+ )
1616
+ log_probs = F.log_softmax(output.logits[:, -1, :], dim=-1)
1617
+
1618
+ # Create new state.
1619
+ state = flatten_past_key_values(output.attn_key_values)
1620
+ if attention_mask is not None:
1621
+ state["attention_mask"] = attention_mask
1622
+ if attention_bias is not None:
1623
+ state["attention_bias"] = attention_bias
1624
+
1625
+ return log_probs, state
1626
+
1627
+ initial_preds = input_ids.new_zeros((batch_size,)) # This is arbitrary, we won't use this.
1628
+ state: dict[str, torch.Tensor] = {"input_ids": input_ids}
1629
+ if attention_mask is not None:
1630
+ state["attention_mask"] = attention_mask
1631
+ if attention_bias is not None:
1632
+ state["attention_bias"] = attention_bias
1633
+ with torch.no_grad():
1634
+ token_ids, scores = beam_search.search(initial_preds, state, step)
1635
+
1636
+ return OLMoGenerateOutput(
1637
+ token_ids=token_ids, # type: ignore[arg-type]
1638
+ scores=scores, # type: ignore[arg-type]
1639
+ )
1640
+
1641
+ @classmethod
1642
+ def from_checkpoint(
1643
+ cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
1644
+ ) -> OLMo:
1645
+ """
1646
+ Load an OLMo model from a checkpoint.
1647
+ """
1648
+ from .util import resource_path
1649
+
1650
+ # Guess checkpoint type.
1651
+ if checkpoint_type is None:
1652
+ try:
1653
+ if resource_path(checkpoint_dir, "model.pt").is_file():
1654
+ checkpoint_type = CheckpointType.unsharded
1655
+ else:
1656
+ checkpoint_type = CheckpointType.sharded
1657
+ except FileNotFoundError:
1658
+ checkpoint_type = CheckpointType.sharded
1659
+
1660
+ # Load config.
1661
+ config_path = resource_path(checkpoint_dir, "config.yaml")
1662
+ model_config = ModelConfig.load(config_path, key="model", validate_paths=False)
1663
+
1664
+ if checkpoint_type == CheckpointType.unsharded:
1665
+ # Initialize model (always on CPU to start with so we don't run out of GPU memory).
1666
+ model_config.init_device = "cpu"
1667
+ model = OLMo(model_config)
1668
+
1669
+ # Load state dict directly to target device.
1670
+ state_dict_path = resource_path(checkpoint_dir, "model.pt")
1671
+ state_dict = torch.load(state_dict_path, map_location="cpu")
1672
+ model.load_state_dict(model._make_state_dict_compatible(state_dict)[0])
1673
+ model = model.to(torch.device(device))
1674
+ else:
1675
+ from .checkpoint import load_model_state
1676
+
1677
+ # Initialize model on target device. In this case the state dict is loaded in-place
1678
+ # so it's not necessary to start on CPU if the target device is a GPU.
1679
+ model_config.init_device = device
1680
+ model = OLMo(model_config)
1681
+
1682
+ # Load state dict in place.
1683
+ load_model_state(checkpoint_dir, model)
1684
+
1685
+ return model.eval()
1686
+
1687
+ def _make_state_dict_compatible(
1688
+ self, state_dict: Dict[str, torch.Tensor]
1689
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Set[str]]]:
1690
+ """
1691
+ Handles some cases where the state dict is valid yet may need to be transformed in order to
1692
+ be loaded.
1693
+
1694
+ This modifies the state dict in-place and also returns it, along with a mapping of original key
1695
+ names to new key names in cases where the keys were simply renamed. That mapping can be used
1696
+ to make a corresponding optimizer state dict compatible as well.
1697
+ """
1698
+ import re
1699
+ from fnmatch import fnmatch
1700
+
1701
+ new_keys_to_og_keys: Dict[str, str] = {}
1702
+
1703
+ # Remove "_fsdp_wrapped_module." prefix from all keys. We don't want this prefix when the model is
1704
+ # not wrapped in FSDP. And when the model is wrapped in FSDP, loading this state dict will still work
1705
+ # fine without the prefixes. This also simplifies the other steps below.
1706
+ for key in list(state_dict.keys()):
1707
+ state_dict[(new_key := key.replace("_fsdp_wrapped_module.", ""))] = state_dict.pop(key)
1708
+ new_keys_to_og_keys[new_key] = key
1709
+
1710
+ # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
1711
+ if self.config.block_type == BlockType.sequential:
1712
+ for key in list(state_dict.keys()):
1713
+ if fnmatch(key, "transformer.*.norm.weight"):
1714
+ tensor = state_dict.pop(key)
1715
+ state_dict[(new_key := key.replace("norm.weight", "attn_norm.weight"))] = tensor
1716
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1717
+ state_dict[(new_key := key.replace("norm.weight", "ff_norm.weight"))] = tensor.clone()
1718
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1719
+ del new_keys_to_og_keys[key]
1720
+ elif fnmatch(key, "transformer.*.norm.bias"):
1721
+ tensor = state_dict.pop(key)
1722
+ state_dict[(new_key := key.replace("norm.bias", "attn_norm.bias"))] = tensor
1723
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1724
+ state_dict[(new_key := key.replace("norm.bias", "ff_norm.bias"))] = tensor.clone()
1725
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys[key]
1726
+ del new_keys_to_og_keys[key]
1727
+
1728
+ # For loading a state dict that was saved with a different `block_group_size`.
1729
+ if "transformer.block_groups.0.0.attn_out.weight" in state_dict.keys():
1730
+ state_dict_block_group_size = len(
1731
+ [k for k in state_dict.keys() if fnmatch(k, "transformer.block_groups.0.*.attn_out.weight")]
1732
+ )
1733
+ else:
1734
+ state_dict_block_group_size = 1
1735
+ if self.config.block_group_size != state_dict_block_group_size:
1736
+ log.info(
1737
+ f"Regrouping state dict blocks from group size {state_dict_block_group_size} to "
1738
+ f"group size {self.config.block_group_size}"
1739
+ )
1740
+ # For simplicity we're first going to flatten out the block groups in the state dict (if necessary)
1741
+ # and then (re-)group them into the right block sizes.
1742
+ if state_dict_block_group_size > 1:
1743
+ for key in list(state_dict.keys()):
1744
+ if (m := re.match(r"transformer.block_groups\.(\d+)\.(\d+)\..*", key)) is not None:
1745
+ group_idx, group_block_idx = int(m.group(1)), int(m.group(2))
1746
+ block_idx = (group_idx * state_dict_block_group_size) + group_block_idx
1747
+ state_dict[
1748
+ (
1749
+ new_key := key.replace(
1750
+ f"block_groups.{group_idx}.{group_block_idx}.", f"blocks.{block_idx}."
1751
+ )
1752
+ )
1753
+ ] = state_dict.pop(key)
1754
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1755
+
1756
+ if self.config.block_group_size > 1:
1757
+ # Group the state dict blocks into the right block size.
1758
+ for key in list(state_dict.keys()):
1759
+ if (m := re.match(r"transformer.blocks\.(\d+)\..*", key)) is not None:
1760
+ block_idx = int(m.group(1))
1761
+ group_idx, group_block_idx = (
1762
+ block_idx // self.config.block_group_size,
1763
+ block_idx % self.config.block_group_size,
1764
+ )
1765
+ state_dict[
1766
+ (
1767
+ new_key := key.replace(
1768
+ f"blocks.{block_idx}.", f"block_groups.{group_idx}.{group_block_idx}."
1769
+ )
1770
+ )
1771
+ ] = state_dict.pop(key)
1772
+ new_keys_to_og_keys[new_key] = new_keys_to_og_keys.pop(key)
1773
+
1774
+ og_keys_to_new: Dict[str, Set[str]] = defaultdict(set)
1775
+ for new_key, og_key in new_keys_to_og_keys.items():
1776
+ og_keys_to_new[og_key].add(new_key)
1777
+
1778
+ return state_dict, og_keys_to_new
modeling_olmo.py CHANGED
@@ -2,16 +2,17 @@ from dataclasses import fields
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
 
 
5
  from transformers import PreTrainedModel
6
- from transformers.modeling_outputs import CausalLMOutputWithPast
7
  from transformers.models.auto import AutoModelForCausalLM
8
 
9
- from olmo.config import ModelConfig
10
- from olmo.model import OLMo
11
 
12
  from .configuration_olmo import OLMoConfig
13
 
14
-
15
  def create_model_config_from_pretrained_config(config: OLMoConfig):
16
  """
17
  Utility function
@@ -24,26 +25,52 @@ def create_model_config_from_pretrained_config(config: OLMoConfig):
24
  model_config = ModelConfig(**kwargs)
25
  return model_config
26
 
27
-
28
- class OLMoForCausalLM(PreTrainedModel):
29
- """
30
- Extremely barebones HF model wrapper.
31
- """
32
-
33
  config_class = OLMoConfig
34
  base_model_prefix = "model"
35
  _no_split_modules = ["OLMoBlock"]
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
38
  super().__init__(config)
 
 
 
 
 
 
 
 
 
 
39
 
40
- if not model:
41
- model_config = create_model_config_from_pretrained_config(config)
42
- # Initialize model (always on CPU to start with so we don't run out of GPU memory).
43
- model_config.init_device = "cpu"
44
- self.model = OLMo(model_config, init_params=init_params)
45
  else:
46
- self.model = model
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def forward(
49
  self,
@@ -58,18 +85,36 @@ class OLMoForCausalLM(PreTrainedModel):
58
  output_hidden_states: Optional[bool] = None,
59
  return_dict: Optional[bool] = None,
60
  ) -> Union[Tuple, CausalLMOutputWithPast]:
61
- if use_cache is None:
62
- use_cache = self.config.use_cache
63
-
64
- if output_attentions:
65
- raise ValueError("output_attentions is not yet supported in OLMo")
66
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
68
 
 
 
69
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
70
- outputs = self.model.forward(
71
  input_ids=input_ids,
72
- input_embeddings=inputs_embeds,
73
  attention_mask=attention_mask,
74
  attention_bias=attention_bias,
75
  past_key_values=past_key_values,
@@ -77,8 +122,16 @@ class OLMoForCausalLM(PreTrainedModel):
77
  output_hidden_states=output_hidden_states,
78
  )
79
 
80
- logits = outputs.logits
81
- hidden_states = outputs.hidden_states
 
 
 
 
 
 
 
 
82
 
83
  loss = None
84
  if labels is not None:
@@ -87,26 +140,25 @@ class OLMoForCausalLM(PreTrainedModel):
87
  shift_labels = labels[..., 1:].contiguous()
88
  # Flatten the tokens
89
  loss_fct = torch.nn.CrossEntropyLoss()
90
- shift_logits = shift_logits.view(-1, self.config.embedding_size)
91
  shift_labels = shift_labels.view(-1)
92
  # Enable model parallelism
93
  shift_labels = shift_labels.to(shift_logits.device)
94
  loss = loss_fct(shift_logits, shift_labels)
95
 
96
  if not return_dict:
97
- output = (logits,) + outputs[1:]
98
  return (loss,) + output if loss is not None else output
99
 
 
100
  return CausalLMOutputWithPast(
101
  loss=loss,
102
  logits=logits,
103
- past_key_values=outputs.attn_key_values,
104
- hidden_states=hidden_states,
 
105
  )
106
 
107
- def can_generate(self) -> bool:
108
- return True
109
-
110
  def prepare_inputs_for_generation(
111
  self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
112
  ):
@@ -115,42 +167,20 @@ class OLMoForCausalLM(PreTrainedModel):
115
  input_ids = input_ids[:, -1:]
116
  model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
117
 
 
118
  model_inputs.update(kwargs)
119
- model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
 
120
  return model_inputs
121
 
122
- # TODO: these are required to make the implementation complete.
123
- # def resize_position_embeddings(self, new_num_position_embeddings: int):
124
- # pass
125
- #
126
- # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
127
- # pass
128
- #
129
- # def _reorder_cache(self, past_key_values, beam_idx):
130
- # pass
131
-
132
- def get_input_embeddings(self) -> torch.nn.Module:
133
- return self.model.transformer.wte
134
-
135
- def set_input_embeddings(self, value: torch.nn.Module):
136
- self.model.transformer.wte = value
137
-
138
- def get_output_embeddings(self):
139
- if self.config.weight_tying:
140
- return self.model.transformer.wte
141
- else:
142
- return self.model.transformer.ff_out
143
-
144
- def set_output_embeddings(self, value: torch.nn.Module):
145
- if self.config.weight_tying:
146
- self.model.transformer.wte = value
147
- else:
148
- self.model.transformer.ff_out = value
149
-
150
- def tie_weights(self):
151
- if self.config.weight_tying:
152
- self.model.transformer.ff_out = self.model.transformer.wte
153
-
154
 
155
  # Register the model so that it is available for transformer pipelines, auto-loading, etc.
156
  AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
 
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
5
+ import torch.nn.functional as F
6
+ import math
7
  from transformers import PreTrainedModel
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
9
  from transformers.models.auto import AutoModelForCausalLM
10
 
11
+ from .config import ModelConfig
12
+ from .model import OLMo
13
 
14
  from .configuration_olmo import OLMoConfig
15
 
 
16
  def create_model_config_from_pretrained_config(config: OLMoConfig):
17
  """
18
  Utility function
 
25
  model_config = ModelConfig(**kwargs)
26
  return model_config
27
 
28
+ class OLMoPreTrainedModel(PreTrainedModel):
 
 
 
 
 
29
  config_class = OLMoConfig
30
  base_model_prefix = "model"
31
  _no_split_modules = ["OLMoBlock"]
32
+ # _skip_keys_device_placement = ["past_key_values", "causal_mask"]
33
+ _skip_keys_device_placement = ["past_key_values"]
34
+
35
+ def _init_weights(self, module):
36
+ # `OLMoModel.reset_parameters` initializes weights of itself and its children
37
+ if isinstance(module, OLMo):
38
+ module.reset_parameters()
39
+
40
+ class OLMoForCausalLM(OLMoPreTrainedModel):
41
+ _tied_weights_keys = []
42
+ # _tied_weights_keys = ["transformer.wte.weight"]
43
 
44
+ def __init__(self, config: OLMoConfig):
45
  super().__init__(config)
46
+ self.model = OLMo(config)
47
+
48
+ # Initialize weights and apply final processing
49
+ self.post_init()
50
+
51
+ def get_input_embeddings(self) -> torch.nn.Module:
52
+ return self.model.transformer.wte
53
+
54
+ def set_input_embeddings(self, value: torch.nn.Module):
55
+ self.model.transformer.wte = value
56
 
57
+ def get_output_embeddings(self):
58
+ if self.config.weight_tying:
59
+ return self.model.transformer.wte
 
 
60
  else:
61
+ return self.model.transformer.ff_out
62
+
63
+ def set_output_embeddings(self, value: torch.nn.Module):
64
+ if self.config.weight_tying:
65
+ self.model.transformer.wte = value
66
+ else:
67
+ self.model.transformer.ff_out = value
68
+
69
+ def set_decoder(self, decoder):
70
+ self.model = decoder
71
+
72
+ def get_decoder(self):
73
+ return self.model
74
 
75
  def forward(
76
  self,
 
85
  output_hidden_states: Optional[bool] = None,
86
  return_dict: Optional[bool] = None,
87
  ) -> Union[Tuple, CausalLMOutputWithPast]:
88
+ r"""
89
+ Args:
90
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
91
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
92
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
93
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
94
+ Returns:
95
+ Example:
96
+ ```python
97
+ >>> from transformers import AutoTokenizer, OLMoForCausalLM
98
+ >>> model = OLMoForCausalLM.from_pretrained("allenai/OLMo-7B")
99
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B")
100
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
101
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
102
+ >>> # Generate
103
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
104
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
105
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
106
+ ```"""
107
+ output_attentions = output_attentions or self.config.output_attentions
108
+ output_hidden_states = output_hidden_states or self.config.output_hidden_states
109
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
110
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
111
 
112
+ assert not output_attentions
113
+
114
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
115
+ base_output: Union[BaseModelOutputWithPast, Tuple] = self.model.forward(
116
  input_ids=input_ids,
117
+ inputs_embeds=inputs_embeds,
118
  attention_mask=attention_mask,
119
  attention_bias=attention_bias,
120
  past_key_values=past_key_values,
 
122
  output_hidden_states=output_hidden_states,
123
  )
124
 
125
+ last_hidden_state = base_output.last_hidden_state if return_dict else base_output[0]
126
+
127
+ # Get logits.
128
+ # shape: (batch_size, seq_len or 1, vocab_size)
129
+ if self.config.weight_tying:
130
+ logits = F.linear(last_hidden_state, self.model.transformer.wte.weight, None) # type: ignore
131
+ else:
132
+ logits = self.model.transformer.ff_out(last_hidden_state) # type: ignore
133
+ if self.config.scale_logits:
134
+ logits.mul_(1 / math.sqrt(self.config.d_model))
135
 
136
  loss = None
137
  if labels is not None:
 
140
  shift_labels = labels[..., 1:].contiguous()
141
  # Flatten the tokens
142
  loss_fct = torch.nn.CrossEntropyLoss()
143
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
144
  shift_labels = shift_labels.view(-1)
145
  # Enable model parallelism
146
  shift_labels = shift_labels.to(shift_logits.device)
147
  loss = loss_fct(shift_logits, shift_labels)
148
 
149
  if not return_dict:
150
+ output = (logits,) + base_output[1:]
151
  return (loss,) + output if loss is not None else output
152
 
153
+ assert isinstance(base_output, BaseModelOutputWithPast)
154
  return CausalLMOutputWithPast(
155
  loss=loss,
156
  logits=logits,
157
+ past_key_values=base_output.past_key_values,
158
+ hidden_states=base_output.hidden_states,
159
+ attentions=base_output.attentions,
160
  )
161
 
 
 
 
162
  def prepare_inputs_for_generation(
163
  self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
164
  ):
 
167
  input_ids = input_ids[:, -1:]
168
  model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
169
 
170
+ kwargs.pop("cache_position")
171
  model_inputs.update(kwargs)
172
+ # logger.warning("%s %s", kwargs.keys(), model_inputs.keys())
173
+ # model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
174
  return model_inputs
175
 
176
+ @staticmethod
177
+ def _reorder_cache(past_key_values, beam_idx):
178
+ reordered_past = ()
179
+ for layer_past in past_key_values:
180
+ reordered_past += (
181
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
182
+ )
183
+ return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # Register the model so that it is available for transformer pipelines, auto-loading, etc.
186
  AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
tokenization_olmo_fast.py DELETED
@@ -1,16 +0,0 @@
1
- from transformers import AutoTokenizer, PreTrainedTokenizerFast
2
-
3
- from hf_olmo.configuration_olmo import OLMoConfig
4
-
5
-
6
- class OLMoTokenizerFast(PreTrainedTokenizerFast):
7
- # Note: OLMo's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
8
- pass
9
-
10
- # def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
11
- # # This is required to make the implementation complete.
12
- # pass
13
-
14
-
15
- # Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
16
- AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tokenizer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import List, Optional, Union
6
+
7
+ from tokenizers import Tokenizer as BaseTokenizer
8
+
9
+ from .aliases import PathOrStr
10
+ from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
11
+ from .exceptions import OLMoConfigurationError
12
+
13
+ __all__ = ["Tokenizer"]
14
+
15
+
16
+ class Tokenizer:
17
+ """
18
+ A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.
19
+
20
+ :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
21
+ :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
22
+ :param truncate_to: Truncate when tokenizing to this number of token IDs.
23
+ :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
24
+ on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
25
+ this setting has no effect.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ base_tokenizer: BaseTokenizer,
31
+ eos_token_id: int,
32
+ pad_token_id: Optional[int] = None,
33
+ truncate_to: Optional[int] = None,
34
+ truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
35
+ ):
36
+ self.base_tokenizer = base_tokenizer
37
+ self.base_tokenizer.no_truncation()
38
+ self.eos_token_id = eos_token_id
39
+ self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
40
+ self.truncate_to = truncate_to
41
+ self.truncate_direction = TruncationDirection(truncate_direction)
42
+
43
+ @property
44
+ def vocab_size(self) -> int:
45
+ return self.base_tokenizer.get_vocab_size()
46
+
47
+ @property
48
+ def eos_token(self) -> str:
49
+ return self.decode([self.eos_token_id], skip_special_tokens=False)
50
+
51
+ @property
52
+ def pad_token(self) -> str:
53
+ return self.decode([self.pad_token_id], skip_special_tokens=False)
54
+
55
+ @classmethod
56
+ def from_train_config(cls, config: TrainConfig) -> Tokenizer:
57
+ tokenizer_identifier = config.tokenizer.identifier
58
+ if Path(tokenizer_identifier).is_file():
59
+ tokenizer = cls.from_file(
60
+ tokenizer_identifier,
61
+ eos_token_id=config.model.eos_token_id,
62
+ pad_token_id=config.model.pad_token_id,
63
+ )
64
+ else:
65
+ tokenizer = cls.from_pretrained(
66
+ tokenizer_identifier,
67
+ eos_token_id=config.model.eos_token_id,
68
+ pad_token_id=config.model.pad_token_id,
69
+ )
70
+ if config.model.vocab_size != tokenizer.vocab_size:
71
+ raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
72
+ return tokenizer
73
+
74
+ @classmethod
75
+ def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
76
+ """
77
+ Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.
78
+
79
+ :param identifier: The identifier of a model on the Hub that contains a
80
+ ``tokenizer.json`` file.
81
+ :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
82
+ """
83
+ base_tokenizer = BaseTokenizer.from_pretrained(identifier)
84
+ eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
85
+ return cls(base_tokenizer, eos_token_id, **kwargs)
86
+
87
+ @classmethod
88
+ def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
89
+ """
90
+ Initialize a tokenizer from a file.
91
+
92
+ You can create those files with ``BaseTokenizer.save()``.
93
+
94
+ :param filename: The name of a file containing a tokenizer specification.
95
+ :param kwargs: Other key word arguments passed to :class:`Tokenizer`.
96
+ """
97
+ base_tokenizer = BaseTokenizer.from_file(filename)
98
+ eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
99
+ return cls(base_tokenizer, eos_token_id, **kwargs)
100
+
101
+ @classmethod
102
+ def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
103
+ """
104
+ Load a tokenizer from a checkpoint.
105
+ """
106
+ from cached_path import cached_path
107
+
108
+ # Load configs.
109
+ config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
110
+ tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
111
+ model_config = ModelConfig.load(config_path, key="model")
112
+
113
+ # Initialize tokenizer and validate vocab size.
114
+ if Path(tokenizer_config.identifier).is_file():
115
+ tokenizer = cls.from_file(
116
+ tokenizer_config.identifier,
117
+ eos_token_id=model_config.eos_token_id,
118
+ pad_token_id=model_config.pad_token_id,
119
+ )
120
+ else:
121
+ tokenizer = cls.from_pretrained(
122
+ tokenizer_config.identifier,
123
+ eos_token_id=model_config.eos_token_id,
124
+ pad_token_id=model_config.pad_token_id,
125
+ )
126
+ if model_config.vocab_size != tokenizer.vocab_size:
127
+ raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
128
+ return tokenizer
129
+
130
+ def add_special_tokens(self, input_ids: List[int]) -> List[int]:
131
+ """
132
+ Add special tokens in-place (if not already present) to the given token IDs.
133
+ """
134
+ if not input_ids or input_ids[-1] != self.eos_token_id:
135
+ input_ids.append(self.eos_token_id)
136
+ return input_ids
137
+
138
+ def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
139
+ return 2 if is_pair else 1
140
+
141
+ def _truncate(
142
+ self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
143
+ ) -> list[int]:
144
+ if truncate_to is None or len(input_ids) <= truncate_to:
145
+ return input_ids
146
+ elif direction == TruncationDirection.left:
147
+ return input_ids[len(input_ids) - truncate_to :]
148
+ else:
149
+ return input_ids[: -(len(input_ids) - truncate_to)]
150
+
151
+ def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
152
+ """
153
+ Encode a string into token IDs.
154
+ """
155
+ return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]
156
+
157
+ def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
158
+ """
159
+ Encode a batch of strings into token IDs.
160
+ """
161
+ truncate_to = self.truncate_to
162
+ if truncate_to is not None and add_special_tokens:
163
+ truncate_to -= self.num_special_tokens_to_add(False)
164
+
165
+ batch_encoding = self.base_tokenizer.encode_batch(inputs)
166
+
167
+ all_input_ids = []
168
+ for encoding in batch_encoding:
169
+ input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
170
+ if add_special_tokens:
171
+ input_ids = self.add_special_tokens(input_ids)
172
+ all_input_ids.append(input_ids)
173
+
174
+ return all_input_ids
175
+
176
+ def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
177
+ """
178
+ Decode a list of token IDs to a string.
179
+ """
180
+ return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
tokenizer_config.json CHANGED
@@ -226,10 +226,13 @@
226
  }
227
  },
228
  "clean_up_tokenization_spaces": true,
229
- "eos_token": "|||IP_ADDRESS|||",
230
  "max_length": null,
231
  "model_max_length": 1000000000000000019884624838656,
232
  "pad_token": "<|padding|>",
233
- "tokenizer_class": "OLMoTokenizer",
234
- "truncation": "right"
 
 
 
235
  }
 
226
  }
227
  },
228
  "clean_up_tokenization_spaces": true,
229
+ "eos_token": "<|endoftext|>",
230
  "max_length": null,
231
  "model_max_length": 1000000000000000019884624838656,
232
  "pad_token": "<|padding|>",
233
+ "tokenizer_class": "PreTrainedTokenizerFast",
234
+ "truncation": "right",
235
+ "auto_map": {
236
+ "AutoConfig": "configuration_olmo.OLMoConfig"
237
+ }
238
  }
torch_util.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ from typing import Optional, TypeVar
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ def seed_all(seed: int):
12
+ """Seed all rng objects."""
13
+ import random
14
+
15
+ import numpy as np
16
+
17
+ if seed < 0 or seed > 2**32 - 1:
18
+ raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]")
19
+ random.seed(seed)
20
+ np.random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ # torch.manual_seed may call manual_seed_all but calling it again here
23
+ # to make sure it gets called at least once
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ def is_distributed() -> bool:
28
+ return dist.is_available() and dist.is_initialized()
29
+
30
+
31
+ def get_node_rank() -> int:
32
+ return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())
33
+
34
+
35
+ def get_world_size() -> int:
36
+ if is_distributed():
37
+ return dist.get_world_size()
38
+ else:
39
+ return 1
40
+
41
+
42
+ def get_local_world_size() -> int:
43
+ return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)
44
+
45
+
46
+ def get_global_rank() -> int:
47
+ return int(os.environ.get("RANK") or dist.get_rank())
48
+
49
+
50
+ def get_local_rank() -> int:
51
+ return int(os.environ.get("LOCAL_RANK") or 0)
52
+
53
+
54
+ def get_fs_local_rank() -> int:
55
+ """Get the local rank per filesystem, meaning that, regardless of the number of nodes,
56
+ if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
57
+ but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
58
+ """
59
+ return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
60
+
61
+
62
+ def move_to_device(o: T, device: torch.device) -> T:
63
+ if isinstance(o, torch.Tensor):
64
+ return o.to(device) # type: ignore[return-value]
65
+ elif isinstance(o, dict):
66
+ return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value]
67
+ elif isinstance(o, list):
68
+ return [move_to_device(x, device) for x in o] # type: ignore[return-value]
69
+ elif isinstance(o, tuple):
70
+ return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value]
71
+ else:
72
+ return o
73
+
74
+
75
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
76
+ """
77
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
78
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
79
+ """
80
+ if check_neg_inf:
81
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
82
+ if check_pos_inf:
83
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
84
+
85
+
86
+ def get_default_device() -> torch.device:
87
+ if torch.cuda.is_available() and torch.cuda.is_initialized():
88
+ return torch.device("cuda")
89
+ else:
90
+ return torch.device("cpu")
91
+
92
+
93
+ def barrier() -> None:
94
+ if is_distributed():
95
+ dist.barrier()
96
+
97
+
98
+ def peak_gpu_memory(reset: bool = False) -> Optional[float]:
99
+ """
100
+ Get the peak GPU memory usage in MB across all ranks.
101
+ Only rank 0 will get the final result.
102
+ """
103
+ if not torch.cuda.is_available():
104
+ return None
105
+
106
+ device = torch.device("cuda")
107
+ peak_mb = torch.cuda.max_memory_allocated(device) / 1000000
108
+ if is_distributed():
109
+ peak_mb_tensor = torch.tensor(peak_mb, device=device)
110
+ dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX)
111
+ peak_mb = peak_mb_tensor.item()
112
+
113
+ if reset:
114
+ # Reset peak stats.
115
+ torch.cuda.reset_max_memory_allocated(device)
116
+
117
+ return peak_mb
118
+
119
+
120
+ V = TypeVar("V", bool, int, float)
121
+
122
+
123
+ def synchronize_value(value: V, device: torch.device) -> V:
124
+ if dist.is_available() and dist.is_initialized():
125
+ value_tensor = torch.tensor(value, device=device)
126
+ dist.broadcast(value_tensor, 0)
127
+ return value_tensor.item() # type: ignore
128
+ else:
129
+ return value
130
+
131
+
132
+ def synchronize_flag(flag: bool, device: torch.device) -> bool:
133
+ return synchronize_value(flag, device)
134
+
135
+
136
+ def gc_cuda():
137
+ gc.collect()
138
+ if torch.cuda.is_available():
139
+ torch.cuda.empty_cache()
util.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ import socket
5
+ import sys
6
+ import time
7
+ import warnings
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from itertools import cycle, islice
11
+ from pathlib import Path
12
+ from queue import Queue
13
+ from threading import Thread
14
+ from typing import Any, Callable, Dict, Optional, Union
15
+
16
+ import boto3
17
+ import botocore.exceptions as boto_exceptions
18
+ import rich
19
+ from botocore.config import Config
20
+ from rich.console import Console, ConsoleRenderable
21
+ from rich.highlighter import NullHighlighter
22
+ from rich.progress import Progress
23
+ from rich.text import Text
24
+ from rich.traceback import Traceback
25
+
26
+ from .aliases import PathOrStr
27
+ from .exceptions import (
28
+ OLMoCliError,
29
+ OLMoEnvironmentError,
30
+ OLMoError,
31
+ OLMoNetworkError,
32
+ OLMoThreadError,
33
+ )
34
+ from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed
35
+
36
+ try:
37
+ from functools import cache
38
+ except ImportError:
39
+ from functools import lru_cache as cache
40
+
41
+
42
+ class StrEnum(str, Enum):
43
+ """
44
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
45
+ We include this here for compatibility with older version of Python.
46
+ """
47
+
48
+ def __str__(self) -> str:
49
+ return self.value
50
+
51
+ def __repr__(self) -> str:
52
+ return f"'{str(self)}'"
53
+
54
+
55
+ _log_extra_fields: Dict[str, Any] = {}
56
+ log = logging.getLogger(__name__)
57
+
58
+
59
+ class LogFilterType(StrEnum):
60
+ rank0_only = "rank0_only"
61
+ local_rank0_only = "local_rank0_only"
62
+ all_ranks = "all_ranks"
63
+
64
+
65
+ def log_extra_field(field_name: str, field_value: Any) -> None:
66
+ global _log_extra_fields
67
+ if field_value is None:
68
+ if field_name in _log_extra_fields:
69
+ del _log_extra_fields[field_name]
70
+ else:
71
+ _log_extra_fields[field_name] = field_value
72
+
73
+
74
+ def setup_logging(log_filter_type: LogFilterType = LogFilterType.rank0_only) -> None:
75
+ """
76
+ :param rank0_only: INFO and below messages will only be emitted on the rank0 process.
77
+ """
78
+ log_extra_field("hostname", socket.gethostname())
79
+ if is_distributed():
80
+ log_extra_field("node_rank", get_node_rank())
81
+ log_extra_field("local_rank", get_local_rank())
82
+ log_extra_field("global_rank", get_global_rank())
83
+ else:
84
+ log_extra_field("node_rank", 0)
85
+ log_extra_field("local_rank", 0)
86
+ log_extra_field("global_rank", 0)
87
+
88
+ old_log_record_factory = logging.getLogRecordFactory()
89
+
90
+ def log_record_factory(*args, **kwargs) -> logging.LogRecord:
91
+ record = old_log_record_factory(*args, **kwargs)
92
+ for field_name, field_value in _log_extra_fields.items():
93
+ setattr(record, field_name, field_value)
94
+ return record
95
+
96
+ logging.setLogRecordFactory(log_record_factory)
97
+
98
+ handler: logging.Handler
99
+ if (
100
+ os.environ.get("OLMo_NONINTERACTIVE", False)
101
+ or os.environ.get("DEBIAN_FRONTEND", None) == "noninteractive"
102
+ or not sys.stdout.isatty()
103
+ ):
104
+ handler = logging.StreamHandler(sys.stdout)
105
+ formatter = logging.Formatter(
106
+ "%(asctime)s\t%(hostname)s:%(local_rank)s\t%(name)s:%(lineno)s\t%(levelname)s\t%(message)s"
107
+ )
108
+ formatter.default_time_format = "%Y-%m-%d %H:%M:%S"
109
+ formatter.default_msec_format = "%s.%03d"
110
+ handler.setFormatter(formatter)
111
+ else:
112
+ handler = RichHandler()
113
+
114
+ def rank0_filter(record: logging.LogRecord) -> int:
115
+ if record.levelno > logging.INFO:
116
+ return 1
117
+ if getattr(record, "global_rank", 0) == 0:
118
+ return 1
119
+ else:
120
+ return 0
121
+
122
+ def local_rank0_filter(record: logging.LogRecord) -> int:
123
+ if record.levelno > logging.INFO:
124
+ return 1
125
+ if getattr(record, "local_rank", 0) == 0:
126
+ return 1
127
+ else:
128
+ return 0
129
+
130
+ if log_filter_type == LogFilterType.rank0_only:
131
+ filter = rank0_filter
132
+ elif log_filter_type == LogFilterType.local_rank0_only:
133
+ filter = local_rank0_filter # type: ignore
134
+ elif log_filter_type == LogFilterType.all_ranks:
135
+ filter = None
136
+ else:
137
+ raise ValueError(log_filter_type)
138
+
139
+ if filter is not None:
140
+ handler.addFilter(filter) # type: ignore
141
+ logging.basicConfig(handlers=[handler], level=logging.INFO)
142
+
143
+ logging.captureWarnings(True)
144
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
145
+
146
+
147
+ def excepthook(exctype, value, traceback):
148
+ """
149
+ Used to patch `sys.excepthook` in order to log exceptions.
150
+ """
151
+ if issubclass(exctype, KeyboardInterrupt):
152
+ sys.__excepthook__(exctype, value, traceback)
153
+ elif issubclass(exctype, OLMoCliError):
154
+ rich.get_console().print(f"[yellow]{value}[/]", highlight=False)
155
+ elif issubclass(exctype, OLMoError):
156
+ rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
157
+ else:
158
+ log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))
159
+
160
+
161
+ def install_excepthook():
162
+ sys.excepthook = excepthook
163
+
164
+
165
+ def filter_warnings():
166
+ # Filter internal deprecation warnings from torch
167
+ warnings.filterwarnings(
168
+ action="ignore",
169
+ category=UserWarning,
170
+ message="torch.distributed.*_base is a private function and will be deprecated.*",
171
+ )
172
+ warnings.filterwarnings(
173
+ action="ignore",
174
+ category=UserWarning,
175
+ message="TypedStorage is deprecated.*",
176
+ )
177
+ warnings.filterwarnings(
178
+ action="ignore",
179
+ category=UserWarning,
180
+ message="Please use DTensor instead.*",
181
+ )
182
+ # Torchvision warnings. We don't actually use torchvision.
183
+ warnings.filterwarnings(
184
+ action="ignore",
185
+ message="failed to load.*",
186
+ module="torchvision.io.image",
187
+ )
188
+
189
+
190
+ def set_env_variables():
191
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
192
+
193
+
194
+ def prepare_cli_environment(log_filter_type: Optional[LogFilterType] = None):
195
+ if log_filter_type is None:
196
+ log_filter_type = LogFilterType(os.environ.get("LOG_FILTER_TYPE", "rank0_only"))
197
+ rich.reconfigure(width=max(rich.get_console().width, 180), soft_wrap=True)
198
+ setup_logging(log_filter_type=log_filter_type)
199
+ install_excepthook()
200
+ filter_warnings()
201
+ set_env_variables()
202
+
203
+
204
+ def clean_opt(arg: str) -> str:
205
+ if "=" not in arg:
206
+ arg = f"{arg}=True"
207
+ name, val = arg.split("=", 1)
208
+ name = name.strip("-").replace("-", "_")
209
+ return f"{name}={val}"
210
+
211
+
212
+ class RichHandler(logging.Handler):
213
+ """
214
+ A simplified version of rich.logging.RichHandler from
215
+ https://github.com/Textualize/rich/blob/master/rich/logging.py
216
+ """
217
+
218
+ def __init__(
219
+ self,
220
+ *,
221
+ level: Union[int, str] = logging.NOTSET,
222
+ console: Optional[Console] = None,
223
+ markup: bool = False,
224
+ ) -> None:
225
+ super().__init__(level=level)
226
+ self.console = console or rich.get_console()
227
+ self.highlighter = NullHighlighter()
228
+ self.markup = markup
229
+
230
+ def emit(self, record: logging.LogRecord) -> None:
231
+ try:
232
+ if hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
233
+ self.console.print(record.msg)
234
+ else:
235
+ msg: Any = record.msg
236
+ if isinstance(record.msg, str):
237
+ msg = self.render_message(record=record, message=record.getMessage())
238
+ renderables = [
239
+ self.get_time_text(record),
240
+ self.get_level_text(record),
241
+ self.get_location_text(record),
242
+ msg,
243
+ ]
244
+ if record.exc_info is not None:
245
+ tb = Traceback.from_exception(*record.exc_info) # type: ignore
246
+ renderables.append(tb)
247
+ self.console.print(*renderables)
248
+ except Exception:
249
+ self.handleError(record)
250
+
251
+ def render_message(self, *, record: logging.LogRecord, message: str) -> ConsoleRenderable:
252
+ use_markup = getattr(record, "markup", self.markup)
253
+ message_text = Text.from_markup(message) if use_markup else Text(message)
254
+
255
+ highlighter = getattr(record, "highlighter", self.highlighter)
256
+ if highlighter:
257
+ message_text = highlighter(message_text)
258
+
259
+ return message_text
260
+
261
+ def get_time_text(self, record: logging.LogRecord) -> Text:
262
+ log_time = datetime.fromtimestamp(record.created)
263
+ time_str = log_time.strftime("[%Y-%m-%d %X]")
264
+ return Text(time_str, style="log.time", end=" ")
265
+
266
+ def get_level_text(self, record: logging.LogRecord) -> Text:
267
+ level_name = record.levelname
268
+ level_text = Text.styled(level_name.ljust(8), f"logging.level.{level_name.lower()}")
269
+ level_text.style = "log.level"
270
+ level_text.end = " "
271
+ return level_text
272
+
273
+ def get_location_text(self, record: logging.LogRecord) -> Text:
274
+ name_and_line = f"{record.name}:{record.lineno}" if record.name != "root" else "root"
275
+ text = f"[{name_and_line}, rank={record.local_rank}]" # type: ignore
276
+ return Text(text, style="log.path")
277
+
278
+
279
+ def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0):
280
+ """Wait for the condition function to return True."""
281
+ start_time = time.monotonic()
282
+ while not condition():
283
+ time.sleep(0.5)
284
+ if time.monotonic() - start_time > timeout:
285
+ raise TimeoutError(f"{description} timed out")
286
+
287
+
288
+ def is_url(path: PathOrStr) -> bool:
289
+ return re.match(r"[a-z0-9]+://.*", str(path)) is not None
290
+
291
+
292
+ def dir_is_empty(dir: PathOrStr) -> bool:
293
+ dir = Path(dir)
294
+ if not dir.is_dir():
295
+ return True
296
+ try:
297
+ next(dir.glob("*"))
298
+ return False
299
+ except StopIteration:
300
+ return True
301
+
302
+
303
+ def get_progress_bar() -> Progress:
304
+ from cached_path import get_download_progress
305
+
306
+ return get_download_progress()
307
+
308
+
309
+ def resource_path(
310
+ folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None, progress: Optional[Progress] = None
311
+ ) -> Path:
312
+ if local_cache is not None and (local_path := Path(local_cache) / fname).is_file():
313
+ log.info(f"Found local cache of {fname} at {local_path}")
314
+ return local_path
315
+ else:
316
+ from cached_path import cached_path
317
+
318
+ return cached_path(f"{str(folder).rstrip('/')}/{fname}", progress=progress)
319
+
320
+
321
+ def file_size(path: PathOrStr) -> int:
322
+ """
323
+ Get the size of a local or remote file in bytes.
324
+ """
325
+ if is_url(path):
326
+ from urllib.parse import urlparse
327
+
328
+ parsed = urlparse(str(path))
329
+ if parsed.scheme == "gs":
330
+ return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
331
+ elif parsed.scheme in ("s3", "r2"):
332
+ return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
333
+ elif parsed.scheme == "file":
334
+ return file_size(str(path).replace("file://", "", 1))
335
+ else:
336
+ raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
337
+ else:
338
+ return os.stat(path).st_size
339
+
340
+
341
+ def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
342
+ """Upload source file to a target location on GCS or S3."""
343
+ from urllib.parse import urlparse
344
+
345
+ source = Path(source)
346
+ assert source.is_file()
347
+ parsed = urlparse(target)
348
+ if parsed.scheme == "gs":
349
+ _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
350
+ elif parsed.scheme in ("s3", "r2"):
351
+ _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
352
+ else:
353
+ raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")
354
+
355
+
356
+ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> bytes:
357
+ if is_url(source):
358
+ from urllib.parse import urlparse
359
+
360
+ parsed = urlparse(str(source))
361
+ if parsed.scheme == "gs":
362
+ return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
363
+ elif parsed.scheme in ("s3", "r2"):
364
+ return _s3_get_bytes_range(
365
+ parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
366
+ )
367
+ elif parsed.scheme == "file":
368
+ return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
369
+ else:
370
+ raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
371
+ else:
372
+ with open(source, "rb") as f:
373
+ f.seek(bytes_start)
374
+ return f.read(num_bytes)
375
+
376
+
377
+ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
378
+ if is_url(dir):
379
+ from urllib.parse import urlparse
380
+
381
+ parsed = urlparse(str(dir))
382
+ if parsed.scheme == "gs":
383
+ raise NotImplementedError
384
+ elif parsed.scheme in ("s3", "r2"):
385
+ return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
386
+ elif parsed.scheme == "file":
387
+ return find_latest_checkpoint(str(dir).replace("file://", "", 1))
388
+ else:
389
+ raise NotImplementedError(f"find_latest_checkpoint not implemented for '{parsed.scheme}' files")
390
+ else:
391
+ latest_step = 0
392
+ latest_checkpoint: Optional[Path] = None
393
+ for path in Path(dir).glob("step*"):
394
+ if path.is_dir():
395
+ try:
396
+ step = int(path.name.replace("step", "").replace("-unsharded", ""))
397
+ except ValueError:
398
+ continue
399
+ # We prioritize sharded checkpoints over unsharded checkpoints.
400
+ if step > latest_step or (step == latest_step and not path.name.endswith("-unsharded")):
401
+ latest_step = step
402
+ latest_checkpoint = path
403
+ return latest_checkpoint
404
+
405
+
406
+ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
407
+ from google.cloud import storage as gcs
408
+
409
+ storage_client = gcs.Client()
410
+ bucket = storage_client.bucket(bucket_name)
411
+ blob = bucket.blob(key)
412
+ if not save_overwrite and blob.exists():
413
+ raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
414
+ blob.upload_from_filename(source)
415
+
416
+
417
+ def _gcs_file_size(bucket_name: str, key: str) -> int:
418
+ from google.api_core.exceptions import NotFound
419
+ from google.cloud import storage as gcs
420
+
421
+ storage_client = gcs.Client()
422
+ bucket = storage_client.bucket(bucket_name)
423
+ blob = bucket.blob(key)
424
+ try:
425
+ blob.reload()
426
+ except NotFound:
427
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
428
+ assert blob.size is not None
429
+ return blob.size
430
+
431
+
432
+ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
433
+ from google.api_core.exceptions import NotFound
434
+ from google.cloud import storage as gcs
435
+
436
+ storage_client = gcs.Client()
437
+ bucket = storage_client.bucket(bucket_name)
438
+ blob = bucket.blob(key)
439
+ try:
440
+ blob.reload()
441
+ except NotFound:
442
+ raise FileNotFoundError(f"gs://{bucket_name}/{key}")
443
+ return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)
444
+
445
+
446
+ def _get_s3_profile_name(scheme: str) -> Optional[str]:
447
+ if scheme == "s3":
448
+ # For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
449
+ return os.environ.get("S3_PROFILE")
450
+ if scheme == "r2":
451
+ profile_name = os.environ.get("R2_PROFILE")
452
+ if profile_name is None:
453
+ raise OLMoEnvironmentError(
454
+ "R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
455
+ )
456
+
457
+ return profile_name
458
+
459
+ raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")
460
+
461
+
462
+ def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
463
+ if scheme == "s3":
464
+ return None
465
+ if scheme == "r2":
466
+ r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
467
+ if r2_endpoint_url is None:
468
+ raise OLMoEnvironmentError(
469
+ "R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
470
+ )
471
+
472
+ return r2_endpoint_url
473
+
474
+ raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
475
+
476
+
477
+ @cache
478
+ def _get_s3_client(scheme: str):
479
+ session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
480
+ return session.client(
481
+ "s3",
482
+ endpoint_url=_get_s3_endpoint_url(scheme),
483
+ config=Config(retries={"max_attempts": 10, "mode": "standard"}),
484
+ use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
485
+ )
486
+
487
+
488
+ def _wait_before_retry(attempt: int):
489
+ time.sleep(min(0.5 * 2**attempt, 3.0))
490
+
491
+
492
+ def _s3_upload(
493
+ source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
494
+ ):
495
+ err: Optional[Exception] = None
496
+ if not save_overwrite:
497
+ for attempt in range(1, max_attempts + 1):
498
+ try:
499
+ _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
500
+ raise FileExistsError(
501
+ f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
502
+ )
503
+ except boto_exceptions.ClientError as e:
504
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
505
+ err = None
506
+ break
507
+ err = e
508
+
509
+ if attempt < max_attempts:
510
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
511
+ _wait_before_retry(attempt)
512
+
513
+ if err is not None:
514
+ raise OLMoNetworkError(f"Failed to check object existence during {scheme} upload") from err
515
+
516
+ try:
517
+ _get_s3_client(scheme).upload_file(source, bucket_name, key)
518
+ except boto_exceptions.ClientError as e:
519
+ raise OLMoNetworkError(f"Failed to upload to {scheme}") from e
520
+
521
+
522
+ def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
523
+ err: Optional[Exception] = None
524
+ for attempt in range(1, max_attempts + 1):
525
+ try:
526
+ return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
527
+ except boto_exceptions.ClientError as e:
528
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
529
+ raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
530
+ err = e
531
+
532
+ if attempt < max_attempts:
533
+ log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
534
+ _wait_before_retry(attempt)
535
+
536
+ raise OLMoNetworkError(f"Failed to get {scheme} file size") from err
537
+
538
+
539
+ def _s3_get_bytes_range(
540
+ scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
541
+ ) -> bytes:
542
+ err: Optional[Exception] = None
543
+ for attempt in range(1, max_attempts + 1):
544
+ try:
545
+ return (
546
+ _get_s3_client(scheme)
547
+ .get_object(
548
+ Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
549
+ )["Body"]
550
+ .read()
551
+ )
552
+ except boto_exceptions.ClientError as e:
553
+ if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
554
+ raise FileNotFoundError(f"{scheme}://{bucket_name}/{key}") from e
555
+ err = e
556
+ except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
557
+ # ResponseStreamingError (subclass of HTTPClientError) can happen as
558
+ # a result of a failed read from the stream (http.client.IncompleteRead).
559
+ # Retrying can help in this case.
560
+ err = e
561
+
562
+ if attempt < max_attempts:
563
+ log.warning(
564
+ "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
565
+ )
566
+ _wait_before_retry(attempt)
567
+
568
+ # When torch's DataLoader intercepts exceptions, it may try to re-raise them
569
+ # by recalling their constructor with a single message arg. Torch has some
570
+ # logic to deal with the absence of a single-parameter constructor, but it
571
+ # doesn't gracefully handle other possible failures in calling such a constructor
572
+ # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
573
+ # in us losing the true exception info. To avoid this, we change the exception
574
+ # to a type that has a single-parameter constructor.
575
+ raise OLMoNetworkError(f"Failed to get bytes range from {scheme}") from err
576
+
577
+
578
+ def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
579
+ if not prefix.endswith("/"):
580
+ prefix = f"{prefix}/"
581
+ response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
582
+ assert not response["IsTruncated"] # need to handle this if it happens
583
+ latest_step = 0
584
+ latest_checkpoint: Optional[str] = None
585
+ for item in response["CommonPrefixes"]:
586
+ prefix = item["Prefix"].strip("/")
587
+ checkpoint_name = os.path.split(prefix)[-1]
588
+ if not checkpoint_name.startswith("step"):
589
+ continue
590
+ try:
591
+ step = int(checkpoint_name.replace("step", "").replace("-unsharded", ""))
592
+ except ValueError:
593
+ continue
594
+ # Make sure the checkpoint dir contains a config, otherwise the checkpoint is incomplete
595
+ # (upload might have have failed part way through).
596
+ try:
597
+ _s3_file_size(scheme, bucket_name, f"{prefix}/config.yaml")
598
+ except FileNotFoundError:
599
+ continue
600
+ # We prioritize sharded checkpoints over unsharded ones.
601
+ if step > latest_step or (step == latest_step and not checkpoint_name.endswith("-unsharded")):
602
+ latest_step = step
603
+ latest_checkpoint = f"{scheme}://ai2-llm/{prefix}"
604
+ return latest_checkpoint
605
+
606
+
607
+ def default_thread_count() -> int:
608
+ return int(os.environ.get("OLMO_NUM_THREADS") or min(32, (os.cpu_count() or 1) + 4))
609
+
610
+
611
+ def pass_through_fn(fn, *args, **kwargs):
612
+ return fn(*args, **kwargs)
613
+
614
+
615
+ def threaded_generator(g, maxsize: int = 16, thread_name: Optional[str] = None):
616
+ q: Queue = Queue(maxsize=maxsize)
617
+
618
+ sentinel = object()
619
+
620
+ def fill_queue():
621
+ try:
622
+ for value in g:
623
+ q.put(value)
624
+ except Exception as e:
625
+ q.put(e)
626
+ finally:
627
+ q.put(sentinel)
628
+
629
+ thread_name = thread_name or repr(g)
630
+ thread = Thread(name=thread_name, target=fill_queue, daemon=True)
631
+ thread.start()
632
+
633
+ for x in iter(q.get, sentinel):
634
+ if isinstance(x, Exception):
635
+ raise OLMoThreadError(f"generator thread {thread_name} failed") from x
636
+ else:
637
+ yield x
638
+
639
+
640
+ def roundrobin(*iterables):
641
+ """
642
+ Call the given iterables in a round-robin fashion. For example:
643
+ ``roundrobin('ABC', 'D', 'EF') --> A D E B F C``
644
+ """
645
+ # Adapted from https://docs.python.org/3/library/itertools.html#itertools-recipes
646
+ num_active = len(iterables)
647
+ nexts = cycle(iter(it).__next__ for it in iterables)
648
+ while num_active:
649
+ try:
650
+ for next in nexts:
651
+ yield next()
652
+ except StopIteration:
653
+ # Remove the iterator we just exhausted from the cycle.
654
+ num_active -= 1
655
+ nexts = cycle(islice(nexts, num_active))