Some changes to run this across multiple GPUs efficiently

#1
by smcleod - opened

Neat project.

I wanted to have a go at running this locally and ended up making some changes that others might find useful: they allow the model to be loaded across multiple GPUs fairly efficiently.

SCR-20240625-irix.png

I was going to submit a PR, but your code is probably fairly specific to running in an HF Space, so I thought I'd share it here in case you or anyone else finds it useful:

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Hub repository of the instruction-synthesizer checkpoint.
MODEL_ID = "instruction-pretrain/instruction-synthesizer"
# NOTE(review): the model authors report fine-tuning in fp16; bf16 output
# quality is unverified -- switch to torch.float16 if results look off.
TORCH_DTYPE = torch.bfloat16

# device_map="auto" lets accelerate place the weights across all visible GPUs;
# flash_attention_2 needs half-precision (fp16/bf16) weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="flash_attention_2",
    device_map="auto",
    torch_dtype=TORCH_DTYPE,
)
# The tokenizer has no attention implementation, device placement or dtype,
# so only the repository id is passed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Tokenizing with padding=True requires a pad token; reuse EOS when the
# checkpoint does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

def parse_pred(pred):
    """Extract the list of instruction-response pairs from the prediction.

    Expects segments of the form ``<QUE> question <ANS> answer </END>``.
    A segment is silently skipped when it:
      * does not contain exactly one ``<ANS>`` separator,
      * does not start with ``<QUE>`` (after stripping whitespace),
      * has an empty answer, or
      * repeats an earlier question (case-insensitive).

    Args:
        pred: Decoded model output text.

    Returns:
        List of ``{'Q': question, 'A': answer}`` dicts in output order.
    """
    QA_str_list = pred.split('</END>')
    # A trailing segment that is not closed by </END> is presumably an
    # incomplete, cut-off generation; discard it.
    if not pred.endswith('</END>'):
        QA_str_list = QA_str_list[:-1]

    QA_list = []
    seen_questions = set()  # lower-cased questions already accepted
    for QA_str in QA_str_list:
        # Explicit checks rather than `assert`: asserts are stripped under
        # `python -O`, which would let malformed or duplicate pairs through.
        parts = QA_str.split('<ANS>')
        if len(parts) != 2:
            continue
        Q_str, A_str = parts[0].strip(), parts[1].strip()
        if not Q_str.startswith('<QUE>') or not A_str:
            continue
        Q_str = Q_str.replace('<QUE>', '').strip()
        question_key = Q_str.lower()
        if question_key in seen_questions:
            continue
        seen_questions.add(question_key)
        QA_list.append({'Q': Q_str, 'A': A_str})

    return QA_list

def get_instruction_response_pairs(context, max_new_tokens=400):
    '''Prompt the synthesizer to generate instruction-response pairs based on the given context.

    Args:
        context: Raw text that the generated pairs should be grounded in.
        max_new_tokens: Maximum number of tokens to generate (default 400).

    Returns:
        List of {'Q': ..., 'A': ...} dicts as produced by parse_pred.
    '''
    prompt = f'<s> <CON> {context} </CON>\n\n'
    # add_special_tokens=False: the prompt already carries a literal "<s>",
    # so the tokenizer must not add its own special tokens on top.
    # NOTE(review): truncation=True cuts from the end of the prompt, so an
    # over-long context would lose its closing "</CON>" marker -- confirm
    # tokenizer.model_max_length is large enough for expected inputs.
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True)
    # model.device reports the device of the model's first parameter (normally
    # the input embeddings when sharded with device_map="auto"); inputs go there.
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    # Greedy decoding. Passing pad_token_id explicitly keeps generate() from
    # falling back to the EOS id with a warning.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

    # generate() returns prompt + continuation; decode only the new tokens.
    pred_start = int(inputs['input_ids'].shape[-1])
    pred = tokenizer.decode(outputs[0][pred_start:], skip_special_tokens=True)
    return parse_pred(pred)

def generate_pairs(context):
    """Generate instruction-response pairs for *context* and format them as Markdown.

    Each pair becomes a numbered "## Instruction N" / "## Response N" section
    followed by a blank line. Returns "" when no pairs could be parsed.
    """
    instruction_response_pairs = get_instruction_response_pairs(context)
    # Collect the sections and join them once instead of growing a string
    # with += in a loop (quadratic in the worst case).
    sections = [
        f"## Instruction {index}:\n{pair['Q']}\n## Response {index}:\n{pair['A']}\n\n"
        for index, pair in enumerate(instruction_response_pairs, start=1)
    ]
    return "".join(sections)

description = """
## Instruction Pre-Training: Language Models as Supervised Multitask Learners

This demo implements the instruction synthesis approach from the paper ["Instruction Pre-Training: Language Models are Supervised Multitask Learners"](https://huggingface.co/papers/2406.14491).

### Method:
1. An instruction synthesizer is trained on diverse datasets to generate instruction-response pairs from raw text.
2. The synthesizer augments raw pre-training corpora with synthesized instruction-response pairs.
3. Language models are then pre-trained on this augmented data, combining unsupervised and supervised multitask learning.

This approach enhances model performance and generalization, particularly benefiting from further instruction tuning.

Try it out by entering some text below!
"""

# Sample contexts offered as one-click examples in the Gradio UI.
examples = [
    "Hugging Face, Inc. is a French-American company incorporated under the Delaware General Corporation Law[1] and based in New York City that develops computation tools for building applications using machine learning. It is most notable for its transformers library built for natural language processing applications and its platform that allows users to share machine learning models and datasets and showcase their work.",
    "In order to make your Space work with ZeroGPU you need to decorate the Python functions that actually require a GPU with @spaces.GPU \n During the time when a decorated function is invoked, the Space will be attributed a GPU, and it will release it upon completion of the function.",
    # Uses a real en dash; the previous text held mojibake ("โ€“").
    "A spectre is haunting Europe – the spectre of communism. All the powers of old Europe have entered into a holy alliance to exorcise this spectre: Pope and Tsar, Metternich and Guizot, French Radicals and German police-spies",
]

# Build a simple text-in / text-out Gradio UI around generate_pairs.
context_input = gr.Textbox(lines=5, label="Enter context here")
pairs_output = gr.Textbox(lines=20, label="Generated Instruction-Response Pairs")

iface = gr.Interface(
    fn=generate_pairs,
    inputs=context_input,
    outputs=pairs_output,
    title="Instruction-Response Pair Generator",
    description=description,
    examples=examples,
)

# Start the server; "0.0.0.0" binds every network interface, so the app is
# reachable from other machines, not just localhost.
iface.launch(server_name="0.0.0.0")

with a requirements.txt of:

transformers
accelerate
hf_transfer
gradio
spaces
bitsandbytes
flash-attn
# On Python 3.12, install flash-attn with: pip install wheel; pip install flash-attn --no-build-isolation

Hi, thanks for your code! Note that our model was fine-tuned in fp16 precision, so I don't know whether it also works well with bf16.

Sign up or log in to comment