17 changes: 15 additions & 2 deletions autoparallel/api.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import functools
import itertools
import warnings
from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
)
from .init_weights import hook_params_setters
from .optimize_sharding import ShardingOptimizer
from .utils import _get_device_from_mesh
from .utils import (
NumericsLogger,
_get_device_from_mesh,
debug_boxed_nop_preserve_node_meta,
)

_APPLY_VIEW_MM_VIEW_PATTERN = False

@@ -212,6 +217,7 @@ def __init__(
reshard_after_forward: bool = True,
dynamic: bool = False,
loss_fn: Optional[Callable] = None,
numerics_logger: NumericsLogger | None = None,
**kwargs,
):
self.stack = ExitStack()
@@ -239,7 +245,14 @@ def __init__(
self.model = move_to_fake(model, self.fake_mode, device)
self.input_fn = input_fn
self.mesh = mesh
self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
if compile:
self.compiler_fn = compile_fx_inner
elif numerics_logger:
self.compiler_fn = functools.partial(
debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
)
else:
self.compiler_fn = boxed_nop_preserve_node_meta # type: ignore[assignment]
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
self.reshard_after_forward = reshard_after_forward
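A minimal, self-contained sketch of the compiler-selection pattern added above: when numerics logging is requested, functools.partial pre-binds the logger into a debug backend, and the returned runner follows the boxed calling convention (a single list of args, flagged via _boxed_call), mirroring the real debug_boxed_nop_preserve_node_meta further down in utils.py. All names in this sketch are illustrative stand-ins, not autoparallel's real API.

```python
import functools

def debug_nop_backend(graph_callable, example_inputs, numerics_logger=None):
    # "Compile" by returning a thin wrapper that runs the graph eagerly and logs.
    def run(args):
        if numerics_logger is not None:
            numerics_logger.append(f"running graph with {len(args)} inputs")
        return graph_callable(args)
    run._boxed_call = True  # boxed convention: the runner takes one list argument
    return run

def pick_compiler(compile: bool, numerics_logger=None):
    # Mirrors the precedence in AutoParallel.__init__: compile > debug logging > plain nop.
    if compile:
        return "compile_fx_inner (stand-in)"
    if numerics_logger is not None:
        return functools.partial(debug_nop_backend, numerics_logger=numerics_logger)
    return debug_nop_backend

logs: list[str] = []
backend = pick_compiler(compile=False, numerics_logger=logs)
runner = backend(lambda args: [sum(args)], example_inputs=None)
print(runner([1, 2, 3]), logs)  # -> [6] ['running graph with 3 inputs']
```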
3 changes: 3 additions & 0 deletions autoparallel/graph_pp_runner.py
@@ -240,6 +240,7 @@ def stage_forward(
action: _Action,
ctx: _PipelineContext,
numerics_logs: Optional[list[str]] = None,
forward_hook: Callable | None = None,
Reviewer comment (Contributor): nit: Optional[Callable]
) -> None:
schedule = ctx.schedule_ref
assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -305,6 +306,8 @@ def stage_forward(
stage.output_chunks.append(output)
if ctx.target_mbs is not None:
ctx.schedule_ref._internal_losses.append(output)
if forward_hook:
forward_hook(stage, action, output)

stage.fwd_cache[mb_index] = (
output_tuple, # stage_output
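A small sketch of the optional forward-hook contract introduced above: if provided, the hook is invoked once per microbatch with the stage, the schedule action, and the forward output, and the real hook is bound in later via functools.partial (see examples/example_ds3_pp.py below). The stub names here are hypothetical.

```python
import functools

def stage_forward_stub(action, *, forward_hook=None):
    # Stand-in for the real stage_forward: run the microbatch, then fire the hook.
    output = f"output({action})"
    if forward_hook:
        forward_hook("stage-0", action, output)  # same (stage, action, output) call shape
    return output

def last_stage_hook(stage, action, output):
    print(f"hook fired: stage={stage} action={action} output={output}")

forward_fn = functools.partial(stage_forward_stub, forward_hook=last_stage_hook)
forward_fn("FORWARD mb0")
```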
53 changes: 52 additions & 1 deletion autoparallel/utils.py
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
continue

self._logs.append(
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
)

def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,6 +429,20 @@ def log_model_weights(self, parallel_mod):

print(f"Weight hashes written to {path}")

def log_fw_intermediates(self, logs):
rank = torch.distributed.get_rank()
path = self.dir / f"rank_{rank}_fw_intermediates.log"
with open(path, "a") as f:
f.write("\n".join(logs) + "\n")

def log_diff(self, t, rank=0, prefix="?"):
if self.rank == rank:
path = self.dir / "diff.log"
if isinstance(t, torch.distributed.tensor.DTensor):
t = t.to_local()
with open(path, "a") as f:
f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")

def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
path = self.dir / "pp_weights.log"

@@ -463,3 +477,40 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):

if self.rank == 0:
print(f"Weight hashes written to {path}")

def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
Reviewer comment (Contributor): What is num_world_stages?
path = self.dir / "diff.log"

torch.distributed.barrier()
for i in range(num_world_stages):
if self.rank in ranks and i in stage_mods:
grad_logs = []
real_params = dict(stage_mods[i].named_parameters())
for name, _ in orig_mod.named_parameters():
if name not in real_params:
continue
grad = real_params[name].grad
if grad is None:
grad_logs.append(f"[grad {name}] None")
else:
grad = grad.to_local()
grad_logs.append(
f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
)
with open(path, "a") as f:
f.write("\n".join(grad_logs) + "\n")
torch.distributed.barrier()


def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
def run(args):
with torch.fx.traceback.preserve_node_meta():
interp = DebugInterpreter(fx_g)
out = interp.boxed_run(args)
mylogs = interp.get_logs()
if numerics_logger:
numerics_logger.log_fw_intermediates(mylogs)
return out

run._boxed_call = True
return run
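For context, here is a self-contained toy version of the debug-interpreter pattern used by debug_boxed_nop_preserve_node_meta above: an FX Interpreter subclass records a cheap per-node summary (a norm and NaN check here, instead of torch.hash_tensor) while node metadata is preserved. The ToyModel and the summary format are illustrative only.

```python
import torch
import torch.fx

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

class LoggingInterpreter(torch.fx.Interpreter):
    # Collects one log line per node, similar in spirit to DebugInterpreter.log above.
    def __init__(self, gm):
        super().__init__(gm)
        self.logs: list[str] = []

    def run_node(self, n: torch.fx.Node):
        out = super().run_node(n)
        if isinstance(out, torch.Tensor):
            self.logs.append(
                f"{n.name}: norm={out.norm().item():.4f} nan={torch.isnan(out).any().item()}"
            )
        return out

gm = torch.fx.symbolic_trace(ToyModel())
with torch.fx.traceback.preserve_node_meta():
    interp = LoggingInterpreter(gm)
    interp.run(torch.randn(4, 4))
print("\n".join(interp.logs))
```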
56 changes: 37 additions & 19 deletions examples/example_ds3_local_map.py
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
mscale=0.70,
)

bs = 4 * mesh.shape[0] * mesh.shape[1]
local_batch_size = 2
global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
device = torch.device(f"cuda:{local_rank}")

# parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
return torch.randint(
0,
config.vocab_size,
(bs, seq_len),
(global_batch_size, seq_len),
device=device,
)

with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
numerics_logger = None
if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
model, input_fn, mesh, dynamic=True, numerics_logger=None
Reviewer comment (Contributor): should this be numerics_logger = numerics_logger?
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

# x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
# ) # maybe not correct value
parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
numerics_logger.log_model_weights(parallel_mod)

x = (
torch.randint(
0,
config.vocab_size,
(bs // mesh.shape[0] // mesh.shape[1], seq_len),
device=device,
),
torch.manual_seed(rng_seed)

n_microbatches = 16
full_batch = torch.randint(
0,
config.vocab_size,
(local_batch_size * n_microbatches, seq_len),
device=device,
)
microbatches = torch.split(full_batch, local_batch_size, dim=0)
assert len(microbatches) == n_microbatches
if rng_seed:
numerics_logger.log_diff(
full_batch.to(torch.float32), prefix="full batch input"
)

# Symbolically evaluate in case you want to test running a graph bigger than your gpu
if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
allow_non_fake_inputs=True,
shape_env=shape_env,
):
# # now let's run it
out = parallel_mod(*x)
out.backward(torch.randn_like(out))
# now let's run it
for x in microbatches:
out = parallel_mod(x)
out.backward(torch.ones_like(out))
else:
out = parallel_mod(*x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
for i, x in enumerate(microbatches):
assert x.shape[0] == 2
out = parallel_mod(x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
out.backward(torch.ones_like(out))
if rng_seed is not None:
numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")

if rng_seed is not None:
numerics_logger.log_forward_output(out)
out.backward(torch.randn_like(out))
for k, v in parallel_mod.named_parameters():
numerics_logger.log_diff(v.grad, prefix=f"grad {k}")

print("All good!")

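A quick standalone check of the microbatch split used above: a batch of local_batch_size * n_microbatches rows, split along dim 0, yields exactly n_microbatches chunks of local_batch_size rows each. The shapes mirror the example; the vocab size is arbitrary.

```python
import torch

local_batch_size, n_microbatches, seq_len = 2, 16, 8
full_batch = torch.randint(0, 100, (local_batch_size * n_microbatches, seq_len))
microbatches = torch.split(full_batch, local_batch_size, dim=0)
assert len(microbatches) == n_microbatches
assert all(mb.shape == (local_batch_size, seq_len) for mb in microbatches)
```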
40 changes: 33 additions & 7 deletions examples/example_ds3_pp.py
@@ -166,9 +166,9 @@ def run_test(
# This is the spmd mesh to be used for tracing
mesh = world_mesh[("dp_mod_ep", "ep")]

global_batch_size = 32 * dp_degree
# Batch size that will be supplied to the schedule and will be broken down into microbatches
local_batch_size = global_batch_size // dp_degree
local_batch_size = 32
# global_batch_size = local_batch_size * dp_degree
n_microbatches = 16
# Batch size with which the spmd graphs will actually be executed
microbatch_size = local_batch_size // n_microbatches
@@ -472,10 +472,6 @@ def last_stage_inp_with_loss_fn():

world_size = torch.distributed.get_world_size()
num_world_stages = world_size * len(stage_mods)
if rng_seed is not None:
NumericsLogger(logs_dir).log_pp_model_weights(
model, stage_mods, num_world_stages, ranks=[0, 4]
)

stages = []
# Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -500,6 +496,7 @@ def last_stage_inp_with_loss_fn():
group=world_mesh.get_group("pp"),
)
stages.append(stage)

# Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
schedule = build_pipeline_schedule(
stages=stages,
@@ -511,9 +508,32 @@ def last_stage_inp_with_loss_fn():
backward_requires_autograd=False,
)
assert isinstance(schedule, _PipelineScheduleRuntime)

if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
numerics_logger.log_pp_model_weights(
model, stage_mods, num_world_stages, ranks=[0, 4]
)
torch.manual_seed(rng_seed)

def last_stage_forward_hook(
stage: GraphPipelineStage, action: str, output: torch.Tensor
):
if not stage.is_last or rng_seed is None:
return

rank = torch.distributed.get_rank()
if rank == 4:
Reviewer comment (Contributor): can you somehow not hardcode this
numerics_logger.log_diff(
output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
Reviewer comment (Contributor): Yeah, very confusing. Also do we care about pp_rank or global rank? Finally v style schedules will have last stage on rank 0?
)

# Step 6. Override the pipeline runner's action implementations
schedule.register_custom_function(
FORWARD, functools.partial(stage_forward, numerics_logs=None)
FORWARD,
functools.partial(
stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
),
)
schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
@@ -542,6 +562,10 @@ def last_stage_inp_with_loss_fn():
)
if pp_rank == 0:
x = runtime_input_fn_first_stage()
if rng_seed:
numerics_logger.log_diff(
x.to(torch.float32), prefix="full batch input"
)
graph_pp_runner.step(
x, target=target, losses=losses, return_outputs=False
)
@@ -556,6 +580,8 @@ def last_stage_inp_with_loss_fn():
payload_fn=lambda: f"losses: {losses}",
)

numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])

print("All good!")

if torch.distributed.is_initialized():
Expand Down