# YAML metadata warning: empty or missing YAML metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
import torch
from PIL import Image
from safetensors.torch import load_file as safetensors_load_file
from torch import nn
from torchvision import transforms
from transformers import ViTConfig, ViTModel

# Preprocessing pipeline applied to every input image: resize to the fixed
# 224x224 resolution of the 'google/vit-base-patch16-224' backbone, then turn
# the PIL image into a float tensor laid out as (C, H, W).
# NOTE(review): no transforms.Normalize step is applied here. This must match
# the preprocessing used during training exactly — confirm before changing.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

class ViTSalesModel(nn.Module):
    """ViT backbone with a single-output linear regression head for sales.

    The last-layer [CLS] token embedding is projected to one scalar per image.

    Args:
        model_name: Hugging Face identifier (or local path) of the pretrained
            ViT backbone passed to ``ViTModel.from_pretrained``. Defaults to
            the checkpoint the original script hard-coded.
    """

    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        # One output unit: the (normalized) sales value.
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and regression head.

        Args:
            pixel_values: Image batch fed to the ViT backbone
                (presumably (batch, 3, 224, 224) floats — confirm with caller).
            labels: Optional regression targets, one per image.

        Returns:
            ``sales`` of shape (batch, 1) when ``labels`` is None; otherwise the
            tuple ``(loss, sales)`` where ``loss`` is the mean-squared error
            (presumably consumed by a Hugging Face Trainer, which reads the
            loss from the first element).
        """
        outputs = self.vit(pixel_values=pixel_values)
        cls_output = outputs.last_hidden_state[:, 0, :]  # Take the [CLS] token
        sales = self.classifier(cls_output)
        if labels is None:
            return sales
        # Integer or float64 targets would not match the float prediction
        # dtype; cast them so MSELoss receives matching dtypes. reshape (not
        # view) also tolerates non-contiguous label tensors.
        targets = labels.to(sales.dtype).reshape(-1)
        loss = nn.MSELoss()(sales.reshape(-1), targets)
        return loss, sales

# Build the architecture (which fetches the pretrained ViT backbone) and then
# replace its weights with the fine-tuned ones stored in the safetensors file.
checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
model = ViTSalesModel()
state_dict = safetensors_load_file(checkpoint_path)
# Strict loading (the default): checkpoint keys must match the model's exactly.
model.load_state_dict(state_dict)
# Switch to evaluation behaviour; identical to calling model.eval().
model.train(False)

# Scale used to map the model's normalized output back to raw sales units;
# predict_sales() multiplies the raw prediction by this value.
# TODO: set this to the actual maximum sales value used during training.
max_sales_value = 100_000

def predict_sales(image_path):
    """Predict the sales value for a single image.

    Uses the module-level ``model``, ``transform`` and ``max_sales_value``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or file object).

    Returns:
        The model's scalar output multiplied by ``max_sales_value``.
    """
    # Image.open is lazy and keeps the file handle open until the data is
    # loaded (and longer for some formats); the context manager closes it
    # deterministically. convert() returns an independent copy, so the
    # result stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        prediction = model(batch)

    print(prediction)  # raw, still-normalized model output (debug trace)
    # De-normalize the single scalar prediction.
    return prediction.item() * max_sales_value

# Example usage: score one sample image and report the de-normalized estimate.
image_path = "/content/0000.png"
predicted_sales = predict_sales(image_path=image_path)
message = f"Predicted sales: {predicted_sales}"
print(message)
# Model card details:
# Downloads last month: - (downloads are not tracked for this model; see "How to track")
# Safetensors — model size: 86.4M params; tensor type: F32
# Inference API: unable to determine this model's library. Check the docs.