PB Unity
commited on
Upload RunMusicGen.cs
Browse files- RunMusicGen.cs +348 -0
RunMusicGen.cs
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
using System.Collections;
|
2 |
+
using System.Collections.Generic;
|
3 |
+
using UnityEngine;
|
4 |
+
using Unity.Sentis;
|
5 |
+
using Newtonsoft.Json;
|
6 |
+
|
7 |
+
|
8 |
+
// Inference for MusicGen-300
|
9 |
+
// ==========================
|
10 |
+
//
|
11 |
+
// Details
|
12 |
+
// -------
|
13 |
+
// The model predicts 4 streams of codes staggered like this:
|
14 |
+
// * * * a b c
|
15 |
+
// * * a b c
|
16 |
+
// * a b c
|
17 |
+
// a b c
|
18 |
+
// Then it aligns the streams so that all the a's are grouped together, all the b's, etc.
|
19 |
+
|
20 |
+
// Put sentis files and json file in Assets/StreamingAssets folder
|
21 |
+
// Put this script on the Main Camera object
|
22 |
+
// Put an audiosource on the Main Camera
|
23 |
+
// Press play and see console window for updates
|
24 |
+
|
25 |
+
// See https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/modeling_musicgen.py
|
26 |
+
|
27 |
+
|
28 |
+
public class RunMusicGen : MonoBehaviour
|
29 |
+
{
|
30 |
+
// Text prompt describing the music to generate; change it to whatever you like.
string prompt = "80s pop track with bassy drums and synth";

// Length of the clip to create, in seconds (up to 30 seconds).
const int seconds = 2;

// Sample rate used for the generated AudioClip; the output frequency must be 32kHz.
const int outputFrequency = 32000;

// Factor the decoder output is multiplied by before the softmax.
// Make this value smaller to make the music more random.
float predictability = 1f;

// Number of code streams the decoder predicts in parallel.
const int numCodeBooks = 4;

// Special music decoder token: fills the staggered start of every code stream.
const int DECODER_START_TOKEN = 2048;

// Special text encoder token: appended after the last prompt token.
const int END_TEXT_TOKEN = 1;

// How much to stagger each code stream by with respect to the next one.
int DELAY = 1;

// The generated clip; assigned once the music has been decoded.
public AudioClip clip;

// One inference worker per model file.
IWorker toWavEngine;
IWorker decoderEngine;
IWorker textEngine;
IWorker projectEngine;

Ops ops;
Model decoder;

// Vocab list; the position of an entry in the list is its token id.
List<string> tokens = new();

// Token ids of the parsed prompt, and how many of them there are.
List<int> TOKENS;
int decoderTokens;

// Projected text-encoder output and the all-ones mask that goes with it.
TensorFloat encoder_hidden_states;
TensorInt encoder_attention_mask;

// Music codes produced so far, laid out as numCodeBooks rows of maxFrames
// entries, together with the tensor that is rebuilt from them every step.
List<int> tokensSoFar = new();
TensorInt input_ids;

// Total number of decoder steps and the step currently being generated.
int maxFrames;
int frame;

// Set once DecodeMusic has been triggered so it only runs a single time.
bool hasDecodedMusic;
|
73 |
+
// Unity entry point: tokenizes the prompt, runs the text encoder once and
// prepares everything Update() needs to generate one music frame per call.
void Start()
{
    // The 3 extra frames are the ones AlignCodeStreams trims off again and
    // DecodeMusic subtracts when reshaping.
    // NOTE(review): 50 is presumably the number of code frames per second - confirm.
    maxFrames = seconds * 50 + 3;
    ops = WorkerFactory.CreateOps(BackendType.GPUCompute, null);

    // The vocab has to be loaded before the prompt can be tokenized.
    LoadVocab();
    TOKENS = GetTokens(prompt);
    Debug.Log($"Parsed tokens=\n{string.Join(",", TOKENS)}");

    // GetTokens has set decoderTokens, which sizes the mask and the encoder input.
    CreateAttentionMask();
    ParseText();

    LoadDecoderModel();
    SetupMusicCodeStreams();

    // Generation starts at the first frame.
    frame = 0;
}
|
93 |
+
|
94 |
+
// Loads decoder.sentis from StreamingAssets and creates the GPU worker
// that GetOneMusicToken executes every frame.
void LoadDecoderModel()
{
    string decoderPath = Application.streamingAssetsPath + "/decoder.sentis";
    decoder = ModelLoader.Load(decoderPath);
    decoderEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, decoder);
}
|
99 |
+
|
100 |
+
// Builds a (1, decoderTokens) mask filled with ones, i.e. every prompt
// token is attended to. Requires decoderTokens to be set (done by GetTokens).
void CreateAttentionMask()
{
    var maskShape = new TensorShape(1, decoderTokens);
    int[] allOnes = new int[1 * decoderTokens];

    int position = 0;
    while (position < allOnes.Length)
    {
        allOnes[position] = 1;
        position++;
    }

    encoder_attention_mask = new TensorInt(maskShape, allOnes);
}
|
106 |
+
|
107 |
+
// Initializes the code streams: stream i holds DECODER_START_TOKEN for
// frames 0..i*DELAY (the staggered start) and -1 for every frame that still
// has to be generated. The result is also uploaded as the first input_ids tensor.
void SetupMusicCodeStreams()
{
    // Reserve numCodeBooks rows of maxFrames entries each.
    tokensSoFar.AddRange(new int[numCodeBooks * maxFrames]);

    for (int stream = 0; stream < numCodeBooks; stream++)
    {
        int rowStart = stream * maxFrames;
        int lastStartFrame = stream * DELAY;

        for (int step = 0; step < maxFrames; step++)
        {
            tokensSoFar[rowStart + step] = step <= lastStartFrame ? DECODER_START_TOKEN : -1;
        }
    }

    var streamShape = new TensorShape(numCodeBooks, maxFrames);
    input_ids = new TensorInt(streamShape, tokensSoFar.ToArray());
}
|
127 |
+
|
128 |
+
// Tokenizes the prompt with a greedy longest-match against the vocab list.
//
// The text is lower-cased and split on whitespace; every word gets a leading
// marker character and is then consumed left to right, always taking the
// longest prefix of the remaining characters that is a vocab entry.
// END_TEXT_TOKEN is appended at the end and decoderTokens is set to the
// number of ids returned.
//
// text:    the prompt to tokenize (null is treated as empty).
// returns: the token ids, always ending with END_TEXT_TOKEN.
//
// If the remainder of a word has no matching vocab entry, that remainder is
// skipped with a warning. (The previous implementation threw an
// ArgumentOutOfRangeException from Substring in that situation whenever an
// earlier piece of the same word had already matched.)
List<int> GetTokens(string text)
{
    // Index the vocab once so every lookup is O(1) instead of a linear
    // List.IndexOf scan. The first occurrence wins, exactly like IndexOf.
    var vocabIndex = new Dictionary<string, int>(tokens.Count);
    for (int id = 0; id < tokens.Count; id++)
    {
        string piece = tokens[id];
        if (piece != null && !vocabIndex.ContainsKey(piece))
        {
            vocabIndex.Add(piece, id);
        }
    }

    // Split over whitespace. Invariant lower-casing keeps the result
    // independent of the machine's culture, and empty entries are dropped so
    // repeated, leading or trailing spaces do not produce marker-only words.
    string[] words = (text ?? string.Empty)
        .ToLowerInvariant()
        .Split((char[])null, System.StringSplitOptions.RemoveEmptyEntries);

    var ids = new List<int>();
    var matchedPieces = new System.Text.StringBuilder();

    foreach (string rawWord in words)
    {
        // NOTE(review): vocabularies of this kind often mark word starts with
        // U+2581 rather than a plain space - confirm this literal matches the
        // entries in tokenizer.json.
        string word = " " + rawWord;

        int start = 0;
        while (start < word.Length)
        {
            // Longest candidate first, shrinking until a vocab entry matches.
            int end = word.Length;
            int id = -1;
            while (end > start && !vocabIndex.TryGetValue(word.Substring(start, end - start), out id))
            {
                end--;
            }

            if (end == start)
            {
                // Not even a single character matched: give up on this word.
                Debug.LogWarning("No vocab entry matches '" + word.Substring(start) + "'; skipping it");
                break;
            }

            ids.Add(id);
            matchedPieces.Append(word, start, end - start).Append(' ');
            start = end;
        }
    }

    ids.Add(END_TEXT_TOKEN);

    decoderTokens = ids.Count;

    Debug.Log("Tokenized sentece = " + matchedPieces);

    return ids;
}
|
164 |
+
|
165 |
+
// Runs the tokenized prompt through the text encoder and the projection
// model, and keeps the projected result in encoder_hidden_states.
// Expects TOKENS, decoderTokens and encoder_attention_mask to be set already.
void ParseText()
{
    string assetFolder = Application.streamingAssetsPath;

    Model textencoder = ModelLoader.Load(assetFolder + "/textencoder.sentis");
    textEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, textencoder);

    Model project = ModelLoader.Load(assetFolder + "/project768_1024.sentis");
    projectEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, project);

    // The prompt ids are only needed for this single run, so dispose them on exit.
    using var promptIds = new TensorInt(new TensorShape(1, decoderTokens), TOKENS.ToArray());

    var encoderInputs = new Dictionary<string, Tensor>();
    encoderInputs.Add("input_ids", promptIds);
    encoderInputs.Add("attention_mask", encoder_attention_mask);
    textEngine.Execute(encoderInputs);

    var textFeatures = textEngine.PeekOutput() as TensorFloat;

    // Convert vectors of size 768 to size 1024.
    projectEngine.Execute(textFeatures);
    encoder_hidden_states = projectEngine.PeekOutput() as TensorFloat;

    // The tensor is stored in a field and disposed in OnDestroy, so take it
    // over from the worker.
    encoder_hidden_states.TakeOwnership();
}
|
189 |
+
|
190 |
+
// Deserialization target for tokenizer.json; only the "model" entry is read.
// The field name is the JSON key, so it must not be renamed.
private class TokenizerData
{
    // Holds the vocab that LoadVocab copies into the token list.
    public ModelData model;
}
|
194 |
+
// The "model" entry of tokenizer.json; only its "vocab" array is read.
// The field name is the JSON key, so it must not be renamed.
private class ModelData
{
    // One array per vocab entry; element [0] is the token text.
    // NOTE(review): the remaining elements are never read here - confirm their meaning.
    public object[][] vocab;
}
|
198 |
+
|
199 |
+
// Reads tokenizer.json from StreamingAssets and appends the text of every
// vocab entry to the token list, so that a token's list index is its id.
void LoadVocab()
{
    string tokenizerPath = Application.streamingAssetsPath+"/tokenizer.json";
    string json = System.IO.File.ReadAllText(tokenizerPath);
    TokenizerData data = JsonConvert.DeserializeObject<TokenizerData>(json);

    foreach (object[] entry in data.model.vocab)
    {
        // The first element of each entry is the token text.
        tokens.Add((string)entry[0]);
    }
}
|
210 |
+
|
211 |
+
// Update is called once per frame
|
212 |
+
// Generates one frame of music codes per rendered frame; once all maxFrames
// frames exist, decodes them to audio exactly once.
void Update()
{
    if (frame < maxFrames)
    {
        GetOneMusicToken();

        // Advance only while generating. Counting on forever after the clip
        // is finished serves no purpose, and on wraparound a negative frame
        // would satisfy the check above again and restart generation with
        // invalid indices.
        frame++;
    }
    else if (!hasDecodedMusic)
    {
        // Set the flag first so a failing decode is not retried every frame.
        hasDecodedMusic = true;
        DecodeMusic();
    }
}
|
225 |
+
|
226 |
+
// Runs the decoder on the codes generated so far and samples the codes for
// the next frame of every code stream, then rebuilds input_ids from the
// updated token list.
void GetOneMusicToken()
{
    var decoderInputs = new Dictionary<string, Tensor>();
    decoderInputs["input_ids"] = input_ids;
    decoderInputs["encoder_hidden_states"] = encoder_hidden_states;
    decoderInputs["encoder_attention_mask"] = encoder_attention_mask;
    decoderEngine.Execute(decoderInputs);

    // Scale the decoder output by predictability, then softmax over axis 2.
    var decoderOutput = decoderEngine.PeekOutput() as TensorFloat;
    using var scaled = ops.Mul(decoderOutput, predictability);
    using var probs = ops.Softmax(scaled, 2);
    probs.MakeReadable();

    // The prediction at position `frame` is written to the following position.
    const int OFFSET = 1;
    int targetFrame = frame + OFFSET;

    // Add new tokens to the code streams (nothing is left to fill on the last frame).
    if (targetFrame < maxFrames)
    {
        for (int stream = 0; stream < numCodeBooks; stream++)
        {
            int slot = stream * maxFrames + targetFrame;

            // Leave the staggered start tokens in place.
            if (tokensSoFar[slot] == DECODER_START_TOKEN)
            {
                continue;
            }

            tokensSoFar[slot] = SelectRandomToken(probs, stream, frame);
        }
    }

    Replace(ref input_ids, new TensorInt(input_ids.shape, tokensSoFar.ToArray()));
    Debug.Log($"Frame={frame}/{maxFrames}");
}
|
259 |
+
|
260 |
+
// Samples an index along axis 2 of probs at position [j, frame]: draws a
// random threshold and returns the first index whose running sum reaches it.
//
// probs:   readable tensor indexed as [j, frame, index].
// j:       index on axis 0 (the code stream).
// frame:   index on axis 1.
// returns: the sampled index, or the last index if the running sum never
//          reaches the threshold.
int SelectRandomToken(TensorFloat probs,int j, int frame)
{
    int choices = probs.shape[2];
    float threshold = UnityEngine.Random.Range(0, 1f);

    float runningSum = 0;
    int candidate = 0;
    while (candidate < choices)
    {
        runningSum += probs[j, frame, candidate];
        if (runningSum >= threshold)
        {
            return candidate;
        }
        candidate++;
    }

    return choices - 1;
}
|
272 |
+
// Creates the worker for encodec.sentis the first time it is needed;
// later calls do nothing.
void LoadMusicTokensToWavModel()
{
    if (toWavEngine == null)
    {
        string encodecPath = Application.streamingAssetsPath + "/encodec.sentis";
        Model toWav = ModelLoader.Load(encodecPath);
        toWavEngine = WorkerFactory.CreateWorker(BackendType.GPUCompute, toWav);
    }
}
|
278 |
+
|
279 |
+
// Turns the generated code streams into audio: aligns the streams, runs
// them through the encodec model, copies the samples into a new mono
// AudioClip and plays it on the AudioSource attached to this object, if any.
void DecodeMusic()
{
    Debug.Log("Please wait while music is decoded...");
    LoadMusicTokensToWavModel();

    // Aligning removes the 3 staggered frames, hence maxFrames - 3 columns.
    using var aligned = AlignCodeStreams(input_ids);
    var wavShape = new TensorShape(1, 1, numCodeBooks, maxFrames - 3);
    using var wavTokens = aligned.ShallowReshape(wavShape);

    toWavEngine.Execute(wavTokens);
    var waveform = toWavEngine.PeekOutput() as TensorFloat;
    waveform.MakeReadable();

    // Never take more samples than the requested clip length allows.
    int available = waveform.shape.length;
    int numSamples = Mathf.Min(available, outputFrequency * seconds);
    Debug.Log($"Number of samples={numSamples} / {available}");

    float[] wav = new float[numSamples];
    System.Array.Copy(waveform.ToReadOnlyArray(), wav, numSamples);

    clip = AudioClip.Create("music", numSamples, 1, outputFrequency, false);
    clip.SetData(wav, 0);

    var audioSource = GetComponent<AudioSource>();
    if (audioSource == null)
    {
        Debug.Log("You need to attach audio source to this object to hear the music");
        return;
    }

    audioSource.PlayOneShot(clip);
}
|
309 |
+
|
310 |
+
// Undoes the stagger between the code streams so that codes belonging to
// the same audio frame end up in the same column.
//
// Row i is trimmed by i entries at the front and (numCodeBooks - 1 - i)
// entries at the back (negative padding), which makes the result
// numCodeBooks - 1 columns narrower than the input.
//
// input:   the staggered codes, one row per code stream.
// returns: a new tensor owned by the caller. When DELAY is 0 there is
//          nothing to align and a plain copy is returned.
//
// NOTE(review): the trimming below assumes a stagger of exactly one frame
// per stream (DELAY == 1) - confirm before changing DELAY to another value.
TensorInt AlignCodeStreams(TensorInt input)
{
    if (DELAY == 0)
    {
        return ops.Copy(input);
    }

    // Total number of frames trimmed from every row. Derived from the
    // stream count instead of repeating the literals 4 and 3.
    int trim = numCodeBooks - 1;

    using var asFloat = ops.Cast(input, DataType.Float);
    TensorFloat[] rows = new TensorFloat[numCodeBooks];
    try
    {
        for (int i = 0; i < numCodeBooks; i++)
        {
            // Cut out row i, then drop its leading and trailing stagger.
            using TensorFloat row = ops.Slice(asFloat, new int[] { i }, new int[] { i + 1 }, new int[] { 0 }, new int[] { 1 }) as TensorFloat;
            rows[i] = ops.Pad(row, new int[] { 0, -i, 0, i - trim });
        }

        using var stacked = ops.Concat(rows, 0) as TensorFloat;
        return ops.Cast(stacked, DataType.Int) as TensorInt;
    }
    finally
    {
        // Release the temporary rows even if one of the ops above throws.
        foreach (TensorFloat row in rows)
        {
            row?.Dispose();
        }
    }
}
|
329 |
+
|
330 |
+
// Disposes the value currently stored in A (if there is one) and stores B
// in its place.
//
// A: the variable to overwrite; its old value is disposed first.
// B: the new value, which A refers to afterwards.
void Replace<T>(ref T A, T B) where T:System.IDisposable
{
    if (A != null)
    {
        A.Dispose();
    }

    A = B;
}
|
335 |
+
|
336 |
+
// Releases every tensor, the ops object and all workers this component
// created. Each one may still be null if setup never got that far.
private void OnDestroy()
{
    // Tensors held in fields.
    if (input_ids != null)
    {
        input_ids.Dispose();
    }
    if (encoder_attention_mask != null)
    {
        encoder_attention_mask.Dispose();
    }
    if (encoder_hidden_states != null)
    {
        encoder_hidden_states.Dispose();
    }

    if (ops != null)
    {
        ops.Dispose();
    }

    // Inference workers.
    if (decoderEngine != null)
    {
        decoderEngine.Dispose();
    }
    if (toWavEngine != null)
    {
        toWavEngine.Dispose();
    }
    if (projectEngine != null)
    {
        projectEngine.Dispose();
    }
    if (textEngine != null)
    {
        textEngine.Dispose();
    }
}
|
347 |
+
}
|
348 |
+
|