YAML Metadata Warning: empty or missing YAML metadata in the repo card
(see https://huggingface.co/docs/hub/model-cards#model-card-metadata).
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file as safetensors_load_file
from torchvision import transforms
from transformers import ViTModel, ViTConfig
# Preprocessing applied to every input image before it reaches the model:
#   1. resize to 224x224 pixels (the "224" input size of vit-base-patch16-224);
#   2. convert the PIL image to a float tensor (for 8-bit RGB images ToTensor
#      also rescales pixel values to [0, 1]).
# NOTE(review): no Normalize() step is applied; inference must use the same
# preprocessing as training, so confirm against the training code before adding one.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
class ViTSalesModel(nn.Module):
    """Regress a single scalar (sales) from an image using a ViT backbone.

    The backbone is a Hugging Face ``ViTModel``. The final hidden state of the
    first token ([CLS]) is fed to a linear head with one output unit.

    Args:
        pretrained_model_name: model id or local path passed to
            ``ViTModel.from_pretrained``. The default is the backbone the
            original script hard-coded, so existing callers are unaffected.
    """

    def __init__(self, pretrained_model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)
        # Single output unit: the (normalized) sales value.
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

    def forward(self, pixel_values, labels=None):
        """Predict sales for a batch of images.

        Args:
            pixel_values: image tensor passed straight to ``ViTModel``;
                presumably (batch, 3, 224, 224) -- TODO confirm.
            labels: optional regression targets, one per image. When given,
                an MSE loss against the flattened predictions is computed.

        Returns:
            ``(loss, sales)`` when ``labels`` is given, otherwise ``sales``;
            ``sales`` has shape (batch, 1).
        """
        outputs = self.vit(pixel_values=pixel_values)
        cls_output = outputs.last_hidden_state[:, 0, :]  # Take the [CLS] token
        sales = self.classifier(cls_output)
        if labels is None:
            return sales
        # reshape (unlike view) accepts non-contiguous labels. Matching the
        # prediction dtype/device avoids mixed-dtype errors such as
        # float64 targets vs float32 predictions in MSELoss.
        targets = labels.reshape(-1).to(device=sales.device, dtype=sales.dtype)
        loss = nn.MSELoss()(sales.view(-1), targets)
        return loss, sales
# Build the regressor, then overwrite its weights with the fine-tuned
# checkpoint written during training (safetensors format).
model = ViTSalesModel()
checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
state_dict = safetensors_load_file(checkpoint_path)
# strict=True (the default) makes any missing or unexpected key an error.
model.load_state_dict(state_dict, strict=True)
# Switch every submodule to eval mode for inference (affects train-only
# layers such as dropout).
model.eval()
# Scale factor that undoes the sales-target normalization used in training:
# model outputs are multiplied by this value to recover raw sales.
# Replace with the actual maximum sales value used during training.
max_sales_value = 100_000
def predict_sales(image_path):
    """Predict the de-normalized sales value for a single image.

    Relies on the module-level ``model``, ``transform`` and ``max_sales_value``.

    Args:
        image_path: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The model's scalar output multiplied by ``max_sales_value``.
    """
    # The context manager guarantees the underlying file is closed (Pillow can
    # keep it open for lazily-loaded / multi-frame formats). convert('RGB')
    # returns a fully loaded, independent image, so it stays valid after exit.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    # Add a batch dimension and place the input on the same device as the model.
    device = next(model.parameters()).device
    pixel_values = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(pixel_values)
    # Undo the target normalization applied during training.
    return prediction.item() * max_sales_value
# Example usage: score one image and print the de-normalized estimate.
image_path = "/content/0000.png"
predicted_sales = predict_sales(image_path=image_path)
print(f"Predicted sales: {predicted_sales}")