import base64
import json
import os
import math
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, AutoConfig


class Transformer(nn.Module):
    """Sentence-Transformers custom module wrapping Qwen2-VL to embed text queries and document images."""

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str = 'llamaindex/vdr-2b-multi-v1',
        processor_name_or_path: Optional[str] = None,
        max_pixels: int = 768 * 28 * 28,
        min_pixels: int = 1 * 28 * 28,
        dimension: int = 2048,
        cache_dir: Optional[str] = None,
        device: str = 'cuda:0',
        config_args: Optional[Dict[str, Any]] = None,
        model_args: Optional[Dict[str, Any]] = None,
        processor_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super(Transformer, self).__init__()

        self.device = device
        self.dimension = dimension
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.model_name_or_path = model_name_or_path
        self.processor_name_or_path = processor_name_or_path or model_name_or_path
        self.cache_dir = cache_dir
        self.config_args = config_args or {}
        self.model_args = model_args or {}
        self.processor_args = processor_args or {}

        # Qwen2-VL chat-format prompts: documents are embedded as images, queries as text (filled in via %s).
        self.document_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n<|endoftext|>"
        self.query_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Query: %s<|im_end|>\n<|endoftext|>"

    @classmethod
    def load(cls, input_path: str) -> 'Transformer':
        """Load the module config (if present), the model, and the processor from `input_path`."""
        config_path = os.path.join(input_path, 'config.json')
        if os.path.exists(config_path):
            with open(config_path) as f:
                config = json.load(f)
        else:
            config = {}

        instance = cls(model_name_or_path=input_path, **config)

        # Load model with flash attention if available
        try:
            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
                input_path,
                attn_implementation="flash_attention_2",
                torch_dtype=torch.bfloat16,
                device_map=instance.device,
                cache_dir=instance.cache_dir,
                **instance.model_args
            ).eval()
        except (ImportError, ValueError) as e:
            print(f"Flash attention not available, falling back to default attention: {e}")
            instance.model = Qwen2VLForConditionalGeneration.from_pretrained(
                input_path,
                torch_dtype=torch.bfloat16,
                device_map=instance.device,
                cache_dir=instance.cache_dir,
                **instance.model_args
            ).eval()

        # Initialize processor
        instance.processor = AutoProcessor.from_pretrained(
            input_path,
            min_pixels=instance.min_pixels,
            max_pixels=instance.max_pixels,
            cache_dir=instance.cache_dir,
            **instance.processor_args
        )

        instance.model.padding_side = "left"
        instance.processor.tokenizer.padding_side = "left"

        return instance

    def _smart_resize(self, height: int, width: int) -> tuple[int, int]:
        """Rescale so both sides are multiples of 28 and the area stays within [min_pixels, max_pixels]; returns (width, height)."""
        h_bar = max(28, self._round_by_factor(height, 28))
        w_bar = max(28, self._round_by_factor(width, 28))
        if h_bar * w_bar > self.max_pixels:
            beta = math.sqrt((height * width) / self.max_pixels)
            h_bar = self._floor_by_factor(height / beta, 28)
            w_bar = self._floor_by_factor(width / beta, 28)
        elif h_bar * w_bar < self.min_pixels:
            beta = math.sqrt(self.min_pixels / (height * width))
            h_bar = self._ceil_by_factor(height * beta, 28)
            w_bar = self._ceil_by_factor(width * beta, 28)
        return w_bar, h_bar

    @staticmethod
    def _round_by_factor(number: float, factor: int) -> int:
        return round(number / factor) * factor

    @staticmethod
    def _ceil_by_factor(number: float, factor: int) -> int:
        return math.ceil(number / factor) * factor

    @staticmethod
    def _floor_by_factor(number: float, factor: int) -> int:
        return math.floor(number / factor) * factor

    def _resize_image(self, image: Image.Image) -> Image.Image:
        new_size = self._smart_resize(image.height, image.width)
        return image.resize(new_size)

    @staticmethod
    def _decode_data_image(data_image_str: str) -> Image.Image:
        header, data = data_image_str.split(',', 1)
        image_data = base64.b64decode(data)
        return Image.open(BytesIO(image_data))

    def _process_input(self, texts: List[Union[str, Image.Image]]) -> tuple[List[str], List[Image.Image]]:
        """Pair each input with a prompt: URLs, data URIs, and PIL images become documents; plain strings become queries with a dummy image."""
        processed_texts = []
        processed_images = []
        dummy_image = Image.new('RGB', (56, 56))

        for sample in texts:
            if isinstance(sample, str):
                if sample.startswith('http') or sample.startswith('data:image/'):
                    try:
                        if sample.startswith('http'):
                            response = requests.get(sample)
                            image = Image.open(BytesIO(response.content)).convert('RGB')
                        else:
                            image = self._decode_data_image(sample).convert('RGB')
                        processed_texts.append(self.document_prompt)
                        processed_images.append(self._resize_image(image))
                    except Exception as e:
                        # Fetching/decoding failed: fall back to treating the string as a query.
                        processed_texts.append(self.query_prompt % sample)
                        processed_images.append(dummy_image)
                else:
                    processed_texts.append(self.query_prompt % sample)
                    processed_images.append(dummy_image)
            elif isinstance(sample, Image.Image):
                processed_texts.append(self.document_prompt)
                processed_images.append(self._resize_image(sample))

        return processed_texts, processed_images

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed a tokenized batch: last hidden state of the final token, truncated to `dimension` and L2-normalized."""
        cache_position = torch.arange(0, features['input_ids'].shape[0])
        inputs = self.model.prepare_inputs_for_generation(
            **features, cache_position=cache_position, use_cache=False
        )

        with torch.no_grad():
            output = self.model(
                **inputs,
                return_dict=True,
                output_hidden_states=True
            )

        embeddings = output.hidden_states[-1][:, -1]
        features['sentence_embedding'] = torch.nn.functional.normalize(
            embeddings[:, :self.dimension], p=2, dim=-1
        )
        return features

    def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
        """Build processor inputs for a mixed batch of queries and document images and move them to the target device."""
        processed_texts, processed_images = self._process_input(texts)

        inputs = self.processor(
            text=processed_texts,
            images=processed_images,
            videos=None,
            padding=padding,
            return_tensors='pt'
        )

        return {k: v.to(self.device) for k, v in inputs.items()}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        # Save the configuration
        config = {
            'model_name_or_path': self.model_name_or_path,
            'processor_name_or_path': self.processor_name_or_path,
            'max_pixels': self.max_pixels,
            'min_pixels': self.min_pixels,
            'dimension': self.dimension,
            'config_args': self.config_args,
            'model_args': self.model_args,
            'processor_args': self.processor_args,
        }
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, 'config.json'), 'w') as f:
            json.dump(config, f)

        self.model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.processor.save_pretrained(output_path)
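

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of driving
# this module standalone. The checkpoint id and the example inputs below are
# assumptions for illustration; in normal use the module is typically loaded
# for you via SentenceTransformer(..., trust_remote_code=True) on a checkpoint
# that ships this file.
if __name__ == "__main__":
    module = Transformer.load('llamaindex/vdr-2b-multi-v1')  # local path or Hub id (assumed reachable)
    batch = [
        'What does the chart on page 3 show?',     # plain string -> query prompt
        'https://example.com/scanned_page.png',    # hypothetical image URL -> document prompt
    ]
    features = module.tokenize(batch)
    embeddings = module(features)['sentence_embedding']
    print(embeddings.shape)  # (2, dimension)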