swag_base
This model is a fine-tuned version of FacebookAI/roberta-base on the SWAG (Situations With Adversarial Generations) dataset.
Model description
The model is designed to perform multiple-choice reasoning about real-world situations. Given a context and four possible continuations, it predicts the most plausible ending based on common sense understanding.
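The decision rule described above can be sketched without loading the model: each (context, ending) pair receives one score, the scores are normalized with a softmax, and the highest-scoring ending is the prediction. This is a minimal illustration only. The scores are made-up placeholders rather than outputs of this model, and `pick_ending` is a hypothetical helper name.

```python
import math

def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def pick_ending(endings, scores):
    """Return (index, ending, probabilities) for the highest-scoring ending."""
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, endings[best], probs

# Placeholder scores for illustration only (not real model logits).
endings = ["ending A", "ending B", "ending C", "ending D"]
scores = [2.0, 0.5, -1.0, 0.0]
best, ending, probs = pick_ending(endings, scores)
print(best, ending, [round(p, 3) for p in probs])
```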
Key Features:
- Base model: RoBERTa-base
- Task: Multiple Choice Prediction
- Training dataset: SWAG
- Performance: 75.21% accuracy on evaluation set
Training Procedure
Training hyperparameters
- Learning rate: 5e-05
- Batch size: 16
- Number of epochs: 3
- Optimizer: AdamW
- Learning rate scheduler: Linear
- Training samples: 73,546
- Training time: 17m 53s
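To make the schedule concrete, the sketch below derives the number of optimizer steps from the values listed above and evaluates a linear decay from the peak learning rate to zero. It rests on assumptions the card does not state: no warmup steps, no gradient accumulation, a batch size of 16 per optimizer step, and a final partial batch that is kept rather than dropped. `linear_lr` is a hypothetical helper, not a library function.

```python
import math

TRAIN_SAMPLES = 73_546   # from the list above
BATCH_SIZE = 16          # from the list above
EPOCHS = 3               # from the list above
PEAK_LR = 5e-5           # from the list above
WARMUP_STEPS = 0         # assumption: the card does not state a warmup

# Assumption: the last partial batch of each epoch is kept.
steps_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS

def linear_lr(step):
    """Learning rate at an optimizer step under linear warmup then linear decay."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / max(1, WARMUP_STEPS)
    remaining = total_steps - step
    return PEAK_LR * max(0.0, remaining / max(1, total_steps - WARMUP_STEPS))

print(steps_per_epoch, total_steps)
for step in (0, total_steps // 2, total_steps):
    print(step, f"{linear_lr(step):.2e}")
```

Under these assumptions the run has 4,597 steps per epoch and 13,791 steps in total, with the learning rate starting at 5e-05 and reaching zero at the final step.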
Training Results
- Training loss: 0.73
- Evaluation loss: 0.7362
- Evaluation accuracy: 0.7521
- Training samples/second: 205.623
- Training steps/second: 12.852
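The throughput figures can be cross-checked against the reported training time and batch size with a few lines of arithmetic. This sketch assumes the two rates are averages over the whole 3-epoch run, which the card does not state explicitly.

```python
TRAIN_SAMPLES = 73_546        # reported training samples
EPOCHS = 3                    # reported number of epochs
SAMPLES_PER_SECOND = 205.623  # reported training samples/second
STEPS_PER_SECOND = 12.852     # reported training steps/second

# Total samples processed divided by throughput gives the wall-clock time.
seconds = TRAIN_SAMPLES * EPOCHS / SAMPLES_PER_SECOND
minutes, secs = divmod(round(seconds), 60)

# Samples per step implied by the two rates.
implied_batch = SAMPLES_PER_SECOND / STEPS_PER_SECOND

print(f"{minutes}m {secs}s", round(implied_batch, 2))
```

The result agrees with the reported 17m 53s training time, and the implied samples per step is very close to the listed batch size of 16.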
Usage Example
Here's how to use the model:
from transformers import AutoTokenizer, AutoModelForMultipleChoice
import torch

# Load model and tokenizer
model_path = "real-jiakai/roberta-base-uncased-finetuned-swag"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMultipleChoice.from_pretrained(model_path)

def predict_swag(context, endings, model, tokenizer):
    # Pair the same context with every candidate ending
    encoding = tokenizer(
        [context] * len(endings),
        endings,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt"
    )
    # Add a batch dimension: (num_choices, seq_len) -> (1, num_choices, seq_len)
    input_ids = encoding['input_ids'].unsqueeze(0)
    attention_mask = encoding['attention_mask'].unsqueeze(0)
    # Gradients are not needed for inference
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # shape: (1, num_choices)
    predicted_idx = torch.argmax(logits, dim=1).item()
    return {
        'context': context,
        'all_endings': endings,
        'predicted_ending': endings[predicted_idx],
        'probabilities': torch.softmax(logits, dim=1)[0].tolist()
    }

# Example scenarios
test_examples = [
    {
        'context': "Stephen Curry dribbles the ball at the three-point line",
        'endings': [
            "He quickly releases a perfect shot that swishes through the net",  # Most plausible
            "He suddenly starts dancing ballet on the court",
            "He transforms the basketball into a pizza",
            "He flies to the moon with the basketball"
        ]
    },
    {
        'context': "Elon Musk walks into a SpaceX facility and looks at a rocket",
        'endings': [
            "He discusses technical details with the engineering team",  # Most plausible
            "He turns the rocket into a giant chocolate bar",
            "He starts playing basketball with the rocket",
            "He teaches the rocket to speak French"
        ]
    }
]

for i, example in enumerate(test_examples, 1):
    result = predict_swag(
        example['context'],
        example['endings'],
        model,
        tokenizer
    )
    print(f"\n=== Test Scenario {i} ===")
    print(f"Initial Context: {result['context']}")
    print(f"\nPredicted Most Likely Ending: {result['predicted_ending']}")
    print("\nProbabilities for All Options:")
    for idx, (ending, prob) in enumerate(zip(result['all_endings'], result['probabilities'])):
        print(f"Option {idx}: {ending}")
        print(f"Probability: {prob:.3f}")
    print("\n" + "="*50)
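
The function above scores one example at a time and adds the batch dimension with `unsqueeze(0)`. To score several examples in one forward pass, the flat list of tokenized (context, ending) rows would need to be grouped into a `(batch_size, num_choices, seq_len)` layout. The sketch below shows that grouping on plain Python lists with stand-in values. `group_choices` is a hypothetical helper, and with real tensors a reshape such as `tensor.view(batch_size, num_choices, -1)` would play the same role.

```python
def group_choices(flat_rows, num_choices=4):
    """Group a flat list of per-pair rows into one block per example."""
    if len(flat_rows) % num_choices != 0:
        raise ValueError("row count must be a multiple of num_choices")
    return [
        flat_rows[i:i + num_choices]
        for i in range(0, len(flat_rows), num_choices)
    ]

# Stand-in token id rows: 2 examples x 4 endings, sequence length 3.
flat = [[ex, ch, 0] for ex in range(2) for ch in range(4)]
batch = group_choices(flat)

# (batch_size, num_choices, seq_len)
shape = (len(batch), len(batch[0]), len(batch[0][0]))
print(shape)
```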
Limitations and Biases
- The model's performance is limited by its training data and may not generalize well to all domains.
- Performance might vary depending on the complexity and domain of the input scenarios.
- The model may exhibit biases present in the training data.
Framework versions
- Transformers 4.47.0.dev0
- PyTorch 2.5.1+cu124
- Datasets 3.1.0
- Tokenizers 0.20.3
Citation
If you use this model, please cite:
@inproceedings{zellers2018swagaf,
  title={SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference},
  author={Zellers, Rowan and Bisk, Yonatan and Schwartz, Roy and Choi, Yejin},
  booktitle={Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year={2018}
}