Sakalti commited on
Commit
cb6fd2e
·
verified ·
1 Parent(s): 172f090

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
3
- from datasets import load_dataset, Dataset, DatasetDict
4
  import os
5
  import time
6
 
@@ -123,7 +123,7 @@ class CustomCallback(TrainerCallback):
123
  global progress_info
124
  total_steps = state.max_steps
125
  current_step = state.global_step
126
- elapsed_time = time.time() - state.log_history[0]["epoch_time"]
127
  time_per_step = elapsed_time / (current_step + 1)
128
  remaining_steps = total_steps - current_step
129
  time_remaining = time_per_step * remaining_steps
@@ -149,6 +149,7 @@ with gr.Blocks() as demo:
149
  progress.update(value=progress_info["progress"])
150
  time_remaining.value = f"{progress_info['time_remaining']}秒" if progress_info['time_remaining'] else "待機中"
151
 
152
- train_button.click(fn=train_and_deploy, inputs=[token_input, repo_input, license_input], outputs=output).then(fn=update_ui)
 
153
 
154
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, TrainerCallback
3
+ from datasets import load_dataset, DatasetDict
4
  import os
5
  import time
6
 
 
123
  global progress_info
124
  total_steps = state.max_steps
125
  current_step = state.global_step
126
+ elapsed_time = time.time() - state.log_history[0].get("epoch_time", time.time()) # デフォルト値を追加
127
  time_per_step = elapsed_time / (current_step + 1)
128
  remaining_steps = total_steps - current_step
129
  time_remaining = time_per_step * remaining_steps
 
149
  progress.update(value=progress_info["progress"])
150
  time_remaining.value = f"{progress_info['time_remaining']}秒" if progress_info['time_remaining'] else "待機中"
151
 
152
+ train_button.click(fn=train_and_deploy, inputs=[token_input, repo_input, license_input], outputs=output)
153
+ train_button.click(fn=update_ui, inputs=[], outputs=[status, progress, time_remaining])
154
 
155
  demo.launch()