language: en
license: apache-2.0

Shears Model Card: shears-llama-13b-50-math-heuristic-adapter

This is the heuristic adapter discovered from the super-adapter that was fine-tuned with Shears on sparsified LLaMA-13B using math reasoning datasets.

Model Details

Information

  • Model name: shears-llama-13b-50-math-heuristic-adapter
  • Base model: Sparsified LLaMA-13B
  • Sparsity: 50%
  • Domain: Math
  • Subnetwork version: Heuristic
  • NNCF Configuration: nncf_shears_llama.json

Sparsified Base Model

Shears uses Wanda, a simple but effective pruning approach, to sparsify the language model. The sparsified model then serves as the base model. Clone the Wanda repo:

git clone https://github.com/locuslab/wanda.git && cd wanda && git checkout 8e8fc87 && cd ..

The following command prunes LLaMA-13B with Wanda to 50% unstructured sparsity:

python wanda/main.py \
    --model yahma/llama-13b-hf \
    --prune_method wanda \
    --sparsity_ratio 0.5 \
    --sparsity_type unstructured \
    --save wanda_out \
    --save_model shears-llama-13b-50-base
  • --model: The identifier for the model on the Hugging Face model hub or local path.
  • --sparsity_ratio: Specifies the fraction of weights to be pruned (0.5 prunes 50% of the weights).
  • --save_model: Specifies the directory where the pruned language model will be stored.

Refer to our repo for the environment information to run this command.
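
To confirm that a pruned checkpoint actually reaches the target sparsity, the fraction of zero-valued weights can be measured directly. The following is a minimal sketch, not part of the Shears or Wanda code: `linear_sparsity` is a hypothetical helper, and the toy model and its zeroing pattern are assumptions used only to exercise the function (they do not reproduce the Wanda criterion). On a real model, any linear layer left unpruned will pull the reported value below the target.

```python
import torch
from torch import nn


def linear_sparsity(model: nn.Module) -> float:
    """Return the fraction of exactly-zero weights across all nn.Linear layers."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight.data
            zeros += (weight == 0).sum().item()
            total += weight.numel()
    return zeros / total if total else 0.0


# Toy stand-in for a pruned model (assumption: illustration only).
toy = nn.Sequential(
    nn.Linear(8, 8, bias=False),
    nn.ReLU(),
    nn.Linear(8, 4, bias=False),
)
with torch.no_grad():
    for layer in toy:
        if isinstance(layer, nn.Linear):
            # Zero every other input column; this is NOT how Wanda selects weights.
            layer.weight[:, ::2] = 0.0

print(f"Sparsity of linear layers: {linear_sparsity(toy):.2%}")
```

The same function can be called on the model loaded from the directory passed to --save_model.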

Adapter Configuration

  • LoRA rank: 32 (24 in the heuristic subnetwork)
  • LoRA alpha: 64
  • LoRA target modules: q_proj, k_proj, v_proj, up_proj, down_proj
  • LoRA rank search space: [32, 24, 16] (for each LoRA module)
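
To give a sense of scale for this configuration, the sketch below counts the LoRA parameters at the maximal rank (32) and at the heuristic rank (24), and the number of rank configurations in the search space. It rests on assumptions that are not stated in this card: the standard LLaMA-13B dimensions (hidden size 5120, intermediate size 13824, 40 decoder layers), an adapter on each of the five target modules in every layer, two bias-free low-rank matrices per adapter, and rank 24 applied uniformly in the heuristic subnetwork.

```python
# Assumed LLaMA-13B dimensions (standard config values, not taken from this card).
HIDDEN = 5120
INTERMEDIATE = 13824
NUM_LAYERS = 40

# (input features, output features) of each targeted linear module.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}
RANK_CHOICES = [32, 24, 16]


def lora_params(rank: int) -> int:
    """Count LoRA parameters when every targeted module uses the same rank."""
    # Each adapter holds A (rank x d_in) and B (d_out x rank).
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * NUM_LAYERS


max_params = lora_params(32)        # super-adapter at maximal rank
heuristic_params = lora_params(24)  # heuristic subnetwork (assumed uniform rank 24)
num_modules = len(TARGET_SHAPES) * NUM_LAYERS
num_configs = len(RANK_CHOICES) ** num_modules

print(f"Rank 32: {max_params:,} LoRA parameters")
print(f"Rank 24: {heuristic_params:,} LoRA parameters")
print(f"Adapted modules: {num_modules}, rank configurations: 3^{num_modules}")
```

Under these assumptions the heuristic subnetwork carries 75% of the maximal adapter's parameters, since the count scales linearly with rank.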

Training Hyperparameters

  • Batch size: 16
  • Learning rate: 3e-4
  • Epochs: 3

Training Data

Unified math reasoning dataset: math_10k.json (collected with the training sets of GSM8K, MAWPS, and AQuA).

Evaluation Data

GSM8K, AQuA, MAWPS, SVAMP

How to use

Install our modified PEFT library by applying the patch to PEFT v0.5.0:

git clone https://github.com/huggingface/peft.git
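cd peft && git checkout v0.5.0 && git apply --ignore-space-change --ignore-whitespace peft-modifications-for-shears-inference-usage.patch && pip install -e . && cd ..

Then load the sparsified base model (the directory written by --save_model above) together with the adapter, and run generation:
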
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

def generate_prompt(instruction):
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. 

                    ### Instruction:
                    {instruction}

                    ### Response:
                    """

base_model = AutoModelForCausalLM.from_pretrained("shears-llama-13b-50-base")
model = PeftModel.from_pretrained(base_model, "IntelLabs/shears-llama-13b-50-math-heuristic-adapter")
model.eval()

non_zero_params = sum([(param.data != 0).sum().item() for _, param in model.named_parameters()])
print(f"Number of all non-zero parameters: {non_zero_params}")

tokenizer = AutoTokenizer.from_pretrained("shears-llama-13b-50-base")

instruction = "Edgar eats 18 pretzels a day. If his brother eats 1/2 as many, how many does his brother eat in a week?"
prompt = generate_prompt(instruction)
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
with torch.no_grad():
    generation_output = model.generate(
        input_ids=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256,
        use_cache=True,
        num_beams=4,
    )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
print(output)
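
The decoded sequence contains the full prompt followed by the model's answer. A minimal sketch of isolating the answer is shown below. `extract_response` is a hypothetical helper, not part of this repository, and the example string is hand-written for illustration rather than real model output. It assumes the answer follows the last "### Response:" marker and that the LLaMA special tokens `<s>` and `</s>` may appear in the decoded text.

```python
def extract_response(decoded: str, marker: str = "### Response:") -> str:
    """Return the text after the last marker; fall back to the whole string."""
    # rpartition yields the full string as the tail when the marker is absent.
    text = decoded.rpartition(marker)[2]
    # Remove LLaMA special tokens that tokenizer.decode keeps by default.
    for token in ("<s>", "</s>"):
        text = text.replace(token, "")
    return text.strip()


# Hand-written example string (assumption: not produced by the model).
example = (
    "<s> Below is an instruction.\n\n"
    "### Response:\n"
    "His brother eats 9 pretzels a day, so 63 in a week.</s>"
)
print(extract_response(example))
```

Passing `skip_special_tokens=True` to `tokenizer.decode` is an alternative way to drop the special tokens before splitting on the marker.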

Evaluation Results

Model             Sparsity  GSM8K  AQuA  MAWPS  SVAMP  Average
LLaMA-7B-LoRA     -         37.5   18.9  79.0   52.1   46.9
LLaMA-7B-Shears   50%       36.1   22.0  78.6   44.5   45.3
LLaMA-13B-LoRA    -         47.5   18.5  83.6   54.6   51.1
LLaMA-13B-Shears  50%       45.1   22.0  83.2   53.3   50.9
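
The Average column is the unweighted mean of the four task accuracies. The short sketch below recomputes it from the scores in the table and reports the gap between each dense LoRA baseline and its 50% sparse Shears counterpart. The variable names are illustrative, and the values are copied from the table above.

```python
# Task accuracies in the order GSM8K, AQuA, MAWPS, SVAMP (from the table above).
results = {
    "LLaMA-7B-LoRA": [37.5, 18.9, 79.0, 52.1],
    "LLaMA-7B-Shears": [36.1, 22.0, 78.6, 44.5],
    "LLaMA-13B-LoRA": [47.5, 18.5, 83.6, 54.6],
    "LLaMA-13B-Shears": [45.1, 22.0, 83.2, 53.3],
}

# Unweighted mean over the four tasks.
averages = {name: sum(scores) / len(scores) for name, scores in results.items()}

# Average-accuracy gap between the dense baseline and the sparse model.
gaps = {
    size: averages[f"LLaMA-{size}-LoRA"] - averages[f"LLaMA-{size}-Shears"]
    for size in ("7B", "13B")
}

for name, avg in averages.items():
    print(f"{name}: {avg:.2f}")
for size, gap in gaps.items():
    print(f"{size} gap (LoRA minus Shears): {gap:.2f} points")
```

The recomputed means agree with the reported averages to within rounding.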

Model Sources

Citation

@article{munoz2024shears,
  title = {Shears: Unstructured Sparsity with Neural Low-rank Adapter Search},
  author={J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
  journal={The 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL-2024)},
  year={2024}
}

License

Apache-2.0