import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextStreamer, Qwen2VLForConditionalGeneration
import torch
from PIL import Image
import re
import requests
from io import BytesIO
import copy
import secrets
from pathlib import Path
from argparse import ArgumentParser
from pathlib import Path
import copy
import gradio as gr
import os
import re
import secrets
import tempfile
from pathlib import Path
import copy
import os
import re
import secrets
import tempfile
from transformers import AutoTokenizer
from transformers.generation import GenerationConfig
# Checkpoint identifier handed to every from_pretrained() call below.
DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'
# Forwarded as the device_map argument when the model is loaded.
device_map = "auto"

# Everything is loaded eagerly when the module is imported: tokenizer first,
# then the model (switched to eval mode), then its generation config.
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_CKPT_PATH, trust_remote_code=True)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    DEFAULT_CKPT_PATH, device_map=device_map, trust_remote_code=True
)
model = model.eval()

# Replace the model's generation config with the one stored in the checkpoint.
model.generation_config = GenerationConfig.from_pretrained(
    DEFAULT_CKPT_PATH, trust_remote_code=True
)
# Matches one <box>...</box> span, contents included (newlines too), so it can
# be removed from the text shown in the chat. Without delimiters the lazy group
# "([\s\S]*?)" only ever matches the empty string, which turns
# re.sub(BOX_TAG_PATTERN, "", ...) into a no-op.
# NOTE(review): the <box> delimiters are an assumption about the model's
# grounding markup -- confirm against the configured checkpoint's output.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# Characters that count as punctuation when add_text() decides whether to trim
# a single trailing mark from the user's text.
# NOTE(review): the second double quote on this line closes the string literal,
# so everything from '#' onwards is a comment and the effective value is only
# the first three characters. Confirm whether the longer list was intended
# before widening it -- doing so changes which trailing characters get trimmed.
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'
'
else:
lines[i] = f"
"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = "
" + line
text = "".join(lines)
return text
def predict(_chatbot, task_history):
    """Answer the newest turn and record the reply in both histories.

    Args:
        _chatbot: Display history. Slot 0 of its last entry is the query
            being answered; the entry is overwritten with the result.
        task_history: Model-side history of ``(query, response)`` pairs. A
            query that is a tuple/list is treated as an uploaded file whose
            path is its first element. The last entry receives the response.

    Returns:
        ``_chatbot`` (the same object, mutated in place).
    """
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]

    # Fold each run of uploaded pictures into the text query that follows it,
    # numbering the pictures across the whole conversation.
    turns = []
    pending = ""
    pic_idx = 1
    for q, a in copy.deepcopy(task_history):
        if isinstance(q, (tuple, list)):
            pending += f'Picture {pic_idx}: {q[0]}' + '\n'
            pic_idx += 1
        else:
            turns.append((pending + q, a))
            pending = ""
    history, message = turns[:-1], turns[-1][0]

    # NOTE(review): relies on the loaded model exposing chat() and the
    # tokenizer exposing draw_bbox_on_latest_picture() -- confirm that the
    # configured checkpoint's classes actually provide both.
    response, history = model.chat(tokenizer, message, history=history)
    image = tokenizer.draw_bbox_on_latest_picture(response, history)

    if image is not None:
        # Save the annotated picture under the platform temp directory (not a
        # hard-coded /tmp) in a randomly named sub-directory.
        temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
        temp_dir.mkdir(exist_ok=True, parents=True)
        filename = temp_dir / f"tmp{secrets.token_hex(5)}.jpg"
        image.save(str(filename))
        _chatbot[-1] = (_parse_text(chat_query), (str(filename),))

        # Show whatever text is left once brackets and box spans are removed.
        chat_response = response.replace("[", "").replace("]", "")
        chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
        if chat_response != "":
            _chatbot.append((None, chat_response))
    else:
        _chatbot[-1] = (_parse_text(chat_query), response)

    task_history[-1] = (query, _parse_text(response))
    return _chatbot
