
Commit 0ce1376

Add unit and GPU tests for core sparse attention functionality
Signed-off-by: Kai Xu <[email protected]>
1 parent 27bb8da commit 0ce1376

30 files changed: +1359 −416 lines changed


examples/llm_sparse_attention/hf_spar_attn.py renamed to examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 9 additions & 49 deletions
@@ -22,56 +22,28 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import modelopt.torch.opt as mto
 import modelopt.torch.sparsity.attention_sparsity as mtsa
 from modelopt.torch.export import export_hf_checkpoint
 from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
-from modelopt.torch.sparsity.attention_sparsity.config import (
-    SKIP_SOFTMAX_CALIB,
-    SKIP_SOFTMAX_DEFAULT,
-)
-from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
+from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
+from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
 from modelopt.torch.utils.memory_monitor import launch_memory_monitor
 
 RAND_SEED = 1234
 
+# Enable HuggingFace checkpointing support
+mto.enable_huggingface_checkpointing()
+
 # You can define custom configurations or use the default
 SPARSE_ATTN_CFG_CHOICES = {
     "skip_softmax": SKIP_SOFTMAX_DEFAULT,
-    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
 }
 
 
-def print_sparsity_stats(model: nn.Module):
-    """Print sparsity statistics if available."""
-    module_stats = []
-    for name, module in model.named_modules():
-        if hasattr(module, "get_stats"):
-            stats = module.get_stats()
-            if stats and "average_sparsity" in stats:
-                module_stats.append((name, stats["average_sparsity"]))
-
-    if not module_stats:
-        print("No sparsity statistics available")
-        return
-
-    # Check if all modules have the same sparsity
-    sparsities = [s for _, s in module_stats]
-    if len(set(sparsities)) == 1:
-        # All identical - show summary
-        print(f"Average sparsity across all {len(module_stats)} modules: {sparsities[0]:.2%}")
-    else:
-        # Different sparsities - show individual values
-        avg_sparsity = sum(sparsities) / len(sparsities)
-        print(f"Average sparsity: {avg_sparsity:.2%}")
-        print("Per-module breakdown:")
-        for name, sparsity in module_stats:
-            print(f"  {name}: {sparsity:.2%} sparse")
-
-
 def get_narrativeqa_samples(num_samples=3):
     """Load samples from NarrativeQA dataset for testing.
 
@@ -173,9 +145,7 @@ def verify_outputs(model, tokenizer, args):
     print("BASELINE vs SPARSE ATTENTION COMPARISON")
     print("=" * 60)
     print(f"\nTest prompt: {display_prompt}")
-    print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})")
-    if "[...]" in truncated_prompt:
-        print("Note: Text was middle-truncated to fit token limit")
+    print(f"Input tokens: {inputs['input_ids'].shape[1]}")
 
     # Helper function to generate text
     def generate_text(model, inputs, args, tokenizer):
@@ -235,23 +205,13 @@ def sparsify_model(model, args):
         modified_sparse_cfg[pattern] = modified_cfg
 
     # Create new config with modified settings
-    sparse_config = SparseAttentionConfig(
-        method=base_config["method"],
-        sparse_cfg=modified_sparse_cfg,
-        collect_stats=True,  # Enable stats collection for monitoring
-    )
+    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)
 
-    # Sparsify with optional calibration - framework handles calibration automatically
+    # Sparsify the model
     model = mtsa.sparsify(model, config=sparse_config)
 
     print("Sparse attention applied successfully!")
 
-    # Show sparsity statistics
-    print("\n" + "=" * 60)
-    print("Sparsity Statistics")
-    print("=" * 60)
-    print_sparsity_stats(model)
-
     return model
 
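The diff above simplifies how the example wires up sparse attention. As a quick orientation, the snippet below sketches the post-change flow (it is not a verbatim excerpt of hf_sa.py): enable HuggingFace checkpointing, build a SparseAttentionConfig, sparsify, and export. The model name, dtype/device settings, and passing SKIP_SOFTMAX_DEFAULT straight through as sparse_cfg are assumptions; only the imports and the enable_huggingface_checkpointing/sparsify/export calls come from the diff itself.

```python
# Minimal sketch of the post-change flow (assumptions noted inline).
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.opt as mto
import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Enable HuggingFace checkpointing so the sparsified model can be saved and reloaded.
mto.enable_huggingface_checkpointing()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative model choice, not from the diff
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Build the config from the default skip_softmax settings; the example's
# sparsify_model() additionally rewrites per-pattern entries before this step,
# so passing the default through unchanged is an assumption.
sparse_config = SparseAttentionConfig(sparse_cfg=SKIP_SOFTMAX_DEFAULT)

# Replace eligible attention modules with sparse attention.
model = mtsa.sparsify(model, config=sparse_config)

# Export a HuggingFace-style checkpoint of the sparsified model.
export_hf_checkpoint(model, export_dir="hf_sa_sparse_ckpt")
```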
File renamed without changes.
File renamed without changes.

examples/llm_sparsity/finetune.py renamed to examples/llm_sparsity/weight_sparsity/finetune.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,18 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # Adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/3783d18/train.py
 
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
File renamed without changes.
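
The commit title mentions new unit and GPU tests; those test files are among the 30 changed files but are not expanded in this view. For orientation only, here is a hypothetical sketch of what such a test could look like, using the APIs visible in the example diff above. The tiny model, the config handling, and the test structure are assumptions, not the commit's actual tests.

```python
# Hypothetical unit/GPU test sketch (not the tests added in this commit).
import pytest
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="needs GPU"),
        ),
    ],
)
def test_sparsify_replaces_attention(device):
    # Tiny model keeps the test fast; the repo name is an assumed test fixture.
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM"
    ).to(device)

    # Assumes the default skip_softmax settings are a valid sparse_cfg value.
    config = SparseAttentionConfig(sparse_cfg=SKIP_SOFTMAX_DEFAULT)
    model = mtsa.sparsify(model, config=config)

    # After sparsification, at least one attention module should be converted.
    assert any(isinstance(m, SparseAttentionModule) for m in model.modules())
```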
