lengyue233 commited on
Commit
cb728f4
·
verified ·
1 Parent(s): d69caf0

Update tools/llama/generate.py

Browse files
Files changed (1) hide show
  1. tools/llama/generate.py +1 -15
tools/llama/generate.py CHANGED
@@ -154,16 +154,11 @@ def decode_one_token_ar_agent(
154
  logits = x.logits # [:, -1:]
155
  hidden_states = x.hidden_states # [:, -1:]
156
 
157
- sampling_kwargs_main = sampling_kwargs.copy()
158
- sampling_kwargs_main["temperature"] = 0.1
159
- sampling_kwargs_main["top_p"] = 0.1
160
- sampling_kwargs_main["repetition_penalty"] = 1.0
161
-
162
  codebooks = [
163
  sample_agent(
164
  logits,
165
  previous_tokens=None, # Disable repetition penalty for the token codebook
166
- **sampling_kwargs_main,
167
  )[0]
168
  ]
169
 
@@ -194,15 +189,6 @@ def decode_one_token_ar_agent(
194
  codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
195
  )
196
 
197
- # for i in range(codebooks.size(1) - 1):
198
- # codebooks[:, i + 1, :] = torch.masked_fill(
199
- # codebooks[:, i + 1, :],
200
- # codebooks[:, :1, :] != semantic_id,
201
- # CODEBOOK_PAD_TOKEN_ID + i * 1024,
202
- # )
203
-
204
- # print(codebooks)
205
-
206
  return codebooks
207
 
208
 
 
154
  logits = x.logits # [:, -1:]
155
  hidden_states = x.hidden_states # [:, -1:]
156
 
 
 
 
 
 
157
  codebooks = [
158
  sample_agent(
159
  logits,
160
  previous_tokens=None, # Disable repetition penalty for the token codebook
161
+ **sampling_kwargs,
162
  )[0]
163
  ]
164
 
 
189
  codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
190
  )
191
 
 
 
 
 
 
 
 
 
 
192
  return codebooks
193
 
194