cmeraki commited on
Commit
fc88a3b
·
1 Parent(s): c149c23

commit files to HF hub

Browse files
Files changed (3) hide show
  1. config.json +10 -0
  2. generation_config.json +7 -3
  3. tts_pipeline.py +102 -0
config.json CHANGED
@@ -1,10 +1,20 @@
1
  {
 
2
  "activation_function": "gelu",
3
  "architectures": [
4
  "GPT2LMHeadModel"
5
  ],
6
  "attn_pdrop": 0,
7
  "bos_token_id": 50256,
 
 
 
 
 
 
 
 
 
8
  "dropout": 0,
9
  "embd_pdrop": 0,
10
  "eos_token_id": 50256,
 
1
  {
2
+ "_name_or_path": "cmeraki/mimi_124m_8cb",
3
  "activation_function": "gelu",
4
  "architectures": [
5
  "GPT2LMHeadModel"
6
  ],
7
  "attn_pdrop": 0,
8
  "bos_token_id": 50256,
9
+ "custom_pipelines": {
10
+ "indri-tts": {
11
+ "impl": "tts_pipeline.IndriTTSPipeline",
12
+ "pt": [
13
+ "AutoModelForCausalLM"
14
+ ],
15
+ "tf": []
16
+ }
17
+ },
18
  "dropout": 0,
19
  "embd_pdrop": 0,
20
  "eos_token_id": 50256,
generation_config.json CHANGED
@@ -1,6 +1,10 @@
1
  {
2
- "_from_model_config": true,
3
- "bos_token_id": 50256,
4
- "eos_token_id": 50256,
 
 
 
 
5
  "transformers_version": "4.46.0"
6
  }
 
1
  {
2
+ "do_sample": true,
3
+ "eos_token_id": [
4
+ 66645
5
+ ],
6
+ "max_length": 1024,
7
+ "temperature": 0.5,
8
+ "top_k": 15,
9
  "transformers_version": "4.46.0"
10
  }
tts_pipeline.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import numpy as np
4
+ from transformers import MimiModel, GenerationConfig
5
+ from transformers import Pipeline
6
+
7
+ class IndriTTSPipeline(Pipeline):
8
+ def __init__(self, *args, **kwargs):
9
+ super().__init__(*args, **kwargs)
10
+
11
+ self.audio_tokenizer = MimiModel.from_pretrained('kyutai/mimi').to(device=self.device)
12
+
13
+ # TODO: Ideally all of this should come from model config
14
+ self.convert_token = self.tokenizer.encode('[convert]')
15
+ self.stop_token = self.tokenizer.encode('[stop]')
16
+ self.text_modality_token = self.tokenizer.encode('[text]')
17
+ self.acoustic_modality_token = self.tokenizer.encode('[mimi]')
18
+ self.num_codebooks = 8
19
+ self.audio_offset = 50257
20
+
21
+ self.model.generation_config = GenerationConfig(
22
+ eos_token_id=self.stop_token,
23
+ max_length=kwargs.get('max_length', 1024),
24
+ temperature=kwargs.get('temperature', 0.5),
25
+ top_k=kwargs.get('top_k', 15),
26
+ do_sample=kwargs.get('do_sample', True)
27
+ )
28
+
29
+ def _sanitize_parameters(self, **kwargs):
30
+ speaker = kwargs.get('speaker', '[spkr_unk]')
31
+
32
+ preprocess_kwargs = {
33
+ 'speaker': speaker
34
+ }
35
+
36
+ return preprocess_kwargs, {}, {}
37
+
38
+ def _prepare_tts_tokens(self, text_tokens, speaker):
39
+ input_tokens = np.hstack([
40
+ self.text_modality_token,
41
+ text_tokens,
42
+ self.convert_token,
43
+ self.acoustic_modality_token,
44
+ self.tokenizer.encode(speaker)
45
+ ])
46
+
47
+ return input_tokens.tolist()
48
+
49
+ def _sanitize_text(self, text):
50
+ text = text.lower()
51
+ text = re.sub(r'\n+', ' ', text)
52
+ text = re.sub(r'[ \t]+', ' ', text)
53
+
54
+ text = re.sub(r'([,\.?])+', r'\1', text)
55
+
56
+ return text.strip()
57
+
58
+ def _deserialize_tokens(self, tokens, num_codebooks):
59
+ cb = [tokens[i::num_codebooks] for i in range(num_codebooks)]
60
+ min_shape = min([c.shape for c in cb])[0]
61
+ acoustic_tokens = torch.vstack([c[:min_shape] - 2048*i for i, c in enumerate(cb)])
62
+
63
+ return acoustic_tokens
64
+
65
+ def preprocess(self, inputs, speaker):
66
+ # TODO: Check for batching
67
+ input_text = self._sanitize_text(inputs)
68
+ input_tokens = self.tokenizer.encode(input_text)
69
+ task_tokens = self._prepare_tts_tokens(input_tokens, speaker)
70
+ task_tokens = torch.tensor(task_tokens).unsqueeze(0)
71
+
72
+ return {'task_tokens': task_tokens}
73
+
74
+ def _forward(self, model_inputs, **forward_args):
75
+
76
+ outputs = self.model.generate(model_inputs['task_tokens'])
77
+ audio_tokens = []
78
+
79
+ for idx, inputs in enumerate(model_inputs['task_tokens']):
80
+ truncated = outputs[idx, inputs.shape[-1]:]
81
+ end = torch.where(truncated == self.stop_token[0])[-1]
82
+
83
+ if end.shape[-1] > 0:
84
+ end = end[0]
85
+ else:
86
+ end = truncated.shape[-1]
87
+
88
+ truncated = truncated[:end]
89
+ truncated -= self.audio_offset
90
+ truncated = self._deserialize_tokens(torch.tensor(truncated), self.num_codebooks)
91
+ audio_tokens.append(truncated)
92
+
93
+ audio_tokens = torch.vstack(audio_tokens).unsqueeze(0)
94
+ audio = self.audio_tokenizer.decode(audio_tokens).audio_values
95
+
96
+ return {
97
+ 'audio_tokens': audio_tokens, # (B, num_codebooks, num_samples)
98
+ 'audio': audio # (B, 1, num_audio_samples)
99
+ }
100
+
101
+ def postprocess(self, model_outputs):
102
+ return model_outputs