Spaces:
Sleeping
Sleeping
import sys | |
import os | |
import pickle | |
import gzip | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from scipy import stats | |
from gluformer.model import Gluformer | |
from utils.darts_processing import * | |
from utils.darts_dataset import * | |
import hashlib | |
from urllib.parse import urlparse | |
from huggingface_hub import hf_hub_download | |
import plotly.graph_objects as go | |
import gradio as gr | |
from format_dexcom import * | |
from typing import Tuple, Union, List | |
from plotly.graph_objs._figure import Figure | |
from gradio.components import Slider | |
from gradio.components import Markdown | |
# Absolute directory containing this script (resolve() also follows symlinks).
glucose = Path(os.path.dirname(os.path.abspath(__file__))).resolve()
# Asset folder next to the script: config.yaml is read from here and the
# generated forecast plots are written here.
file_directory = glucose.joinpath("files")
def plot_forecast(forecasts: np.ndarray, filename: str, ind: int = 10) -> Tuple[Path, Figure]:
    """Plot the predictive distribution for one test sample and save it to disk.

    Reads the module-level ``scalers`` and ``dataset_test_glufo`` populated by
    ``prep_predict_glucose_tool``.

    :param forecasts: Model output in scaled units, indexed as
        ``forecasts[sample, step, draw]`` (presumably
        ``(n_samples, len_pred, num_samples)`` from ``Gluformer.predict`` --
        TODO confirm).
    :param filename: Image file name inside ``file_directory``; an ``.html``
        copy is written next to it.
    :param ind: Index of the test sample to plot (negative values count from
        the end, as with numpy indexing).
    :return: ``(path_of_written_image, plotly_figure)``.
    :raises IndexError: If ``ind`` is outside the test set.
    """
    target_scaler = scalers['target']
    # Validate the index and normalise negative values in one step.
    ind = range(len(dataset_test_glufo))[int(ind)]

    # Map the forecasts back to original units: (x - min_) / scale_.
    forecasts = (forecasts - target_scaler.min_) / target_scaler.scale_

    # Ground truth for the forecast horizon, in original units.
    # NOTE(review): this evaluates every test sample although only `ind` is
    # plotted; kept as-is because inverse_transform is called on the full list.
    trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]
    trues = target_scaler.inverse_transform(trues)
    trues = np.array([ts.values() for ts in trues])  # TimeSeries -> numpy

    # Input window of the requested sample only, in original units.
    past = (np.asarray(dataset_test_glufo[ind][0]) - target_scaler.min_) / target_scaler.scale_

    # Jitter each forecast draw with N(0, 0.1) noise. Presumably this keeps
    # gaussian_kde below from failing on identical draws -- TODO confirm intent.
    samples = np.random.normal(
        loc=forecasts[ind, :],
        scale=0.1,
        size=(forecasts.shape[1], forecasts.shape[2]),
    )

    fig = go.Figure()

    # One density "violin half" per forecast step, drawn left of x=step.
    n_steps = samples.shape[0]
    for step, draws in enumerate(samples):
        kde = stats.gaussian_kde(draws)
        y_grid = np.linspace(0.8 * np.min(draws), 1.2 * np.max(draws), 200)
        density = kde(y_grid)
        fig.add_trace(go.Scatter(
            # Closed polygon: up the vertical line x=step, then back down along
            # the density curve mirrored to the left (scaled x15 for visibility).
            x=np.concatenate([np.full_like(y_grid, step), (step - density * 15)[::-1]]),
            y=np.concatenate([y_grid, y_grid[::-1]]),
            # 'toself' fills the polygon itself; 'tonexty' filled to y=0 for the
            # first step and between neighbouring steps' polygons afterwards.
            fill='toself',
            fillcolor=f'rgba(53, 138, 217, {(step + 1) / n_steps})',  # darker further ahead
            line=dict(color='rgba(0,0,0,0)'),
            showlegend=False,
        ))

    # Observed values: last hour (12 steps) of input followed by the horizon.
    history = 12
    past_tail = past[-history:].ravel()
    future = trues[ind].ravel()
    true_values = np.concatenate([past_tail, future])
    fig.add_trace(go.Scatter(
        x=list(range(-len(past_tail), len(future))),
        y=true_values.tolist(),
        mode='lines+markers',
        line=dict(color='blue', width=2),
        marker=dict(size=6),
        name='True Values',
    ))

    # Median over the (jittered) draws at every forecast step.
    median = np.quantile(samples, 0.5, axis=-1)
    fig.add_trace(go.Scatter(
        x=list(range(len(median))),
        y=median.tolist(),
        mode='lines+markers',
        line=dict(color='red', width=2),
        marker=dict(size=8),
        name='Median Forecast',
    ))

    fig.update_layout(
        title='Gluformer Prediction with Gradient for dataset',
        xaxis_title='Time (in 5 minute intervals)',
        yaxis_title='Glucose (mg/dL)',
        font=dict(size=14),
        showlegend=True,
        width=1000,
        height=600,
    )

    # Save an interactive HTML copy and the static image side by side.
    where = file_directory / filename
    fig.write_html(str(where.with_suffix('.html')))
    fig.write_image(str(where))
    return where, fig
def generate_filename_from_url(url: Union[str, os.PathLike], extension: str = "png") -> str:
    """Build an output file name from a URL or local path.

    The last path segment (with dots replaced by underscores) is combined with
    the MD5 hex digest of the full location. Sources that share a base name
    therefore still map to different files.

    :param url: URL or filesystem path of the source data; ``str`` or any
        ``os.PathLike`` such as ``pathlib.Path``.
    :param extension: Extension of the generated name, without the leading dot.
    :return: ``"<last_segment>_<md5>.<extension>"``.
    """
    # urlparse() and str.encode() reject Path objects, so normalise to str first.
    location = os.fspath(url)
    last_segment = urlparse(location).path.split('/')[-1]
    # MD5 serves only as a stable fingerprint of the location, not for security.
    url_hash = hashlib.md5(location.encode('utf-8')).hexdigest()
    return f"{last_segment.replace('.', '_')}_{url_hash}.{extension}"
