Spaces:
Running
on
Zero
Running
on
Zero
grahamwhiteuk
commited on
Revert "feat: temporarily switch out to 2b model"
Browse filesThis reverts commit 2e41a220488ada8d0d858681b691d78ef41327d8.
app.py
CHANGED
@@ -205,7 +205,7 @@ with gr.Blocks(
|
|
205 |
gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
|
206 |
gr.HTML(
|
207 |
elem_classes="system-description",
|
208 |
-
value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-
|
209 |
)
|
210 |
with gr.Row(elem_classes="column-gap"):
|
211 |
with gr.Column(scale=0, elem_classes="no-gap"):
|
|
|
205 |
gr.HTML("<h2>IBM Granite Guardian 3.0</h2>", elem_classes="title")
|
206 |
gr.HTML(
|
207 |
elem_classes="system-description",
|
208 |
+
value="<p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in generative AI systems. They can be used with any large language model to make interactions with generative AI systems safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user prompts, assistant responses, and for hallucinations in retrieval-augmented generation. In this demo, we use granite-guardian-3.0-8b.</p>",
|
209 |
)
|
210 |
with gr.Row(elem_classes="column-gap"):
|
211 |
with gr.Column(scale=0, elem_classes="no-gap"):
|
model.py
CHANGED
@@ -23,7 +23,7 @@ logger.debug(f"Inference engine is: '{inference_engine}'")
|
|
23 |
if inference_engine == "VLLM":
|
24 |
device = torch.device("cuda")
|
25 |
|
26 |
-
model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-
|
27 |
logger.debug(f"model_path is {model_path}")
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
29 |
# sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
@@ -37,10 +37,10 @@ elif inference_engine == "WATSONX":
|
|
37 |
)
|
38 |
|
39 |
client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
|
40 |
-
hf_model_path = "ibm-granite/granite-guardian-3.0-
|
41 |
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
42 |
|
43 |
-
model_id = "ibm/granite-guardian-3-
|
44 |
model = ModelInference(model_id=model_id, api_client=client)
|
45 |
|
46 |
|
@@ -48,14 +48,13 @@ def parse_output(output, input_len):
|
|
48 |
label, prob_of_risk = None, None
|
49 |
if nlogprobs > 0:
|
50 |
|
51 |
-
list_index_logprobs_i = [
|
52 |
-
|
53 |
-
]
|
54 |
if list_index_logprobs_i is not None:
|
55 |
prob = get_probablities(list_index_logprobs_i)
|
56 |
prob_of_risk = prob[1]
|
57 |
|
58 |
-
res = tokenizer.decode(output.sequences[:,
|
59 |
if risky_token.lower() == res.lower():
|
60 |
label = risky_token
|
61 |
elif safe_token.lower() == res.lower():
|
@@ -65,7 +64,6 @@ def parse_output(output, input_len):
|
|
65 |
|
66 |
return label, prob_of_risk.item()
|
67 |
|
68 |
-
|
69 |
def get_probablities(logprobs):
|
70 |
safe_token_prob = 1e-50
|
71 |
unsafe_token_prob = 1e-50
|
@@ -77,7 +75,9 @@ def get_probablities(logprobs):
|
|
77 |
if decoded_token.strip().lower() == risky_token.lower():
|
78 |
unsafe_token_prob += math.exp(logprob)
|
79 |
|
80 |
-
probabilities = torch.softmax(
|
|
|
|
|
81 |
|
82 |
return probabilities
|
83 |
|
@@ -87,7 +87,6 @@ def softmax(values):
|
|
87 |
total = sum(exp_values)
|
88 |
return [v / total for v in exp_values]
|
89 |
|
90 |
-
|
91 |
def get_probablities_watsonx(top_tokens_list):
|
92 |
safe_token_prob = 1e-50
|
93 |
risky_token_prob = 1e-50
|
@@ -110,9 +109,9 @@ def get_prompt(messages, criteria_name, tokenize=False, add_generation_prompt=Fa
|
|
110 |
guardian_config=guardian_config,
|
111 |
tokenize=tokenize,
|
112 |
add_generation_prompt=add_generation_prompt,
|
113 |
-
return_tensors=return_tensors
|
114 |
)
|
115 |
-
logger.debug(f
|
116 |
return prompt
|
117 |
|
118 |
|
@@ -167,15 +166,18 @@ def generate_text(messages, criteria_name):
|
|
167 |
|
168 |
elif inference_engine == "VLLM":
|
169 |
# input_ids = get_prompt(
|
170 |
-
# messages=messages,
|
171 |
-
# criteria_name=criteria_name,
|
172 |
# tokenize=True,
|
173 |
# add_generation_prompt=True,
|
174 |
# return_tensors="pt").to(model.device)
|
175 |
guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
|
176 |
-
logger.debug(f
|
177 |
input_ids = tokenizer.apply_chat_template(
|
178 |
-
messages,
|
|
|
|
|
|
|
179 |
).to(model.device)
|
180 |
logger.debug(f"input_ids are: {input_ids}")
|
181 |
input_len = input_ids.shape[1]
|
@@ -188,8 +190,7 @@ def generate_text(messages, criteria_name):
|
|
188 |
do_sample=False,
|
189 |
max_new_tokens=nlogprobs,
|
190 |
return_dict_in_generate=True,
|
191 |
-
output_scores=True,
|
192 |
-
)
|
193 |
logger.debug(f"model output is:\n{output}")
|
194 |
|
195 |
label, prob_of_risk = parse_output(output, input_len)
|
|
|
23 |
if inference_engine == "VLLM":
|
24 |
device = torch.device("cuda")
|
25 |
|
26 |
+
model_path = os.getenv("MODEL_PATH", "ibm-granite/granite-guardian-3.0-8b")
|
27 |
logger.debug(f"model_path is {model_path}")
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
29 |
# sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
|
|
|
37 |
)
|
38 |
|
39 |
client.set.default_project(os.getenv("WATSONX_PROJECT_ID"))
|
40 |
+
hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
|
41 |
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
|
42 |
|
43 |
+
model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
|
44 |
model = ModelInference(model_id=model_id, api_client=client)
|
45 |
|
46 |
|
|
|
48 |
label, prob_of_risk = None, None
|
49 |
if nlogprobs > 0:
|
50 |
|
51 |
+
list_index_logprobs_i = [torch.topk(token_i, k=nlogprobs, largest=True, sorted=True)
|
52 |
+
for token_i in list(output.scores)[:-1]]
|
|
|
53 |
if list_index_logprobs_i is not None:
|
54 |
prob = get_probablities(list_index_logprobs_i)
|
55 |
prob_of_risk = prob[1]
|
56 |
|
57 |
+
res = tokenizer.decode(output.sequences[:,input_len:][0],skip_special_tokens=True).strip()
|
58 |
if risky_token.lower() == res.lower():
|
59 |
label = risky_token
|
60 |
elif safe_token.lower() == res.lower():
|
|
|
64 |
|
65 |
return label, prob_of_risk.item()
|
66 |
|
|
|
67 |
def get_probablities(logprobs):
|
68 |
safe_token_prob = 1e-50
|
69 |
unsafe_token_prob = 1e-50
|
|
|
75 |
if decoded_token.strip().lower() == risky_token.lower():
|
76 |
unsafe_token_prob += math.exp(logprob)
|
77 |
|
78 |
+
probabilities = torch.softmax(
|
79 |
+
torch.tensor([math.log(safe_token_prob), math.log(unsafe_token_prob)]), dim=0
|
80 |
+
)
|
81 |
|
82 |
return probabilities
|
83 |
|
|
|
87 |
total = sum(exp_values)
|
88 |
return [v / total for v in exp_values]
|
89 |
|
|
|
90 |
def get_probablities_watsonx(top_tokens_list):
|
91 |
safe_token_prob = 1e-50
|
92 |
risky_token_prob = 1e-50
|
|
|
109 |
guardian_config=guardian_config,
|
110 |
tokenize=tokenize,
|
111 |
add_generation_prompt=add_generation_prompt,
|
112 |
+
return_tensors=return_tensors
|
113 |
)
|
114 |
+
logger.debug(f'prompt is\n{prompt}')
|
115 |
return prompt
|
116 |
|
117 |
|
|
|
166 |
|
167 |
elif inference_engine == "VLLM":
|
168 |
# input_ids = get_prompt(
|
169 |
+
# messages=messages,
|
170 |
+
# criteria_name=criteria_name,
|
171 |
# tokenize=True,
|
172 |
# add_generation_prompt=True,
|
173 |
# return_tensors="pt").to(model.device)
|
174 |
guardian_config = {"risk_name": criteria_name if criteria_name != "general_harm" else "harm"}
|
175 |
+
logger.debug(f'guardian_config is: {guardian_config}')
|
176 |
input_ids = tokenizer.apply_chat_template(
|
177 |
+
messages,
|
178 |
+
guardian_config=guardian_config,
|
179 |
+
add_generation_prompt=True,
|
180 |
+
return_tensors='pt'
|
181 |
).to(model.device)
|
182 |
logger.debug(f"input_ids are: {input_ids}")
|
183 |
input_len = input_ids.shape[1]
|
|
|
190 |
do_sample=False,
|
191 |
max_new_tokens=nlogprobs,
|
192 |
return_dict_in_generate=True,
|
193 |
+
output_scores=True,)
|
|
|
194 |
logger.debug(f"model output is:\n{output}")
|
195 |
|
196 |
label, prob_of_risk = parse_output(output, input_len)
|