update
Browse files- README.md +21 -13
- config.json +3 -3
- modeling_lsg_bart.py +85 -480
- pytorch_model.bin +1 -1
README.md
CHANGED
@@ -23,16 +23,23 @@ should probably proofread and complete it, then remove this comment. -->
|
|
23 |
This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers arxiv dataset. \
|
24 |
It achieves the following results on the test set:
|
25 |
|
26 |
-
| Length | Sparse Type
|
27 |
-
|:------
|
28 |
-
| 4096 |
|
29 |
-
| 4096 |
|
30 |
-
| 4096 |
|
31 |
-
| 4096 |
|
32 |
-
| 4096 |
|
33 |
-
| 4096 |
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
## Model description
|
38 |
The model relies on Local-Sparse-Global attention to handle long sequences:
|
@@ -61,7 +68,8 @@ The following hyperparameters were used during training:
|
|
61 |
- total_train_batch_size: 32
|
62 |
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
63 |
- lr_scheduler_type: linear
|
64 |
-
-
|
|
|
65 |
|
66 |
### Generate hyperparameters
|
67 |
|
@@ -69,13 +77,13 @@ The following hyperparameters were used during generation:
|
|
69 |
- dataset_name: scientific_papers
|
70 |
- dataset_config_name: arxiv
|
71 |
- eval_batch_size: 8
|
|
|
72 |
- early_stopping: True
|
73 |
- ignore_pad_token_for_loss: True
|
74 |
- length_penalty: 2.0
|
75 |
- max_length: 320
|
76 |
-
- min_length:
|
77 |
- num_beams: 5
|
78 |
-
- num_samples: None
|
79 |
- no_repeat_ngram_size: None
|
80 |
- seed: 123
|
81 |
|
|
|
23 |
This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers arxiv dataset. \
|
24 |
It achieves the following results on the test set:
|
25 |
|
26 |
+
| Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
|
27 |
+
|:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
|
28 |
+
| 4096 | Local | 256 | 0 | 768 | 46.65 | 18.91 | 26.90 | 42.18 |
|
29 |
+
| 4096 | Local | 128 | 0 | 384 | 46.18 | 18.57 | 26.71 | 41.69 |
|
30 |
+
| 4096 | Pooling | 128 | 4 | 644 | 46.27 | 18.68 | 26.87 | 41.82 |
|
31 |
+
| 4096 | Stride | 128 | 4 | 644 | 46.34 | 18.64 | 26.69 | 41.87 |
|
32 |
+
| 4096 | Norm | 128 | 4 | 644 | 45.96 | 18.46 | 26.52 | 41.51 |
|
33 |
+
| 4096 | LSH | 128 | 4 | 644 | 46.19 | 18.72 | 26.89 | 41.76 |
|
34 |
+
|
35 |
+
With blocks of size 32 (lower resources):
|
36 |
+
| Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
|
37 |
+
|:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
|
38 |
+
| 4096 | Pooling | 32 | 4 | 160 | 42.75 | 16.34 | 25.20 | 38.23 |
|
39 |
+
| 4096 | Stride | 32 | 4 | 160 | 44.23 | 17.21 | 25.71 | 39.72 |
|
40 |
+
| 4096 | Block Stride | 32 | 4 | 160 | 44.15 | 17.10 | 25.68 | 39.60 |
|
41 |
+
| 4096 | Norm | 32 | 4 | 160 | 42.02 | 15.65 | 24.56 | 37.45 |
|
42 |
+
| 4096 | LSH | 32 | 4 | 160 | 42.58 | 16.21 | 25.10 | 38.04 |
|
43 |
|
44 |
## Model description
|
45 |
The model relies on Local-Sparse-Global attention to handle long sequences:
|
|
|
68 |
- total_train_batch_size: 32
|
69 |
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
70 |
- lr_scheduler_type: linear
|
71 |
+
- lr_scheduler_warmup_ratio: 0.1
|
72 |
+
- num_epochs: 6.0
|
73 |
|
74 |
### Generate hyperparameters
|
75 |
|
|
|
77 |
- dataset_name: scientific_papers
|
78 |
- dataset_config_name: arxiv
|
79 |
- eval_batch_size: 8
|
80 |
+
- eval_samples: 6440
|
81 |
- early_stopping: True
|
82 |
- ignore_pad_token_for_loss: True
|
83 |
- length_penalty: 2.0
|
84 |
- max_length: 320
|
85 |
+
- min_length: 32
|
86 |
- num_beams: 5
|
|
|
87 |
- no_repeat_ngram_size: None
|
88 |
- seed: 123
|
89 |
|
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "
|
3 |
"activation_dropout": 0.1,
|
4 |
"activation_function": "gelu",
|
5 |
"adaptive": true,
|
@@ -67,8 +67,8 @@
|
|
67 |
"pool_with_global": true,
|
68 |
"scale_embedding": false,
|
69 |
"sparse_block_size": 0,
|
70 |
-
"sparsity_factor":
|
71 |
-
"sparsity_type": "
|
72 |
"task_specific_params": {
|
73 |
"summarization": {
|
74 |
"length_penalty": 1.0,
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local",
|
3 |
"activation_dropout": 0.1,
|
4 |
"activation_function": "gelu",
|
5 |
"adaptive": true,
|
|
|
67 |
"pool_with_global": true,
|
68 |
"scale_embedding": false,
|
69 |
"sparse_block_size": 0,
|
70 |
+
"sparsity_factor": 2,
|
71 |
+
"sparsity_type": "none",
|
72 |
"task_specific_params": {
|
73 |
"summarization": {
|
74 |
"length_penalty": 1.0,
|
modeling_lsg_bart.py
CHANGED
@@ -41,8 +41,6 @@ class LSGBartConfig(BartConfig):
|
|
41 |
):
|
42 |
"""Constructs LSGConfig."""
|
43 |
super().__init__(**kwargs)
|
44 |
-
|
45 |
-
assert sparsity_type in ["norm", "lsh", "pooling", "stride"], "Sparsity mode must be 'norm', 'lsh' or 'pooling'"
|
46 |
|
47 |
self.adaptive = adaptive
|
48 |
self.auto_map = AUTO_MAP
|
@@ -55,7 +53,33 @@ class LSGBartConfig(BartConfig):
|
|
55 |
self.sparse_block_size = sparse_block_size
|
56 |
self.sparsity_factor = sparsity_factor
|
57 |
self.sparsity_type = sparsity_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
|
|
|
|
|
|
|
|
59 |
|
60 |
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
|
61 |
"""
|
@@ -208,8 +232,6 @@ class LSGAttentionProduct(nn.Module):
|
|
208 |
# Shape of blocks
|
209 |
self.local_shapes = (self.block_size*3, self.block_size)
|
210 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
211 |
-
assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
|
212 |
-
assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
|
213 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
214 |
|
215 |
self.attention = BaseAttentionProduct(config)
|
@@ -393,21 +415,15 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
393 |
}
|
394 |
|
395 |
self.sparsity_type = config.sparsity_type
|
396 |
-
self.get_sparse_elements = sparse_functions
|
397 |
-
|
398 |
-
if config.sparsity_type == "stride":
|
399 |
-
if config.sparsity_factor > config.encoder_attention_heads:
|
400 |
-
logger.warning(
|
401 |
-
"Warning: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
|
402 |
-
)
|
403 |
|
404 |
if config.sparsity_type == "lsh":
|
405 |
self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
|
406 |
-
|
407 |
def get_sparse_tokens_with_norm(self, keys, values, mask):
|
408 |
|
409 |
if self.sparsity_factor == 1:
|
410 |
-
return keys, values, mask
|
411 |
|
412 |
with torch.no_grad():
|
413 |
|
@@ -435,7 +451,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
435 |
def get_sparse_tokens_with_pooling(self, keys, values, mask):
|
436 |
|
437 |
if self.sparsity_factor == 1:
|
438 |
-
return keys, values, mask
|
439 |
|
440 |
keys = self.chunk(keys, self.sparsity_factor)
|
441 |
values = self.chunk(values, self.sparsity_factor)
|
@@ -457,13 +473,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
457 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
458 |
|
459 |
if self.sparsity_factor == 1:
|
460 |
-
return keys, values, mask
|
461 |
|
462 |
n, h, t, d = keys.size()
|
463 |
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
|
464 |
sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
|
465 |
sparse_idx = sparse_idx.expand(n, h, -1, 1)
|
466 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
468 |
values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
469 |
mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
|
@@ -473,7 +506,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
473 |
def get_sparse_tokens_with_lsh(self, keys, values, mask):
|
474 |
|
475 |
if self.sparsity_factor == 1:
|
476 |
-
return keys, values, mask
|
477 |
|
478 |
block_size = min(self.block_size, self.sparse_block_size)
|
479 |
keys = self.chunk(keys, block_size)
|
@@ -490,9 +523,9 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
490 |
extra_factor = 1
|
491 |
|
492 |
for _ in range(self.lsh_num_pre_rounds):
|
493 |
-
keys, values, mask = self.
|
494 |
|
495 |
-
keys, values, mask = self.
|
496 |
keys /= mask + 1e-8
|
497 |
values /= mask + 1e-8
|
498 |
|
@@ -500,7 +533,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
500 |
|
501 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
502 |
|
503 |
-
def
|
504 |
|
505 |
with torch.no_grad():
|
506 |
|
@@ -1304,6 +1337,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
1304 |
self.padding_idx = config.pad_token_id
|
1305 |
self.max_target_positions = config.max_position_embeddings
|
1306 |
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
|
|
1307 |
|
1308 |
if embed_tokens is not None:
|
1309 |
self.embed_tokens = embed_tokens
|
@@ -1346,6 +1380,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
1346 |
|
1347 |
return combined_attention_mask
|
1348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1349 |
def forward(
|
1350 |
self,
|
1351 |
input_ids=None,
|
@@ -1386,12 +1429,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
1386 |
if inputs_embeds is None:
|
1387 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
1388 |
|
1389 |
-
#
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
|
|
|
|
1395 |
|
1396 |
attention_mask = self._prepare_decoder_attention_mask(
|
1397 |
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
@@ -1485,6 +1530,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
|
|
1485 |
if encoder_hidden_states is not None:
|
1486 |
all_cross_attentions += (layer_outputs[2],)
|
1487 |
|
|
|
|
|
|
|
1488 |
# add hidden states from the last decoder layer
|
1489 |
if output_hidden_states:
|
1490 |
all_hidden_states += (hidden_states,)
|
@@ -1621,14 +1669,14 @@ class LSGBartModel(LSGBartPretrainedModel):
|
|
1621 |
)
|
1622 |
|
1623 |
|
1624 |
-
class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
|
1625 |
|
1626 |
base_model_prefix = "model"
|
1627 |
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
|
1628 |
|
1629 |
def __init__(self, config):
|
1630 |
|
1631 |
-
|
1632 |
self.model = LSGBartModel(config)
|
1633 |
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
1634 |
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
@@ -1636,157 +1684,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
|
|
1636 |
# Initialize weights and apply final processing
|
1637 |
self.post_init()
|
1638 |
|
1639 |
-
def get_encoder(self):
|
1640 |
-
return self.model.get_encoder()
|
1641 |
-
|
1642 |
-
def get_decoder(self):
|
1643 |
-
return self.model.get_decoder()
|
1644 |
-
|
1645 |
-
def resize_token_embeddings(self, new_num_tokens):
|
1646 |
-
new_embeddings = super().resize_token_embeddings(new_num_tokens)
|
1647 |
-
self._resize_final_logits_bias(new_num_tokens)
|
1648 |
-
return new_embeddings
|
1649 |
-
|
1650 |
-
def _resize_final_logits_bias(self, new_num_tokens):
|
1651 |
-
old_num_tokens = self.final_logits_bias.shape[-1]
|
1652 |
-
if new_num_tokens <= old_num_tokens:
|
1653 |
-
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
1654 |
-
else:
|
1655 |
-
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
1656 |
-
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
1657 |
-
self.register_buffer("final_logits_bias", new_bias)
|
1658 |
-
|
1659 |
-
def get_output_embeddings(self):
|
1660 |
-
return self.lm_head
|
1661 |
-
|
1662 |
-
def set_output_embeddings(self, new_embeddings):
|
1663 |
-
self.lm_head = new_embeddings
|
1664 |
-
|
1665 |
-
def forward(
|
1666 |
-
self,
|
1667 |
-
input_ids=None,
|
1668 |
-
attention_mask=None,
|
1669 |
-
decoder_input_ids=None,
|
1670 |
-
decoder_attention_mask=None,
|
1671 |
-
head_mask=None,
|
1672 |
-
decoder_head_mask=None,
|
1673 |
-
cross_attn_head_mask=None,
|
1674 |
-
encoder_outputs=None,
|
1675 |
-
past_key_values=None,
|
1676 |
-
inputs_embeds=None,
|
1677 |
-
decoder_inputs_embeds=None,
|
1678 |
-
labels=None,
|
1679 |
-
use_cache=None,
|
1680 |
-
output_attentions=None,
|
1681 |
-
output_hidden_states=None,
|
1682 |
-
return_dict=None,
|
1683 |
-
):
|
1684 |
-
|
1685 |
-
r"""
|
1686 |
-
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1687 |
-
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
1688 |
-
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
|
1689 |
-
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
|
1690 |
-
Returns:
|
1691 |
-
"""
|
1692 |
-
|
1693 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1694 |
-
|
1695 |
-
if labels is not None:
|
1696 |
-
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
1697 |
-
decoder_input_ids = shift_tokens_right(
|
1698 |
-
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1699 |
-
)
|
1700 |
-
|
1701 |
-
outputs = self.model(
|
1702 |
-
input_ids,
|
1703 |
-
attention_mask=attention_mask,
|
1704 |
-
decoder_input_ids=decoder_input_ids,
|
1705 |
-
encoder_outputs=encoder_outputs,
|
1706 |
-
decoder_attention_mask=decoder_attention_mask,
|
1707 |
-
head_mask=head_mask,
|
1708 |
-
decoder_head_mask=decoder_head_mask,
|
1709 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
1710 |
-
past_key_values=past_key_values,
|
1711 |
-
inputs_embeds=inputs_embeds,
|
1712 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
1713 |
-
use_cache=use_cache,
|
1714 |
-
output_attentions=output_attentions,
|
1715 |
-
output_hidden_states=output_hidden_states,
|
1716 |
-
return_dict=return_dict,
|
1717 |
-
)
|
1718 |
-
|
1719 |
-
|
1720 |
-
lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
|
1721 |
-
|
1722 |
-
masked_lm_loss = None
|
1723 |
-
if labels is not None:
|
1724 |
-
loss_fct = CrossEntropyLoss()
|
1725 |
-
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
1726 |
-
|
1727 |
-
if not return_dict:
|
1728 |
-
output = (lm_logits,) + outputs[1:]
|
1729 |
-
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1730 |
-
|
1731 |
-
return Seq2SeqLMOutput(
|
1732 |
-
loss=masked_lm_loss,
|
1733 |
-
logits=lm_logits,
|
1734 |
-
past_key_values=outputs.past_key_values,
|
1735 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
1736 |
-
decoder_attentions=outputs.decoder_attentions,
|
1737 |
-
cross_attentions=outputs.cross_attentions,
|
1738 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1739 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
1740 |
-
encoder_attentions=outputs.encoder_attentions,
|
1741 |
-
)
|
1742 |
-
|
1743 |
-
def prepare_inputs_for_generation(
|
1744 |
-
self,
|
1745 |
-
decoder_input_ids,
|
1746 |
-
past=None,
|
1747 |
-
attention_mask=None,
|
1748 |
-
head_mask=None,
|
1749 |
-
decoder_head_mask=None,
|
1750 |
-
cross_attn_head_mask=None,
|
1751 |
-
use_cache=None,
|
1752 |
-
encoder_outputs=None,
|
1753 |
-
**kwargs
|
1754 |
-
):
|
1755 |
-
# cut decoder_input_ids if past is used
|
1756 |
-
if past is not None:
|
1757 |
-
decoder_input_ids = decoder_input_ids[:, -1:]
|
1758 |
-
|
1759 |
-
return {
|
1760 |
-
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
1761 |
-
"encoder_outputs": encoder_outputs,
|
1762 |
-
"past_key_values": past,
|
1763 |
-
"decoder_input_ids": decoder_input_ids,
|
1764 |
-
"attention_mask": attention_mask,
|
1765 |
-
"head_mask": head_mask,
|
1766 |
-
"decoder_head_mask": decoder_head_mask,
|
1767 |
-
"cross_attn_head_mask": cross_attn_head_mask,
|
1768 |
-
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
1769 |
-
}
|
1770 |
-
|
1771 |
-
def prepare_decoder_input_ids_from_labels(self, labels):
|
1772 |
-
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
1773 |
-
|
1774 |
-
@staticmethod
|
1775 |
-
def _reorder_cache(past, beam_idx):
|
1776 |
-
reordered_past = ()
|
1777 |
-
for layer_past in past:
|
1778 |
-
# cached cross_attention states don't have to be reordered -> they are always the same
|
1779 |
-
reordered_past += (
|
1780 |
-
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
1781 |
-
)
|
1782 |
-
return reordered_past
|
1783 |
-
|
1784 |
|
1785 |
-
class LSGBartForSequenceClassification(LSGBartPretrainedModel):
|
1786 |
|
1787 |
-
def __init__(self, config, **kwargs):
|
1788 |
|
1789 |
-
|
1790 |
self.model = LSGBartModel(config)
|
1791 |
self.classification_head = LSGBartClassificationHead(
|
1792 |
config.d_model,
|
@@ -1797,115 +1700,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
|
|
1797 |
self.model._init_weights(self.classification_head.dense)
|
1798 |
self.model._init_weights(self.classification_head.out_proj)
|
1799 |
|
1800 |
-
def forward(
|
1801 |
-
self,
|
1802 |
-
input_ids=None,
|
1803 |
-
attention_mask=None,
|
1804 |
-
decoder_input_ids=None,
|
1805 |
-
decoder_attention_mask=None,
|
1806 |
-
head_mask=None,
|
1807 |
-
decoder_head_mask=None,
|
1808 |
-
cross_attn_head_mask=None,
|
1809 |
-
encoder_outputs=None,
|
1810 |
-
inputs_embeds=None,
|
1811 |
-
decoder_inputs_embeds=None,
|
1812 |
-
labels=None,
|
1813 |
-
use_cache=None,
|
1814 |
-
output_attentions=None,
|
1815 |
-
output_hidden_states=None,
|
1816 |
-
return_dict=None,
|
1817 |
-
):
|
1818 |
-
|
1819 |
-
r"""
|
1820 |
-
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1821 |
-
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
1822 |
-
config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1823 |
-
"""
|
1824 |
-
|
1825 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1826 |
-
if labels is not None:
|
1827 |
-
use_cache = False
|
1828 |
-
|
1829 |
-
if input_ids is None and inputs_embeds is not None:
|
1830 |
-
raise NotImplementedError(
|
1831 |
-
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
|
1832 |
-
)
|
1833 |
-
|
1834 |
-
outputs = self.model(
|
1835 |
-
input_ids,
|
1836 |
-
attention_mask=attention_mask,
|
1837 |
-
decoder_input_ids=decoder_input_ids,
|
1838 |
-
decoder_attention_mask=decoder_attention_mask,
|
1839 |
-
head_mask=head_mask,
|
1840 |
-
decoder_head_mask=decoder_head_mask,
|
1841 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
1842 |
-
encoder_outputs=encoder_outputs,
|
1843 |
-
inputs_embeds=inputs_embeds,
|
1844 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
1845 |
-
use_cache=use_cache,
|
1846 |
-
output_attentions=output_attentions,
|
1847 |
-
output_hidden_states=output_hidden_states,
|
1848 |
-
return_dict=return_dict,
|
1849 |
-
)
|
1850 |
-
hidden_states = outputs[0] # last hidden state
|
1851 |
-
|
1852 |
-
eos_mask = input_ids.eq(self.config.eos_token_id)
|
1853 |
-
|
1854 |
-
t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
|
1855 |
-
if t > t_:
|
1856 |
-
eos_mask = eos_mask[:, :t_]
|
1857 |
-
|
1858 |
-
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
1859 |
-
raise ValueError("All examples must have the same number of <eos> tokens.")
|
1860 |
-
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
1861 |
-
:, -1, :
|
1862 |
-
]
|
1863 |
-
logits = self.classification_head(sentence_representation)
|
1864 |
-
|
1865 |
-
loss = None
|
1866 |
-
if labels is not None:
|
1867 |
-
if self.config.problem_type is None:
|
1868 |
-
if self.config.num_labels == 1:
|
1869 |
-
self.config.problem_type = "regression"
|
1870 |
-
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1871 |
-
self.config.problem_type = "single_label_classification"
|
1872 |
-
else:
|
1873 |
-
self.config.problem_type = "multi_label_classification"
|
1874 |
-
|
1875 |
-
if self.config.problem_type == "regression":
|
1876 |
-
loss_fct = MSELoss()
|
1877 |
-
if self.config.num_labels == 1:
|
1878 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1879 |
-
else:
|
1880 |
-
loss = loss_fct(logits, labels)
|
1881 |
-
elif self.config.problem_type == "single_label_classification":
|
1882 |
-
loss_fct = CrossEntropyLoss()
|
1883 |
-
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
1884 |
-
elif self.config.problem_type == "multi_label_classification":
|
1885 |
-
loss_fct = BCEWithLogitsLoss()
|
1886 |
-
loss = loss_fct(logits, labels)
|
1887 |
-
if not return_dict:
|
1888 |
-
output = (logits,) + outputs[1:]
|
1889 |
-
return ((loss,) + output) if loss is not None else output
|
1890 |
-
|
1891 |
-
return Seq2SeqSequenceClassifierOutput(
|
1892 |
-
loss=loss,
|
1893 |
-
logits=logits,
|
1894 |
-
past_key_values=outputs.past_key_values,
|
1895 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
1896 |
-
decoder_attentions=outputs.decoder_attentions,
|
1897 |
-
cross_attentions=outputs.cross_attentions,
|
1898 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1899 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
1900 |
-
encoder_attentions=outputs.encoder_attentions,
|
1901 |
-
)
|
1902 |
|
|
|
1903 |
|
1904 |
-
|
1905 |
|
1906 |
-
|
1907 |
-
|
1908 |
-
super().__init__(config)
|
1909 |
|
1910 |
config.num_labels = 2
|
1911 |
self.num_labels = config.num_labels
|
@@ -1915,102 +1715,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
|
|
1915 |
|
1916 |
self.model._init_weights(self.qa_outputs)
|
1917 |
|
1918 |
-
def forward(
|
1919 |
-
self,
|
1920 |
-
input_ids=None,
|
1921 |
-
attention_mask=None,
|
1922 |
-
decoder_input_ids=None,
|
1923 |
-
decoder_attention_mask=None,
|
1924 |
-
head_mask=None,
|
1925 |
-
decoder_head_mask=None,
|
1926 |
-
cross_attn_head_mask=None,
|
1927 |
-
encoder_outputs=None,
|
1928 |
-
start_positions=None,
|
1929 |
-
end_positions=None,
|
1930 |
-
inputs_embeds=None,
|
1931 |
-
decoder_inputs_embeds=None,
|
1932 |
-
use_cache=None,
|
1933 |
-
output_attentions=None,
|
1934 |
-
output_hidden_states=None,
|
1935 |
-
return_dict=None,
|
1936 |
-
):
|
1937 |
-
|
1938 |
-
r"""
|
1939 |
-
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1940 |
-
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1941 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1942 |
-
are not taken into account for computing the loss.
|
1943 |
-
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1944 |
-
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1945 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1946 |
-
are not taken into account for computing the loss.
|
1947 |
-
"""
|
1948 |
-
|
1949 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1950 |
-
if start_positions is not None and end_positions is not None:
|
1951 |
-
use_cache = False
|
1952 |
-
|
1953 |
-
outputs = self.model(
|
1954 |
-
input_ids,
|
1955 |
-
attention_mask=attention_mask,
|
1956 |
-
decoder_input_ids=decoder_input_ids,
|
1957 |
-
decoder_attention_mask=decoder_attention_mask,
|
1958 |
-
head_mask=head_mask,
|
1959 |
-
decoder_head_mask=decoder_head_mask,
|
1960 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
1961 |
-
encoder_outputs=encoder_outputs,
|
1962 |
-
inputs_embeds=inputs_embeds,
|
1963 |
-
decoder_inputs_embeds=decoder_inputs_embeds,
|
1964 |
-
use_cache=use_cache,
|
1965 |
-
output_attentions=output_attentions,
|
1966 |
-
output_hidden_states=output_hidden_states,
|
1967 |
-
return_dict=return_dict,
|
1968 |
-
)
|
1969 |
-
|
1970 |
-
sequence_output = outputs[0]
|
1971 |
-
|
1972 |
-
logits = self.qa_outputs(sequence_output)
|
1973 |
-
start_logits, end_logits = logits.split(1, dim=-1)
|
1974 |
-
start_logits = start_logits.squeeze(-1).contiguous()
|
1975 |
-
end_logits = end_logits.squeeze(-1).contiguous()
|
1976 |
-
|
1977 |
-
total_loss = None
|
1978 |
-
if start_positions is not None and end_positions is not None:
|
1979 |
-
# If we are on multi-GPU, split add a dimension
|
1980 |
-
if len(start_positions.size()) > 1:
|
1981 |
-
start_positions = start_positions.squeeze(-1)
|
1982 |
-
if len(end_positions.size()) > 1:
|
1983 |
-
end_positions = end_positions.squeeze(-1)
|
1984 |
-
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1985 |
-
ignored_index = start_logits.size(1)
|
1986 |
-
start_positions = start_positions.clamp(0, ignored_index)
|
1987 |
-
end_positions = end_positions.clamp(0, ignored_index)
|
1988 |
-
|
1989 |
-
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1990 |
-
start_loss = loss_fct(start_logits, start_positions)
|
1991 |
-
end_loss = loss_fct(end_logits, end_positions)
|
1992 |
-
total_loss = (start_loss + end_loss) / 2
|
1993 |
-
|
1994 |
-
if not return_dict:
|
1995 |
-
output = (
|
1996 |
-
start_logits,
|
1997 |
-
end_logits,
|
1998 |
-
) + outputs[1:]
|
1999 |
-
return ((total_loss,) + output) if total_loss is not None else output
|
2000 |
-
|
2001 |
-
return Seq2SeqQuestionAnsweringModelOutput(
|
2002 |
-
loss=total_loss,
|
2003 |
-
start_logits=start_logits,
|
2004 |
-
end_logits=end_logits,
|
2005 |
-
past_key_values=outputs.past_key_values,
|
2006 |
-
decoder_hidden_states=outputs.decoder_hidden_states,
|
2007 |
-
decoder_attentions=outputs.decoder_attentions,
|
2008 |
-
cross_attentions=outputs.cross_attentions,
|
2009 |
-
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
2010 |
-
encoder_hidden_states=outputs.encoder_hidden_states,
|
2011 |
-
encoder_attentions=outputs.encoder_attentions,
|
2012 |
-
)
|
2013 |
-
|
2014 |
|
2015 |
class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
2016 |
"""
|
@@ -2018,7 +1722,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
|
2018 |
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
2019 |
"""
|
2020 |
|
2021 |
-
def __init__(self, config):
|
2022 |
super().__init__(config)
|
2023 |
self.decoder = LSGBartDecoder(config)
|
2024 |
|
@@ -2026,14 +1730,14 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
|
2026 |
return self.decoder(*args, **kwargs)
|
2027 |
|
2028 |
|
2029 |
-
class LSGBartForCausalLM(LSGBartPretrainedModel):
|
2030 |
|
2031 |
-
def __init__(self, config):
|
2032 |
|
2033 |
-
super().__init__(config)
|
2034 |
config = copy.deepcopy(config)
|
2035 |
config.is_decoder = True
|
2036 |
config.is_encoder_decoder = False
|
|
|
2037 |
self.model = LSGBartDecoderWrapper(config)
|
2038 |
|
2039 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
@@ -2041,105 +1745,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
|
|
2041 |
# Initialize weights and apply final processing
|
2042 |
self.post_init()
|
2043 |
|
2044 |
-
def get_input_embeddings(self):
|
2045 |
-
return self.model.decoder.embed_tokens
|
2046 |
-
|
2047 |
-
def set_input_embeddings(self, value):
|
2048 |
-
self.model.decoder.embed_tokens = value
|
2049 |
-
|
2050 |
-
def get_output_embeddings(self):
|
2051 |
-
return self.lm_head
|
2052 |
-
|
2053 |
-
def set_output_embeddings(self, new_embeddings):
|
2054 |
-
self.lm_head = new_embeddings
|
2055 |
-
|
2056 |
-
def set_decoder(self, decoder):
|
2057 |
-
self.model.decoder = decoder
|
2058 |
-
|
2059 |
-
def get_decoder(self):
|
2060 |
-
return self.model.decoder
|
2061 |
-
|
2062 |
-
def forward(
|
2063 |
-
self,
|
2064 |
-
input_ids=None,
|
2065 |
-
attention_mask=None,
|
2066 |
-
encoder_hidden_states=None,
|
2067 |
-
encoder_attention_mask=None,
|
2068 |
-
head_mask=None,
|
2069 |
-
cross_attn_head_mask=None,
|
2070 |
-
past_key_values=None,
|
2071 |
-
inputs_embeds=None,
|
2072 |
-
labels=None,
|
2073 |
-
use_cache=None,
|
2074 |
-
output_attentions=None,
|
2075 |
-
output_hidden_states=None,
|
2076 |
-
return_dict=None,
|
2077 |
-
):
|
2078 |
-
|
2079 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
2080 |
-
output_hidden_states = (
|
2081 |
-
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
2082 |
-
)
|
2083 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
2084 |
-
|
2085 |
-
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
2086 |
-
outputs = self.model.decoder(
|
2087 |
-
input_ids=input_ids,
|
2088 |
-
attention_mask=attention_mask,
|
2089 |
-
encoder_hidden_states=encoder_hidden_states,
|
2090 |
-
encoder_attention_mask=encoder_attention_mask,
|
2091 |
-
head_mask=head_mask,
|
2092 |
-
cross_attn_head_mask=cross_attn_head_mask,
|
2093 |
-
past_key_values=past_key_values,
|
2094 |
-
inputs_embeds=inputs_embeds,
|
2095 |
-
use_cache=use_cache,
|
2096 |
-
output_attentions=output_attentions,
|
2097 |
-
output_hidden_states=output_hidden_states,
|
2098 |
-
return_dict=return_dict,
|
2099 |
-
)
|
2100 |
-
|
2101 |
-
logits = self.lm_head(outputs[0])
|
2102 |
-
|
2103 |
-
loss = None
|
2104 |
-
if labels is not None:
|
2105 |
-
loss_fct = CrossEntropyLoss()
|
2106 |
-
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
2107 |
-
|
2108 |
-
if not return_dict:
|
2109 |
-
output = (logits,) + outputs[1:]
|
2110 |
-
return (loss,) + output if loss is not None else output
|
2111 |
-
|
2112 |
-
return CausalLMOutputWithCrossAttentions(
|
2113 |
-
loss=loss,
|
2114 |
-
logits=logits,
|
2115 |
-
past_key_values=outputs.past_key_values,
|
2116 |
-
hidden_states=outputs.hidden_states,
|
2117 |
-
attentions=outputs.attentions,
|
2118 |
-
cross_attentions=outputs.cross_attentions,
|
2119 |
-
)
|
2120 |
-
|
2121 |
-
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
|
2122 |
-
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
2123 |
-
if attention_mask is None:
|
2124 |
-
attention_mask = input_ids.new_ones(input_ids.shape)
|
2125 |
-
|
2126 |
-
if past:
|
2127 |
-
input_ids = input_ids[:, -1:]
|
2128 |
-
# first step, decoder_cached_states are empty
|
2129 |
-
return {
|
2130 |
-
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
2131 |
-
"attention_mask": attention_mask,
|
2132 |
-
"past_key_values": past,
|
2133 |
-
"use_cache": use_cache,
|
2134 |
-
}
|
2135 |
-
|
2136 |
-
@staticmethod
|
2137 |
-
def _reorder_cache(past, beam_idx):
|
2138 |
-
reordered_past = ()
|
2139 |
-
for layer_past in past:
|
2140 |
-
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
2141 |
-
return reordered_past
|
2142 |
-
|
2143 |
|
2144 |
def str_to_class(classname):
    """Return the attribute of this module whose name is ``classname``.

    Raises:
        AttributeError: if the module has no attribute with that name.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
|
|
|
41 |
):
|
42 |
"""Constructs LSGConfig."""
|
43 |
super().__init__(**kwargs)
|
|
|
|
|
44 |
|
45 |
self.adaptive = adaptive
|
46 |
self.auto_map = AUTO_MAP
|
|
|
53 |
self.sparse_block_size = sparse_block_size
|
54 |
self.sparsity_factor = sparsity_factor
|
55 |
self.sparsity_type = sparsity_type
|
56 |
+
|
57 |
+
if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
|
58 |
+
logger.warning(
|
59 |
+
"[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
|
60 |
+
self.sparsity_type = None
|
61 |
+
|
62 |
+
if self.sparsity_type == "stride":
|
63 |
+
if self.sparsity_factor > self.encoder_attention_heads:
|
64 |
+
logger.warning(
|
65 |
+
"[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
|
66 |
+
)
|
67 |
+
|
68 |
+
if self.num_global_tokens < 1:
|
69 |
+
logger.warning(
|
70 |
+
"[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
|
71 |
+
)
|
72 |
+
self.num_global_tokens = 1
|
73 |
+
elif self.num_global_tokens > 512:
|
74 |
+
logger.warning(
|
75 |
+
"[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
|
76 |
+
)
|
77 |
+
self.num_global_tokens = 512
|
78 |
|
79 |
+
if self.sparsity_factor > 0:
|
80 |
+
assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
|
81 |
+
assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
|
82 |
+
|
83 |
|
84 |
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
|
85 |
"""
|
|
|
232 |
# Shape of blocks
|
233 |
self.local_shapes = (self.block_size*3, self.block_size)
|
234 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
|
|
|
|
235 |
self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
|
236 |
|
237 |
self.attention = BaseAttentionProduct(config)
|
|
|
415 |
}
|
416 |
|
417 |
self.sparsity_type = config.sparsity_type
|
418 |
+
self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
|
420 |
if config.sparsity_type == "lsh":
|
421 |
self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
|
422 |
+
|
423 |
def get_sparse_tokens_with_norm(self, keys, values, mask):
|
424 |
|
425 |
if self.sparsity_factor == 1:
|
426 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
427 |
|
428 |
with torch.no_grad():
|
429 |
|
|
|
451 |
def get_sparse_tokens_with_pooling(self, keys, values, mask):
|
452 |
|
453 |
if self.sparsity_factor == 1:
|
454 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
455 |
|
456 |
keys = self.chunk(keys, self.sparsity_factor)
|
457 |
values = self.chunk(values, self.sparsity_factor)
|
|
|
473 |
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
474 |
|
475 |
if self.sparsity_factor == 1:
|
476 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
477 |
|
478 |
n, h, t, d = keys.size()
|
479 |
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
|
480 |
sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
|
481 |
sparse_idx = sparse_idx.expand(n, h, -1, 1)
|
482 |
|
483 |
+
"""
|
484 |
+
t, b = self.block_size, t // self.block_size
|
485 |
+
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
|
486 |
+
sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1, 1)
|
487 |
+
sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
|
488 |
+
sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
|
489 |
+
|
490 |
+
|
491 |
+
t, b = self.block_size, t // self.block_size
|
492 |
+
sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
|
493 |
+
sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
|
494 |
+
sparse_idx = (sparse_idx % t)
|
495 |
+
#sparse_idx[..., -t//2:, :] = (sparse_idx[..., -t//2:, :] + t//2) % t
|
496 |
+
sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
|
497 |
+
sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
|
498 |
+
"""
|
499 |
+
|
500 |
keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
501 |
values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
|
502 |
mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
|
|
|
506 |
def get_sparse_tokens_with_lsh(self, keys, values, mask):
|
507 |
|
508 |
if self.sparsity_factor == 1:
|
509 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
510 |
|
511 |
block_size = min(self.block_size, self.sparse_block_size)
|
512 |
keys = self.chunk(keys, block_size)
|
|
|
523 |
extra_factor = 1
|
524 |
|
525 |
for _ in range(self.lsh_num_pre_rounds):
|
526 |
+
keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
|
527 |
|
528 |
+
keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
|
529 |
keys /= mask + 1e-8
|
530 |
values /= mask + 1e-8
|
531 |
|
|
|
533 |
|
534 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
|
535 |
|
536 |
+
def lsh_round(self, keys, values, mask, output_size):
|
537 |
|
538 |
with torch.no_grad():
|
539 |
|
|
|
1337 |
self.padding_idx = config.pad_token_id
|
1338 |
self.max_target_positions = config.max_position_embeddings
|
1339 |
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
1340 |
+
self.adaptive = config.adaptive
|
1341 |
|
1342 |
if embed_tokens is not None:
|
1343 |
self.embed_tokens = embed_tokens
|
|
|
1380 |
|
1381 |
return combined_attention_mask
|
1382 |
|
1383 |
+
def resize_inputs(self, inputs_embeds, attention_mask):
    """Trim trailing positions that are masked out in every row of the batch.

    The kept length is the largest per-row sum of ``attention_mask`` over its
    last dimension.
    NOTE(review): this only drops padding correctly if the mask is 1 for real
    tokens, 0 for padding, and padding sits on the right — confirm with callers.

    Args:
        inputs_embeds: tensor sliced on its second dimension (``[:, :max_len]``).
        attention_mask: mask tensor sliced on its last dimension.

    Returns:
        Tuple ``(pad, inputs_embeds, attention_mask)`` where ``pad`` is the
        number of trailing positions removed (original last-dimension size of
        the mask minus the kept length) and the two tensors are the trimmed
        views.
    """
    # The former ``pad = 0`` initialisation was a dead store: ``pad`` is
    # always recomputed below before it is returned.
    max_len = int(attention_mask.sum(dim=-1).max())
    pad = attention_mask.size()[-1] - max_len
    inputs_embeds = inputs_embeds[:, :max_len]
    attention_mask = attention_mask[..., :max_len]
    return pad, inputs_embeds, attention_mask
|
1391 |
+
|
1392 |
def forward(
|
1393 |
self,
|
1394 |
input_ids=None,
|
|
|
1429 |
if inputs_embeds is None:
|
1430 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
1431 |
|
1432 |
+
# Resize to reduce computation
|
1433 |
+
pad = 0
|
1434 |
+
if self.adaptive:
|
1435 |
+
if attention_mask is not None:
|
1436 |
+
pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
|
1437 |
+
input_shape = inputs_embeds.size()[:-1]
|
1438 |
+
if encoder_attention_mask is not None:
|
1439 |
+
_, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
|
1440 |
|
1441 |
attention_mask = self._prepare_decoder_attention_mask(
|
1442 |
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
|
|
1530 |
if encoder_hidden_states is not None:
|
1531 |
all_cross_attentions += (layer_outputs[2],)
|
1532 |
|
1533 |
+
# Resize to original shape
|
1534 |
+
hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
|
1535 |
+
|
1536 |
# add hidden states from the last decoder layer
|
1537 |
if output_hidden_states:
|
1538 |
all_hidden_states += (hidden_states,)
|
|
|
1669 |
)
|
1670 |
|
1671 |
|
1672 |
+
class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
|
1673 |
|
1674 |
base_model_prefix = "model"
|
1675 |
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
|
1676 |
|
1677 |
def __init__(self, config):
|
1678 |
|
1679 |
+
LSGBartPretrainedModel.__init__(self, config)
|
1680 |
self.model = LSGBartModel(config)
|
1681 |
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
|
1682 |
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
|
|
|
1684 |
# Initialize weights and apply final processing
|
1685 |
self.post_init()
|
1686 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1687 |
|
1688 |
+
class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
|
1689 |
|
1690 |
+
def __init__(self, config: LSGBartConfig, **kwargs):
|
1691 |
|
1692 |
+
LSGBartPretrainedModel.__init__(self, config, **kwargs)
|
1693 |
self.model = LSGBartModel(config)
|
1694 |
self.classification_head = LSGBartClassificationHead(
|
1695 |
config.d_model,
|
|
|
1700 |
self.model._init_weights(self.classification_head.dense)
|
1701 |
self.model._init_weights(self.classification_head.out_proj)
|
1702 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1703 |
|
1704 |
+
class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
|
1705 |
|
1706 |
+
def __init__(self, config: LSGBartConfig):
|
1707 |
|
1708 |
+
LSGBartPretrainedModel.__init__(self, config)
|
|
|
|
|
1709 |
|
1710 |
config.num_labels = 2
|
1711 |
self.num_labels = config.num_labels
|
|
|
1715 |
|
1716 |
self.model._init_weights(self.qa_outputs)
|
1717 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1718 |
|
1719 |
class LSGBartDecoderWrapper(LSGBartPretrainedModel):
|
1720 |
"""
|
|
|
1722 |
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
|
1723 |
"""
|
1724 |
|
1725 |
+
def __init__(self, config: LSGBartConfig):
|
1726 |
super().__init__(config)
|
1727 |
self.decoder = LSGBartDecoder(config)
|
1728 |
|
|
|
1730 |
return self.decoder(*args, **kwargs)
|
1731 |
|
1732 |
|
1733 |
+
class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
|
1734 |
|
1735 |
+
def __init__(self, config: LSGBartConfig):
|
1736 |
|
|
|
1737 |
config = copy.deepcopy(config)
|
1738 |
config.is_decoder = True
|
1739 |
config.is_encoder_decoder = False
|
1740 |
+
LSGBartPretrainedModel.__init__(self, config)
|
1741 |
self.model = LSGBartDecoderWrapper(config)
|
1742 |
|
1743 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
1745 |
# Initialize weights and apply final processing
|
1746 |
self.post_init()
|
1747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1748 |
|
1749 |
def str_to_class(classname):
    """Return the attribute of this module whose name is ``classname``.

    Raises:
        AttributeError: if the module has no attribute with that name.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 578416695
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:88af6fadc19698eaa5d49e63aa969487846fbdfb41852afe199350a98d04801d
|
3 |
size 578416695
|