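"""Zero-shot classification of room types in short videos with X-CLIP, served via Gradio."""
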
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
import numpy as np
from decord import VideoReader

FRAME_SAMPLING_RATE = 4
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)
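
# Comma-separated candidate labels for the zero-shot classifier; the list is
# editable in the UI before running a prediction.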
ROOM_TYPES = (
    "bathroom,sauna,living room,bedroom,kitchen,toilet,hallway,dressing,attic,basement,home office,garage"
)

examples = [
    ["movies/bathroom.mp4", ROOM_TYPES],
    ["movies/bedroom.mp4", ROOM_TYPES],
    ["movies/dressing.mp4", ROOM_TYPES],
    ["movies/home-office.mp4", ROOM_TYPES],
    ["movies/kitchen.mp4", ROOM_TYPES],
    ["movies/living-room.mp4", ROOM_TYPES],
    ["movies/toilet.mp4", ROOM_TYPES],
]
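

# Decode the video with decord and return `num_frames` frames, one every
# `frame_sampling_rate` frames, as a (num_frames, height, width, 3) numpy array.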
def sample_frames_from_video_file(
    file_path: str, num_frames: int = 16, frame_sampling_rate: int = 1
):
    videoreader = VideoReader(file_path)
    videoreader.seek(0)
    # evenly spaced indices covering num_frames * frame_sampling_rate frames
    start_idx = 0
    end_idx = num_frames * frame_sampling_rate - 1
    indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
    frames = videoreader.get_batch(indices).asnumpy()
    return frames


def get_num_total_frames(file_path: str):
    videoreader = VideoReader(file_path)
    videoreader.seek(0)
    return len(videoreader)
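

# Swap the global processor/model pair for a different checkpoint. Defined as a
# helper but not wired into the UI below.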
def select_model(model_name):
    global processor, model
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)


def get_frame_sampling_rate(video_path, num_model_input_frames):
    # adapt the sampling rate to the clip length so the sampled indices never
    # run past the end of the video
    num_total_frames = get_num_total_frames(video_path)
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # max(1, ...) guards against a sampling rate of 0 for very short clips
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE
    return frame_sampling_rate
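

# Sample frames, run the X-CLIP forward pass, and map each label to its
# softmax probability for the gr.Label output.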
def predict(video_path, labels_text):
    labels = [label.strip() for label in labels_text.split(",")]
    num_model_input_frames = model.config.vision_config.num_frames
    frame_sampling_rate = get_frame_sampling_rate(video_path, num_model_input_frames)
    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    # forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_video: (num_videos, num_labels); softmax over the labels
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    return {label: float(prob) for label, prob in zip(labels, probs)}
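

# Gradio UI: video input and editable label list on the left, predicted
# probabilities on the right, with cached example videos underneath.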
app = gr.Blocks()
with app:
    gr.Markdown("# **<p align='center'>Classification of Rooms</p>**")
    gr.Markdown(
        "#### **<p align='center'>Upload a video (mp4) of a room and provide a comma-separated list of room types the model should choose from.</p>**"
    )
    with gr.Row():
        with gr.Column():
            video_file = gr.Video(label="Video File:", show_label=True)
            local_video_labels_text = gr.Textbox(
                value=ROOM_TYPES, label="Room Types", show_label=True
            )
            submit_button = gr.Button(value="Predict")
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)
    gr.Markdown("**Examples:**")
    gr.Examples(
        examples,
        [video_file, local_video_labels_text],
        predictions,
        fn=predict,
        cache_examples=True,
    )
    submit_button.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=predictions,
    )

app.launch()
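
# When run directly (e.g. `python app.py`), Gradio serves the app locally,
# on http://127.0.0.1:7860 by default.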