deepapaikar committed on
Commit
1a4cca4
·
verified ·
1 Parent(s): b56227c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -3,13 +3,16 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import spaces
5
 
 
 
6
  # Load model and tokenizer only once, outside the function
7
  model_name = "deepapaikar/Katzbot_Llama_7b_QA_10eps"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
10
 
11
 
12
- @spaces
 
13
  def generate_text(input_text):
14
  """Generates text using the LlamaKatz-3x8B model.
15
 
@@ -19,7 +22,7 @@ def generate_text(input_text):
19
  Returns:
20
  str: The generated text.
21
  """
22
- inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
23
  outputs = model.generate(**inputs)
24
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
  return generated_text
 
3
import torch
import spaces

# ZeroGPU pattern: allocating a CUDA tensor at import time signals the
# Spaces runtime that this app needs a GPU. NOTE(review): `torch.Tensor([0])`
# is the legacy constructor; kept for parity with the ZeroGPU examples.
zero = torch.Tensor([0]).cuda()

# Load model and tokenizer only once, at import time, so repeated calls to
# generate_text() do not pay the (very large) model-loading cost.
model_name = "deepapaikar/Katzbot_Llama_7b_QA_10eps"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map='auto' lets accelerate decide where each layer lives.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')


@spaces.GPU
def generate_text(input_text):
    """Generate a continuation for *input_text* with the Katzbot QA model.

    Args:
        input_text (str): The prompt / question to feed the model.

    Returns:
        str: The generated text, decoded with special tokens stripped.
    """
    # BUG FIX: move the inputs to the *model's* device rather than
    # `zero.device` (always cuda:0). With device_map='auto' the embedding
    # layer is not guaranteed to sit on cuda:0, and a mismatch raises a
    # device-placement RuntimeError inside generate().
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text