pytorch-fsdp2
Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
Packaged view
This page reorganizes the original catalog entry around fit, installability, and workflow context first. The original raw source lives below.
Install command
npx @skill-hub/cli install orchestra-research-ai-research-skills-pytorch-fsdp2
Repository
Skill path: 08-distributed-training/pytorch-fsdp2
Best for
Primary workflow: Analyze Data & AI.
Technical facets: Full Stack, Data / AI.
Target audience: everyone.
License: MIT.
Original source
Catalog source: SkillHub Club.
Repository owner: Orchestra-Research.
This is still a mirrored public skill entry. Review the repository before installing into production workflows.
What it helps with
- Install pytorch-fsdp2 into Claude Code, Codex CLI, Gemini CLI, or OpenCode workflows
- Review https://github.com/Orchestra-Research/AI-Research-SKILLs before adding pytorch-fsdp2 to shared team environments
- Use pytorch-fsdp2 for development workflows
Works across
Favorites: 0.
Sub-skills: 0.
Aggregator: No.
Original source / Raw SKILL.md
---
name: pytorch-fsdp2
description: Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [PyTorch, FSDP2, Fully Sharded Data Parallel, Distributed Training, DTensor, Device Mesh, Sharded Checkpointing, Mixed Precision, Offload, Torch Distributed]
dependencies: [torch]
---
# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script
This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.
> FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
---
## When to use this skill
Use FSDP2 when:
- Your model **doesn’t fit** on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode sharding approach with **DTensor-based per-parameter sharding**, which is more inspectable than FSDP1 and yields simpler sharded state dicts.
- You may later compose DP with **Tensor Parallel** using **DeviceMesh**.
Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
- You’re forced onto older PyTorch versions without the FSDP2 stack.
## Alternatives (when FSDP2 is not the best fit)
- **DistributedDataParallel (DDP)**: Use the standard data-parallel wrapper when you want classic distributed data parallel training.
- **FullyShardedDataParallel (FSDP1)**: Use the original FSDP wrapper for parameter sharding across data-parallel workers.
Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.
---
## Contract the agent must follow
1. **Launch with `torchrun`** and set the CUDA device per process (usually via `LOCAL_RANK`).
2. **Apply `fully_shard()` bottom-up**, i.e., shard submodules (e.g., Transformer blocks) before the root module.
3. **Call `model(input)`**, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
4. **Create the optimizer after sharding** and make sure it is built on the **DTensor parameters** (post-`fully_shard`).
5. **Checkpoint using Distributed Checkpoint (DCP)** or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.
(Each of these rules is directly described in the official API docs/tutorial; see references.)
---
## Step-by-step procedure
### 0) Version & environment sanity
- Prefer a recent stable PyTorch release whose docs cover the FSDP2 (`fully_shard`) API and the current DCP helpers.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.
Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).
---
### 1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)
Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).
---
### 2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)
Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).
---
### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
**Do not** only call `fully_shard` on the topmost module.
Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`
Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.
Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).

---
### 4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).
Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: pass an `int` to reshard to a smaller world size after forward (e.g., intra-node); it must evenly divide the shard dimension's size.
Reference: `references/pytorch_fully_shard_api.md` (full semantics).
---
### 5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload
Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.
Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).
---
### 6) Optimizer, gradient clipping, accumulation
- Create the optimizer **after** sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync:
- use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.
Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.
Reference: `references/pytorch_fsdp2_tutorial.md`.
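A sketch tying these three rules together (`micro_batches` and the function names are illustrative):

```python
import torch


def build_optimizer(model, lr: float = 3e-4):
    # Called only after all fully_shard() calls, so it captures DTensor parameters.
    return torch.optim.AdamW(model.parameters(), lr=lr)


def train_step(model, optimizer, micro_batches, max_norm: float = 1.0):
    last = len(micro_batches) - 1
    for i, (inputs, targets) in enumerate(micro_batches):
        # FSDP2's replacement for FSDP1's no_sync(): only reduce-scatter
        # gradients on the last micro-batch of the accumulation window.
        model.set_requires_gradient_sync(i == last)
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
    # clip_grad_norm_ handles DTensor gradients; the total norm it returns
    # is itself a DTensor.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
```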
---
### 7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:
**A) Distributed Checkpoint (DCP) — best default**
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces **multiple files** (often at least one per rank) and operates “in place”.
**B) Distributed state dict helpers**
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`
Avoid:
- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.
References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
---
## Workflow checklists (copy-paste friendly)
### Workflow A: Retrofit FSDP2 into an existing training script
- [ ] Launch with `torchrun` and initialize the process group.
- [ ] Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- [ ] Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- [ ] Create the optimizer after sharding so it captures DTensor parameters.
- [ ] Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- [ ] Add DCP save/load via `torch.distributed.checkpoint` helpers.
Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.
### Workflow B: Add DCP save/load (minimal pattern)
- [ ] Wrap state in `Stateful` or assemble state via `get_state_dict`.
- [ ] Call `dcp.save(...)` from all ranks to a shared path.
- [ ] Call `dcp.load(...)` and restore with `set_state_dict`.
- [ ] Validate any resharding assumptions when loading into a different mesh.
Reference: `references/pytorch_dcp_recipe.md`.
## Debug checklist (what the agent should check first)
1. **All ranks on distinct GPUs?**
If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
2. **Did you accidentally call `forward()` directly?**
Use `model(input)` or explicitly `unshard()` / register forward.
3. **Is `fully_shard()` applied bottom-up?**
If only root is sharded, expect worse memory/perf and possible confusion.
4. **Optimizer created at the right time?**
Must be built on DTensor parameters *after* sharding.
5. **Checkpointing path consistent?**
- If using DCP, don’t mix with ad-hoc `torch.save` unless you understand conversions.
- Be mindful of PyTorch-version compatibility warnings for DCP.
---
## Common issues and fixes
- **Forward hooks not running** → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- **Optimizer sees non-DTensor params** → Create optimizer after all `fully_shard` calls.
- **Only root module sharded** → Apply `fully_shard` bottom-up on submodules before the root.
- **Memory spikes after forward** → Set `reshard_after_forward=True` for more modules.
- **Gradient accumulation desync** → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.
Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
---
## Minimal reference implementation outline (agent-friendly)
The coding agent should implement a script with these labeled blocks:
- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers
Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
---
## References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)
---
## Referenced Files
> The following files are referenced in this skill and included for context.
### references/pytorch_fully_shard_api.md
```markdown
# Reference: `torch.distributed.fsdp.fully_shard` API (FSDP2)
**Source (official):** PyTorch docs — `torch.distributed.fsdp.fully_shard`
https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html
Created: Dec 04, 2024 • Last updated: Oct 13, 2025
## Key facts (paraphrased from the API docs)
### User contract highlights
- `fully_shard(model)` converts `model.parameters()` to **DTensor** at init, then hooks **all-gather** before forward/backward and **free/reshard** after.
- The optimizer **must be initialized with DTensor parameters** and step must happen on DTensors.
- Call `model(input)` (not `model.forward(input)`) so hooks run; otherwise explicitly `unshard()` or register the forward method for hooking.
- Apply `fully_shard` **bottom-up**: shard submodules first, then the root module, to form efficient communication groups and enable overlap.
- `fully_shard` “unions” the module type in-place with `FSDPModule`, enabling methods like `unshard()` / `reshard()`.
> Short excerpt (<= 25 words): “Users generally should not call fully_shard() only on the topmost root module.”
### Signature & core args
`fully_shard(module, *, mesh=None, reshard_after_forward=None, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(...), offload_policy=OffloadPolicy(), ignored_params=None)`
- **mesh** (`DeviceMesh`):
- 1D mesh ⇒ “classic” FSDP sharding, placement `(Shard(0),)`
- 2D mesh ⇒ Hybrid sharding (HSDP): sharded across one dim, replicated across the other, placement `(Replicate(), Shard(0))`
- **reshard_after_forward**:
- `True`: free unsharded params after forward (re-all-gather during backward)
- `False`: keep unsharded params after forward (avoid backward all-gather)
- `None`: defaults to `True` for non-root, `False` for root
- `int`: reshard to a smaller world-size after forward (must divide shard-dim size)
- **shard_placement_fn**: override per-parameter sharding dim (requires even sharding if not dim-0)
- **ignored_params**: parameters not sharded / not moved / not reduced
## Mixed precision & offload policy classes (same doc page)
### `MixedPrecisionPolicy`
Controls:
- `param_dtype`: dtype used for unsharded parameters during forward/backward
- `reduce_dtype`: dtype used for gradient reduction
- `output_dtype`: dtype used for forward output
- `cast_forward_inputs`: whether to cast forward inputs to `param_dtype`
### `OffloadPolicy` and `CPUOffloadPolicy`
OffloadPolicy controls:
- `param_device` / `reduce_device` / `output_device` (and for CPU offload policy, also `optimizer_state_device`)
## Practical implications for agents
- **Bottom-up sharding** is not optional: it affects grouping and memory/perf.
- **Don’t bypass hooks**: using `model.forward` directly breaks all-gather scheduling.
- **Optimizer construction order matters**: construct optimizer after `fully_shard`.
```
### references/pytorch_fsdp2_tutorial.md
```markdown
# Reference: Getting Started with Fully Sharded Data Parallel (FSDP2) tutorial
**Source (official):** PyTorch Tutorials — “Getting Started with Fully Sharded Data Parallel (FSDP2)”
https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html
Created: Mar 17, 2022 • Last updated: Sep 02, 2025 • Last verified: Nov 05, 2024
## What the tutorial emphasizes
### How FSDP2 differs from DDP and FSDP1
- FSDP shards **parameters, gradients, and optimizer state**; parameters are all-gathered for compute and reduce-scattered for grads.
- Compared to FSDP1, FSDP2:
- uses **DTensor per-parameter sharding** (more direct manipulation; sharded state dicts)
- improves memory management for more deterministic memory behavior
- supports extensibility points for custom all-gather (e.g., float8/NF4 use cases)
### Model initialization flow (meta-device pattern)
The tutorial’s migration section shows a typical pattern:
- initialize model on `meta`
- apply `fully_shard` to the intended layers (policy expressed by explicit calls)
- apply `fully_shard` to the root module
- materialize weights via `to_empty(device="cuda")`, then run `reset_parameters()`
### State dict workflows
The tutorial describes two main ways:
**A) DTensor APIs (manual)**
- Loading: use `distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)` then `model.load_state_dict(..., assign=True)`
- Saving: call `DTensor.full_tensor()` to all-gather; optionally CPU-offload on rank0 to avoid peak GPU memory
**B) DCP distributed state-dict helpers (recommended when no custom handling needed)**
- Loading: `set_model_state_dict(..., StateDictOptions(full_state_dict=True, broadcast_from_rank0=True))`
- Saving: `get_model_state_dict(..., StateDictOptions(full_state_dict=True, cpu_offload=True))`
- Points to `pytorch/examples` for optimizer state dict save/load with `set_optimizer_state_dict` / `get_optimizer_state_dict`
### Migration guide mapping
The tutorial explicitly maps FSDP1 concepts to FSDP2:
- `sharding_strategy` ↔ `reshard_after_forward` (+ 2D mesh for HYBRID)
- `cpu_offload` ↔ `offload_policy` (`CPUOffloadPolicy`)
- `no_sync()` ↔ `set_requires_gradient_sync`
- `sync_module_states` moves to DCP broadcast-from-rank0 flows
## Practical takeaways for agents
- Express wrapping policy by **explicitly applying `fully_shard`** to chosen submodules.
- Use DCP APIs for flexible checkpointing and resharding unless you must interop with third-party formats.
```
### references/pytorch_ddp_notes.md
```markdown
# Reference: Distributed Data Parallel (DDP) notes
**Source (official):** PyTorch docs — “Distributed Data Parallel”
https://docs.pytorch.org/docs/stable/notes/ddp.html
Last accessed: Jan 30, 2026
## Key points (paraphrased from the notes)
- DDP is the standard PyTorch wrapper for distributed data parallel training.
- Typical usage includes initializing the process group, wrapping the model with `DistributedDataParallel`, and training normally.
```
### references/pytorch_fsdp1_api.md
```markdown
# Reference: Fully Sharded Data Parallel (FSDP1) API
**Source (official):** PyTorch docs — “Fully Sharded Data Parallel”
https://docs.pytorch.org/docs/stable/fsdp.html
Last accessed: Jan 30, 2026
## Key points (paraphrased from the API docs)
- `torch.distributed.fsdp.FullyShardedDataParallel` is the original FSDP wrapper for sharding module parameters across data-parallel workers.
```
### references/pytorch_device_mesh_tutorial.md
```markdown
# Reference: Getting Started with DeviceMesh (PyTorch tutorial)
**Source (official):** PyTorch Recipes — “Getting Started with DeviceMesh”
https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html
Created: Jan 24, 2024 • Last updated: Jul 18, 2025 • Last verified: Nov 05, 2024
## What DeviceMesh is (as defined by the tutorial)
DeviceMesh is a higher-level abstraction that **manages ProcessGroups**, making it easier to set up the right communication groups for multi-dimensional parallelism.
The tutorial motivation:
- Without DeviceMesh, users must manually compute rank groupings (replicate/shard groups) and create multiple process groups.
- With DeviceMesh, you describe topology with a shape (e.g., 2D mesh), and slice submeshes by dimension name.
## Why this matters for FSDP2
FSDP2 `fully_shard(..., mesh=...)` takes a `DeviceMesh`:
- 1D mesh: standard full sharding across DP workers.
- 2D mesh: hybrid sharding (HSDP), combining replication + sharding across mesh dimensions.
So the agent should:
- Prefer to create a DeviceMesh early (after init_process_group and setting CUDA device).
- Pass the correct (sub)mesh into `fully_shard` if composing with TP or other dimensions.
```
### references/pytorch_dcp_overview.md
```markdown
# Reference: Distributed Checkpoint (DCP) overview (torch.distributed.checkpoint)
**Source (official):** PyTorch docs — `torch.distributed.checkpoint`
https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
Created: Nov 16, 2022 • Last updated: Oct 08, 2025
## What DCP does
- Supports saving/loading from **multiple ranks in parallel**
- Handles **load-time resharding**, enabling saving with one cluster topology and loading into another
- Produces **multiple files per checkpoint** (often at least one per rank)
- Operates “in place”: the model allocates storage first; DCP loads into that storage
## Important caveats
- The docs warn: **no guarantees of backwards compatibility** across PyTorch versions for saved `state_dict`s.
- Process-group usage: if you pass a process group, only those ranks should call save/load, and all tensors must belong to that group.
## Where to learn usage
The doc links to official “Getting Started with DCP” and “Asynchronous Saving with DCP” recipes.
```
### references/pytorch_dcp_recipe.md
```markdown
# Reference: Getting Started with Distributed Checkpoint (DCP) recipe
**Source (official):** PyTorch Tutorials recipe — “Getting Started with Distributed Checkpoint (DCP)”
https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
Created: Oct 02, 2023 • Last updated: Jul 10, 2025 • Last verified: Nov 05, 2024
## Key ideas shown in the recipe
- DCP saves/loads in parallel, and supports resharding across topologies at load time.
- It provides helpers under `torch.distributed.checkpoint.state_dict` to manage distributed `state_dict` generation/loading.
## Example structure (high level)
- Wrap application state in a `Stateful` object, so DCP automatically calls `state_dict()` / `load_state_dict()`
- Use `dcp.save(...)` / `dcp.load(...)`
- Use `get_state_dict` / `set_state_dict` helpers to correctly obtain and apply model/optimizer state dicts in distributed settings
## Practical agent guidance
If adding checkpointing to an FSDP2 training script, this recipe’s patterns are the safest default.
```
### references/pytorch_dcp_async_recipe.md
```markdown
# Reference: Asynchronous Saving with Distributed Checkpoint (DCP) recipe
**Source (official):** PyTorch Tutorials recipe — “Asynchronous Saving with Distributed Checkpoint (DCP)”
https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html
Created: Jul 22, 2024 • Last updated: Sep 29, 2025 • Last verified: Nov 05, 2024
## What async checkpointing changes
- Moves checkpointing off the critical training path via `torch.distributed.checkpoint.async_save`
- Introduces extra memory overhead because async save first copies model state into internal CPU buffers
## Practical agent guidance
- Use async save when checkpoint stalls are significant and you have headroom for CPU memory.
- Consider pinned memory strategies described in the recipe if performance matters.
```
### references/pytorch_examples_fsdp2.md
```markdown
# Reference: Official `pytorch/examples` FSDP2 scripts
**Sources (official, code):**
- `pytorch/examples` repository: https://github.com/pytorch/examples
- FSDP2 checkpoint example: https://github.com/pytorch/examples/blob/main/distributed/FSDP2/checkpoint.py
## Why this matters
The FSDP2 tutorial explicitly points users to `pytorch/examples` for end-to-end scripts, especially for:
- optimizer state dict save/load with the DCP state-dict helpers
- runnable command lines and minimal scaffolding
## How agents should use this
- Prefer copying patterns from these scripts over inventing new checkpoint logic.
- Keep the script structure (init distributed, build model, shard, optimizer, train loop, save/load) similar to ease debugging.
```
### references/pytorch_tp_tutorial.md
```markdown
# Reference: Tensor Parallel (TP) tutorial (and how it composes with FSDP)
**Source (official):** PyTorch Tutorials — “Large Scale Transformer model training with Tensor Parallel (TP)”
https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html
Created: Apr 19, 2024 • Last updated: Jul 18, 2025 • Last verified: Nov 05, 2024
## Key composition pattern: TP intra-host + FSDP inter-host
The tutorial recommends:
- Run TP on a fast intra-host fabric (e.g., NVLink).
- Run FSDP across hosts (inter-host).
It shows a **2D DeviceMesh** pattern and slicing:
- `mesh_2d = init_device_mesh("cuda", (dp, tp))`
- `tp_mesh = mesh_2d["tp"]` and `dp_mesh = mesh_2d["dp"]`
- Apply TP with `parallelize_module(..., tp_mesh, ...)`
- Apply FSDP2 with `fully_shard(..., mesh=dp_mesh, ...)`
## Practical agent guidance
If the user is already doing TP:
- Ensure FSDP2 `mesh` only includes the DP dimension (often inter-host).
- Leave the TP dimension to `torch.distributed.tensor.parallel`.
```
### references/torchtitan_fsdp_notes.md
```markdown
# Reference: TorchTitan notes on FSDP/FSDP2 (production-oriented)
**Source (official-ish, PyTorch org):** TorchTitan — FSDP docs
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
## Why include this
TorchTitan is a PyTorch reference stack for large-scale LLM training. Its FSDP documentation often contains pragmatic guidance around:
- configuration choices (e.g., sharding strategy vs memory/perf)
- checkpointing workflows in larger systems
- composition with other parallelisms
## Agent guidance
Treat TorchTitan as a “how people do it in production” complement to the API docs/tutorials. Always defer to the official API docs on semantics.
```
### references/ray_train_fsdp2_example.md
```markdown
# Reference: Ray Train FSDP2 integration guide (third-party, useful patterns)
**Source (third-party):** Ray docs — “Get started with PyTorch FSDP2 (Ray Train)”
https://docs.ray.io/en/latest/train/examples/pytorch/pytorch-fsdp/README.html
## Why include this
- Shows how to integrate FSDP2 into a higher-level training orchestrator.
- Mentions common mitigation knobs (mixed precision, CPU offload, sharding granularity).
- Demonstrates checkpointing with DCP in a managed training environment.
## Agent guidance
Use as integration inspiration, not as the semantic source of truth.
```