20 changes: 10 additions & 10 deletions .github/ISSUE_TEMPLATE/feature_request.md
@@ -8,26 +8,26 @@ assignees: ''
---

# 🎯 **Goal (What & Why)**
> **Clearly state the purpose of this feature.**
> **Clearly state the purpose of this feature.**
> _(Example: Add FP8 support using torchao to improve training throughput by 1.5x.)_

# 🚀 **Execution Plan**
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_
> _(This section may start as an incomplete draft but must be defined before implementation begins.)_

### **Step 1: What is the smallest working version?**
> _(Describe the simplest way to implement this feature with minimal effort.)_
> _(Describe the simplest way to implement this feature with minimal effort.)_

### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_
### **Step 2: What additional optimizations are possible (but optional)?**
> _(List potential refinements that can be added in later PRs if needed.)_

# 📌 **Acceptance Criteria** (Must-Haves for Completion)
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.
* The feature must be **functional and tested**.
* The implementation must be **documented in practical terms**.
* The PR must include a **performance/impact summary**.
* **No refactors unless directly necessary** for feature completion.

# 🛠️ **Project Management**
- [ ] **Assign the project to the Fast-LLM project.**
- [ ] **Set the `Estimate` field (in days) in the GitHub project.**
- [ ] **Use the `Size` field to categorize the PR size (Small/Medium/Large).**
- [ ] **Assign an owner when opening the issue.**
- [ ] **Assign an owner when opening the issue.**
14 changes: 7 additions & 7 deletions .github/workflows/manual-build.yml
@@ -34,12 +34,12 @@ jobs:
sudo rm -rf /usr/share/dotnet || true
sudo rm -rf /opt/ghc || true
sudo rm -rf /usr/local/.ghcup || true

- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.commit_sha != '' && inputs.commit_sha || inputs.branch }}

- name: Get commit info
id: commit_info
run: |
@@ -48,7 +48,7 @@ jobs:
echo "full_sha=${COMMIT_SHA}" >> $GITHUB_OUTPUT
echo "short_sha=${COMMIT_SHORT}" >> $GITHUB_OUTPUT
echo "Building from commit: ${COMMIT_SHA}"

- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -59,18 +59,18 @@
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}
type=raw,value=${{ inputs.branch }}-${{ inputs.tag_suffix }}-${{ steps.commit_info.outputs.short_sha }}
type=raw,value=latest-${{ inputs.tag_suffix }},enable=${{ inputs.branch == 'main' && inputs.commit_sha == '' }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Login to GHCR
if: ${{ inputs.push_image }}
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Build and push
uses: docker/build-push-action@v6
with:
@@ -80,7 +80,7 @@
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=registry,ref=ghcr.io/servicenow/fast-llm:cache
cache-to: type=registry,ref=ghcr.io/servicenow/fast-llm:cache,mode=max

- name: Output build info
run: |
echo "Built Docker image with tags:"
75 changes: 75 additions & 0 deletions fast_llm/layers/decoder/config.py
@@ -1,3 +1,4 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class
@@ -11,6 +12,7 @@

if typing.TYPE_CHECKING:
from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer


@config_class()
@@ -55,6 +57,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


class SamplingStrategy(str, enum.Enum):
"""Strategy for sampling mixers in a stochastic mixer."""

uniform = "uniform"
weighted = "weighted"


@config_class(registry=True)
class MixerConfig(BlockWithBiasConfig):
"""
@@ -71,6 +80,72 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi
return super()._from_dict(default, strict=strict)


@config_class(dynamic_type={MixerConfig: "stochastic"})
class StochasticMixerConfig(MixerConfig):
"""
Stochastic mixer that uniformly samples from multiple mixer options during training.

For supernet training, each forward pass randomly selects one mixer to execute,
training all mixers with different subsets of data.
"""

_abstract = False

mixers: list[MixerConfig] = Field(
Reviewer comment (Collaborator): Use a dict so we can refer to them by name, ex. in debug?

desc="List of mixer options to sample from (must contain at least 1).",
hint=FieldHint.architecture,
)

sampling_strategy: SamplingStrategy = Field(
default=SamplingStrategy.uniform,
desc="Strategy for sampling mixers during training.",
hint=FieldHint.feature,
)

sampling_weights: list[float] | None = Field(
default=None,
desc="Sampling probability for each mixer (must sum to 1.0). "
"Only used when sampling_strategy='weighted'. "
"If None with uniform strategy, all mixers have equal probability.",
hint=FieldHint.feature,
)

main_mixer_index: int = Field(
default=0,
desc="Index of the main mixer. "
"Used for inference/eval, checkpoint loading (receives pretrained weights), "
"and checkpoint saving (only this mixer is exported).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)

def _validate(self) -> None:
super()._validate()

# Validate mixers list is not empty
Assert.gt(len(self.mixers), 0)

# Validate sampling weights
if self.sampling_weights is not None:
Assert.eq(len(self.sampling_weights), len(self.mixers))
# Check sum is close to 1.0
weight_sum = sum(self.sampling_weights)
if abs(weight_sum - 1.0) > 1e-5:
raise ValueError(f"Sampling weights must sum to 1.0, got {weight_sum}")
# Check all weights are non-negative
if any(w < 0 for w in self.sampling_weights):
raise ValueError("All sampling weights must be non-negative")

# Validate main mixer index
Assert.lt(self.main_mixer_index, len(self.mixers))

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer

return StochasticMixer


@config_class(dynamic_type={BlockConfig: "decoder"})
class DecoderBlockConfig(BlockConfig):
_abstract = False
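For orientation, the snippet below is a minimal, hypothetical sketch of what a config payload for the new `stochastic` mixer type might look like, based only on the fields defined in `StochasticMixerConfig` above. It is not taken from this PR; in particular, the inner mixer `type` strings and the surrounding nesting are assumptions made for illustration.

```python
# Hypothetical config fragment (illustration only, not from this PR).
# Field names follow StochasticMixerConfig; the inner mixer "type" strings are assumed.
stochastic_mixer = {
    "type": "stochastic",              # dynamic_type registered for MixerConfig above
    "mixers": [
        {"type": "attention"},         # index 0: main mixer, used for eval and checkpoint export
        {"type": "discrete_mamba_2"},  # alternative mixer trained on the remaining forward passes
    ],
    "sampling_strategy": "weighted",
    "sampling_weights": [0.7, 0.3],    # must sum to 1.0 and be non-negative (see _validate)
    "main_mixer_index": 0,             # must be < len(mixers)
}
```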
193 changes: 193 additions & 0 deletions fast_llm/layers/decoder/stochastic_mixer.py
@@ -0,0 +1,193 @@
import logging
import typing

import torch

from fast_llm.core.distributed import set_generator
from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.common.peft.config import PeftConfig
from fast_llm.layers.decoder.block import BlockWithBias
from fast_llm.layers.decoder.config import SamplingStrategy, StochasticMixerConfig
from fast_llm.tensor import TensorMeta

logger = logging.getLogger(__name__)


class StochasticMixer[ConfigType: StochasticMixerConfig](BlockWithBias[ConfigType]):
"""
A mixer that stochastically samples from multiple mixer options during training.
In training mode, each forward pass randomly selects one mixer according to
the sampling strategy. In eval mode, uses the configured inference mixer.
This is useful for supernet training where you want to train multiple
architecture variants (e.g., attention vs. Mamba) with different data subsets.
"""

_config: ConfigType

def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
hidden_dim: TensorDim,
lr_scale: float | None,
peft: PeftConfig | None,
return_bias: bool = True,
):
super().__init__(
config,
distributed_config,
hidden_dim=hidden_dim,
lr_scale=lr_scale,
peft=peft,
return_bias=return_bias,
)

# Initialize all mixers
self.mixers = torch.nn.ModuleList(
[
mixer_config.get_layer(
distributed_config,
hidden_dim,
lr_scale,
peft=peft,
return_bias=return_bias,
)
for mixer_config in self._config.mixers
]
)

# Precompute sampling probabilities as a tensor
if self._config.sampling_strategy == SamplingStrategy.uniform:
self._sampling_probs = torch.ones(len(self.mixers)) / len(self.mixers)
elif self._config.sampling_strategy == SamplingStrategy.weighted:
if self._config.sampling_weights is None:
raise ValueError("sampling_weights must be provided when using weighted sampling strategy")
self._sampling_probs = torch.tensor(self._config.sampling_weights, dtype=torch.float32)
else:
raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented")

logger.info(
f"Initialized StochasticMixer with {len(self.mixers)} mixers: "
f"{[type(m).__name__ for m in self.mixers]}"
)

# Mark all mixer parameters with allow_no_grad since only one mixer
# is active per forward pass during training. Even though all mixers
# will eventually be trained, on any single forward pass, the non-selected
# mixers won't receive gradients.
for mixer in self.mixers:
for param in mixer.parameters(recurse=True):
if hasattr(param, 'allow_no_grad'):
param.allow_no_grad = True

def setup(self, distributed: Distributed) -> None:
"""Setup all mixers with the distributed context."""
super().setup(distributed)
for mixer in self.mixers:
mixer.setup(distributed)

def _sample_mixer_index(self) -> int:
"""
Sample a mixer index according to the configured strategy.
Returns:
Index of the mixer to use for this forward pass.
"""
if not self.training:
# Inference mode: use the configured main mixer
return self._config.main_mixer_index

# Training mode: stochastic sampling
# Use distributed RNG to ensure consistency across TP/PP ranks
# This ensures all ranks in a TP/PP group use the same mixer
generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator

with set_generator(generator):
# Sample from categorical distribution
idx = torch.multinomial(self._sampling_probs, num_samples=1).item()
Reviewer comment (Collaborator): This requires a costly cuda sync. How about we sample for all layers at once during preprocessing?


return idx

def _forward(
self,
input_: torch.Tensor,
kwargs: dict[str, typing.Any],
losses: dict[str, typing.Any] | None = None,
metrics: dict[str, typing.Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Forward pass through a randomly selected mixer.
Args:
input_: Input tensor
kwargs: Forward pass arguments
losses: Optional dictionary to store losses
metrics: Optional dictionary to store metrics
Returns:
Tuple of (output tensor, bias tensor or None)
"""
# Sample which mixer to use
mixer_idx = self._sample_mixer_index()

if self._debug.enabled:
logger.debug(f"StochasticMixer selecting mixer {mixer_idx}: {type(self.mixers[mixer_idx]).__name__}")
Reviewer comment (Collaborator): Ambiguous if multiple mixers share the same type. Use named mixers instead?


# Forward through selected mixer
return self.mixers[mixer_idx]._forward(input_, kwargs, losses, metrics)

def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None:
"""
Preprocess for all mixers.
Since we don't know which mixer will be selected during training,
we need to preprocess for all of them. This includes things like
attention masks, rotary embeddings, etc.
"""
for mixer in self.mixers:
Reviewer comment (Collaborator): There could be name conflicts. Consider namespace?

mixer.preprocess(batch, kwargs)

def get_compute_usage(
self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig
) -> int:
"""
Return expected compute usage (weighted average of all mixers).
This gives a more accurate estimate than just using one mixer,
since during training we'll be using all of them according to
their sampling probabilities.
"""
usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers]

# Weight by sampling probability and return the expected value
expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs))

return int(expected_usage)

def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
Reviewer comment (Collaborator): This is a bit dangerous, there could be name conflicts and counts will be wrong for averaging. Not sure how to fix though.

"""
Merge loss definitions from all mixers.
Returns the union of all loss definitions, deduplicated by name.
This ensures we allocate space for any auxiliary losses that any
of the mixers might need.
"""
all_losses = []
for mixer in self.mixers:
all_losses.extend(mixer.get_loss_definitions(count=count))

# Deduplicate by loss name
seen = set()
unique_losses = []
for loss_def in all_losses:
if loss_def.name not in seen:
seen.add(loss_def.name)
unique_losses.append(loss_def)

return unique_losses
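As a sanity check on the numeric behaviour described above (one categorical draw per forward pass in `_sample_mixer_index`, and the probability-weighted compute estimate in `get_compute_usage`), here is a small standalone sketch. It is not part of the PR and only assumes `torch`; the probabilities and FLOP figures are invented for illustration.

```python
# Standalone sketch: weighted mixer sampling and expected compute, illustration only.
import torch

sampling_probs = torch.tensor([0.7, 0.3], dtype=torch.float32)  # e.g. attention vs. Mamba

# One categorical draw per forward pass, mirroring _sample_mixer_index.
idx = torch.multinomial(sampling_probs, num_samples=1).item()
print(f"selected mixer index: {idx}")

# Expected compute: each mixer's estimated cost weighted by its sampling probability,
# mirroring get_compute_usage. The per-mixer FLOP counts below are made up.
per_mixer_flops = [4.0e9, 2.5e9]
expected = int(sum(f * p.item() for f, p in zip(per_mixer_flops, sampling_probs)))
print(f"expected compute per forward pass: {expected} FLOPs")
```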
1 change: 1 addition & 0 deletions fast_llm/models/auto.py
@@ -2,6 +2,7 @@
Import these submodules to ensure classes are added to the dynamic class registry.
"""

from fast_llm.layers.attention.config import AttentionConfig # isort: skip
from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip
from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip