Livia_Zaharia commited on
Commit
de5608c
·
1 Parent(s): 64e42c0
Files changed (2) hide show
  1. app.py +39 -3
  2. tools.py +39 -19
app.py CHANGED
@@ -2,7 +2,43 @@ import gradio as gr
2
  from tools import *
3
 
4
 
5
- def gradio_output(file):
6
- return (predict_glucose_tool(file))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- gr.Interface(fn=gradio_output,inputs=gr.File(label="Upload CSV File"),outputs="plot").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from tools import *
3
 
4
 
5
+ with gr.Blocks() as demo:
6
+ file_input = gr.File(label="Upload CSV File")
7
+ with gr.Row():
8
+ index_slider = gr.Slider(
9
+ minimum=0,
10
+ maximum=100, # This will be updated dynamically
11
+ value=10,
12
+ step=1,
13
+ label="Select Sample Index",
14
+ visible=False
15
+ )
16
+ sample_count = gr.Markdown(visible=False)
17
+ plot_output = gr.Plot()
18
+
19
+ # Update slider and show total samples when file is uploaded
20
+ file_input.change(
21
+ fn=prep_predict_glucose_tool,
22
+ inputs=[file_input],
23
+ outputs=[index_slider, sample_count],
24
+ queue=False
25
+ )
26
+ # Set visibility separately
27
+ file_input.change(
28
+ fn=lambda: (gr.Slider(visible=True), gr.Markdown(visible=True)),
29
+ outputs=[index_slider, sample_count]
30
+ )
31
 
32
+ # Update plot when slider changes or file uploads
33
+ file_input.change(
34
+ fn=predict_glucose_tool,
35
+ inputs=[index_slider],
36
+ outputs=plot_output
37
+ )
38
+ index_slider.change(
39
+ fn=predict_glucose_tool,
40
+ inputs=[index_slider],
41
+ outputs=plot_output
42
+ )
43
+
44
+ demo.launch()
tools.py CHANGED
@@ -13,13 +13,14 @@ import hashlib
13
  from urllib.parse import urlparse
14
  from huggingface_hub import hf_hub_download
15
  import plotly.graph_objects as go
 
16
 
17
 
18
  glucose = Path(os.path.abspath(__file__)).parent.resolve()
19
  file_directory = glucose / "files"
20
 
21
 
22
- def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
23
 
24
  forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_
25
 
@@ -33,7 +34,6 @@ def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any,
33
  inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_
34
 
35
  # Select a specific sample to plot
36
- ind = 10 # Example index
37
 
38
  samples = np.random.normal(
39
  loc=forecasts[ind, :], # Mean (center) of the distribution
@@ -129,8 +129,7 @@ def generate_filename_from_url(url: str, extension: str = "png") -> str:
129
  return filename
130
 
131
 
132
-
133
- def predict_glucose_tool(file) -> go.Figure:
134
  """
135
  Function to predict future glucose of user. It receives URL with users csv. It will run an ML and will return URL with predictions that user can open on her own..
136
  :param file: it is the csv file imported as a string path to the temporary location gradio allows
@@ -140,16 +139,17 @@ def predict_glucose_tool(file) -> go.Figure:
140
  :return:
141
  """
142
 
143
- url = file
144
  model="Livia-Zaharia/gluformer_models"
145
  model_path = hf_hub_download(repo_id= model, filename="gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth")
146
 
147
-
148
- formatter, series, scalers = load_data(url=str(url), config_path=file_directory / "config.yaml", use_covs=True,
 
 
149
  cov_type='dual',
150
  use_static_covs=True)
151
 
152
- filename = generate_filename_from_url(url)
153
 
154
  formatter.params['gluformer'] = {
155
  'in_len': 96, # example input length, adjust as necessary
@@ -158,12 +158,13 @@ def predict_glucose_tool(file) -> go.Figure:
158
  'd_fcn': 1024, # fully connected layer dimension
159
  'num_enc_layers': 2, # number of encoder layers
160
  'num_dec_layers': 2, # number of decoder layers
161
- 'length_pred': 12 # prediction length, adjust as necessary
162
  }
163
 
164
  num_dynamic_features = series['train']['future'][-1].n_components
165
  num_static_features = series['train']['static'][-1].n_components
166
 
 
167
  glufo = Gluformer(
168
  d_model=formatter.params['gluformer']['d_model'],
169
  n_heads=formatter.params['gluformer']['n_heads'],
@@ -183,6 +184,7 @@ def predict_glucose_tool(file) -> go.Figure:
183
  device = "cuda" if torch.cuda.is_available() else "cpu"
184
  glufo.load_state_dict(torch.load(str(model_path), map_location=torch.device(device), weights_only=True))
185
 
 
186
  # Define dataset for inference
187
  dataset_test_glufo = SamplingDatasetInferenceDual(
188
  target_series=series['test']['target'],
@@ -193,17 +195,35 @@ def predict_glucose_tool(file) -> go.Figure:
193
  array_output_only=True
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  forecasts, _ = glufo.predict(
197
- dataset_test_glufo,
198
- batch_size=16,#######
199
- num_samples=10,
200
- device=device
201
  )
202
- figure_path, result = plot_forecast(forecasts, scalers, dataset_test_glufo,filename)
203
 
204
  return result
205
-
206
-
207
-
208
- if __name__ == "__main__":
209
- predict_glucose_tool()
 
13
  from urllib.parse import urlparse
14
  from huggingface_hub import hf_hub_download
15
  import plotly.graph_objects as go
16
+ import gradio as gr
17
 
18
 
19
  glucose = Path(os.path.abspath(__file__)).parent.resolve()
20
  file_directory = glucose / "files"
21
 
22
 
23
+ def plot_forecast(forecasts: np.ndarray, filename: str,ind:int=10):
24
 
25
  forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_
26
 
 
34
  inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_
35
 
36
  # Select a specific sample to plot
 
37
 
38
  samples = np.random.normal(
39
  loc=forecasts[ind, :], # Mean (center) of the distribution
 
129
  return filename
130
 
131
 
132
+ def prep_predict_glucose_tool(file):
 
133
  """
134
  Function to predict future glucose of user. It receives URL with users csv. It will run an ML and will return URL with predictions that user can open on her own..
135
  :param file: it is the csv file imported as a string path to the temporary location gradio allows
 
139
  :return:
140
  """
141
 
142
+
143
  model="Livia-Zaharia/gluformer_models"
144
  model_path = hf_hub_download(repo_id= model, filename="gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth")
145
 
146
+ global formatter
147
+ global series
148
+ global scalers
149
+ formatter, series, scalers = load_data(url=str(file), config_path=file_directory / "config.yaml", use_covs=True,
150
  cov_type='dual',
151
  use_static_covs=True)
152
 
 
153
 
154
  formatter.params['gluformer'] = {
155
  'in_len': 96, # example input length, adjust as necessary
 
158
  'd_fcn': 1024, # fully connected layer dimension
159
  'num_enc_layers': 2, # number of encoder layers
160
  'num_dec_layers': 2, # number of decoder layers
161
+ 'length_pred': 12 # prediction length, adjust as necessary represents 1 h
162
  }
163
 
164
  num_dynamic_features = series['train']['future'][-1].n_components
165
  num_static_features = series['train']['static'][-1].n_components
166
 
167
+ global glufo
168
  glufo = Gluformer(
169
  d_model=formatter.params['gluformer']['d_model'],
170
  n_heads=formatter.params['gluformer']['n_heads'],
 
184
  device = "cuda" if torch.cuda.is_available() else "cpu"
185
  glufo.load_state_dict(torch.load(str(model_path), map_location=torch.device(device), weights_only=True))
186
 
187
+ global dataset_test_glufo
188
  # Define dataset for inference
189
  dataset_test_glufo = SamplingDatasetInferenceDual(
190
  target_series=series['test']['target'],
 
195
  array_output_only=True
196
  )
197
 
198
+ global filename
199
+ filename = generate_filename_from_url(file)
200
+
201
+ max_index = len(dataset_test_glufo) - 1
202
+
203
+ return (
204
+ gr.Slider(
205
+ minimum=0,
206
+ maximum=max_index,
207
+ value=10,
208
+ step=1,
209
+ label="Select Sample Index",
210
+ ),
211
+ gr.Markdown(f"Total number of test samples: {max_index + 1}")
212
+ )
213
+
214
+
215
+ def predict_glucose_tool(ind) -> go.Figure:
216
+
217
+
218
+
219
+ device = "cuda" if torch.cuda.is_available() else "cpu"
220
+
221
  forecasts, _ = glufo.predict(
222
+ dataset_test_glufo,
223
+ batch_size=16,#######
224
+ num_samples=10,
225
+ device=device
226
  )
227
+ figure_path, result = plot_forecast(forecasts,filename,ind)
228
 
229
  return result