Llama3.1-8B converges faster than the RCPs #838

@psyhtest

Description

As discussed in the Training WG meeting on 2/Oct, Llama3.1-8B converges faster than the RCPs, at least with global batch size (GBS) 32.

Here's one example:

INFO - ------------------------------
INFO -  Running RCP Checker, pass: pruned_rcps
INFO - ------------------------------
INFO -  RCP Record: {'Benchmark': 'llama31_8b', 'BS': 32, 'Hyperparams': {'opt_base_learning_rate': 0.001, 'opt_learning_rate_warmup_samples': 16348, 'gradient_accumulation_steps': 2},
'Epochs to converge': [196608, 196608, 196608, 208896, 208896, 208896, 208896, 208896, 208896, 208896, 208896, 221184, 221184, 221184, 221184, 221184, 233472, 233472, 233472, 233472],
'RCP Mean': np.float64(215040.0), 'RCP Stdev': np.float64(11976.860890901255), 'Max Speedup': np.float64(1.042198772353707), 'Min Epochs': np.float64(206333.00067543983)}
INFO -  Submission mean epochs: 180576.0000
ERROR - RCP Test Failed: RCP found
INFO - ------------------------------

The minimum allowed epochs (the RCP mean of ~215.0k samples adjusted for the maximum permitted speedup of ~1.042) is ~206.3k, while the submission mean is ~180.6k, so the check fails.
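
For illustration, here is a minimal sketch of the pass/fail arithmetic, reconstructed from the values in the log above. The actual logic lives in the mlperf_logging RCP checker and includes pruning and interpolation steps not shown here; only the final comparison is sketched.

    # Sketch of the RCP comparison using the GBS=32 record from the log above.
    # Not the real checker: this only reproduces the final mean-vs-min-epochs test.
    import statistics

    rcp_epochs = [196608]*3 + [208896]*8 + [221184]*5 + [233472]*4  # samples to converge
    rcp_mean = statistics.mean(rcp_epochs)       # 215040.0, matches 'RCP Mean'
    max_speedup = 1.042198772353707              # taken from the checker output
    min_epochs = rcp_mean / max_speedup          # ~206333, the fastest allowed mean

    submission_mean = 180576.0
    print(f"min allowed mean: {min_epochs:.0f}, submission mean: {submission_mean:.0f}")
    if submission_mean < min_epochs:
        print("RCP Test Failed: submission converges faster than the RCPs allow")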

Proposed workarounds:

  • Slowing down convergence by increasing the number of warm-up samples (up to 16,348 as in the RCPs) and adjusting the learning rate accordingly (see the sketch after this list).
  • NVIDIA submitting new RCPs by early next week.
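
For reference, the first workaround amounts to aligning the submission's hyperparameters with the RCP record above. The keys below mirror the 'Hyperparams' field of the RCP record; the actual flag names and config mechanism depend on the benchmark's reference implementation.

    # Hypothetical overrides matching the GBS=32 RCP record's 'Hyperparams' field.
    hparams = {
        "opt_base_learning_rate": 1e-3,             # as in the RCP record
        "opt_learning_rate_warmup_samples": 16348,  # raised to the RCP value
        "gradient_accumulation_steps": 2,
    }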
