"""Gradio chat demo for a Qwen vision-language checkpoint.

The script loads the tokenizer and model once at import time, builds a small
Blocks UI (chat window, text box, image upload, regenerate / clear buttons)
and launches it.

Two parallel histories are kept:

* the ``gr.Chatbot`` value, holding HTML-rendered text and file tuples that
  are shown to the user, and
* ``task_history`` (a ``gr.State`` list), holding the raw text / file tuples
  that are turned into the prompt passed to the model.
"""

import copy
import os
import re
import secrets
import tempfile
from argparse import ArgumentParser
from io import BytesIO
from pathlib import Path

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
    TextStreamer,
)
from transformers.generation import GenerationConfig

DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'
device_map = "auto"

tokenizer = AutoTokenizer.from_pretrained(
    DEFAULT_CKPT_PATH,
    trust_remote_code=True,
)

model = Qwen2VLForConditionalGeneration.from_pretrained(
    DEFAULT_CKPT_PATH,
    device_map=device_map,
    trust_remote_code=True,
).eval()

model.generation_config = GenerationConfig.from_pretrained(
    DEFAULT_CKPT_PATH,
    trust_remote_code=True,
)

# Matches a grounding box span emitted by the model so it can be removed from
# the text shown next to the annotated image.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"

# Full-width / CJK punctuation (plus the ASCII full stop).  The characters are
# deliberately the full-width forms: their ASCII look-alikes include a double
# quote, which would terminate the literal early.
PUNCTUATION = "！？｡＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

# Characters escaped inside a fenced code block so the chat widget's Markdown
# renderer shows them literally.  No replacement text contains another key of
# this table, so one translate() pass equals the chained replace() calls.
_CODE_ESCAPES = str.maketrans({
    "`": r"\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def _parse_text(text):
    """Convert model / user text into the HTML fragment shown in the chatbot.

    Empty lines are dropped.  A line containing a ``` fence opens or closes a
    ``<pre><code>`` element; the text after the last backtick of an opening
    fence becomes the ``language-…`` class.  Every other line except the very
    first is prefixed with ``<br>``, and lines inside an open fence are
    additionally escaped with ``_CODE_ESCAPES``.

    Args:
        text: Raw text, possibly containing Markdown code fences.

    Returns:
        The lines joined into one HTML string (no newline characters).
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Opening fence: the language tag follows the backticks.
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # The first line is intentionally left untouched (no leading
            # <br>, no escaping), matching the original behaviour.
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(_chatbot, task_history):
    """Run the model on the latest turn and write the answer into the chat.

    Consecutive image entries in ``task_history`` are folded into the text
    turn that follows them as ``Picture N: <img>path</img>`` lines, producing
    the (prompt, answer) pairs handed to the model.

    Args:
        _chatbot: Chatbot value; its last entry is the pending user turn.
        task_history: Raw history; its last entry is ``(query, None)``.
            The last entry is overwritten in place with the parsed response.

    Returns:
        The updated chatbot value.
    """
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]

    history_filter = []
    pic_idx = 1
    pending = ""
    for q, a in copy.deepcopy(task_history):
        if isinstance(q, (tuple, list)):
            # Image upload: buffer it until the next text turn.
            pending += f'Picture {pic_idx}: <img>{q[0]}</img>' + '\n'
            pic_idx += 1
        else:
            pending += q
            history_filter.append((pending, a))
            pending = ""
    history, message = history_filter[:-1], history_filter[-1][0]

    # NOTE(review): `model.chat` and `tokenizer.draw_bbox_on_latest_picture`
    # are helpers from the original Qwen-VL remote code; confirm they are
    # available on the Qwen2-VL classes loaded above - TODO verify.
    response, history = model.chat(tokenizer, message, history=history)
    image = tokenizer.draw_bbox_on_latest_picture(response, history)

    if image is not None:
        # Save the annotated picture under a random directory / file name and
        # show it as the answer; any remaining text goes in a follow-up row.
        temp_dir = Path("/tmp") / secrets.token_hex(20)
        temp_dir.mkdir(exist_ok=True, parents=True)
        filename = temp_dir / f"tmp{secrets.token_hex(5)}.jpg"
        image.save(str(filename))
        _chatbot[-1] = (_parse_text(chat_query), (str(filename),))

        chat_response = response.replace("<ref>", "")
        chat_response = chat_response.replace(r"</ref>", "")
        chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
        if chat_response != "":
            _chatbot.append((None, chat_response))
    else:
        _chatbot[-1] = (_parse_text(chat_query), response)

    task_history[-1] = (query, _parse_text(response))
    return _chatbot


def add_text(history, task_history, text):
    """Append a user text turn to both histories.

    A single trailing punctuation mark from ``PUNCTUATION`` is stripped from
    the copy sent to the model (but not from the displayed text), unless the
    text ends in a run of such marks.

    Returns:
        ``(chatbot_history, task_history, "")`` - the empty string clears the
        input textbox.
    """
    task_text = text
    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
        task_text = text[:-1]
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(task_text, None)]
    return history, task_history, ""


def add_file(history, task_history, file):
    """Append an uploaded file, as a one-element path tuple, to both histories."""
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    return history, task_history


def reset_user_input():
    """Return an update that empties the input textbox."""
    return gr.update(value="")


def reset_state(task_history):
    """Clear the raw history in place and return an empty chatbot value."""
    task_history.clear()
    return []


def regenerate(_chatbot, task_history):
    """Discard the last answer and ask the model again.

    Does nothing when there is no history or the last turn has no answer yet.

    Returns:
        The chatbot value, re-populated by ``predict`` when a retry happened.
    """
    print("Regenerate clicked")
    print("Before:", task_history, _chatbot)
    if not task_history:
        return _chatbot
    item = task_history[-1]
    if item[1] is None:
        return _chatbot
    task_history[-1] = (item[0], None)

    chatbot_item = _chatbot.pop(-1)
    if chatbot_item[0] is None:
        # The popped row was the extra text row that follows an image answer;
        # blank the answer of the row holding the actual question.
        _chatbot[-1] = (_chatbot[-1][0], None)
    else:
        _chatbot.append((chatbot_item[0], None))
    print("After:", task_history, _chatbot)
    return predict(_chatbot, task_history)


css = ''' .gradio-container{max-width:800px !important} '''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Qwen-VL-Chat Bot")
    gr.Markdown("## Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud **Space by [@Artificialguybr](https://twitter.com/artificialguybr). Test the [QwenLLM-14B](https://huggingface.co/spaces/artificialguybr/qwen-14b-chat-demo) here for free!")
    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
    query = gr.Textbox(lines=2, label='Input')
    task_history = gr.State([])

    with gr.Row():
        addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
        submit_btn = gr.Button("🚀 Submit")
        regen_btn = gr.Button("🤔️ Regenerate")
        empty_bin = gr.Button("🧹 Clear History")

    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # add_text returns three values; the third ("") is routed to the textbox
    # so the number of outputs matches the number of returned values.
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    submit_btn.click(reset_user_input, [], [query])
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

demo.launch()