23 changes: 12 additions & 11 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -127,7 +127,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']

# batch : batch dimension of data and activations
# hidden :
@@ -142,31 +142,32 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],

['batch', ['data', 'fsdp_gpu']],
['activation_batch', ['data', 'fsdp_gpu']],
['activation_length', 'fsdp_tpu'],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['fsdp_tpu', 'fsdp_gpu']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'fsdp_tpu', 'fsdp_gpu']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'fsdp_tpu'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'tensor', 'fsdp_tpu', 'fsdp_gpu']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
dcn_fsdp_tpu_parallelism: -1
dcn_fsdp_gpu_parallelism: 1 # recommended DCN axis to be auto-sharded
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1
ici_fsdp_tpu_parallelism: -1
ici_fsdp_gpu_parallelism: 1 # recommended ICI axis to be auto-sharded

allow_split_physical_axes: False
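
Note: the split of the old `fsdp` axis into `fsdp_tpu` and `fsdp_gpu` makes the mesh four-dimensional. A minimal sketch of how the ICI values above could be turned into a single-slice device mesh (axis sizes are illustrative and assume 4 local devices; multi-slice DCN setups would typically go through `mesh_utils.create_hybrid_device_mesh` instead):

```python
# Sketch only: building a 4-axis mesh from the ICI parallelism list above.
# Assumes one slice with 4 local devices; the product of the ICI axes must
# equal the number of devices per slice, and -1 entries are auto-filled first.
from jax.sharding import Mesh
from jax.experimental import mesh_utils

mesh_axes = ["data", "tensor", "fsdp_tpu", "fsdp_gpu"]
ici_parallelism = [1, 1, 4, 1]  # illustrative values after auto-fill

devices = mesh_utils.create_device_mesh(ici_parallelism)
mesh = Mesh(devices, mesh_axes)
print(mesh.shape)  # axis sizes: data=1, tensor=1, fsdp_tpu=4, fsdp_gpu=1
```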

49 changes: 37 additions & 12 deletions src/maxdiffusion/max_utils.py
@@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True):
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]
if "dcn_fsdp_tpu_parallelism" in config.get_keys():
dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_tensor_parallelism,
config.dcn_fsdp_tpu_parallelism,
config.dcn_fsdp_gpu_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_tensor_parallelism,
config.ici_fsdp_tpu_parallelism,
config.ici_fsdp_gpu_parallelism,
]
else:
dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
@@ -641,4 +654,16 @@ def maybe_initialize_jax_distributed_system(raw_keys):
initialize_jax_for_gpu()
max_logging.log("Jax distributed system initialized on GPU!")
else:
jax.distributed.initialize()
jax.distributed.initialize()

def get_axis_names(axis_key: str, config=None) -> str:
"""Returns the mesh axis names given the logical axis key from config.logical_axis_rules."""
axis_name = ''
if config:
axis_rules = config.logical_axis_rules
else:
axis_rules = nn.get_logical_axis_rules()
for rules in axis_rules:
if rules[0] == axis_key:
axis_name = rules[1]
return axis_name
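
A quick usage sketch of `get_axis_names`, using a toy reimplementation and a toy rules list rather than a real config (values are illustrative). Note that callers such as `_tpu_flash_attention` index `mesh.shape` with the result, which assumes the matching rule maps to a single axis name rather than a list:

```python
# Sketch: get_axis_names returns the mesh axis (or axes) mapped to a logical
# axis key in logical_axis_rules, or '' when no rule matches.
toy_rules = [
    ["activation_batch", ["data", "fsdp_gpu"]],
    ["activation_length", "fsdp_tpu"],
]

def lookup(axis_key, axis_rules):
    axis_name = ""
    for rule in axis_rules:
        if rule[0] == axis_key:
            axis_name = rule[1]
    return axis_name

print(lookup("activation_length", toy_rules))  # -> 'fsdp_tpu'
print(lookup("conv_out", toy_rules))           # -> '' (callers fall back to a default)
```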
88 changes: 51 additions & 37 deletions src/maxdiffusion/models/attention_flax.py
@@ -27,6 +27,7 @@
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
from einops import rearrange
from .. import common_types, max_logging
from .. import max_utils

from . import quantizations

@@ -68,8 +69,11 @@ def _reshape_data_from_cudnn_flash(tensor):

def _reshape_data_for_cudnn_flash(tensor, heads):
# reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
batch, seq, heads_and_dim_head = tensor.shape
tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
if len(tensor.shape) == 3:
batch, seq, dim_head = tensor.shape
tensor = tensor.reshape(batch, seq, heads, dim_head // heads)
else:
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
return tensor


@@ -79,7 +83,8 @@ def _reshape_batch_dim_to_heads(tensor, heads):
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _reshape_heads_to_batch_dim(tensor, heads):
@@ -92,8 +97,8 @@ def _reshape_heads_to_batch_dim(tensor, heads):
else:
batch_size, head_size, seq_len, head_dim = tensor.shape
reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)

return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _reshape_heads_to_head_dim(tensor):
Expand All @@ -102,7 +107,8 @@ def _reshape_heads_to_head_dim(tensor):
b, h, s, d = tensor.shape
tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _unflatten_heads(tensor, heads):
Expand Down Expand Up @@ -200,7 +206,8 @@ def _tpu_flash_attention(
block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
)
num_fsdp_shards = mesh.shape["fsdp"]
fsdp_key = max_utils.get_axis_names("activation_length")
num_fsdp_shards = mesh.shape[fsdp_key]
query = _reshape_data_for_flash(query, heads)
key = _reshape_data_for_flash(key, heads)
value = _reshape_data_for_flash(value, heads)
@@ -274,13 +281,13 @@ def wrap_flash_attention(query, key, value):

perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]

k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
k1 = jax.lax.ppermute(key, axis_name=fsdp_key, perm=perm)
v1 = jax.lax.ppermute(value, axis_name=fsdp_key, perm=perm)

def ring_scan_body(carry, _):
m, l, o, k_current, v_current = carry
k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
k_next = jax.lax.ppermute(k_current, axis_name=fsdp_key, perm=perm)
v_next = jax.lax.ppermute(v_current, axis_name=fsdp_key, perm=perm)

out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)

@@ -305,7 +312,7 @@ def ring_scan_body(carry, _):

return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
devices_in_data_fsdp = mesh.shape["data"] * mesh.shape[fsdp_key]
# This warning might show up when doing model eval for example, when calculating model flops
# and that is expected.
if not (query.shape[0] / devices_in_data_fsdp).is_integer():
@@ -403,24 +410,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
key = _reshape_data_for_cudnn_flash(key, heads)
value = _reshape_data_for_cudnn_flash(value, heads)

cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)

query = nn.with_logical_constraint(query, axis_names)
key = nn.with_logical_constraint(key, axis_names)
value = nn.with_logical_constraint(value, axis_names)

@functools.partial(
shard_map.shard_map,
mesh=mesh,
in_specs=(axis_names, axis_names, axis_names),
out_specs=axis_names,
check_rep=False,
)
def wrap_flash_attention(query, key, value):
return jax.vmap(dpa_layer)(query, key, value, mask=None)

out = wrap_flash_attention(query, key, value)
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV))
query = jax.lax.with_sharding_constraint(query, axis_names)
key = jax.lax.with_sharding_constraint(key, axis_names)
value = jax.lax.with_sharding_constraint(value, axis_names)

out = dpa_layer(query, key, value, mask=None)
return _reshape_data_from_cudnn_flash(out)


@@ -611,7 +606,24 @@ def __init__(
):
self.dpa_layer = None
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"{self} has not been tested with {attention_kernel}")
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
jax.config.update("jax_use_shardy_partitioner", False)

dpa_layer = DotProductAttention(
head_dim=dim_head,
num_attention_heads=heads,
num_gqa_groups=heads,
attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal'
attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
# attention_dropout=self.dropout_rate,
dropout_rng_name="aqt",
dtype=dtype,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=scale,
transpose_batch_sequence=False,
)
variables = {}
self.dpa_layer = functools.partial(dpa_layer.apply, variables)

self.mesh = mesh
self.scale = scale
@@ -672,8 +684,9 @@ def setup(self):
self.dpa_layer = None
if self.attention_kernel == "cudnn_flash_te":
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
jax.config.update("jax_use_shardy_partitioner", False)

self.dpa_layer = DotProductAttention(
dpa_layer = DotProductAttention(
head_dim=self.dim_head,
num_attention_heads=self.heads,
num_gqa_groups=self.heads,
@@ -687,6 +700,9 @@
scale_factor=self.scale,
transpose_batch_sequence=False,
)
variables = {}
self.dpa_layer = functools.partial(dpa_layer.apply, variables)


def apply_attention(self, query: Array, key: Array, value: Array):
return _apply_attention(
@@ -740,9 +756,6 @@ def __init__(
residual_checkpoint_name: str | None = None,
enable_jax_named_scopes: bool = False,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
self.dim_head = dim_head
@@ -889,8 +902,9 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs = None,
) -> jax.Array:
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
dtype = hidden_states.dtype
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
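
The hard-coded `PartitionSpec("data", "fsdp", "tensor")` constraints above are replaced with logical names resolved through `nn.logical_to_mesh_axes`. A minimal sketch of that resolution, using a subset of the rules from `base_wan_14b.yml` (it is assumed here that `BATCH`/`LENGTH`/`HEAD` in `common_types` correspond to the `activation_*` logical names):

```python
# Sketch: resolving logical axis names against logical_axis_rules.
# The rules below are a subset of base_wan_14b.yml; the exact constants in
# common_types are an assumption in this example.
import flax.linen as nn

rules = (
    ("activation_batch", ("data", "fsdp_gpu")),
    ("activation_length", "fsdp_tpu"),
    ("activation_heads", "tensor"),
)

with nn.logical_axis_rules(rules):
    spec = nn.logical_to_mesh_axes(
        ("activation_batch", "activation_length", "activation_heads")
    )

print(spec)  # PartitionSpec(('data', 'fsdp_gpu'), 'fsdp_tpu', 'tensor')
```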
6 changes: 5 additions & 1 deletion src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -20,6 +20,7 @@
import jax.numpy as jnp
from flax import nnx
from ...configuration_utils import ConfigMixin
from ... import max_utils
from ..modeling_flax_utils import FlaxModelMixin, get_activation
from ... import common_types
from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput)
@@ -72,7 +73,10 @@ def __init__(
self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0]

# Set sharding dynamically based on out_channels.
num_fsdp_axis_devices = mesh.device_ids.shape[1]
fspd_key = max_utils.get_axis_names("activation_length")
if not fspd_key:
fspd_key = "fsdp"
num_fsdp_axis_devices = mesh.shape[fspd_key]
kernel_sharding = (None, None, None, None, None)
if out_channels % num_fsdp_axis_devices == 0:
kernel_sharding = (None, None, None, None, "conv_out")
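
The fallback above keeps older configs working when they still define a plain `fsdp` mesh axis and no `activation_length` rule. A minimal sketch with a hypothetical 1x4 mesh built from whatever devices are available:

```python
# Sketch: mesh.shape maps axis names to sizes, so once the key is resolved
# (or defaulted to 'fsdp'), the FSDP axis size can be read directly.
# Assumes at least 4 devices are available; purely illustrative.
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()[:4]).reshape(1, 4)
mesh = Mesh(devices, axis_names=("data", "fsdp"))

fsdp_key = ""          # what get_axis_names returns when no rule matches
if not fsdp_key:
    fsdp_key = "fsdp"
print(mesh.shape[fsdp_key])  # -> 4
```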
10 changes: 6 additions & 4 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -358,9 +358,11 @@ def __call__(
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)

# 1. Self-attention
with self.conditional_named_scope("self_attn"):
@@ -501,7 +503,7 @@ def init_block(rngs):
if scan_layers:
self.blocks = init_block(rngs)
else:
blocks = nnx.List([])
blocks = []
for _ in range(num_layers):
block = WanTransformerBlock(
rngs=rngs,
Expand All @@ -521,7 +523,7 @@ def init_block(rngs):
enable_jax_named_scopes=enable_jax_named_scopes,
)
blocks.append(block)
self.blocks = blocks
self.blocks = nnx.data(blocks)

self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
self.proj_out = nnx.Linear(
5 changes: 4 additions & 1 deletion src/maxdiffusion/train_wan.py
@@ -35,7 +35,10 @@ def main(argv: Sequence[str]) -> None:
config = pyconfig.config
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
flax.config.update("flax_always_shard_variable", False)
try:
flax.config.update("flax_always_shard_variable", False)
except:
pass
train(config)

