
pytorch-fsdp2

Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.

Packaged view

This page reorganizes the original catalog entry to put fit, installability, and workflow context first. The original raw source appears below.

Stars: 5,242
Hot score: 99
Updated: March 20, 2026
Overall rating: C (5.0)
Composite score: 5.0
Best-practice grade: B (75.6)

Install command

npx @skill-hub/cli install orchestra-research-ai-research-skills-pytorch-fsdp2
Tags: PyTorch, FSDP2, Fully Sharded Data Parallel, Distributed Training, DTensor, Device Mesh, Sharded Checkpointing, Mixed Precision, Offload, Torch Distributed

Repository

Orchestra-Research/AI-Research-SKILLs

Skill path: 08-distributed-training/pytorch-fsdp2

Open repository

Best for

Primary workflow: Analyze Data & AI.

Technical facets: Full Stack, Data / AI.

Target audience: everyone.

License: MIT.

Original source

Catalog source: SkillHub Club.

Repository owner: Orchestra-Research.

This is a mirrored public skill entry; review the repository before installing it into production workflows.

What it helps with

  • Install pytorch-fsdp2 into Claude Code, Codex CLI, Gemini CLI, or OpenCode workflows
  • Review https://github.com/Orchestra-Research/AI-Research-SKILLs before adding pytorch-fsdp2 to shared team environments
  • Use pytorch-fsdp2 when retrofitting FSDP2 (`fully_shard`) into distributed training workflows

Works across

Claude Code, Codex CLI, Gemini CLI, OpenCode

Favorites: 0.

Sub-skills: 0.

Aggregator: No.

Original source / Raw SKILL.md

---
name: pytorch-fsdp2
description: Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [PyTorch, FSDP2, Fully Sharded Data Parallel, Distributed Training, DTensor, Device Mesh, Sharded Checkpointing, Mixed Precision, Offload, Torch Distributed]
dependencies: [torch]
---

# Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to **add PyTorch FSDP2** to a training loop with correct initialization, sharding, mixed precision/offload configuration, and checkpointing.

> FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in-place to modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.

---

## When to use this skill

Use FSDP2 when:
- Your model **doesn’t fit** on one GPU (parameters + gradients + optimizer state).
- You want an eager-mode, **DTensor-based per-parameter sharding** approach (more inspectable, with simpler sharded state dicts) rather than FSDP1's flat-parameter sharding.
- You may later compose DP with **Tensor Parallel** using **DeviceMesh**.

Avoid (or be careful) if:
- You need strict backwards-compatible checkpoints across PyTorch versions (DCP warns against this).
- You’re forced onto older PyTorch versions without the FSDP2 stack.

## Alternatives (when FSDP2 is not the best fit)

- **DistributedDataParallel (DDP)**: Use the standard data-parallel wrapper when the model (plus gradients and optimizer state) fits on a single GPU and you only need gradient synchronization across replicas.
- **FullyShardedDataParallel (FSDP1)**: Use the original flat-parameter FSDP wrapper when you must stay compatible with existing FSDP1 code or checkpoints.

Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.

---

## Contract the agent must follow

1. **Launch with `torchrun`** and set the CUDA device per process (usually via `LOCAL_RANK`).  
2. **Apply `fully_shard()` bottom-up**, i.e., shard submodules (e.g., Transformer blocks) before the root module.  
3. **Call `model(input)`**, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).  
4. **Create the optimizer after sharding** and make sure it is built on the **DTensor parameters** (post-`fully_shard`).  
5. **Checkpoint using Distributed Checkpoint (DCP)** or the distributed-state-dict helpers, not naïve `torch.save(model.state_dict())` unless you deliberately gather to full tensors.

(Each of these rules is directly described in the official API docs/tutorial; see references.)

---

## Step-by-step procedure

### 0) Version & environment sanity
- Prefer a recent stable PyTorch release whose documentation covers the FSDP2 (`fully_shard`) API and the current DCP helpers.
- Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, `LOCAL_RANK` are visible.

Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).

---

### 1) Initialize distributed and set device
Minimal, correct pattern:
- `dist.init_process_group(backend="nccl")`
- `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
- Optionally create a `DeviceMesh` to describe the data-parallel group(s)

Reference: `references/pytorch_device_mesh_tutorial.md` (why DeviceMesh exists & how it manages process groups).
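
A minimal sketch of this initialization, assuming a single-node `torchrun` launch (the launcher exports `RANK`, `WORLD_SIZE`, and `LOCAL_RANK`):

```python
# Launch with e.g.: torchrun --nproc_per_node 8 train.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def init_distributed():
    # torchrun sets LOCAL_RANK per worker process; bind each rank to its GPU.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    # Optional: a 1D data-parallel mesh spanning all ranks.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))
    return mesh
```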

---

### 2) Build model on meta device (recommended for very large models)
For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:
- `with torch.device("meta"): model = ...`
- apply `fully_shard(...)` on submodules, then `fully_shard(model)`
- `model.to_empty(device="cuda")`
- `model.reset_parameters()` (or your init routine)

Reference: `references/pytorch_fsdp2_tutorial.md` (migration guide shows this flow explicitly).
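
A sketch of the meta-device flow; `Transformer`, its `config`, and the `model.layers` container are hypothetical stand-ins for your own model:

```python
import torch
from torch.distributed.fsdp import fully_shard

with torch.device("meta"):
    model = Transformer(config)  # hypothetical model class and config

# Shard submodules first (see step 3), then the root module.
for block in model.layers:       # hypothetical block container
    fully_shard(block)
fully_shard(model)

# Materialize the (sharded) parameters on GPU, then run the real init.
model.to_empty(device="cuda")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()  # or your own init routine
```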

---

### 3) Apply `fully_shard()` bottom-up (wrapping policy = “apply where needed”)
**Do not** only call `fully_shard` on the topmost module.

Recommended sharding pattern for transformer-like models:
- iterate modules, `if isinstance(m, TransformerBlock): fully_shard(m, ...)`
- then `fully_shard(model, ...)`

Why:
- `fully_shard` forms “parameter groups” for collective efficiency and excludes params already grouped by earlier calls. Bottom-up gives better overlap and lower peak memory.

Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and why).
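
A sketch of the bottom-up pattern, assuming a hypothetical `TransformerBlock` class and the `mesh` from step 1:

```python
from torch.distributed.fsdp import fully_shard


def shard_model(model, mesh):
    # Shard each transformer block first so each forms its own parameter group
    # (better all-gather overlap, lower peak memory).
    for module in model.modules():
        if isinstance(module, TransformerBlock):  # hypothetical block class
            fully_shard(module, mesh=mesh)
    # The root call then groups only the remaining parameters
    # (e.g. embeddings, final norm/head) not claimed by earlier calls.
    fully_shard(model, mesh=mesh)
    return model
```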

---

### 4) Configure `reshard_after_forward` for memory/perf trade-offs
Default behavior:
- `None` means `True` for non-root modules and `False` for root modules (good default).

Heuristics:
- If you’re memory-bound: keep defaults or force `True` on many blocks.
- If you’re throughput-bound and can afford memory: consider keeping unsharded params longer (root often `False`).
- Advanced: use an `int` to reshard to a smaller mesh after forward (e.g., intra-node) if it’s a meaningful divisor.

Reference: `references/pytorch_fully_shard_api.md` (full semantics).
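
For illustration, a sketch of setting the knob per call; the specific values are assumptions that illustrate the trade-off, not recommendations for any particular model:

```python
from torch.distributed.fsdp import fully_shard

for block in model.layers:  # hypothetical block container
    # Memory-bound choice: free unsharded params after forward,
    # re-all-gather them during backward.
    fully_shard(block, reshard_after_forward=True)

# Root module: keeping params unsharded after forward skips one backward
# all-gather (this is also what the None default resolves to for the root).
fully_shard(model, reshard_after_forward=False)
```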

---

### 5) Mixed precision & offload (optional but common)
FSDP2 uses:
- `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
- `offload_policy=CPUOffloadPolicy()` if you want CPU offload

Rules of thumb:
- Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
- Keep `reduce_dtype` aligned with your gradient reduction expectations.
- If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.

Reference: `references/pytorch_fully_shard_api.md` (MixedPrecisionPolicy / OffloadPolicy classes).
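
A sketch combining both policies, assuming BF16 is numerically acceptable for the model:

```python
import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # unsharded params / compute in BF16
    reduce_dtype=torch.bfloat16,  # switch to torch.float32 if reductions drift
)

for block in model.layers:  # hypothetical block container
    fully_shard(block, mp_policy=mp_policy)
fully_shard(
    model,
    mp_policy=mp_policy,
    offload_policy=CPUOffloadPolicy(),  # only if you accept the PCIe traffic
)
```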

---

### 6) Optimizer, gradient clipping, accumulation
- Create the optimizer **after** sharding so it holds DTensor params.
- If you need gradient accumulation / no_sync:
  - use the FSDP2 mechanism (`set_requires_gradient_sync`) instead of FSDP1’s `no_sync()`.

Gradient clipping:
- Use the approach shown in the FSDP2 tutorial (“Gradient Clipping and Optimizer with DTensor”), because parameters/gradients are DTensors.

Reference: `references/pytorch_fsdp2_tutorial.md`.
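
A sketch of the post-sharding optimizer setup with accumulation and clipping; `dataloader`, `accum_steps`, and the loss computation are placeholders:

```python
import torch

# Create the optimizer AFTER fully_shard so it holds DTensor parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

accum_steps = 4  # placeholder
for step, batch in enumerate(dataloader):  # placeholder dataloader
    last_micro_step = (step + 1) % accum_steps == 0
    # FSDP2 replacement for FSDP1's no_sync(): only reduce-scatter gradients
    # on the last micro-step of each accumulation window.
    model.set_requires_gradient_sync(last_micro_step)

    loss = model(batch).mean()  # placeholder loss; note model(batch), not model.forward
    loss.backward()

    if last_micro_step:
        # Works on DTensor gradients in recent PyTorch; see the tutorial's
        # "Gradient Clipping and Optimizer with DTensor" section if this errors.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```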

---

### 7) Checkpointing: prefer DCP or distributed state dict helpers
Two recommended approaches:

**A) Distributed Checkpoint (DCP) — best default**
- DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
- DCP produces **multiple files** (often at least one per rank) and operates “in place”.

**B) Distributed state dict helpers**
- `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
- For optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`

Avoid:
- Saving DTensor state dicts with plain `torch.save` unless you intentionally convert with `DTensor.full_tensor()` and manage memory carefully.

References:
- `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
- `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
- `references/pytorch_fsdp2_tutorial.md` (DTensor vs DCP state-dict flows)
- `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
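
A sketch of approach A combined with the distributed state-dict helpers; the checkpoint path is a placeholder and must be reachable by every rank:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CKPT_DIR = "/shared/ckpt/step_1000"  # placeholder shared path


def save_checkpoint(model, optimizer):
    # DTensor-aware sharded state dicts; every rank saves its shards in parallel.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)


def load_checkpoint(model, optimizer):
    # DCP loads "in place" into pre-allocated storage, then we push the
    # values back into the live model/optimizer.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)
    set_state_dict(
        model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd
    )
```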

---

## Workflow checklists (copy-paste friendly)

### Workflow A: Retrofit FSDP2 into an existing training script
- [ ] Launch with `torchrun` and initialize the process group.
- [ ] Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dim parallelism.
- [ ] Build the model (use `meta` if needed), apply `fully_shard` bottom-up, then `fully_shard(model)`.
- [ ] Create the optimizer after sharding so it captures DTensor parameters.
- [ ] Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
- [ ] Add DCP save/load via `torch.distributed.checkpoint` helpers.

Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.

### Workflow B: Add DCP save/load (minimal pattern)
- [ ] Wrap state in `Stateful` or assemble state via `get_state_dict`.
- [ ] Call `dcp.save(...)` from all ranks to a shared path.
- [ ] Call `dcp.load(...)` and restore with `set_state_dict`.
- [ ] Validate any resharding assumptions when loading into a different mesh.

Reference: `references/pytorch_dcp_recipe.md`.
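
A minimal sketch of the `Stateful` pattern from the recipe, reusing the model and optimizer from the earlier steps (the path is a placeholder):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful


class AppState(Stateful):
    """Bundles model + optimizer so dcp.save/load drive these hooks automatically."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def load_state_dict(self, state_dict):
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


state = {"app": AppState(model, optimizer)}
dcp.save(state, checkpoint_id="/shared/ckpt/latest")  # all ranks call this
dcp.load(state, checkpoint_id="/shared/ckpt/latest")  # all ranks call this
```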

## Debug checklist (what the agent should check first)

1. **All ranks on distinct GPUs?**  
   If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
2. **Did you accidentally call `forward()` directly?**  
   Use `model(input)` or explicitly `unshard()` / register forward.
3. **Is `fully_shard()` applied bottom-up?**  
   If only root is sharded, expect worse memory/perf and possible confusion.
4. **Optimizer created at the right time?**  
   Must be built on DTensor parameters *after* sharding.
5. **Checkpointing path consistent?**  
   - If using DCP, don’t mix with ad-hoc `torch.save` unless you understand conversions.
   - Be mindful of PyTorch-version compatibility warnings for DCP.

---

## Common issues and fixes

- **Forward hooks not running** → Call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- **Optimizer sees non-DTensor params** → Create optimizer after all `fully_shard` calls.
- **Only root module sharded** → Apply `fully_shard` bottom-up on submodules before the root.
- **Memory spikes after forward** → Set `reshard_after_forward=True` for more modules.
- **Gradient accumulation desync** → Use `set_requires_gradient_sync` instead of FSDP1’s `no_sync()`.

Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.

---

## Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

- `init_distributed()`: init process group, set device
- `build_model_meta()`: model on meta, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or distributed state dict helpers

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
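
A structural sketch of how those blocks might be wired together in `train.py`; the function bodies follow the step-by-step sections above, and `dataloader` / `save_every` are placeholders:

```python
def main():
    mesh = init_distributed()                  # step 1
    model = build_model_meta(mesh)             # steps 2-5: meta init, fully_shard, policies
    optimizer = build_optimizer(model)         # step 6: created after sharding
    checkpoint_load(model, optimizer)          # step 7: resume if a checkpoint exists

    for step, batch in enumerate(dataloader):  # placeholder dataloader
        train_step(model, optimizer, batch)    # model(inputs), backward, clip, step
        if (step + 1) % save_every == 0:       # placeholder interval
            checkpoint_save(model, optimizer)  # DCP save from all ranks


if __name__ == "__main__":
    main()
```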

---

## References
- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)


---

## Referenced Files

> The following files are referenced in this skill and included for context.

### references/pytorch_fully_shard_api.md

```markdown
# Reference: `torch.distributed.fsdp.fully_shard` API (FSDP2)

**Source (official):** PyTorch docs — `torch.distributed.fsdp.fully_shard`  
https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html  
Created: Dec 04, 2024 • Last updated: Oct 13, 2025

## Key facts (paraphrased from the API docs)

### User contract highlights
- `fully_shard(model)` converts `model.parameters()` to **DTensor** at init, then hooks **all-gather** before forward/backward and **free/reshard** after.  
- The optimizer **must be initialized with DTensor parameters** and step must happen on DTensors.
- Call `model(input)` (not `model.forward(input)`) so hooks run; otherwise explicitly `unshard()` or register the forward method for hooking.
- Apply `fully_shard` **bottom-up**: shard submodules first, then the root module, to form efficient communication groups and enable overlap.
- `fully_shard` “unions” the module type in-place with `FSDPModule`, enabling methods like `unshard()` / `reshard()`.

> Short excerpt (<= 25 words): “Users generally should not call fully_shard() only on the topmost root module.”

### Signature & core args
`fully_shard(module, *, mesh=None, reshard_after_forward=None, shard_placement_fn=None, mp_policy=MixedPrecisionPolicy(...), offload_policy=OffloadPolicy(), ignored_params=None)`

- **mesh** (`DeviceMesh`):  
  - 1D mesh ⇒ “classic” FSDP sharding, placement `(Shard(0),)`  
  - 2D mesh ⇒ Hybrid sharding (HSDP): sharded across one dim, replicated across the other, placement `(Replicate(), Shard(0))`
- **reshard_after_forward**:
  - `True`: free unsharded params after forward (re-all-gather during backward)
  - `False`: keep unsharded params after forward (avoid backward all-gather)
  - `None`: defaults to `True` for non-root, `False` for root
  - `int`: reshard to a smaller world-size after forward (must divide shard-dim size)
- **shard_placement_fn**: override per-parameter sharding dim (requires even sharding if not dim-0)
- **ignored_params**: parameters not sharded / not moved / not reduced

## Mixed precision & offload policy classes (same doc page)

### `MixedPrecisionPolicy`
Controls:
- `param_dtype`: dtype used for unsharded parameters during forward/backward
- `reduce_dtype`: dtype used for gradient reduction
- `output_dtype`: dtype used for forward output
- `cast_forward_inputs`: whether to cast forward inputs to `param_dtype`

### `OffloadPolicy` and `CPUOffloadPolicy`
OffloadPolicy controls:
- `param_device` / `reduce_device` / `output_device` (and for CPU offload policy, also `optimizer_state_device`)

## Practical implications for agents
- **Bottom-up sharding** is not optional: it affects grouping and memory/perf.
- **Don’t bypass hooks**: using `model.forward` directly breaks all-gather scheduling.
- **Optimizer construction order matters**: construct optimizer after `fully_shard`.

```

### references/pytorch_fsdp2_tutorial.md

```markdown
# Reference: Getting Started with Fully Sharded Data Parallel (FSDP2) tutorial

**Source (official):** PyTorch Tutorials — “Getting Started with Fully Sharded Data Parallel (FSDP2)”  
https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html  
Created: Mar 17, 2022 • Last updated: Sep 02, 2025 • Last verified: Nov 05, 2024

## What the tutorial emphasizes

### How FSDP2 differs from DDP and FSDP1
- FSDP shards **parameters, gradients, and optimizer state**; parameters are all-gathered for compute and reduce-scattered for grads.
- Compared to FSDP1, FSDP2:
  - uses **DTensor per-parameter sharding** (more direct manipulation; sharded state dicts)
  - improves memory management for more deterministic memory behavior
  - supports extensibility points for custom all-gather (e.g., float8/NF4 use cases)

### Model initialization flow (meta-device pattern)
The tutorial’s migration section shows a typical pattern:
- initialize model on `meta`
- apply `fully_shard` to the intended layers (policy expressed by explicit calls)
- apply `fully_shard` to the root module
- materialize weights via `to_empty(device="cuda")`, then run `reset_parameters()`

### State dict workflows
The tutorial describes two main ways:

**A) DTensor APIs (manual)**
- Loading: use `distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)` then `model.load_state_dict(..., assign=True)`
- Saving: call `DTensor.full_tensor()` to all-gather; optionally CPU-offload on rank0 to avoid peak GPU memory

**B) DCP distributed state-dict helpers (recommended when no custom handling needed)**
- Loading: `set_model_state_dict(..., StateDictOptions(full_state_dict=True, broadcast_from_rank0=True))`
- Saving: `get_model_state_dict(..., StateDictOptions(full_state_dict=True, cpu_offload=True))`
- Points to `pytorch/examples` for optimizer state dict save/load with `set_optimizer_state_dict` / `get_optimizer_state_dict`

### Migration guide mapping
The tutorial explicitly maps FSDP1 concepts to FSDP2:
- `sharding_strategy` ↔ `reshard_after_forward` (+ 2D mesh for HYBRID)
- `cpu_offload` ↔ `offload_policy` (`CPUOffloadPolicy`)
- `no_sync()` ↔ `set_requires_gradient_sync`
- `sync_module_states` moves to DCP broadcast-from-rank0 flows

## Practical takeaways for agents
- Express wrapping policy by **explicitly applying `fully_shard`** to chosen submodules.
- Use DCP APIs for flexible checkpointing and resharding unless you must interop with third-party formats.

```

### references/pytorch_ddp_notes.md

```markdown
# Reference: Distributed Data Parallel (DDP) notes

**Source (official):** PyTorch docs — “Distributed Data Parallel”  
https://docs.pytorch.org/docs/stable/notes/ddp.html  
Last accessed: Jan 30, 2026

## Key points (paraphrased from the notes)
- DDP is the standard PyTorch wrapper for distributed data parallel training.
- Typical usage includes initializing the process group, wrapping the model with `DistributedDataParallel`, and training normally.

```

### references/pytorch_fsdp1_api.md

```markdown
# Reference: Fully Sharded Data Parallel (FSDP1) API

**Source (official):** PyTorch docs — “Fully Sharded Data Parallel”  
https://docs.pytorch.org/docs/stable/fsdp.html  
Last accessed: Jan 30, 2026

## Key points (paraphrased from the API docs)
- `torch.distributed.fsdp.FullyShardedDataParallel` is the original FSDP wrapper for sharding module parameters across data-parallel workers.

```

### references/pytorch_device_mesh_tutorial.md

```markdown
# Reference: Getting Started with DeviceMesh (PyTorch tutorial)

**Source (official):** PyTorch Recipes — “Getting Started with DeviceMesh”  
https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html  
Created: Jan 24, 2024 • Last updated: Jul 18, 2025 • Last verified: Nov 05, 2024

## What DeviceMesh is (as defined by the tutorial)
DeviceMesh is a higher-level abstraction that **manages ProcessGroups**, making it easier to set up the right communication groups for multi-dimensional parallelism.

The tutorial motivation:
- Without DeviceMesh, users must manually compute rank groupings (replicate/shard groups) and create multiple process groups.
- With DeviceMesh, you describe topology with a shape (e.g., 2D mesh), and slice submeshes by dimension name.

## Why this matters for FSDP2
FSDP2 `fully_shard(..., mesh=...)` takes a `DeviceMesh`:
- 1D mesh: standard full sharding across DP workers.
- 2D mesh: hybrid sharding (HSDP), combining replication + sharding across mesh dimensions.

So the agent should:
- Prefer to create a DeviceMesh early (after init_process_group and setting CUDA device).
- Pass the correct (sub)mesh into `fully_shard` if composing with TP or other dimensions.

```

### references/pytorch_dcp_overview.md

```markdown
# Reference: Distributed Checkpoint (DCP) overview (torch.distributed.checkpoint)

**Source (official):** PyTorch docs — `torch.distributed.checkpoint`  
https://docs.pytorch.org/docs/stable/distributed.checkpoint.html  
Created: Nov 16, 2022 • Last updated: Oct 08, 2025

## What DCP does
- Supports saving/loading from **multiple ranks in parallel**
- Handles **load-time resharding**, enabling saving with one cluster topology and loading into another
- Produces **multiple files per checkpoint** (often at least one per rank)
- Operates “in place”: the model allocates storage first; DCP loads into that storage

## Important caveats
- The docs warn: **no guarantees of backwards compatibility** across PyTorch versions for saved `state_dict`s.
- Process-group usage: if you pass a process group, only those ranks should call save/load, and all tensors must belong to that group.

## Where to learn usage
The doc links to official “Getting Started with DCP” and “Asynchronous Saving with DCP” recipes.

```

### references/pytorch_dcp_recipe.md

```markdown
# Reference: Getting Started with Distributed Checkpoint (DCP) recipe

**Source (official):** PyTorch Tutorials recipe — “Getting Started with Distributed Checkpoint (DCP)”  
https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html  
Created: Oct 02, 2023 • Last updated: Jul 10, 2025 • Last verified: Nov 05, 2024

## Key ideas shown in the recipe
- DCP saves/loads in parallel, and supports resharding across topologies at load time.
- It provides helpers under `torch.distributed.checkpoint.state_dict` to manage distributed `state_dict` generation/loading.

## Example structure (high level)
- Wrap application state in a `Stateful` object, so DCP automatically calls `state_dict()` / `load_state_dict()`
- Use `dcp.save(...)` / `dcp.load(...)`
- Use `get_state_dict` / `set_state_dict` helpers to correctly obtain and apply model/optimizer state dicts in distributed settings

## Practical agent guidance
If adding checkpointing to an FSDP2 training script, this recipe’s patterns are the safest default.

```

### references/pytorch_dcp_async_recipe.md

```markdown
# Reference: Asynchronous Saving with Distributed Checkpoint (DCP) recipe

**Source (official):** PyTorch Tutorials recipe — “Asynchronous Saving with Distributed Checkpoint (DCP)”  
https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html  
Created: Jul 22, 2024 • Last updated: Sep 29, 2025 • Last verified: Nov 05, 2024

## What async checkpointing changes
- Moves checkpointing off the critical training path via `torch.distributed.checkpoint.async_save`
- Introduces extra memory overhead because async save first copies model state into internal CPU buffers

## Practical agent guidance
- Use async save when checkpoint stalls are significant and you have headroom for CPU memory.
- Consider pinned memory strategies described in the recipe if performance matters.

```

### references/pytorch_examples_fsdp2.md

```markdown
# Reference: Official `pytorch/examples` FSDP2 scripts

**Sources (official, code):**
- `pytorch/examples` repository: https://github.com/pytorch/examples
- FSDP2 checkpoint example: https://github.com/pytorch/examples/blob/main/distributed/FSDP2/checkpoint.py

## Why this matters
The FSDP2 tutorial explicitly points users to `pytorch/examples` for end-to-end scripts, especially for:
- optimizer state dict save/load with the DCP state-dict helpers
- runnable command lines and minimal scaffolding

## How agents should use this
- Prefer copying patterns from these scripts over inventing new checkpoint logic.
- Keep the script structure (init distributed, build model, shard, optimizer, train loop, save/load) similar to ease debugging.

```

### references/pytorch_tp_tutorial.md

```markdown
# Reference: Tensor Parallel (TP) tutorial (and how it composes with FSDP)

**Source (official):** PyTorch Tutorials — “Large Scale Transformer model training with Tensor Parallel (TP)”  
https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html  
Created: Apr 19, 2024 • Last updated: Jul 18, 2025 • Last verified: Nov 05, 2024

## Key composition pattern: TP intra-host + FSDP inter-host
The tutorial recommends:
- Run TP on a fast intra-host fabric (e.g., NVLink).
- Run FSDP across hosts (inter-host).

It shows a **2D DeviceMesh** pattern and slicing:
- `mesh_2d = init_device_mesh("cuda", (dp, tp))`
- `tp_mesh = mesh_2d["tp"]` and `dp_mesh = mesh_2d["dp"]`
- Apply TP with `parallelize_module(..., tp_mesh, ...)`
- Apply FSDP2 with `fully_shard(..., mesh=dp_mesh, ...)`

## Practical agent guidance
If the user is already doing TP:
- Ensure FSDP2 `mesh` only includes the DP dimension (often inter-host).
- Leave the TP dimension to `torch.distributed.tensor.parallel`.

```

### references/torchtitan_fsdp_notes.md

```markdown
# Reference: TorchTitan notes on FSDP/FSDP2 (production-oriented)

**Source (official-ish, PyTorch org):** TorchTitan — FSDP docs  
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md

## Why include this
TorchTitan is a PyTorch reference stack for large-scale LLM training. Its FSDP documentation often contains pragmatic guidance around:
- configuration choices (e.g., sharding strategy vs memory/perf)
- checkpointing workflows in larger systems
- composition with other parallelisms

## Agent guidance
Treat TorchTitan as a “how people do it in production” complement to the API docs/tutorials. Always defer to the official API docs on semantics.

```

### references/ray_train_fsdp2_example.md

```markdown
# Reference: Ray Train FSDP2 integration guide (third-party, useful patterns)

**Source (third-party):** Ray docs — “Get started with PyTorch FSDP2 (Ray Train)”  
https://docs.ray.io/en/latest/train/examples/pytorch/pytorch-fsdp/README.html

## Why include this
- Shows how to integrate FSDP2 into a higher-level training orchestrator.
- Mentions common mitigation knobs (mixed precision, CPU offload, sharding granularity).
- Demonstrates checkpointing with DCP in a managed training environment.

## Agent guidance
Use as integration inspiration, not as the semantic source of truth.

```