# Module-level state shared by the Gradio callbacks. It is populated by
# prep_predict_glucose_tool() and read by predict_glucose_tool() and
# plot_forecast().
glufo = None               # Gluformer model with loaded weights
scalers = None             # scalers returned by load_data(); 'target' is used for plotting
dataset_test_glufo = None  # SamplingDatasetInferenceDual over the test split
filename = None            # output image name derived from the uploaded file
# prep_predict_glucose_tool() also assigns these; they are declared here so that
# every global it writes exists before the first upload.
formatter = None
series = None
def prep_predict_glucose_tool(file: Union[str, Path], model_name: str = "gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth") -> Tuple[Slider, Markdown]:
    """Load the uploaded data and the pretrained Gluformer, then prepare the UI.

    Downloads ``model_name`` from the ``Livia-Zaharia/gluformer_models``
    Hugging Face repo and loads ``file`` with ``load_data``. It then builds the
    test inference dataset and stores everything in the module globals
    ``formatter``, ``series``, ``scalers``, ``glufo``, ``dataset_test_glufo``
    and ``filename``, which ``predict_glucose_tool`` uses.

    :param file: Path or URL of the uploaded glucose data.
    :param model_name: Weights file name inside the model repo.
    :return: A visible index slider spanning every test sample, and a Markdown
        note with the sample count.
    :raises gr.Error: If no test samples could be built from ``file``.
    """
    global formatter, series, scalers, glufo, dataset_test_glufo, filename

    repo_id = "Livia-Zaharia/gluformer_models"
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)

    formatter, series, scalers = load_data(
        url=str(file),
        config_path=file_directory / "config.yaml",
        use_covs=True,
        cov_type='dual',
        use_static_covs=True,
    )
    formatter.params['gluformer'] = {
        'in_len': 96,          # input window length, in time steps
        'd_model': 512,        # model dimension
        'n_heads': 10,         # number of attention heads
        'd_fcn': 1024,         # fully connected layer dimension
        'num_enc_layers': 2,   # number of encoder layers
        'num_dec_layers': 2,   # number of decoder layers
        # Original comment: 12 steps represents 1 h.
        # NOTE(review): the model and dataset below read the top-level
        # formatter.params['length_pred'], not this entry -- confirm they agree.
        'length_pred': 12,
    }
    gluformer_params = formatter.params['gluformer']

    num_dynamic_features = series['train']['future'][-1].n_components
    num_static_features = series['train']['static'][-1].n_components
    glufo = Gluformer(
        d_model=gluformer_params['d_model'],
        n_heads=gluformer_params['n_heads'],
        d_fcn=gluformer_params['d_fcn'],
        r_drop=0.2,
        activ='gelu',
        num_enc_layers=gluformer_params['num_enc_layers'],
        num_dec_layers=gluformer_params['num_dec_layers'],
        distil=True,
        len_seq=gluformer_params['in_len'],
        label_len=gluformer_params['in_len'] // 3,
        len_pred=formatter.params['length_pred'],
        num_dynamic_features=num_dynamic_features,
        num_static_features=num_static_features,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch.load unpickles the downloaded file. If the installed
    # torch supports it (>= 1.13), weights_only=True would restrict loading to tensors.
    glufo.load_state_dict(torch.load(str(model_path), map_location=torch.device(device)))

    dataset_test_glufo = SamplingDatasetInferenceDual(
        target_series=series['test']['target'],
        covariates=series['test']['future'],
        input_chunk_length=gluformer_params['in_len'],
        output_chunk_length=formatter.params['length_pred'],
        use_static_covariates=True,
        array_output_only=True,
    )

    # generate_filename_from_url() works on text; `file` may be a Path.
    filename = generate_filename_from_url(str(file))

    n_samples = len(dataset_test_glufo)
    print(f"Total number of test samples: {n_samples}")
    if n_samples == 0:
        raise gr.Error("No test samples could be built from the uploaded data.")
    max_index = n_samples - 1  # last valid sample index
    return (
        gr.Slider(
            minimum=0,
            # Allow the last sample to be selected. The previous max_index-1
            # excluded it and left the initial value out of range.
            maximum=max_index,
            value=max_index,
            step=1,
            label="Select Sample Index",
            visible=True,
        ),
        gr.Markdown(f"Total number of test samples: {n_samples}", visible=True),
    )
def predict_glucose_tool(ind: int) -> Figure:
    """Run the prepared Gluformer on the test set and plot sample ``ind``.

    Requires ``prep_predict_glucose_tool`` to have populated the module
    globals ``glufo``, ``dataset_test_glufo`` and ``filename``.

    :param ind: Index of the test sample to plot (from the Gradio slider).
    :return: The Plotly figure produced by ``plot_forecast``.
    :raises gr.Error: If no data file and model have been prepared yet.
    """
    if glufo is None or dataset_test_glufo is None:
        raise gr.Error("Please upload a glucose data file before requesting a prediction.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): predicts on the whole test set on every call, although
    # only one sample is plotted.
    forecasts, _ = glufo.predict(
        dataset_test_glufo,
        batch_size=16,
        num_samples=10,
        device=device,
    )
    # Gradio may deliver slider values as floats (e.g. 5.0); numpy indexing needs an int.
    _, figure = plot_forecast(forecasts, filename, int(ind))
    return figure