# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10

import torch

# import torch_directml
from transformers import pipeline

# Candidate labels used by classify() when the caller supplies none.
_DEFAULT_LABELS = ("education", "politics", "business")


def getDevice(DEVICE):
    """Return the torch device for a backend name.

    Args:
        DEVICE: Backend name, ``"cpu"`` or ``"cuda"``.

    Returns:
        A ``torch.device`` for the requested backend, or ``None`` when
        ``DEVICE`` is not a recognised name. Passing that ``None`` on to
        ``loadGenerator`` leaves device placement to the pipeline default.
    """
    if DEVICE == "cpu":
        return torch.device("cpu")
    if DEVICE == "cuda":
        return torch.device("cuda")
    # DirectML support, disabled while torch_directml is not imported:
    # if DEVICE == "directml":
    #     return torch_directml.device()
    return None


def getDtype(DEVICE):
    """Return the floating-point precision intended for a backend name.

    Args:
        DEVICE: Backend name, ``"cpu"`` or ``"cuda"``.

    Returns:
        ``torch.float32`` for ``"cpu"``, ``torch.float16`` for ``"cuda"``,
        or ``None`` for any other name. ``loadGenerator`` does not apply this
        automatically.
    """
    if DEVICE == "cpu":
        return torch.float32
    if DEVICE == "cuda":
        return torch.float16
    # if DEVICE == "directml":
    #     return torch.float16
    return None


def loadGenerator(device, model=None):
    """Build a zero-shot-classification pipeline placed on ``device``.

    Args:
        device: A ``torch.device`` (e.g. from ``getDevice``), or ``None`` to
            keep the pipeline's default placement.
        model: Optional model name or path. ``None`` lets transformers choose
            its default zero-shot model, as the original code did.

    Returns:
        The transformers zero-shot-classification pipeline.
    """
    # ``device`` used to be accepted but ignored; hand it to pipeline() so
    # the model is actually created on the requested device.
    return pipeline("zero-shot-classification", model=model, device=device)


def classify(generator, text, labels=None):
    """Score ``text`` against a set of candidate labels.

    Args:
        generator: Zero-shot pipeline, e.g. from ``loadGenerator``.
        text: Input passed straight through to the pipeline.
        labels: Candidate labels. Defaults to
            ``["education", "politics", "business"]``.

    Returns:
        The pipeline's output, unchanged.
    """
    # Build a fresh list per call rather than sharing one mutable default
    # list across every invocation.
    if labels is None:
        labels = list(_DEFAULT_LABELS)
    return generator(text, candidate_labels=labels)


def clearCache(DEVICE, generator, cache_dir="cache"):
    """Save the pipeline's tokenizer and model, then free cached GPU memory.

    Args:
        DEVICE: Backend name as given to ``getDevice``.
        generator: Pipeline whose ``tokenizer`` and ``model`` are saved.
        cache_dir: Directory handed to ``save_pretrained``
            (default ``"cache"``).

    Note:
        ``del generator`` below only drops this function's own reference.
        The model stays alive until the caller drops its reference as well.
    """
    generator.tokenizer.save_pretrained(cache_dir)
    generator.model.save_pretrained(cache_dir)
    del generator
    if DEVICE == "cuda":
        # Release unoccupied blocks held by PyTorch's CUDA caching allocator.
        torch.cuda.empty_cache()
    # if DEVICE == "directml":
    #     torch_directml.empty_cache()