PB Unity commited on
Commit
6df43b3
·
verified ·
1 Parent(s): 6689fb2

Upload RunMusicGen.cs

Browse files
Files changed (1) hide show
  1. RunMusicGen.cs +348 -0
RunMusicGen.cs ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using System.Collections;
2
+ using System.Collections.Generic;
3
+ using UnityEngine;
4
+ using Unity.Sentis;
5
+ using Newtonsoft.Json;
6
+
7
+
8
+ // Inference for MusicGen-300
9
+ // ==========================
10
+ //
11
+ // Details
12
+ // -------
13
+ // The model predicts 4 streams of codes staggered like this:
14
+ // * * * a b c
15
+ // * * a b c
16
+ // * a b c
17
+ // a b c
18
+ // Then aligns the streams so that it groups all the a's togther etc.
19
+
20
+ // Put sentis files and json file in Assets/StreamingAssets folder
21
+ // Put this script on the Main Camera object
22
+ // Put an audiosource on the Main Camera
23
+ // Press play and see console window for updates
24
+
25
+ // See https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/modeling_musicgen.py
26
+
27
+
28
+ public class RunMusicGen : MonoBehaviour
29
+ {
30
+ //Change this prompt to whatever you like:
31
+ string prompt = "80s pop track with bassy drums and synth";
32
+
33
+ // number of seconds to create clip for (up to 30 seconds)
34
+ const int seconds = 2;
35
+
36
+ // Make this value smaller to make music more random
37
+ float predictability = 1f;
38
+
39
+ public AudioClip clip;
40
+ IWorker toWavEngine, decoderEngine, textEngine, projectEngine;
41
+
42
+ const int numCodeBooks = 4;
43
+
44
+ // Special music decoder tokens
45
+ const int DECODER_START_TOKEN = 2048;
46
+
47
+ // Special text encoder tokens
48
+ const int END_TEXT_TOKEN = 1;
49
+
50
+ int decoderTokens; //text tokens
51
+
52
+ List<int> tokensSoFar = new();
53
+ TensorFloat encoder_hidden_states;
54
+ TensorInt encoder_attention_mask, input_ids;
55
+ Ops ops;
56
+ Model decoder;
57
+
58
+ // How much to stagger each code stream by wrt the next one
59
+ int DELAY = 1;
60
+
61
+ // Vocab list
62
+ List<string> tokens = new List<string>();
63
+
64
+ //The output frequency must be 32kHz
65
+ const int outputFrequency = 32000;
66
+
67
+ int maxFrames;
68
+
69
+ List<int> TOKENS;
70
+
71
+ int frame = 0;
72
+ bool hasDecodedMusic = false;
73
+ void Start()
74
+ {
75
+ ops = WorkerFactory.CreateOps(BackendType.GPUCompute, null);
76
+
77
+ maxFrames = 50 * seconds + 3;
78
+
79
+ LoadVocab();
80
+
81
+ TOKENS = GetTokens(prompt);
82
+
83
+ Debug.Log("Parsed tokens=\n" + string.Join(",", TOKENS));
84
+
85
+ CreateAttentionMask();
86
+ ParseText();
87
+ LoadDecoderModel();
88
+
89
+ SetupMusicCodeStreams();
90
+
91
+ frame = 0;
92
+ }
93
+
94
+ void LoadDecoderModel()
95
+ {
96
+ decoder = ModelLoader.Load(Application.streamingAssetsPath + "/decoder.sentis");
97
+ decoderEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, decoder);
98
+ }
99
+
100
+ void CreateAttentionMask()
101
+ {
102
+ int[] mask = new int[1 * decoderTokens];
103
+ for (int i = 0; i < mask.Length; i++) mask[i] = 1;
104
+ encoder_attention_mask = new TensorInt(new TensorShape(1, decoderTokens), mask);
105
+ }
106
+
107
+ void SetupMusicCodeStreams()
108
+ {
109
+ //Sets the staggered start codes
110
+ tokensSoFar.AddRange(new int[numCodeBooks * maxFrames]);
111
+ for (int j = 0; j < maxFrames; j++)
112
+ {
113
+ for (int i = 0; i < numCodeBooks; i++)
114
+ {
115
+ if ( i * DELAY >= j)
116
+ {
117
+ tokensSoFar[i * maxFrames + j] = DECODER_START_TOKEN;
118
+ }
119
+ else
120
+ {
121
+ tokensSoFar[i * maxFrames + j] = -1;
122
+ }
123
+ }
124
+ }
125
+ input_ids = new TensorInt(new TensorShape(numCodeBooks, maxFrames), tokensSoFar.ToArray());
126
+ }
127
+
128
+ List<int> GetTokens(string text)
129
+ {
130
+ //split over whitespace
131
+ string[] words = text.ToLower().Split(null);
132
+ for (int i = 0; i < words.Length; i++) words[i] = " " + words[i];
133
+
134
+ var ids = new List<int>();
135
+
136
+ string s = "";
137
+
138
+ foreach (var word in words)
139
+ {
140
+ int start = 0;
141
+ for (int i = word.Length; i >= 0; i--)
142
+ {
143
+ string subword = word.Substring(start, i - start);
144
+ int index = tokens.IndexOf(subword);
145
+ if (index >= 0)
146
+ {
147
+ ids.Add(index);
148
+ s += subword + " ";
149
+ if (i == word.Length) break;
150
+ start = i;
151
+ i = word.Length + 1;
152
+ }
153
+ }
154
+ }
155
+
156
+ ids.Add(END_TEXT_TOKEN);
157
+
158
+ decoderTokens = ids.Count;
159
+
160
+ Debug.Log("Tokenized sentece = " + s);
161
+
162
+ return ids;
163
+ }
164
+
165
+ void ParseText()
166
+ {
167
+ Model textencoder = ModelLoader.Load(Application.streamingAssetsPath + "/textencoder.sentis");
168
+ textEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, textencoder);
169
+
170
+ Model project = ModelLoader.Load(Application.streamingAssetsPath + "/project768_1024.sentis");
171
+ projectEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, project);
172
+
173
+ using var input = new TensorInt(new TensorShape(1, decoderTokens), TOKENS.ToArray());
174
+
175
+ var inputs = new Dictionary<string, Tensor>
176
+ {
177
+ {"input_ids", input },
178
+ {"attention_mask", encoder_attention_mask }
179
+ };
180
+ textEngine.Execute(inputs);
181
+
182
+ var output = textEngine.PeekOutput() as TensorFloat;
183
+
184
+ //Convert vectors of size 768 to size 1024
185
+ projectEngine.Execute(output);
186
+ encoder_hidden_states = projectEngine.PeekOutput() as TensorFloat;
187
+ encoder_hidden_states.TakeOwnership();
188
+ }
189
+
190
+ private class TokenizerData
191
+ {
192
+ public ModelData model;
193
+ }
194
+ private class ModelData
195
+ {
196
+ public object[][] vocab;
197
+ }
198
+
199
+ void LoadVocab()
200
+ {
201
+ var data = Newtonsoft.Json.JsonConvert.DeserializeObject<TokenizerData>(System.IO.File.ReadAllText(
202
+ Application.streamingAssetsPath+"/tokenizer.json"
203
+ ));
204
+ for(int i = 0; i < data.model.vocab.Length; i++)
205
+ {
206
+ string tokenName = (string)data.model.vocab[i][0];
207
+ tokens.Add(tokenName);
208
+ }
209
+ }
210
+
211
+ // Update is called once per frame
212
+ void Update()
213
+ {
214
+ if (frame < maxFrames)
215
+ {
216
+ GetOneMusicToken();
217
+ }
218
+ else if(!hasDecodedMusic)
219
+ {
220
+ hasDecodedMusic = true;
221
+ DecodeMusic();
222
+ }
223
+ frame++;
224
+ }
225
+
226
+ void GetOneMusicToken()
227
+ {
228
+ var inputs = new Dictionary<string, Tensor>
229
+ {
230
+ {"input_ids", input_ids },
231
+ {"encoder_hidden_states", encoder_hidden_states },
232
+ {"encoder_attention_mask" , encoder_attention_mask }
233
+ };
234
+
235
+ decoderEngine.Execute(inputs);
236
+ var decoderOutput = decoderEngine.PeekOutput() as TensorFloat;
237
+ using var dec2 = ops.Mul(decoderOutput, predictability);
238
+ using var probs = ops.Softmax(dec2, 2);
239
+ probs.MakeReadable();
240
+
241
+ int OFFSET = 1;
242
+
243
+ //Add new tokens to code streams
244
+ for (int j = 0; j < numCodeBooks; j++)
245
+ {
246
+ if (frame < maxFrames - OFFSET)
247
+ {
248
+ int N = j * maxFrames + frame + OFFSET;
249
+
250
+ if (tokensSoFar[N] != DECODER_START_TOKEN)
251
+ {
252
+ tokensSoFar[N] = SelectRandomToken(probs, j, frame);
253
+ }
254
+ }
255
+ }
256
+ Replace(ref input_ids, new TensorInt(input_ids.shape, tokensSoFar.ToArray()));
257
+ Debug.Log("Frame=" + frame + "/" + maxFrames);
258
+ }
259
+
260
+ int SelectRandomToken(TensorFloat probs,int j, int frame)
261
+ {
262
+ int numItems = probs.shape[2];
263
+ float p = UnityEngine.Random.Range(0, 1f);
264
+ float tot = 0;
265
+ for(int i = 0; i < numItems; i++)
266
+ {
267
+ tot += probs[j, frame, i];
268
+ if (p <= tot) return i;
269
+ }
270
+ return numItems - 1;
271
+ }
272
+ void LoadMusicTokensToWavModel()
273
+ {
274
+ if (toWavEngine != null) return;
275
+ Model toWav = ModelLoader.Load(Application.streamingAssetsPath + "/encodec.sentis");
276
+ toWavEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, toWav);
277
+ }
278
+
279
+ void DecodeMusic()
280
+ {
281
+ Debug.Log("Please wait while music is decoded...");
282
+ LoadMusicTokensToWavModel();
283
+
284
+ using var input2 = AlignCodeStreams(input_ids);
285
+ using var wavTokens = input2.ShallowReshape(new TensorShape(1, 1, numCodeBooks, maxFrames - 3));
286
+
287
+ toWavEngine.Execute(wavTokens);
288
+ var output = toWavEngine.PeekOutput() as TensorFloat;
289
+ output.MakeReadable();
290
+
291
+ int numSamples = Mathf.Min(output.shape.length, outputFrequency * seconds);
292
+ Debug.Log("Number of samples=" + numSamples + " / " + output.shape.length);
293
+ clip = AudioClip.Create("music", numSamples, 1, outputFrequency, false);
294
+
295
+ float[] wav = new float[numSamples];
296
+ System.Array.Copy(output.ToReadOnlyArray(), wav, numSamples);
297
+ clip.SetData(wav, 0);
298
+
299
+ var audioSource = GetComponent<AudioSource>();
300
+ if (audioSource != null)
301
+ {
302
+ audioSource.PlayOneShot(clip);
303
+ }
304
+ else
305
+ {
306
+ Debug.Log("You need to attach audio source to this object to hear the music");
307
+ }
308
+ }
309
+
310
+ TensorInt AlignCodeStreams(TensorInt input)
311
+ {
312
+ if (DELAY == 0)
313
+ {
314
+ return ops.Copy(input);
315
+ }
316
+ using var input2 = ops.Cast(input, DataType.Float);
317
+ TensorFloat[] B = new TensorFloat[4];
318
+ for (int i = 0; i < 4; i++) {
319
+ using TensorFloat A = ops.Slice(input2, new int[] { i }, new int[] { i + 1 }, new int[] { 0 }, new int[] { 1 }) as TensorFloat;
320
+ B[i] = ops.Pad(A, new int[] { 0, -i, 0, i - 3 });
321
+ }
322
+ using var input3 = ops.Concat(B, 0) as TensorFloat;
323
+ for(int i = 0; i < 4; i++)
324
+ {
325
+ B[i].Dispose();
326
+ }
327
+ return ops.Cast(input3, DataType.Int) as TensorInt;
328
+ }
329
+
330
+ void Replace<T>(ref T A, T B) where T:System.IDisposable
331
+ {
332
+ A?.Dispose();
333
+ A = B;
334
+ }
335
+
336
+ private void OnDestroy()
337
+ {
338
+ input_ids?.Dispose();
339
+ encoder_attention_mask?.Dispose();
340
+ encoder_hidden_states?.Dispose();
341
+ ops?.Dispose();
342
+ decoderEngine?.Dispose();
343
+ toWavEngine?.Dispose();
344
+ projectEngine?.Dispose();
345
+ textEngine?.Dispose();
346
+ }
347
+ }
348
+