import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class MedicalAssistant:
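    """Thin Q&A wrapper around a causal-LM medical chat model from the Hugging Face Hub."""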
    def __init__(self, model_name="sethuiyer/Medichat-Llama3-8B", device=None):
        # Fall back to CPU when no CUDA device is available.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Tip: pass torch_dtype=torch.float16 to from_pretrained to roughly halve GPU memory use.
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        self.sys_message = '''
        You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
        provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
        '''
    # Note: the GGML checkpoint "ThisIs-Developer/Llama-2-GGML-Medical-Chatbot" cannot be
    # loaded with transformers.AutoModel; GGML weights require a llama.cpp-style loader
    # (e.g. the ctransformers package) rather than the transformers API.

    def format_prompt(self, question):
        messages = [
            {"role": "system", "content": self.sys_message},
            {"role": "user", "content": question}
        ]
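        # Render the messages with the model's chat template; add_generation_prompt=True
        # appends the assistant header so the model responds as the assistant.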
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        return prompt

    def generate_response(self, question, max_new_tokens=512):
        prompt = self.format_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
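        # Run generation without gradient tracking; use_cache reuses past key/values
        # to speed up autoregressive decoding.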
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        # Decode only the newly generated tokens so the prompt is not echoed back.
        answer = self.tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True).strip()
        return answer

if __name__ == "__main__":
    assistant = MedicalAssistant()
    question = '''
    Symptoms:
    Dizziness, headache, and nausea.

    What is the differential diagnosis?
    '''
    response = assistant.generate_response(question)
    print(response)