def add_text(history, task_history, text):
    """Append the user's text as a new, unanswered turn to both histories.

    The display history receives the text passed through ``_parse_text``.
    The task history receives the raw text, minus its final character when
    that character is in ``PUNCTUATION`` and the one before it is not.

    Args:
        history: Display history; not mutated, a new list is returned.
        task_history: Model-side history; not mutated, a new list is returned.
        text: The text typed by the user.

    Returns:
        ``(new_history, new_task_history, "")``.
    """
    has_single_trailing_mark = (
        len(text) >= 2
        and text[-1] in PUNCTUATION
        and text[-2] not in PUNCTUATION
    )
    model_text = text[:-1] if has_single_trailing_mark else text

    display_turn = (_parse_text(text), None)
    model_turn = (model_text, None)
    return history + [display_turn], task_history + [model_turn], ""
def add_file(history, task_history, file):
    """Append an uploaded file as a new, unanswered turn to both histories.

    The turn's query is a one-element tuple holding ``file.name``; that tuple
    shape is how an upload is told apart from a text query.

    Args:
        history: Display history; not mutated, a new list is returned.
        task_history: Model-side history; not mutated, a new list is returned.
        file: Object with a ``name`` attribute.

    Returns:
        ``(new_history, new_task_history)``.
    """
    upload_turn = ((file.name,), None)
    return history + [upload_turn], task_history + [upload_turn]
def reset_user_input():
    """Return a Gradio update that sets the target component's value to ""."""
    cleared = gr.update(value="")
    return cleared
def reset_state(task_history):
    """Empty the model-side history in place and return a blank display history.

    Args:
        task_history: History to clear; the same object is emptied.

    Returns:
        A new empty list.
    """
    blank_display = []
    task_history.clear()
    return blank_display
def regenerate(_chatbot, task_history):
    """Discard the latest answer and ask the model again.

    Returns ``_chatbot`` untouched when there is no history or the last turn
    has no response yet. Otherwise the last response is cleared in both
    histories and ``predict`` is called to produce a new one.

    Args:
        _chatbot: Display history, mutated in place.
        task_history: Model-side history, mutated in place.

    Returns:
        ``_chatbot`` unchanged, or the result of ``predict``.
    """
    print("Regenerate clicked")
    print("Before:", task_history, _chatbot)

    # Nothing to redo: no turns at all, or the last turn is still unanswered.
    if not task_history or task_history[-1][1] is None:
        return _chatbot

    last_query = task_history[-1][0]
    task_history[-1] = (last_query, None)

    removed = _chatbot.pop()
    if removed[0] is not None:
        # Ordinary turn: put the query back with an empty answer slot.
        _chatbot.append((removed[0], None))
    else:
        # The removed row was an answer-only row; the query sits in the row
        # before it, so blank that row's answer instead.
        _chatbot[-1] = (_chatbot[-1][0], None)

    print("After:", task_history, _chatbot)
    return predict(_chatbot, task_history)
# Page-level CSS: caps the width of the Gradio container.
css = '''
.gradio-container{max-width:800px !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Qwen-VL-Chat Bot")
    gr.Markdown("## Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud **Space by [@Artificialguybr](https://twitter.com/artificialguybr). Test the [QwenLLM-14B](https://huggingface.co/spaces/artificialguybr/qwen-14b-chat-demo) here for free!")
    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
    query = gr.Textbox(lines=2, label='Input')
    # Per-session model-side history, kept separately from what the chatbot shows.
    task_history = gr.State([])
    with gr.Row():
        addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
        submit_btn = gr.Button("🚀 Submit")
        regen_btn = gr.Button("🤔️ Regenerate")
        empty_bin = gr.Button("🧹 Clear History")
    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # add_text returns three values (display history, task history, and "" for
    # the input box), so all three components are listed as outputs; with only
    # two outputs declared the third value had no component to go to.
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    submit_btn.click(reset_user_input, [], [query])
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
demo.launch()