Spaces:
Runtime error
Runtime error
# Author: Ricardo Lisboa Santos | |
# Creation date: 2024-01-10 | |
import torch | |
# import torch_directml | |
from transformers import pipeline | |
def getDevice(DEVICE):
    """Translate a device name into a ``torch.device``.

    Args:
        DEVICE: Device name. Only ``"cpu"`` and ``"cuda"`` are recognised.

    Returns:
        ``torch.device(DEVICE)`` for a recognised name, otherwise ``None``.
    """
    # The previous version also assigned a per-device dtype (float32 for CPU,
    # float16 for CUDA) that was never used or returned; that dead code is
    # removed. A DirectML option (via torch_directml) was sketched in comments
    # but never enabled, so "directml" still yields None.
    if DEVICE in ("cpu", "cuda"):
        return torch.device(DEVICE)
    # Unrecognised names keep the original behaviour of returning None.
    return None
def loadClassifier(device):
    """Create a Hugging Face ``"sentiment-analysis"`` pipeline on ``device``.

    Args:
        device: Where to place the model, e.g. a ``torch.device`` from
            ``getDevice``. ``None`` leaves the pipeline's default placement.

    Returns:
        The pipeline object produced by ``transformers.pipeline``.
    """
    # Placement is requested through pipeline's own ``device`` argument. The
    # earlier ``.to(device)`` call was commented out, so ``device`` was
    # silently ignored. Note: asking for CUDA on a machine without a GPU will
    # now fail at load time instead of quietly running on the default device.
    return pipeline("sentiment-analysis", device=device)
def classify(classifier, text):
    """Apply ``classifier`` to ``text`` and return its result unchanged.

    Args:
        classifier: Any callable, such as the pipeline from ``loadClassifier``.
        text: Input passed straight through to ``classifier``.

    Returns:
        Exactly what ``classifier(text)`` returns.
    """
    return classifier(text)
def clearCache(DEVICE, classifier): | |
classifier.tokenizer.save_pretrained("cache") | |
classifier.model.save_pretrained("cache") | |
del classifier | |
# if DEVICE == "directml": | |
# torch_directml.empty_cache() | |