# # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All Rights Reserved. # """Module to generate OpenELM output given a model and an input prompt.""" import os import logging import time import argparse from typing import Optional, Union import torch from transformers import AutoTokenizer, AutoModelForCausalLM def generate( prompt: str, model: Union[str, AutoModelForCausalLM], hf_access_token: str = None, tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf', device: Optional[str] = None, max_length: int = 1024, assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None, generate_kwargs: Optional[dict] = None, ) -> str: """ Generates output given a prompt. Args: prompt: The string prompt. model: The LLM Model. If a string is passed, it should be the path to the hf converted checkpoint. hf_access_token: Hugging face access token. tokenizer: Tokenizer instance. If model is set as a string path, the tokenizer will be loaded from the checkpoint. device: String representation of device to run the model on. If None and cuda available it would be set to cuda:0 else cpu. max_length: Maximum length of tokens, input prompt + generated tokens. assistant_model: If set, this model will be used for speculative generation. If a string is passed, it should be the path to the hf converted checkpoint. generate_kwargs: Extra kwargs passed to the hf generate function. Returns: output_text: output generated as a string. generation_time: generation time in seconds. Raises: ValueError: If device is set to CUDA but no CUDA device is detected. ValueError: If tokenizer is not set. ValueError: If hf_access_token is not specified. """ if not device: if torch.cuda.is_available() and torch.cuda.device_count(): device = "cuda:0" logging.warning( 'inference device is not set, using cuda:0, %s', torch.cuda.get_device_name(0) ) else: device = 'cpu' logging.warning( ( 'No CUDA device detected, using cpu, ' 'expect slower speeds.' ) ) if 'cuda' in device and not torch.cuda.is_available(): raise ValueError('CUDA device requested but no CUDA device detected.') if not tokenizer: raise ValueError('Tokenizer is not set in the generate function.') if not hf_access_token: raise ValueError(( 'Hugging face access token needs to be specified. ' 'Please refer to https://huggingface.co/docs/hub/security-tokens' ' to obtain one.' ) ) if isinstance(model, str): checkpoint_path = model model = AutoModelForCausalLM.from_pretrained( checkpoint_path, trust_remote_code=True ) model.to(device).eval() if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained( tokenizer, token=hf_access_token, ) # Speculative mode draft_model = None if assistant_model: draft_model = assistant_model if isinstance(assistant_model, str): draft_model = AutoModelForCausalLM.from_pretrained( assistant_model, trust_remote_code=True ) draft_model.to(device).eval() # Prepare the prompt tokenized_prompt = tokenizer(prompt) tokenized_prompt = torch.tensor( tokenized_prompt['input_ids'], device=device ) tokenized_prompt = tokenized_prompt.unsqueeze(0) # Generate stime = time.time() output_ids = model.generate( tokenized_prompt, max_length=max_length, pad_token_id=0, assistant_model=draft_model, **(generate_kwargs if generate_kwargs else {}), ) generation_time = time.time() - stime output_text = tokenizer.decode( output_ids[0].tolist(), skip_special_tokens=True ) return output_text, generation_time def openelm_generate_parser(): """Argument Parser""" class KwargsParser(argparse.Action): """Parser action class to parse kwargs of form key=value""" def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for val in values: if '=' not in val: raise ValueError( ( 'Argument parsing error, kwargs are expected in' ' the form of key=value.' ) ) kwarg_k, kwarg_v = val.split('=') try: converted_v = int(kwarg_v) except ValueError: try: converted_v = float(kwarg_v) except ValueError: converted_v = kwarg_v getattr(namespace, self.dest)[kwarg_k] = converted_v parser = argparse.ArgumentParser('OpenELM Generate Module') parser.add_argument( '--model', dest='model', help='Path to the hf converted model.', required=True, type=str, ) parser.add_argument( '--hf_access_token', dest='hf_access_token', help='Hugging face access token, starting with "hf_".', type=str, ) parser.add_argument( '--prompt', dest='prompt', help='Prompt for LLM call.', default='', type=str, ) parser.add_argument( '--device', dest='device', help='Device used for inference.', type=str, ) parser.add_argument( '--max_length', dest='max_length', help='Maximum length of tokens.', default=256, type=int, ) parser.add_argument( '--assistant_model', dest='assistant_model', help=( ( 'If set, this is used as a draft model ' 'for assisted speculative generation.' ) ), type=str, ) parser.add_argument( '--generate_kwargs', dest='generate_kwargs', help='Additional kwargs passed to the HF generate function.', type=str, nargs='*', action=KwargsParser, ) return parser.parse_args() if __name__ == '__main__': args = openelm_generate_parser() prompt = args.prompt output_text, genertaion_time = generate( prompt=prompt, model=args.model, device=args.device, max_length=args.max_length, assistant_model=args.assistant_model, generate_kwargs=args.generate_kwargs, hf_access_token=args.hf_access_token, ) print_txt = ( f'\r\n{"=" * os.get_terminal_size().columns}\r\n' '\033[1m Prompt + Generated Output\033[0m\r\n' f'{"-" * os.get_terminal_size().columns}\r\n' f'{output_text}\r\n' f'{"-" * os.get_terminal_size().columns}\r\n' '\r\nGeneration took' f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m' 'seconds.\r\n' ) print(print_txt)