# MedIntel/backend.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class MedicalAssistant:
    def __init__(self, model_name="sethuiyer/Medichat-Llama3-8B", device="cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        self.sys_message = '''
        You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
        provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
        '''
        # A second __init__ from the "trying other model" commit shadowed the one
        # above and loaded a GGML checkpoint through transformers:
        #   self.model = AutoModel.from_pretrained("ThisIs-Developer/Llama-2-GGML-Medical-Chatbot")
        # That path is broken: AutoModel cannot read GGML weights, and it never set
        # self.tokenizer or self.device, so generate_response() would crash. The
        # attempt is kept here as a comment only.
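
    # A minimal sketch (not in the original file) of how the GGML checkpoint
    # above could actually be loaded, assuming the ctransformers package is
    # installed; the repo's exact .bin filename is unknown here and may need
    # to be passed via model_file=.
    #
    # from ctransformers import AutoModelForCausalLM as GGMLModel
    # llm = GGMLModel.from_pretrained(
    #     "ThisIs-Developer/Llama-2-GGML-Medical-Chatbot",
    #     model_type="llama",  # GGML Llama-2 family weights
    # )
    # print(llm("Q: What causes dizziness? A:", max_new_tokens=128))
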
    def format_prompt(self, question):
        # Build a chat-formatted prompt: the system message plus the user's
        # question, rendered with the tokenizer's chat template.
        messages = [
            {"role": "system", "content": self.sys_message},
            {"role": "user", "content": question}
        ]
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return prompt
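
    # For orientation only: with a stock Llama-3 chat template (an assumption,
    # not verified for this particular merge), the rendered prompt looks like
    #   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    #   {system message}<|eot_id|><|start_header_id|>user<|end_header_id|>
    #   {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
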
    def generate_response(self, question, max_new_tokens=512):
        prompt = self.format_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        # Decode only the newly generated tokens; decoding the full sequence
        # would echo the prompt back at the start of the answer.
        generated = outputs[0][inputs["input_ids"].shape[-1]:]
        answer = self.tokenizer.decode(generated, skip_special_tokens=True).strip()
        return answer
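
    # A possible tweak (an assumption, using standard transformers generate()
    # kwargs): greedy decoding can get repetitive, so sampling may read better:
    #   self.model.generate(**inputs, max_new_tokens=max_new_tokens,
    #                       do_sample=True, temperature=0.7, top_p=0.9)
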
if __name__ == "__main__":
    assistant = MedicalAssistant()
    question = '''
    Symptoms:
    Dizziness, headache, and nausea.
    What is the differential diagnosis?
    '''
    response = assistant.generate_response(question)
    print(response)