From 02c28a50103a38ced74593fc0e0f48121c63acde Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 22:30:48 -0500 Subject: [PATCH 01/11] Triton loss --- fast_llm/functional/triton/cross_entropy.py | 2 +- tests/layers/test_lm_losses.py | 24 ++++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index ef2039ad..a8becfb6 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -128,6 +128,7 @@ def triton_cross_entropy_forward_backward( target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, + looped: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -143,7 +144,6 @@ def triton_cross_entropy_forward_backward( n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) - assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 639a3ba7..1a31db90 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,4 +1,3 @@ -import contextlib import pathlib import random @@ -173,18 +172,17 @@ def _test_entropy_loss( # Triton implementation only supports cross-entropy. return assert TritonConfig.TRITON_ENABLED - with pytest.raises(AssertionError) if num_columns > 65536 else contextlib.nullcontext(): - out_triton, grad_triton = entropy_loss_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.triton, - ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + out_triton, grad_triton = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.triton, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): From 22c8e0ba0935a63db086042a18f824903c8f4f08 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 22:33:49 -0500 Subject: [PATCH 02/11] Triton loss --- fast_llm/functional/triton/cross_entropy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index a8becfb6..354223a9 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -128,7 +128,6 @@ def triton_cross_entropy_forward_backward( target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, - looped: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, From 
7491e0f8bf0619a5ec647d4af858589a191f4761 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 23:39:32 -0500 Subject: [PATCH 03/11] Parallel attempt --- fast_llm/functional/entropy_loss.py | 3 +- fast_llm/functional/triton/cross_entropy.py | 62 ++++++++++++++++++- .../language_model/loss/entropy_loss.py | 35 ++++------- 3 files changed, 73 insertions(+), 27 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index d56c745a..37486ddc 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -14,6 +14,7 @@ def torch_entropy_loss_forward_backward( logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, + group: ProcessGroup | None = None, temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: # (), (*batch, vocab) """ @@ -21,7 +22,7 @@ def torch_entropy_loss_forward_backward( The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. """ - + assert group is None # Torch methods require flattened batch dimension. target = target.flatten() if target_format == TargetFormat.labels else target.flatten(0, -2) if target_format == TargetFormat.labels: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 354223a9..12d96a88 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -5,12 +5,42 @@ from fast_llm.utils import Assert +@triton_jit() +def triton_softmax_base_kernel( + logits_ptr, + max_logits_ptr, + sum_exp_logits_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + logits_scale_factor: tl_constexpr, + block_size: tl_constexpr, +): + # TODO: Int64 ptr only if needed? 
+ block_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, block_size) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + mask = col_offsets < n_cols + + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + + tl.store(max_logits_ptr + block_idx, max_logits) + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + + @triton_jit() def triton_cross_entropy_forward_backward_kernel( logits_ptr, labels_ptr, grad_logits_ptr, losses_ptr, + max_logits_ptr, + sum_exp_logits_ptr, grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, @@ -28,9 +58,15 @@ def triton_cross_entropy_forward_backward_kernel( if logits_scale_factor != 1.0: logits *= logits_scale_factor - max_logits = tl.max(logits, 0) + if max_logits_ptr is None: + max_logits = tl.max(logits, 0) + else: + max_logits = tl.load(max_logits_ptr + block_idx) exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + if sum_exp_logits_ptr is None: + sum_exp_logits = tl.sum(exp_logits, 0) + else: + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) label_idx = tl.load(labels_ptr + block_idx) @@ -127,6 +163,7 @@ def triton_cross_entropy_forward_backward( logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, + group: torch.distributed.ProcessGroup | None = None, temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -148,11 +185,31 @@ def triton_cross_entropy_forward_backward( # TODO: Safe to do inplace? grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: + if group is None: + max_logits = sum_exp_logits = None + else: + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) + triton_softmax_base_kernel[(n_rows,)]( + logits, + local_max_logits, + sum_exp_logits, + n_cols, + logits.stride(-2), + logits_scale_factor, + block_size=block_size, + ) + max_logits = local_max_logits.clone() + torch.distributedall_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = sum_exp_logits * (local_max_logits - max_logits).exp() + torch.distributedall_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) triton_cross_entropy_forward_backward_kernel[(n_rows,)]( logits, target, grad_logits, losses, + max_logits, + sum_exp_logits, None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(-2), @@ -162,6 +219,7 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, ) else: + assert group is None if loss_mask is not None: assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 1dfd3920..550f8f33 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -122,27 +122,14 @@ def entropy_loss_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, EntropyLossImplementation.fused) - return fused_entropy_loss_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - 
target_format, - entropy_loss_type, - group, - temperature, - ) - else: - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - temperature=temperature, - ) + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature=temperature, + ) From b8e7179976f49c64947c177d93293c3786e013b0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 3 Feb 2026 00:26:45 -0500 Subject: [PATCH 04/11] fix --- fast_llm/functional/triton/cross_entropy.py | 145 ++++++++++++++------ 1 file changed, 105 insertions(+), 40 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 12d96a88..22498cf4 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -6,10 +6,13 @@ @triton_jit() -def triton_softmax_base_kernel( +def triton_cross_entropy_forward_parallel_kernel( logits_ptr, + labels_ptr, max_logits_ptr, sum_exp_logits_ptr, + predicted_logits_ptr, + col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, @@ -29,6 +32,17 @@ def triton_softmax_base_kernel( exp_logits = tl.exp(logits - max_logits) sum_exp_logits = tl.sum(exp_logits, 0) + if labels_ptr is not None and predicted_logits_ptr is not None: + label_idx = tl.load(labels_ptr + block_idx) - col_min + if label_idx < 0 or label_idx >= n_cols: + # Loss mask + predicted_logits = 0.0 + else: + predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logits *= logits_scale_factor + tl.store(predicted_logits_ptr + block_idx, predicted_logits) + tl.store(max_logits_ptr + block_idx, max_logits) tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) @@ -42,6 +56,7 @@ def triton_cross_entropy_forward_backward_kernel( max_logits_ptr, sum_exp_logits_ptr, grad_losses, + col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, @@ -68,26 +83,32 @@ def triton_cross_entropy_forward_backward_kernel( else: sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) - label_idx = tl.load(labels_ptr + block_idx) + label_idx = tl.load(labels_ptr + block_idx) - col_min - if label_idx < 0: - # Loss mask - loss = 0.0 - else: - label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) - if logits_scale_factor != 1.0: - label_logits *= logits_scale_factor - loss = tl.log(sum_exp_logits) + max_logits - label_logits - tl.store(losses_ptr + block_idx, loss) + if losses_ptr is not None: + if label_idx < 0 or label_idx >= n_cols: + # Loss mask + loss = 0.0 + else: + predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logits *= logits_scale_factor + loss = tl.log(sum_exp_logits) + max_logits - predicted_logits + tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - if label_idx < 0: + if label_idx < -col_min: grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor grad_base = exp_logits / sum_exp_logits - grad_logits = grad_losses * tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) - if logits_scale_factor != 1.0: - grad_logits *= logits_scale_factor - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + if 
label_idx < 0 or label_idx >= n_cols: + grad_logits = grad_base + else: + grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask + ) @triton_jit() @@ -155,6 +176,25 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) +@torch.compile +def _rescale_sum_exp_logits( + sum_exp_logits: torch.Tensor, + local_max_logits: torch.Tensor, + max_logits: torch.Tensor, +) -> torch.Tensor: + return sum_exp_logits * (local_max_logits - max_logits).exp() + + +@torch.compile +def _calculate_loss( + predicted_logits: torch.Tensor, + target: torch.Tensor, + sum_exp_logits: torch.Tensor, + max_logits: torch.Tensor, +) -> torch.Tensor: + return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() + + def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -181,45 +221,69 @@ def triton_cross_entropy_forward_backward( n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: if group is None: - max_logits = sum_exp_logits = None + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None, + None, + None if grad_output is None else grad_output / n_rows, + 0, + n_cols, + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) + loss = losses.mean() else: - local_max_logits = torch.empty_like(losses) - sum_exp_logits = torch.empty_like(losses) - triton_softmax_base_kernel[(n_rows,)]( + predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(predicted_logits) + sum_exp_logits = torch.empty_like(predicted_logits) + triton_cross_entropy_forward_parallel_kernel[(n_rows,)]( logits, + target, local_max_logits, sum_exp_logits, + predicted_logits, + n_cols * group.rank(), n_cols, logits.stride(-2), logits_scale_factor, block_size=block_size, ) max_logits = local_max_logits.clone() - torch.distributedall_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) - sum_exp_logits = sum_exp_logits * (local_max_logits - max_logits).exp() - torch.distributedall_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( - logits, - target, - grad_logits, - losses, - max_logits, - sum_exp_logits, - None if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - ) + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + torch.distributed.all_reduce(predicted_logits, 
op=torch.distributed.ReduceOp.SUM, group=group) + loss = _calculate_loss(predicted_logits, target, sum_exp_logits, max_logits) + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + None, + max_logits, + sum_exp_logits, + None if grad_output is None else grad_output / n_rows, + n_cols * group.rank(), + n_cols, + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) else: assert group is None + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) if loss_mask is not None: assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( @@ -238,4 +302,5 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, from_logits=target_format == TargetFormat.logits, ) - return losses.mean(), grad_logits + loss = losses.mean() + return loss, grad_logits From 3c3e0c8704d14fdfe77d41370e70a6b0b1242bce Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 3 Feb 2026 15:45:19 -0500 Subject: [PATCH 05/11] fixes --- fast_llm/functional/triton/cross_entropy.py | 97 ++++++++++++------- .../language_model/loss/entropy_loss.py | 2 + 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 22498cf4..82312b99 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -5,6 +5,32 @@ from fast_llm.utils import Assert +@triton_jit() +def triton_fused_softmax_base( + logits_ptr, + n_cols: tl_constexpr, + logits_scale_factor: tl_constexpr, + block_size: tl_constexpr, +): + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl.arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if col_offset == 0: + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + else: + new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + max_logits = new_max_logits + return exp_logits, sum_exp_logits, max_logits, mask + + @triton_jit() def triton_cross_entropy_forward_parallel_kernel( logits_ptr, @@ -20,17 +46,11 @@ def triton_cross_entropy_forward_parallel_kernel( ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) logits_ptr = logits_ptr + block_idx * logits_stride_0 - mask = col_offsets < n_cols - - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( + logits_ptr, n_cols, logits_scale_factor, block_size + ) if labels_ptr is not None and predicted_logits_ptr is not None: label_idx = tl.load(labels_ptr + block_idx) - col_min @@ -65,22 +85,14 @@ def triton_cross_entropy_forward_backward_kernel( ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) logits_ptr = logits_ptr + block_idx * logits_stride_0 - mask = col_offsets < n_cols - - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - if max_logits_ptr is None: - max_logits = tl.max(logits, 0) + if max_logits_ptr is None or sum_exp_logits_ptr is None: + exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( + logits_ptr, n_cols, logits_scale_factor, block_size + ) else: max_logits = tl.load(max_logits_ptr + block_idx) - exp_logits = tl.exp(logits - max_logits) - if sum_exp_logits_ptr is None: - sum_exp_logits = tl.sum(exp_logits, 0) - else: sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) label_idx = tl.load(labels_ptr + block_idx) - col_min @@ -89,6 +101,7 @@ def triton_cross_entropy_forward_backward_kernel( if label_idx < 0 or label_idx >= n_cols: # Loss mask loss = 0.0 + predicted_logits = 0.0 else: predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if logits_scale_factor != 1.0: @@ -97,18 +110,28 @@ def triton_cross_entropy_forward_backward_kernel( tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - if label_idx < -col_min: - grad_losses = 0.0 - elif logits_scale_factor != 1.0: - grad_losses *= logits_scale_factor - grad_base = exp_logits / sum_exp_logits - if label_idx < 0 or label_idx >= n_cols: - grad_logits = grad_base - else: - grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) - tl.store( - grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask - ) + # Run in reverse order to maximize input and cache reuse. + for col_offset in tl.static_range((n_cols - 1) // block_size * block_size, -1, -block_size): + if max_logits_ptr is None or sum_exp_logits_ptr is None or col_offset != n_cols - block_size: + col_offsets = tl.arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + + if label_idx < -col_min: + grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + grad_base = exp_logits / sum_exp_logits + if label_idx < 0 or label_idx >= n_cols: + grad_logits = grad_base + else: + grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask + ) @triton_jit() @@ -205,6 +228,8 @@ def triton_cross_entropy_forward_backward( entropy_loss_type: EntropyLossType, group: torch.distributed.ProcessGroup | None = None, temperature: float = 1.0, + block_size: int | None = None, + num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -219,8 +244,10 @@ def triton_cross_entropy_forward_backward( assert target.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 
16) # TODO: Safe to do inplace? grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 550f8f33..351aa210 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -106,6 +106,7 @@ def entropy_loss_forward_backward( temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -132,4 +133,5 @@ def entropy_loss_forward_backward( entropy_loss_type, group, temperature=temperature, + **kwargs, ) From 2d293eaf7dc5782c2c3c204d3fe7f8c1747bd138 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 5 Feb 2026 02:51:31 -0500 Subject: [PATCH 06/11] Cross-entropy from distribution --- fast_llm/functional/entropy_loss.py | 8 +- fast_llm/functional/triton/__init__.py | 17 + fast_llm/functional/triton/cross_entropy.py | 509 ++++++++++++++------ tests/layers/test_lm_losses.py | 89 ++-- 4 files changed, 444 insertions(+), 179 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 37486ddc..25e1ae31 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -121,7 +121,7 @@ def fused_softmax_base( @torch.compile -def _fused_reverse_kl_base( +def _fused_reverse_kl_base_from_distribution( logits: torch.Tensor, # (*batch, vocab) target: torch.Tensor, # (*batch, vocab) grad_output: float | None, @@ -161,7 +161,7 @@ def _fused_reverse_kl_base( @torch.compile -def _fused_cross_entropy_base( +def _fused_cross_entropy_base_from_distribution( logits: torch.Tensor, # (*batch, vocab) target: torch.Tensor, # (*batch, vocab) grad_output: float | None, @@ -302,7 +302,7 @@ def fused_entropy_loss_forward_backward( group, ) elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): - per_sample_loss, grad = _fused_cross_entropy_base( + per_sample_loss, grad = _fused_cross_entropy_base_from_distribution( logits, target, grad_output, @@ -313,7 +313,7 @@ def fused_entropy_loss_forward_backward( return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - per_sample_loss, grad = _fused_reverse_kl_base( + per_sample_loss, grad = _fused_reverse_kl_base_from_distribution( logits, target, grad_output, diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index 778559db..82f67621 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -1,16 +1,33 @@ +import torch + from fast_llm.utils import InvalidObject, try_decorate try: import triton + import triton.knobs import triton.language as tl tl_constexpr = tl.constexpr TritonConfig = triton.Config + triton_available = torch.cuda.is_available() or triton.knobs.runtime.interpret except ImportError as e: triton = InvalidObject(e) tl = triton tl_constexpr = None TritonConfig = lambda *args, **kwargs: None + triton_available = False triton_jit = try_decorate(lambda: triton.jit) triton_autotune = try_decorate(lambda: triton.autotune) + + +if not triton_available: + tl_arange = None +elif triton.knobs.runtime.interpret: + # Workaround for a triton bug. 
+ @triton_jit + def tl_arange(start, end): + return tl.arange(int(start), int(end)) + +else: + tl_arange = tl.arange diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 82312b99..6fb8e930 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.utils import Assert @@ -9,11 +9,11 @@ def triton_fused_softmax_base( logits_ptr, n_cols: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + logits_scale_factor: tl_constexpr = 1.0, ): for col_offset in tl.static_range(0, n_cols, block_size): - col_offsets = tl.arange(col_offset, col_offset + block_size) + col_offsets = tl_arange(col_offset, col_offset + block_size) mask = col_offsets < n_cols logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: @@ -28,28 +28,28 @@ def triton_fused_softmax_base( exp_logits = tl.exp(logits - new_max_logits) sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) max_logits = new_max_logits - return exp_logits, sum_exp_logits, max_logits, mask + return exp_logits, sum_exp_logits, max_logits @triton_jit() -def triton_cross_entropy_forward_parallel_kernel( +def triton_cross_entropy_forward_from_labels_parallel_kernel( logits_ptr, labels_ptr, - max_logits_ptr, - sum_exp_logits_ptr, - predicted_logits_ptr, - col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 - exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( - logits_ptr, n_cols, logits_scale_factor, block_size + exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) if labels_ptr is not None and predicted_logits_ptr is not None: @@ -63,33 +63,35 @@ def triton_cross_entropy_forward_parallel_kernel( predicted_logits *= logits_scale_factor tl.store(predicted_logits_ptr + block_idx, predicted_logits) - tl.store(max_logits_ptr + block_idx, max_logits) - tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) @triton_jit() -def triton_cross_entropy_forward_backward_kernel( +def triton_cross_entropy_forward_backward_from_labels_kernel( logits_ptr, labels_ptr, - grad_logits_ptr, - losses_ptr, - max_logits_ptr, - sum_exp_logits_ptr, - grad_losses, - col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, - grad_logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 if max_logits_ptr is None or sum_exp_logits_ptr is None: - exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( - logits_ptr, n_cols, logits_scale_factor, block_size + exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) else: max_logits = tl.load(max_logits_ptr + block_idx) @@ -101,7 +103,6 @@ def triton_cross_entropy_forward_backward_kernel( if label_idx < 0 or label_idx >= n_cols: # Loss mask loss = 0.0 - predicted_logits = 0.0 else: predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if logits_scale_factor != 1.0: @@ -110,20 +111,21 @@ def triton_cross_entropy_forward_backward_kernel( tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: + if label_idx < -col_min: + grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor # Run in reverse order to maximize input and cache reuse. 
- for col_offset in tl.static_range((n_cols - 1) // block_size * block_size, -1, -block_size): - if max_logits_ptr is None or sum_exp_logits_ptr is None or col_offset != n_cols - block_size: - col_offsets = tl.arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols + col_offset_start = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: logits *= logits_scale_factor exp_logits = tl.exp(logits - max_logits) - if label_idx < -col_min: - grad_losses = 0.0 - elif logits_scale_factor != 1.0: - grad_losses *= logits_scale_factor grad_base = exp_logits / sum_exp_logits if label_idx < 0 or label_idx >= n_cols: grad_logits = grad_base @@ -135,68 +137,209 @@ def triton_cross_entropy_forward_backward_kernel( @triton_jit() -def triton_cross_entropy_from_distribution_forward_backward_kernel( +def triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + from_logits: tl_constexpr = True, + target_logits_scale_factor: tl_constexpr = 1.0, + logits_scale_factor: tl_constexpr = 1.0, + unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. +): + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if from_logits: + target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if target_logits_scale_factor != 1.0: + target_logits *= target_logits_scale_factor + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + + if col_offset == 0: + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + if from_logits: + target_max_logits = tl.max(target_logits, 0) + target_exp_logits = tl.exp(target_logits - target_max_logits) + target_sum_exp_logits = tl.sum(target_exp_logits, 0) + predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits, 0)) + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + predicted_logits = tl.sum(tl.where(mask, target * logits, 0)) + target_max_logits = None + target_sum_exp_logits = None + else: + new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + max_logits = new_max_logits + if from_logits: + target_new_max_logits = tl.maximum(tl.max(target_logits, 0), target_max_logits) + target_exp_logits = tl.exp(target_logits - target_new_max_logits) + target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( + target_max_logits - target_new_max_logits + ) + predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( + tl.where(mask, target_exp_logits * logits, 0) + ) + target_max_logits = target_new_max_logits + else: + predicted_logits 
+= tl.sum(tl.where(mask, target * logits, 0)) + + if from_logits: + target = target_exp_logits + if not unscaled_probabilities: + predicted_logits /= target_sum_exp_logits + target /= target_sum_exp_logits + + return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target + + +@triton_jit() +def triton_cross_entropy_from_distribution_forward_parallel_kernel( logits_ptr, target_ptr, - loss_mask_ptr, - grad_logits_ptr, - losses_ptr, - grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, - grad_logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, - from_logits: tl_constexpr, block_size: tl_constexpr, + loss_mask_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + from_logits: tl_constexpr = True, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) - mask = col_offsets < n_cols + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 - if loss_mask_ptr is not None: - loss_mask = tl.load(loss_mask_ptr + block_idx) - if loss_mask == 0: - tl.store(losses_ptr + block_idx, 0) - if grad_losses is not None: - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=mask) - return + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + tl.store(predicted_logits_ptr + block_idx, 0) + return - logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( - tl.float32 + predicted_logits, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( + triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + unscaled_probabilities=True, + ) ) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + if predicted_logits_ptr is not None: + tl.store(predicted_logits_ptr + block_idx, predicted_logits) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) - max_logits = tl.max(logits, 0) - logits_norm = logits - max_logits - exp_logits = tl.exp(logits_norm) - sum_exp_logits = tl.sum(exp_logits, 0) + if target_max_logits_ptr is not None: + tl.store(target_max_logits_ptr + block_idx, target_max_logits) + if target_sum_exp_logits_ptr is not None: + tl.store(target_sum_exp_logits_ptr + block_idx, target_sum_exp_logits) - target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( - tl.float32 - ) - if from_logits: - if logits_scale_factor != 1.0: - target *= logits_scale_factor - max_target_logits = tl.max(target, 0) - exp_target_logits = tl.exp(target - max_target_logits) - sum_exp_target_logits = tl.sum(exp_target_logits, 0) - target = exp_target_logits / sum_exp_target_logits - # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(tl.where(mask, target * logits_norm, 0), 0) - tl.store(losses_ptr + block_idx, loss) 
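The blocked accumulation used by triton_fused_softmax_base and triton_predicted_logits_from_distribution is the standard online-softmax rescaling: keep a running maximum over the blocks seen so far and rescale the running sum of exponentials whenever that maximum grows. A minimal plain-PyTorch sketch of the same accumulation, for reference only (the helper name and 1-D block handling are illustrative, not part of this patch):

    import torch

    def blocked_softmax_stats(logits: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Running maximum and rescaled running sum of exponentials over fixed-size blocks.
        max_logits = torch.tensor(float("-inf"))
        sum_exp = torch.tensor(0.0)
        for start in range(0, logits.numel(), block_size):
            block = logits[start : start + block_size].float()
            new_max = torch.maximum(block.max(), max_logits)
            # Rescale the partial sum to the new maximum before adding this block's contribution.
            sum_exp = sum_exp * (max_logits - new_max).exp() + (block - new_max).exp().sum()
            max_logits = new_max
        return max_logits, sum_exp

The pair recovers logsumexp as sum_exp.log() + max_logits, which is the quantity the loss and gradient passes consume.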
+@triton_jit() +def triton_cross_entropy_from_distribution_forward_backward_kernel( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + from_logits: tl_constexpr = True, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( + triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + ) + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + if grad_losses is not None and from_logits: + target_max_logits = tl.load(target_max_logits_ptr + block_idx) + target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) + + if losses_ptr is not None: + # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) + loss = tl.log(sum_exp_logits) + max_logits - predicted_logits + tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) if logits_scale_factor != 1.0: - grad_logits *= logits_scale_factor - if loss_mask_ptr is not None: - grad_logits = grad_logits - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + grad_losses *= logits_scale_factor + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
+ col_offset_start = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + if from_logits: + target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if target_logits_scale_factor != 1.0: + target_logits *= target_logits_scale_factor + target = tl.exp(target_logits - target_max_logits) / target_sum_exp_logits + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) @torch.compile @@ -208,8 +351,20 @@ def _rescale_sum_exp_logits( return sum_exp_logits * (local_max_logits - max_logits).exp() +def _parallel_sum_exp_logits( + sum_exp_logits: torch.Tensor, + local_max_logits: torch.Tensor, + group: torch.distributed.ProcessGroup | None, +) -> torch.Tensor: + max_logits = local_max_logits.clone() + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + return max_logits, sum_exp_logits + + @torch.compile -def _calculate_loss( +def _cross_entropy_loss_from_labels( predicted_logits: torch.Tensor, target: torch.Tensor, sum_exp_logits: torch.Tensor, @@ -218,6 +373,30 @@ def _calculate_loss( return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() +@torch.compile +def _rescale_predicted_logits( + predicted_logits: torch.Tensor, + target_sum_exp_logits: torch.Tensor, + local_target_max_logits: torch.Tensor, + target_max_logits: torch.Tensor, +): + # We skipped the division by `target_sum_exp_logits` in the triton kernel so we do it here. + return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) / target_sum_exp_logits + + +@torch.compile +def _cross_entropy_loss_from_distribution( + predicted_logits: torch.Tensor, + loss_mask: torch.Tensor | None, + sum_exp_logits: torch.Tensor, + max_logits: torch.Tensor, +) -> torch.Tensor: + per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits + if loss_mask is not None: + per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) + return per_sample_losses.mean() + + def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -248,86 +427,134 @@ def triton_cross_entropy_forward_backward( block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "block_size": block_size, + "num_warps": num_warps, + } + # TODO: Safe to do inplace? 
grad_logits = None if grad_output is None else torch.empty_like(logits) + backward_kwargs = ( + {} + if grad_output is None + else { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / n_rows, + "grad_logits_stride_0": grad_logits.stride(-2), + } + ) if target_format == TargetFormat.labels: if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, target, - grad_logits, - losses, - None, - None, - None if grad_output is None else grad_output / n_rows, - 0, - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, + losses_ptr=losses, + **kwargs, + **backward_kwargs, ) loss = losses.mean() else: predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) local_max_logits = torch.empty_like(predicted_logits) sum_exp_logits = torch.empty_like(predicted_logits) - triton_cross_entropy_forward_parallel_kernel[(n_rows,)]( + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, - local_max_logits, - sum_exp_logits, - predicted_logits, - n_cols * group.rank(), - n_cols, - logits.stride(-2), - logits_scale_factor, - block_size=block_size, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits, + col_min=n_cols * group.rank(), + **kwargs, ) - max_logits = local_max_logits.clone() - torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) - sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _calculate_loss(predicted_logits, target, sum_exp_logits, max_logits) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + loss = _cross_entropy_loss_from_labels(predicted_logits, target, sum_exp_logits, max_logits) + if grad_output is not None: + triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + col_min=n_cols * group.rank(), + **kwargs, + **backward_kwargs, + ) + else: + if group is None: + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if loss_mask is not None: + assert loss_mask.is_contiguous() + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, - grad_logits, - None, - max_logits, - sum_exp_logits, - None if grad_output is None else grad_output / n_rows, - n_cols * group.rank(), - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, + ) + loss = losses.mean() + else: + predicted_logits = torch.empty(n_rows, dtype=torch.float, 
device=logits.device) + local_max_logits = torch.empty_like(predicted_logits) + sum_exp_logits = torch.empty_like(predicted_logits) + if target_format == TargetFormat.logits: + local_target_max_logits = torch.empty_like(predicted_logits) + target_sum_exp_logits = torch.empty_like(predicted_logits) + else: + local_target_max_logits = target_sum_exp_logits = None + + triton_cross_entropy_from_distribution_forward_parallel_kernel[(n_rows,)]( + logits, + target, + loss_mask_ptr=loss_mask, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + target_max_logits_ptr=local_target_max_logits, + target_sum_exp_logits_ptr=target_sum_exp_logits, + predicted_logits_ptr=predicted_logits, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, + ) + max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + if target_format == TargetFormat.logits: + target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( + target_sum_exp_logits, local_target_max_logits, group + ) + predicted_logits = _rescale_predicted_logits( + predicted_logits, target_sum_exp_logits, local_target_max_logits, target_max_logits + ) + else: + target_max_logits = None + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) + + loss = _cross_entropy_loss_from_distribution(predicted_logits, loss_mask, sum_exp_logits, max_logits) + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + logits, + target, + loss_mask_ptr=loss_mask, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + target_max_logits_ptr=target_max_logits, + target_sum_exp_logits_ptr=target_sum_exp_logits, + predicted_logits_ptr=predicted_logits, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, ) - else: - assert group is None - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if loss_mask is not None: - assert loss_mask.is_contiguous() - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( - logits, - target / temperature, - loss_mask, - grad_logits, - losses, - None if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(-2), - target.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - from_logits=target_format == TargetFormat.logits, - ) - loss = losses.mean() return loss, grad_logits diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 1a31db90..37cd99fb 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -10,6 +10,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.triton import triton_available from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward @@ -18,9 +19,6 @@ from tests.utils.dataset import get_random_spans from tests.utils.subtest import DistributedTestContext 
-VOCAB_SIZE = 100 -NUM_TOKENS = 200 - def _get_lm_loss_inputs( num_columns: int, loss_masking: bool, target_format: TargetFormat, batch_shape: tuple[int], dtype @@ -108,15 +106,15 @@ def reference_dpo_loss( _BATCH_SHAPES = ((64,), (16, 8)) _LOSS_PARAMETERS = ( - (500, 1.0, 1.0, False, DataType.float32), # Simple - (512, 1.0, 1.0, False, DataType.float32), # Power of 2 - (500, None, 1.0, False, DataType.float32), # No grad - (500, 1.0, 4.0, False, DataType.float32), # Loss scaling - (500, 4.0, 1.0, False, DataType.float32), # Grad scaling - (500, 1.0, 1.0, True, DataType.float32), # Loss masking - (500, 1.0, 1.0, False, DataType.float16), # Fp16 - (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16, loss masking - (65538, 1.0, 1.0, False, DataType.float32), # Above max block size + (500, 1.0, 1.0, False, DataType.float32, None), # Simple + (256, 1.0, 1.0, False, DataType.float32, None), # Power of 2 + (500, None, 1.0, False, DataType.float32, None), # No grad + (500, 1.0, 4.0, False, DataType.float32, None), # Loss scaling + (500, 4.0, 1.0, False, DataType.float32, None), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32, None), # Loss masking + (500, 1.0, 1.0, False, DataType.float16, None), # Fp16 + (500, 1.0, 1.0, False, DataType.float32, 256), # Looped + (1000, 2.0, 3.0, True, DataType.float16, 256), # Hard ) @@ -129,12 +127,15 @@ def _test_entropy_loss( target_format, entropy_loss_type, dtype, + block_size, group=None, ): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: pytest.skip(reason="Not implemented") # TODO: Test tensor-parallel implementation. logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) + local_logits = split_op(logits, group, -1).contiguous() + local_target = target if target_format == TargetFormat.labels else split_op(target, group, -1).contiguous() # Torch serves as the reference implementation. out_ref, grad_ref = entropy_loss_forward_backward( logits=logits, @@ -147,8 +148,8 @@ def _test_entropy_loss( implementation=EntropyLossImplementation.torch, ) out_fused, grad_fused = entropy_loss_forward_backward( - logits=split_op(logits, group, -1), - target=target if target_format == TargetFormat.labels else split_op(target, group, -1), + logits=local_logits, + target=local_target, loss_mask=loss_mask, grad_output=grad_output, group=group, @@ -157,7 +158,6 @@ def _test_entropy_loss( entropy_loss_type=entropy_loss_type, implementation=EntropyLossImplementation.fused, ) - _compare_losses_and_grads( out_fused, out_ref, @@ -168,21 +168,23 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available() or group is not None: + if entropy_loss_type != EntropyLossType.cross_entropy or not triton_available: # Triton implementation only supports cross-entropy. 
return assert TritonConfig.TRITON_ENABLED out_triton, grad_triton = entropy_loss_forward_backward( - logits=logits, - target=target, + logits=local_logits, + target=local_target, loss_mask=loss_mask, grad_output=grad_output, logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, implementation=EntropyLossImplementation.triton, + group=group, + block_size=block_size, ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): @@ -201,18 +203,34 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los group=group, logits_scale_factor=logits_scale_factor, ) - _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + _compare_losses_and_grads( + out_fused, + out_ref, + grad_output is not None, + grad_fused, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) @pytest.mark.parametrize("entropy_loss_type", EntropyLossType) def test_entropy_loss( - batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type, dtype + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + block_size, ): _test_entropy_loss( batch_shape, @@ -223,23 +241,24 @@ def test_entropy_loss( target_format, entropy_loss_type, dtype, + block_size, ) @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype): +def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size): _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) @pytest.mark.skip(reason="DPO loss is broken") def test_dpo_loss(): - logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) - reference_model_logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) - labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,)) + logits = torch.normal(0, 1, (200, 100)) + reference_model_logits = torch.normal(0, 1, (200, 100)) + labels = torch.randint(0, 100, (200,)) spans = get_random_spans(np.full(10, 50), 0, 10) fast_llm_loss = dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2]) @@ -249,8 +268,8 @@ def test_dpo_loss(): def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path, seed: int): for batch_shape in _BATCH_SHAPES: - for num_columns, grad_output, logits_scale_factor, loss_masking, dtype in _LOSS_PARAMETERS: - suffix = 
f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}" + for num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size in _LOSS_PARAMETERS: + suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}" # Entropy loss for entropy_loss_type in EntropyLossType: for target_format in TargetFormat: @@ -270,6 +289,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa target_format, entropy_loss_type, dtype, + block_size, test_context.group, ) # Z loss @@ -302,8 +322,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): _run_lm_loss_distributed, (result_path / "test_losses", random.randint(0, 2**32 - 1)), world_size=2, - backend=DistributedBackend.gloo, - use_cuda=False, # Disable device count check. + backend=DistributedBackend.nccl if (use_nccl := torch.cuda.device_count() >= 2) else DistributedBackend.gloo, + use_cuda=use_nccl, # Disable device count check. ) @@ -311,7 +331,7 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): @pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) @pytest.mark.parametrize( "loss_type", @@ -335,10 +355,11 @@ def test_lm_loss_distributed( logits_scale_factor, loss_masking, dtype, + block_size, ): report_subtest( result_path - / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}", + / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}", 2, use_cuda=False, ) From 1d0439e0c6419a7040932ba520a89ba2554437fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 5 Feb 2026 05:06:43 -0500 Subject: [PATCH 07/11] Forward KL --- fast_llm/functional/entropy_loss.py | 4 +- fast_llm/functional/triton/cross_entropy.py | 41 ++++++++++++++++++--- tests/layers/test_lm_losses.py | 4 +- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 25e1ae31..4d39b3a7 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -183,7 +183,7 @@ def _fused_cross_entropy_base_from_distribution( # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) if return_kl_loss: if target_format == TargetFormat.logits: - target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1) + target_log_probability = target_logits_norm else: target_log_probability = torch.log(target) logits_norm = logits_norm - target_log_probability @@ -194,6 +194,8 @@ def _fused_cross_entropy_base_from_distribution( all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) per_sample_loss = sum_exp_logits.log() - predicted_logits + if return_kl_loss and target_format == TargetFormat.logits: + per_sample_loss = per_sample_loss - sum_exp_target_logits.log() if grad_output is None: grad = None diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 6fb8e930..516df046 100644 
--- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -146,6 +146,7 @@ def triton_predicted_logits_from_distribution( target_logits_scale_factor: tl_constexpr = 1.0, logits_scale_factor: tl_constexpr = 1.0, unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. + return_kl_loss: tl.constexpr = False, ): for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(col_offset, col_offset + block_size) @@ -169,10 +170,14 @@ def triton_predicted_logits_from_distribution( target_max_logits = tl.max(target_logits, 0) target_exp_logits = tl.exp(target_logits - target_max_logits) target_sum_exp_logits = tl.sum(target_exp_logits, 0) - predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits, 0)) + # entropy = sum(logits*exp_logits)/sum_exp_logits - log_sum_exp_logits + # `log_sum_exp_logits` term and division by `sum_exp_logits` kept for later, + logits_shifted = logits - target_logits if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits_shifted, 0)) else: target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - predicted_logits = tl.sum(tl.where(mask, target * logits, 0)) + logits_shifted = logits - tl.log(target) if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) target_max_logits = None target_sum_exp_logits = None else: @@ -186,18 +191,22 @@ def triton_predicted_logits_from_distribution( target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( target_max_logits - target_new_max_logits ) + logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( - tl.where(mask, target_exp_logits * logits, 0) + tl.where(mask, target_exp_logits * logits_shifted, 0) ) target_max_logits = target_new_max_logits else: - predicted_logits += tl.sum(tl.where(mask, target * logits, 0)) + logits_shifted = logits - tl.log(target) if return_kl_loss else logits + predicted_logits += tl.sum(tl.where(mask, target * logits_shifted, 0)) if from_logits: target = target_exp_logits if not unscaled_probabilities: predicted_logits /= target_sum_exp_logits target /= target_sum_exp_logits + if return_kl_loss: + predicted_logits = predicted_logits + tl.log(target_sum_exp_logits) + target_max_logits return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target @@ -219,6 +228,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( from_logits: tl_constexpr = True, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, + return_kl_loss: tl.constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) @@ -240,6 +250,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, unscaled_probabilities=True, + return_kl_loss=return_kl_loss, ) ) if predicted_logits_ptr is not None: @@ -275,6 +286,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_logits_stride_0: tl_constexpr = None, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, + return_kl_loss: tl.constexpr = False, ): # TODO: Int64 ptr only if needed? 
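# Illustrative sketch (assumed helper name and shapes, not code from this patch): the identity the
# `return_kl_loss` path implements. Forward KL(p || q) = sum_i p_i * (log p_i - log q_i) = CE(p, q) - H(p);
# the kernel folds the -H(p) term in by shifting `logits` by the target log-probabilities and adding the
# target log-partition afterwards.
import torch

def forward_kl_reference(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
    log_q = torch.log_softmax(logits.float(), dim=-1)
    log_p = torch.log_softmax(target_logits.float(), dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum(-1).mean()

# For 2-D inputs this matches
# torch.nn.functional.kl_div(log_q, log_p, log_target=True, reduction="batchmean").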
block_idx = tl.program_id(0).to(tl.int64) @@ -303,6 +315,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( from_logits=from_logits, logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, + return_kl_loss=return_kl_loss, ) ) else: @@ -390,7 +403,12 @@ def _cross_entropy_loss_from_distribution( loss_mask: torch.Tensor | None, sum_exp_logits: torch.Tensor, max_logits: torch.Tensor, + target_sum_exp_logits: torch.Tensor | None, + target_max_logits: torch.Tensor | None, + return_kl_loss: bool = False, ) -> torch.Tensor: + if return_kl_loss: + predicted_logits = predicted_logits + target_sum_exp_logits.log() + target_max_logits per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits if loss_mask is not None: per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) @@ -417,7 +435,7 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED - Assert.eq(entropy_loss_type, EntropyLossType.cross_entropy) + Assert.incl(entropy_loss_type, (EntropyLossType.cross_entropy, EntropyLossType.forward_kl)) # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() @@ -500,6 +518,7 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -526,6 +545,7 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -541,7 +561,16 @@ def triton_cross_entropy_forward_backward( target_max_logits = None torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_distribution(predicted_logits, loss_mask, sum_exp_logits, max_logits) + loss = _cross_entropy_loss_from_distribution( + predicted_logits, + loss_mask, + sum_exp_logits, + max_logits, + target_sum_exp_logits=target_sum_exp_logits, + target_max_logits=target_max_logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl + and target_format == TargetFormat.logits, + ) triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 37cd99fb..e5419720 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -168,7 +168,7 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type != EntropyLossType.cross_entropy or not triton_available: + if entropy_loss_type == EntropyLossType.reverse_kl or not triton_available: # Triton implementation only supports cross-entropy. 
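# Illustrative sketch (assumed names, not code from this patch): how the per-sample loss is assembled from
# the stored statistics. With q = softmax(logits) and p the target distribution,
#   CE(p, q)  = logsumexp(logits) - sum_i p_i * logits_i
#   KL(p||q)  = CE(p, q) + sum_i p_i * log p_i
# which is why `_cross_entropy_loss_from_distribution` adds `target_sum_exp_logits.log() + target_max_logits`
# (the target log-partition) in the forward-KL case. The gradient is identical for both losses,
# d/dlogits = softmax(logits) - p, so the backward path needs no `return_kl_loss` switch.
import torch

def assemble_per_sample_losses(logits, target_logits, forward_kl=False):
    max_logits = logits.max(dim=-1, keepdim=True).values
    sum_exp_logits = (logits - max_logits).exp().sum(-1)
    p = torch.softmax(target_logits, dim=-1)
    # Equivalent to the kernel's shift by unnormalized target logits plus the log-partition correction.
    shift = target_logits - target_logits.logsumexp(-1, keepdim=True) if forward_kl else 0.0
    predicted_logits = (p * (logits - shift)).sum(-1)
    return sum_exp_logits.log() + max_logits.squeeze(-1) - predicted_logits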
return assert TritonConfig.TRITON_ENABLED @@ -219,7 +219,7 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +@pytest.mark.parametrize("target_format", TargetFormat) @pytest.mark.parametrize("entropy_loss_type", EntropyLossType) def test_entropy_loss( batch_shape, From 1b40518e3550710d03628fc82adde8d6c4811ab3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 04:17:43 -0500 Subject: [PATCH 08/11] Reverse KL, triton tweaks --- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/functional/config.py | 13 + fast_llm/functional/triton/__init__.py | 13 +- fast_llm/functional/triton/adam.py | 6 +- fast_llm/functional/triton/cross_entropy.py | 495 ++++++++++++++---- fast_llm/functional/triton/mlp.py | 13 +- fast_llm/functional/triton/normalization.py | 16 +- fast_llm/functional/triton/pointwise.py | 19 +- fast_llm/functional/triton/rotary.py | 6 +- fast_llm/functional/triton/sparse_copy.py | 17 +- fast_llm/functional/triton/sparse_linear.py | 30 +- fast_llm/layers/attention/config.py | 5 - fast_llm/layers/attention/rotary/rotary.py | 12 +- .../common/normalization/normalization.py | 4 +- fast_llm/layers/decoder/mlp/mlp.py | 4 +- fast_llm/layers/language_model/loss/config.py | 21 +- .../language_model/loss/entropy_loss.py | 44 +- tests/conftest.py | 5 + tests/functional/test_functional.py | 32 +- tests/functional/test_sparse_matmul.py | 53 +- tests/functional/test_triton_kernels.py | 81 ++- tests/layers/test_lm_losses.py | 34 +- tests/layers/test_rotary.py | 15 +- tests/layers/test_ssm.py | 4 +- tests/models/test_checkpoint.py | 13 +- tests/test_loss_mask.py | 5 +- tests/utils/model_configs.py | 2 + tests/utils/utils.py | 2 + 28 files changed, 628 insertions(+), 338 deletions(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 2c6c8105..baa38633 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 050c700c..f863a99a 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -14,6 +14,19 @@ class TritonConfig: POINTWISE_BLOCK_SIZE = 1024 MAX_BLOCK_SIZE_BYTES = 65536 + @classmethod + def enabled(cls, device: "torch.device|None" = None, default: bool | None = None) -> bool: + if default is False: + return False + from fast_llm.functional.triton import triton_available, triton_interpret + + available = triton_available and (device is None or device.type == "cuda" or triton_interpret) + if default is None: + default = available and cls.TRITON_ENABLED + else: + assert available + return default + class MLPRecomputeLevel(enum.StrEnum): none = "none" diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index 82f67621..61ead1c6 100644 --- 
a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -9,25 +9,32 @@ tl_constexpr = tl.constexpr TritonConfig = triton.Config - triton_available = torch.cuda.is_available() or triton.knobs.runtime.interpret + # Use `TRITON_INTERPRET=1` to enable triton on CPU. + triton_interpret = triton.knobs.runtime.interpret + triton_available = torch.cuda.is_available() or triton_interpret except ImportError as e: triton = InvalidObject(e) tl = triton tl_constexpr = None TritonConfig = lambda *args, **kwargs: None + triton_interpret = False triton_available = False triton_jit = try_decorate(lambda: triton.jit) triton_autotune = try_decorate(lambda: triton.autotune) - if not triton_available: tl_arange = None -elif triton.knobs.runtime.interpret: + tl_full = None +elif triton_interpret: # Workaround for a triton bug. @triton_jit def tl_arange(start, end): return tl.arange(int(start), int(end)) + @triton_jit + def tl_full(shape, value, dtype): + return tl.full(tuple(int(x) for x in shape), value, dtype) + else: tl_arange = tl.arange diff --git a/fast_llm/functional/triton/adam.py b/fast_llm/functional/triton/adam.py index 07ba2df4..2c835ca0 100644 --- a/fast_llm/functional/triton/adam.py +++ b/fast_llm/functional/triton/adam.py @@ -8,7 +8,7 @@ from torch.optim.adamw import adamw # noqa from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit @triton_jit() @@ -37,7 +37,7 @@ def triton_adam_kernel( # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel params = tl.load(params_ptr + offsets, mask=mask) @@ -75,7 +75,7 @@ def triton_adam( epsilon: float, use_triton=True, ) -> None: - if not use_triton or (use_triton is None and TritonConfig.TRITON_ENABLED): + if not TritonConfig.enabled(params.device, use_triton): if noop_flag.item() == 0: return adamw( [params], diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 516df046..33504877 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,34 +1,61 @@ import torch -from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit -from fast_llm.utils import Assert @triton_jit() -def triton_fused_softmax_base( +def triton_fused_softmax_iter_base( logits_ptr, + col_offset: tl.constexpr, n_cols: tl_constexpr, block_size: tl_constexpr, + max_logits=None, + sum_exp_logits=None, + col_offsets=None, + mask=None, logits_scale_factor: tl_constexpr = 1.0, ): - for col_offset in tl.static_range(0, n_cols, block_size): + if col_offsets is None: col_offsets = tl_arange(col_offset, col_offset + block_size) + if mask is None: mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if col_offset == 0: + new_max_logits = tl.max(logits, 0) + exp_logits = 
tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + else: + new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + return logits, exp_logits, sum_exp_logits, new_max_logits, col_offsets, mask - if col_offset == 0: - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) - else: - new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) - exp_logits = tl.exp(logits - new_max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) - max_logits = new_max_logits - return exp_logits, sum_exp_logits, max_logits + +@triton_jit() +def triton_fused_softmax_base( + logits_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + logits_scale_factor: tl_constexpr = 1.0, +): + exp_logits = None + sum_exp_logits = None + max_logits = None + for col_offset in tl.static_range(0, n_cols, block_size): + logits, exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) + return exp_logits, sum_exp_logits, max_logits, col_offsets, mask @triton_jit() @@ -48,7 +75,7 @@ def triton_cross_entropy_forward_from_labels_parallel_kernel( block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 - exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + exp_logits, sum_exp_logits, max_logits, _, _ = triton_fused_softmax_base( logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) @@ -90,7 +117,7 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( logits_ptr = logits_ptr + block_idx * logits_stride_0 if max_logits_ptr is None or sum_exp_logits_ptr is None: - exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) else: @@ -116,11 +143,11 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( elif logits_scale_factor != 1.0: grad_losses *= logits_scale_factor # Run in reverse order to maximize input and cache reuse. - col_offset_start = (n_cols - 1) // block_size * block_size + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size for col_offset in tl.static_range(col_offset_start, -1, -block_size): - col_offsets = tl_arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: logits *= logits_scale_factor @@ -145,68 +172,68 @@ def triton_predicted_logits_from_distribution( from_logits: tl_constexpr = True, target_logits_scale_factor: tl_constexpr = 1.0, logits_scale_factor: tl_constexpr = 1.0, - unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. 
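# Illustrative sketch (assumed helper name, not code from this patch): the online-softmax update performed by
# `triton_fused_softmax_iter_base`. Each chunk updates a running maximum and rescales the running sum of
# exponentials, so rows wider than `block_size` need only a single pass over the logits.
import torch

def streaming_logsumexp(row: torch.Tensor, block_size: int) -> torch.Tensor:
    max_logits, sum_exp_logits = None, None
    for start in range(0, row.numel(), block_size):
        chunk = row[start : start + block_size].float()
        new_max = chunk.max() if max_logits is None else torch.maximum(chunk.max(), max_logits)
        exp_chunk = (chunk - new_max).exp()
        if max_logits is None:
            sum_exp_logits = exp_chunk.sum()
        else:
            sum_exp_logits = exp_chunk.sum() + sum_exp_logits * (max_logits - new_max).exp()
        max_logits = new_max
    return sum_exp_logits.log() + max_logits  # equals torch.logsumexp(row, -1)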
+ return_partial_loss: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. return_kl_loss: tl.constexpr = False, ): + max_logits = None + sum_exp_logits = None + if from_logits: + target_max_logits = None + target_sum_exp_logits = None + else: + target_max_logits = 0 + target_sum_exp_logits = 0 + for col_offset in tl.static_range(0, n_cols, block_size): - col_offsets = tl_arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + logits, exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) if from_logits: - target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if target_logits_scale_factor != 1.0: - target_logits *= target_logits_scale_factor - else: - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - - if col_offset == 0: - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) - if from_logits: - target_max_logits = tl.max(target_logits, 0) - target_exp_logits = tl.exp(target_logits - target_max_logits) - target_sum_exp_logits = tl.sum(target_exp_logits, 0) - # entropy = sum(logits*exp_logits)/sum_exp_logits - log_sum_exp_logits - # `log_sum_exp_logits` term and division by `sum_exp_logits` kept for later, + target_logits, target_exp_logits, target_sum_exp_logits, target_new_max_logits, _, _ = ( + triton_fused_softmax_iter_base( + target_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=target_max_logits, + sum_exp_logits=target_sum_exp_logits, + logits_scale_factor=target_logits_scale_factor, + col_offsets=col_offsets, + mask=mask, + ) + ) + if col_offset == 0: logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits_shifted, 0)) else: - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - logits_shifted = logits - tl.log(target) if return_kl_loss else logits - predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) - target_max_logits = None - target_sum_exp_logits = None - else: - new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) - exp_logits = tl.exp(logits - new_max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) - max_logits = new_max_logits - if from_logits: - target_new_max_logits = tl.maximum(tl.max(target_logits, 0), target_max_logits) - target_exp_logits = tl.exp(target_logits - target_new_max_logits) - target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( - target_max_logits - target_new_max_logits - ) logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( tl.where(mask, target_exp_logits * logits_shifted, 0) ) - target_max_logits = target_new_max_logits + target_max_logits = target_new_max_logits + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if col_offset == 0: + logits_shifted = logits - 
tl.log(target) if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) else: logits_shifted = logits - tl.log(target) if return_kl_loss else logits predicted_logits += tl.sum(tl.where(mask, target * logits_shifted, 0)) if from_logits: target = target_exp_logits - if not unscaled_probabilities: + if not return_partial_loss: predicted_logits /= target_sum_exp_logits target /= target_sum_exp_logits if return_kl_loss: - predicted_logits = predicted_logits + tl.log(target_sum_exp_logits) + target_max_logits + predicted_logits += tl.log(target_sum_exp_logits) + target_max_logits return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target @@ -224,7 +251,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( sum_exp_logits_ptr=None, target_max_logits_ptr=None, target_sum_exp_logits_ptr=None, - predicted_logits_ptr=None, + partial_losses_ptr=None, from_logits: tl_constexpr = True, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, @@ -237,7 +264,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: # This entry is masked, ignore. - tl.store(predicted_logits_ptr + block_idx, 0) + tl.store(partial_losses_ptr + block_idx, 0) return predicted_logits, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( @@ -249,12 +276,12 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( from_logits=from_logits, logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, - unscaled_probabilities=True, + return_partial_loss=True, return_kl_loss=return_kl_loss, ) ) - if predicted_logits_ptr is not None: - tl.store(predicted_logits_ptr + block_idx, predicted_logits) + if partial_losses_ptr is not None: + tl.store(partial_losses_ptr + block_idx, predicted_logits) if max_logits_ptr is not None: tl.store(max_logits_ptr + block_idx, max_logits) if sum_exp_logits_ptr is not None: @@ -275,6 +302,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target_stride_0: tl_constexpr, block_size: tl_constexpr, loss_mask_ptr=None, + partial_losses_ptr=None, losses_ptr=None, max_logits_ptr=None, sum_exp_logits_ptr=None, @@ -321,11 +349,22 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( else: max_logits = tl.load(max_logits_ptr + block_idx) sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) - if grad_losses is not None and from_logits: + if from_logits: target_max_logits = tl.load(target_max_logits_ptr + block_idx) target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) if losses_ptr is not None: + # if return_kl_loss: + # predicted_logits = predicted_logits + target_sum_exp_logits.log() + target_max_logits + # per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits + # if loss_mask is not None: + # per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) + if partial_losses_ptr is not None: + predicted_logits = tl.load(partial_losses_ptr + block_idx) + if from_logits: + predicted_logits /= target_sum_exp_logits + if return_kl_loss: + predicted_logits += tl.log(target_sum_exp_logits) + target_max_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) loss = tl.log(sum_exp_logits) + max_logits - predicted_logits tl.store(losses_ptr + block_idx, loss) @@ -334,7 +373,7 @@ def 
triton_cross_entropy_from_distribution_forward_backward_kernel( if logits_scale_factor != 1.0: grad_losses *= logits_scale_factor # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - col_offset_start = (n_cols - 1) // block_size * block_size + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size for col_offset in tl.static_range(col_offset_start, -1, -block_size): col_offsets = tl_arange(col_offset, col_offset + block_size) mask = col_offsets < n_cols @@ -355,6 +394,239 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) +@triton_jit() +def triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + from_logits: tl_constexpr = True, + target_logits_scale_factor: tl_constexpr = 1.0, + logits_scale_factor: tl_constexpr = 1.0, + return_partial_loss: tl_constexpr = False, +): + max_logits = None + sum_exp_logits = None + if from_logits: + target_max_logits = None + target_sum_exp_logits = None + else: + target_max_logits = 0 + target_sum_exp_logits = 0 + + for col_offset in tl.static_range(0, n_cols, block_size): + logits, exp_logits, sum_exp_logits, new_max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) + + # print("sum_exp_logits", sum_exp_logits) + # print("max_logits", new_max_logits) + if from_logits: + # log_target excludes the log_sum_exp term to be added later + log_target, _, target_sum_exp_logits, target_new_max_logits, _, _ = triton_fused_softmax_iter_base( + target_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=target_max_logits, + sum_exp_logits=target_sum_exp_logits, + logits_scale_factor=target_logits_scale_factor, + col_offsets=col_offsets, + mask=mask, + ) + target = log_target + # print("target_sum_exp_logits", target_sum_exp_logits) + # print("new_max_logits", target_new_max_logits) + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + log_target = tl.log(target) + if col_offset == 0: + # predicted_log_probability=logits - new_max_logits - tl.log(sum_exp_logits) + # target_log_probability=log_target-target_new_max_logits-tl.log(target_sum_exp_logits) + # print("predicted_log_probability", predicted_log_probability) + # print("target_log_probability", target_log_probability) + # print("IUWH", exp_logits * (predicted_log_probability-target_log_probability)/sum_exp_logits) + loss = tl.sum(tl.where(mask, exp_logits * (logits - log_target), 0)) + # print("max_logits", new_max_logits) + # print("partial_losses", exp_logits * (logits-log_target)) + + else: + loss = loss * tl.exp(max_logits - new_max_logits) + tl.sum( + tl.where(mask, exp_logits * (logits - log_target), 0) + ) + max_logits = new_max_logits + if from_logits: + target_max_logits = target_new_max_logits + + # print("partial_loss", loss) + if not return_partial_loss: + loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits + if from_logits: + loss = loss + tl.log(target_sum_exp_logits) + target_max_logits + + return loss, logits, target, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits + + +@triton_jit() +def triton_reverse_kl_forward_kernel_from_distribution( + logits_ptr, + 
target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + partial_losses_ptr=None, + from_logits: tl_constexpr = True, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + tl.store(partial_losses_ptr + block_idx, 0) + return + + partial_loss, _, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits = ( + triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + return_partial_loss=True, + ) + ) + if partial_losses_ptr is not None: + tl.store(partial_losses_ptr + block_idx, partial_loss) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + + if target_max_logits_ptr is not None: + tl.store(target_max_logits_ptr + block_idx, target_max_logits) + if target_sum_exp_logits_ptr is not None: + tl.store(target_sum_exp_logits_ptr + block_idx, target_sum_exp_logits) + + +@triton_jit() +def triton_reverse_kl_forward_backward_kernel_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + partial_losses_ptr=None, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + from_logits: tl_constexpr = True, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. 
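# Illustrative sketch (assumed helper name, not code from this patch): reference for the reverse-KL kernels.
# With q = softmax(logits) and p = softmax(target_logits),
#   KL(q || p) = sum_i q_i * (log q_i - log p_i)
#   d KL / d logits_j = q_j * (log q_j - log p_j - KL)
# which matches the `(predicted_log_probability - target_log_probability - loss) * predicted_probability`
# expression in the backward loop further down (the mean reduction and `grad_output` scaling are applied
# through `grad_losses`).
import torch

def reverse_kl_reference(logits, target_logits):
    log_q = torch.log_softmax(logits.float(), dim=-1)
    log_p = torch.log_softmax(target_logits.float(), dim=-1)
    per_sample_kl = (log_q.exp() * (log_q - log_p)).sum(-1)
    per_sample_grad = log_q.exp() * (log_q - log_p - per_sample_kl.unsqueeze(-1))
    return per_sample_kl, per_sample_grad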
+ if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + loss, logits, target, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits = ( + triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + ) + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + if from_logits: + target_max_logits = tl.load(target_max_logits_ptr + block_idx) + target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) + + # print("sum_exp_logits", sum_exp_logits) + # print("max_logits", max_logits) + + # if from_logits: + # print("target_sum_exp_logits", target_sum_exp_logits) + # print("target_max_logits", target_max_logits) + + if losses_ptr is not None: + if partial_losses_ptr is not None: + loss = tl.load(partial_losses_ptr + block_idx) + # print("partial_loss", loss) + loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits + if from_logits: + loss = loss + tl.log(target_sum_exp_logits) + target_max_logits + tl.store(losses_ptr + block_idx, loss) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if from_logits and target_logits_scale_factor != 1.0: + target *= target_logits_scale_factor + if from_logits: + target_log_probability = target - target_max_logits - tl.log(target_sum_exp_logits) + else: + target_log_probability = tl.log(target) + + predicted_probability = tl.exp(logits - max_logits) / sum_exp_logits + predicted_log_probability = logits - max_logits - tl.log(sum_exp_logits) + grad_logits = ( + grad_losses * (predicted_log_probability - target_log_probability - loss) * predicted_probability + ) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + + @torch.compile def _rescale_sum_exp_logits( sum_exp_logits: torch.Tensor, @@ -389,12 +661,11 @@ def _cross_entropy_loss_from_labels( @torch.compile def _rescale_predicted_logits( predicted_logits: torch.Tensor, - target_sum_exp_logits: torch.Tensor, local_target_max_logits: torch.Tensor, target_max_logits: torch.Tensor, ): # We skipped the division by `target_sum_exp_logits` in the triton kernel so we do it here. 
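# Illustrative sketch (assumed reduction, not code from this patch): how per-rank softmax statistics are
# combined when the vocabulary is sharded over `group`. `_parallel_sum_exp_logits` is not shown in this diff,
# so the reduction below is an assumed equivalent rather than the actual implementation; partial losses
# accumulated against the local maximum are rescaled by exp(local_max - global_max) in the same way before
# their SUM all-reduce, which is what `_rescale_predicted_logits` does.
import torch
import torch.distributed as dist

def combine_shard_statistics(local_max: torch.Tensor, local_sum_exp: torch.Tensor, group):
    global_max = local_max.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
    # Rescale each shard's sum of exponentials to the shared maximum before summing across ranks.
    global_sum_exp = local_sum_exp * (local_max - global_max).exp()
    dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=group)
    return global_max, global_sum_exp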
- return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) / target_sum_exp_logits + return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) @torch.compile @@ -415,7 +686,7 @@ def _cross_entropy_loss_from_distribution( return per_sample_losses.mean() -def triton_cross_entropy_forward_backward( +def triton_entropy_loss_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, @@ -434,8 +705,6 @@ def triton_cross_entropy_forward_backward( Compared to a standard pytorch implementation, this reduces memory usage (of logits) by 3x and memory I/O by 5x. TODO: Better handling of `grad_output = None` """ - assert TritonConfig.TRITON_ENABLED - Assert.incl(entropy_loss_type, (EntropyLossType.cross_entropy, EntropyLossType.forward_kl)) # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() @@ -465,6 +734,7 @@ def triton_cross_entropy_forward_backward( } ) if target_format == TargetFormat.labels: + assert entropy_loss_type != EntropyLossType.reverse_kl if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( @@ -476,21 +746,21 @@ def triton_cross_entropy_forward_backward( ) loss = losses.mean() else: - predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(predicted_logits) - sum_exp_logits = torch.empty_like(predicted_logits) + partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(partial_losses) + sum_exp_logits = torch.empty_like(partial_losses) triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, max_logits_ptr=local_max_logits, sum_exp_logits_ptr=sum_exp_logits, - predicted_logits_ptr=predicted_logits, + predicted_logits_ptr=partial_losses, col_min=n_cols * group.rank(), **kwargs, ) max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_labels(predicted_logits, target, sum_exp_logits, max_logits) + torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) if grad_output is not None: triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, @@ -502,11 +772,17 @@ def triton_cross_entropy_forward_backward( **backward_kwargs, ) else: + if loss_mask is not None: + assert loss_mask.is_contiguous() + if entropy_loss_type == EntropyLossType.reverse_kl: + kernel = triton_reverse_kl_forward_backward_kernel_from_distribution + else: + kernel = triton_cross_entropy_from_distribution_forward_backward_kernel + kwargs["return_kl_loss"] = entropy_loss_type == EntropyLossType.forward_kl + if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if loss_mask is not None: - assert loss_mask.is_contiguous() - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -518,22 +794,27 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, - return_kl_loss=entropy_loss_type == 
EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) loss = losses.mean() else: - predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(predicted_logits) - sum_exp_logits = torch.empty_like(predicted_logits) + partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(partial_losses) + sum_exp_logits = torch.empty_like(partial_losses) if target_format == TargetFormat.logits: - local_target_max_logits = torch.empty_like(predicted_logits) - target_sum_exp_logits = torch.empty_like(predicted_logits) + local_target_max_logits = torch.empty_like(partial_losses) + target_sum_exp_logits = torch.empty_like(partial_losses) else: local_target_max_logits = target_sum_exp_logits = None - triton_cross_entropy_from_distribution_forward_parallel_kernel[(n_rows,)]( + forward_kernel = ( + triton_reverse_kl_forward_kernel_from_distribution + if entropy_loss_type == EntropyLossType.reverse_kl + else triton_cross_entropy_from_distribution_forward_parallel_kernel + ) + + forward_kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -541,11 +822,10 @@ def triton_cross_entropy_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=local_target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - predicted_logits_ptr=predicted_logits, + partial_losses_ptr=partial_losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, - return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -554,24 +834,17 @@ def triton_cross_entropy_forward_backward( target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( target_sum_exp_logits, local_target_max_logits, group ) - predicted_logits = _rescale_predicted_logits( - predicted_logits, target_sum_exp_logits, local_target_max_logits, target_max_logits - ) + if entropy_loss_type != EntropyLossType.reverse_kl: + partial_losses = _rescale_predicted_logits( + partial_losses, local_target_max_logits, target_max_logits + ) else: target_max_logits = None - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - - loss = _cross_entropy_loss_from_distribution( - predicted_logits, - loss_mask, - sum_exp_logits, - max_logits, - target_sum_exp_logits=target_sum_exp_logits, - target_max_logits=target_max_logits, - return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl - and target_format == TargetFormat.logits, - ) - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + if entropy_loss_type == EntropyLossType.reverse_kl: + partial_losses = _rescale_predicted_logits(partial_losses, local_max_logits, max_logits) + torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + + kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -579,11 +852,13 @@ def triton_cross_entropy_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - predicted_logits_ptr=predicted_logits, + partial_losses_ptr=partial_losses, + losses_ptr=partial_losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, **backward_kwargs, ) + loss = partial_losses.mean() return loss, grad_logits diff --git 
a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 286e7159..7949faaf 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -14,7 +14,7 @@ output_parallel_linear_forward, update_linear_gradients, ) -from fast_llm.functional.triton import tl, tl_constexpr, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton_jit from fast_llm.functional.triton.sparse_copy import ( SparseMap, copy_dense_to_sparse_backward, @@ -37,7 +37,7 @@ def triton_mlp_activation_forward_kernel( ): # TODO: Int64 ptr only if needed? row_idx = tl.program_id(0).to(tl.int64) - columns = tl.program_id(1) * block_size + tl.arange(0, block_size) + columns = tl.program_id(1) * block_size + tl_arange(0, block_size) output_offsets = n_cols * row_idx + columns input_offsets = 2 * n_cols * row_idx + columns if gated else output_offsets @@ -85,7 +85,7 @@ def triton_mlp_activation_backward_kernel( ): # TODO: Int64 ptr only if needed? row_idx = tl.program_id(0).to(tl.int64) - columns = tl.program_id(1) * block_size + tl.arange(0, block_size) + columns = tl.program_id(1) * block_size + tl_arange(0, block_size) output_offsets = n_cols * row_idx + columns input_offsets = 2 * n_cols * row_idx + columns if gated else output_offsets @@ -219,6 +219,7 @@ def mlp_forward( recompute_level: MLPRecomputeLevel = MLPRecomputeLevel.none, transposed_layer_2_weight: bool = False, sparse_map: SparseMap | None = None, + use_triton: bool | None = None, ) -> tuple[torch.Tensor, list[typing.Any] | None]: # Sparse copy input_shape = input_.shape @@ -235,7 +236,7 @@ def mlp_forward( input_ = None # Activation - if TritonConfig.TRITON_ENABLED and intermediate_1.device.type == "cuda": + if TritonConfig.enabled(intermediate_1.device, use_triton): intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, gated, activation_type) else: do_grad = training and not recompute_level.recompute_activation @@ -287,6 +288,7 @@ def mlp_forward( transposed_layer_2_weight, sparse_map, input_shape, + use_triton, ] if training else None @@ -313,6 +315,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ transposed_layer_2_weight, sparse_map, input_shape, + use_triton, ) = context context.clear() @@ -344,7 +347,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ )[0] # Activation recomputation and/or backward - if TritonConfig.TRITON_ENABLED and grad_output.device.type == "cuda": + if TritonConfig.enabled(grad_output.device, use_triton): grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward( grad_intermediate_2, (intermediate_1, gated, activation_type), intermediate_2 is None ) diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index a018ad44..9538a927 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -4,7 +4,7 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit from fast_llm.tensor import param_get_and_unset_is_zero @@ -23,7 +23,7 @@ def triton_normalization_forward_kernel( ): # Program dimensions row = tl.program_id(0).to(tl.int64) - cols = tl.arange(0, block_size) + cols = tl_arange(0, block_size) mask = cols < n_cols offsets = row * n_cols + cols 
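# Illustrative sketch (toy forward/backward pair, not code from this patch): `use_triton` is appended to the
# saved context above so that `mlp_backward` reproduces the exact dispatch decision made in `mlp_forward`
# instead of re-deriving it from global state.
import torch

def toy_forward(x: torch.Tensor, use_triton: bool | None = None):
    y = x.relu()                 # stand-in for the fused activation
    context = [x, use_triton]    # the dispatch flag rides along with the saved tensors
    return y, context

def toy_backward(grad_output: torch.Tensor, context: list):
    x, use_triton = context
    context.clear()              # release saved tensors early, mirroring `mlp_backward`
    # The same `use_triton` value would gate the Triton-vs-PyTorch path here.
    return grad_output * (x > 0)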
@@ -75,10 +75,10 @@ def triton_normalization_backward_kernel_1( block_size_row: tl_constexpr, ): # row_start = tl.program_id(0)*block_size_row - rows = tl.program_id(0) * block_size_row + tl.arange(0, block_size_row)[:, None] + rows = tl.program_id(0) * block_size_row + tl_arange(0, block_size_row)[:, None] row_mask = rows < n_rows - cols = tl.arange(0, block_size)[None, :] + cols = tl_arange(0, block_size)[None, :] col_mask = cols < n_cols mask = col_mask & row_mask @@ -140,15 +140,15 @@ def triton_normalization_backward_kernel_2( block_size_n: tl_constexpr, ): pid = tl.program_id(0) - cols = pid * block_size_n + tl.arange(0, block_size_n) - grad_weight_partial_sum = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + cols = pid * block_size_n + tl_arange(0, block_size_n) + grad_weight_partial_sum = tl_full((block_size_m, block_size_n), 0, dtype=tl.float32) if has_bias: - grad_bias_partial_sum = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + grad_bias_partial_sum = tl_full((block_size_m, block_size_n), 0, dtype=tl.float32) col_mask = cols < n_cols # Partial sums. for i in range(0, m, block_size_m): - rows = i + tl.arange(0, block_size_m) + rows = i + tl_arange(0, block_size_m) mask = (rows[:, None] < m) & (cols[None, :] < n_cols) offsets = rows[:, None] * n_cols + cols[None, :] grad_weight_partial_sum += tl.load(grad_weight_partial_ptr + offsets, mask=mask, other=0.0) diff --git a/fast_llm/functional/triton/pointwise.py b/fast_llm/functional/triton/pointwise.py index 22676ae1..44bb805f 100644 --- a/fast_llm/functional/triton/pointwise.py +++ b/fast_llm/functional/triton/pointwise.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit @triton_jit() @@ -19,7 +19,7 @@ def triton_copy_kernel( ): # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel input_ = tl.load(input_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, input_, mask=mask) @@ -28,11 +28,12 @@ def triton_copy_kernel( def triton_copy( input_: torch.Tensor, out: torch.Tensor, + use_triton: bool | None = None, ) -> torch.Tensor: """ A triton implementation of tensor copying (`torch.Tensor.copy_()`). """ - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return out.copy_(input_) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -53,19 +54,20 @@ def triton_fill_kernel( ): # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel - tl.store(input_ptr + offsets, tl.full((block_size,), value, dtype), mask=mask) + tl.store(input_ptr + offsets, tl_full((block_size,), value, dtype), mask=mask) def triton_fill( input_: torch.Tensor, value: float | int, + use_triton: bool | None = None, ) -> torch.Tensor: """ A faster triton implementation of tensor copying (`torch.Tensor.fill_()`). 
""" - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return input_.fill_(value) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -91,7 +93,7 @@ def triton_add_kernel( ): # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel input_ = tl.load(input_ptr + offsets, mask=mask) other = tl.load(other_ptr + offsets, mask=mask) @@ -102,11 +104,12 @@ def triton_add( input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor | None = None, + use_triton: bool | None = None, ) -> torch.Tensor: """ A faster triton implementation of tensor addition (`torch.Tensor.add()`). """ - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return torch.add(input_, other, out=out) # TODO: Improve assumptions. assert input_.is_contiguous() diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index c510925c..2c93776a 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -2,7 +2,7 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.utils import div @@ -25,8 +25,8 @@ def triton_rotary_kernel( pid_1 = tl.program_id(axis=1) # Head index position_id = pid_0 % seq_len - offsets = tl.arange(0, rotary_block_size) - head_offsets = pid_1 * head_block_size + tl.arange(0, head_block_size)[:, None] + offsets = tl_arange(0, rotary_block_size) + head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] input_re_ptr = input_ptr + input_offsets input_im_ptr = input_re_ptr + rotary_dim diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index 7c803689..e68692d9 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit @dataclasses.dataclass() @@ -36,7 +36,7 @@ def copy_dense_to_sparse_kernel( block_size: tl_constexpr, ): dense_row = tl.program_id(0) - offsets = tl.arange(0, block_size) + block_size * tl.program_id(1) + offsets = tl_arange(0, block_size) + block_size * tl.program_id(1) mask = None if num_columns % block_size == 0 else offsets < num_columns out = tl.load(input_ptr + dense_row * num_columns + offsets, mask=mask) # Write to each expert. 
@@ -78,7 +78,7 @@ def copy_sparse_to_dense_kernel( block_size: tl_constexpr, ): dense_row = tl.program_id(0) - offsets = tl.arange(0, block_size) + block_size * tl.program_id(1) + offsets = tl_arange(0, block_size) + block_size * tl.program_id(1) mask = None if num_columns % block_size == 0 else offsets < num_columns out = tl.zeros((block_size,), tl.float32) # Sum over experts. @@ -125,7 +125,7 @@ def copy_sparse_to_dense_grad_score_kernel( grad_output_ptr += dense_row * num_columns input_ptr += sparse_row * num_columns - offsets = tl.arange(0, block_size) + offsets = tl_arange(0, block_size) if num_columns % block_size == 0: grad_scores = tl.load(input_ptr + offsets).to(tl.float32) * tl.load(grad_output_ptr + offsets).to(tl.float32) @@ -216,8 +216,8 @@ def sparse_map_kernel( we use a one-hot representation to get the quantities we want. TODO: Next triton release will support tl.histogram, maybe argsort. """ - block_range = tl.arange(0, block_size) - expert_range = tl.arange(0, block_size_expert) + block_range = tl_arange(0, block_size) + expert_range = tl_arange(0, block_size_expert) expert_mask = None if block_size_expert == num_experts else expert_range < num_experts if num_sparse_rows >= block_size: @@ -256,7 +256,7 @@ def sparse_map_kernel( if sparse_rows_ptr is not None: # Assign a new unique index to each row so that it lies in the range (expert_begin, expert_end) # for its assigned expert. - block_range = tl.arange(0, block_size) + block_range = tl_arange(0, block_size) for i in range(tl.cdiv(num_sparse_rows, block_size)): if num_sparse_rows % block_size == 0: mask = None @@ -307,7 +307,8 @@ def get_sparse_map( num_rows_unpadded = num_rows_dense * num_experts_per_token max_rows = (num_rows_unpadded + num_experts * pad_to_multiple) // pad_to_multiple * pad_to_multiple dtype = torch.int16 if max_rows < 32768 else torch.int32 - if (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton: + + if TritonConfig.enabled(top_experts.device, use_triton): expert_ends, expert_pad_begins = top_experts.new_empty((2 * num_experts,), dtype=dtype).chunk(2) sparse_rows = expert_ends.new_empty(num_rows_dense, num_experts_per_token) sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)]( diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index ae46655e..15af789d 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -2,7 +2,7 @@ import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.utils import Assert, div @@ -99,9 +99,9 @@ def dense_matmul_kernel( col_offset = pid_col * block_size_col # Pointers - row_range = tl.arange(0, block_size_row)[:, None] + row_offset - col_range = tl.arange(0, block_size_col)[None, :] + col_offset - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + row_offset + col_range = tl_arange(0, block_size_col)[None, :] + col_offset + inner_range = tl_arange(0, block_size_inner) lhs_ptr += row_range * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + col_range * rhs_stride_col out_ptr += row_range * out_stride_row + col_range * out_stride_col @@ -228,7 +228,7 @@ def output_sparse_matmul_kernel( # 
Grid offsets row_offset = pid_row * block_size_row col_sparse_offset = pid_col * block_size_col - sparse_range = tl.arange(0, padded_sparse_dim) + sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa if sparse_index == sparse_dim: @@ -236,9 +236,9 @@ def output_sparse_matmul_kernel( col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + (col_dense_offset + col_range) * rhs_stride_col out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col @@ -351,7 +351,7 @@ def input_inner_sparse_matmul_kernel( # Grid offsets row_offset = pid_row * block_size_row - sparse_range = tl.arange(0, padded_sparse_dim) + sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa if sparse_index == sparse_dim: @@ -360,9 +360,9 @@ def input_inner_sparse_matmul_kernel( col_offset = pid_col * block_size_col # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += (inner_dense_offset + inner_range[:, None]) * rhs_stride_inner + ( col_offset + col_range @@ -485,9 +485,9 @@ def input_row_sparse_matmul_kernel( inner_offset = (inner_begin // block_size_inner) * block_size_inner # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + inner_offset + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) + inner_offset lhs_ptr += (row_sparse_offset + row_range) * lhs_stride_row rhs_ptr += (col_offset + col_range) * rhs_stride_col out_ptr += (row_dense_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 626a8fde..40baf200 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,8 @@ import enum import logging import typing -import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig @@ -132,9 +130,6 @@ class AttentionConfig(MixerConfig): def _validate(self) -> None: super()._validate() - if not 
TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - Assert.multiple(self.heads, self.head_groups) if not self.causal: diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 304f96b8..307256a7 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -100,11 +100,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = ( - triton_rotary_autograd_ - if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" - else rotary_embeddings_real - ) + rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key @@ -238,11 +234,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = ( - triton_rotary_autograd_ - if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" - else rotary_embeddings_real - ) + rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 55e62af2..6fe1ea51 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -190,7 +190,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | and not self._config.zero_centered ): implementation = NormalizationImplementation.fast - elif (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: + elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -259,7 +259,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | assert not hidden_dim.is_parallel implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: + if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 882963ce..88c86c8a 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -44,7 +44,9 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() - self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + self._activation_fn = ( + 
triton_mlp_activation_autograd if TritonConfig.enabled(torch.device("cuda")) else torch_mlp_activation + ) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = self._config.layer_1.get_layer( diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index f531a1d4..d636d6af 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -2,7 +2,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType +from fast_llm.functional.config import EntropyLossType from fast_llm.layers.block.config import BlockKwargs from fast_llm.utils import Assert @@ -77,11 +77,10 @@ class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): desc="Type of loss to use.", hint=FieldHint.core, ) - - implementation: EntropyLossImplementation = Field( - default=EntropyLossImplementation.auto, - desc="Loss implementation.", - hint=FieldHint.performance, + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, ) @property @@ -100,11 +99,6 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): desc="Type of loss to use.", hint=FieldHint.core, ) - implementation: EntropyLossImplementation = Field( - default=EntropyLossImplementation.auto, - desc="Loss implementation.", - hint=FieldHint.performance, - ) reference_model: str = Field( default="teacher", desc="Name of the reference model for knowledge distillation.", @@ -116,6 +110,11 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): desc="Temperature for teacher softmax.", valid=check_field(Assert.gt, 0.0), ) + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelDistillationLoss]": diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 351aa210..25e0c19b 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -4,7 +4,7 @@ from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, @@ -13,32 +13,9 @@ from fast_llm.utils import Assert -def _get_implementation( - default: EntropyLossImplementation = EntropyLossImplementation.auto, - loss_type: EntropyLossType = EntropyLossType.cross_entropy, - vocab_parallel: bool = False, -) -> EntropyLossImplementation: - # Vocab parallel requires fused. 
- if vocab_parallel: - assert default in (EntropyLossImplementation.auto, EntropyLossImplementation.fused) - return EntropyLossImplementation.fused - - # Triton only available for cross_entropy - if TritonConfig.TRITON_ENABLED and torch.cuda.is_available() and loss_type == EntropyLossType.cross_entropy: - return EntropyLossImplementation.triton if default == EntropyLossImplementation.auto else default - else: - assert default != EntropyLossImplementation.triton - - # Otherwise, use fused. - return EntropyLossImplementation.fused if default == EntropyLossImplementation.auto else default - - class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._implementation = _get_implementation( - self._config.implementation, self._config.loss_type, self._vocab_parallel - ) def forward_backward( self, @@ -52,10 +29,10 @@ def forward_backward( None, # Labels are already masked grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, - implementation=self._implementation, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels, entropy_loss_type=self._config.loss_type, + use_triton=self._config.use_triton, ) @@ -65,10 +42,6 @@ def __init__(self, *args, **kwargs): if self._prediction_distance > 0: raise NotImplementedError() - self._implementation = _get_implementation( - self._config.implementation, self._config.loss_type, self._vocab_parallel - ) - def forward_backward( self, logits: "torch.Tensor", @@ -81,17 +54,17 @@ def forward_backward( self._get_loss_mask(kwargs, split_index), grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, - implementation=self._implementation, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, + use_triton=self._config.use_triton, ) _ENTROPY_LOSS_IMPLEMENTATIONS = { EntropyLossImplementation.torch: torch_entropy_loss_forward_backward, EntropyLossImplementation.fused: fused_entropy_loss_forward_backward, - EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, + EntropyLossImplementation.triton: triton_entropy_loss_forward_backward, } @@ -101,11 +74,11 @@ def entropy_loss_forward_backward( loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, group: torch.distributed.ProcessGroup | None = None, - implementation: EntropyLossImplementation = EntropyLossImplementation.fused, logits_scale_factor: float = 1.0, temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + use_triton: bool | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -114,6 +87,7 @@ def entropy_loss_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. 
""" + if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) @@ -123,7 +97,11 @@ def entropy_loss_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + return ( + triton_entropy_loss_forward_backward + if TritonConfig.enabled(logits.device, use_triton) + else fused_entropy_loss_forward_backward + )( logits, target, loss_mask, diff --git a/tests/conftest.py b/tests/conftest.py index 4f7d7bad..23fc58b1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -279,3 +279,8 @@ def pytest_xdist_make_scheduler(config, log): # Always use grouped load balancing to handle dependencies, and make it work with `-n`. assert config.getvalue("dist") == "load" return xdist.scheduler.LoadGroupScheduling(config, log) + + +@pytest.fixture(scope="session") +def testing_device() -> torch.device: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 6471a516..7980f05b 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -2,6 +2,7 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert @@ -19,12 +20,12 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] ) -def test_mlp_recomputation(gated, activation): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - tokens = 1024 - hidden_size = 2048 - intermediate_size = 4096 - std = 1 / 64 +def test_mlp_recomputation(gated, activation, testing_device): + device = torch.device(testing_device) + tokens = 64 + hidden_size = 128 + intermediate_size = 256 + std = 1 / 16 input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device=device, requires_grad=True) @@ -53,7 +54,20 @@ def test_mlp_recomputation(gated, activation): param.grad = None param.grad_buffer = torch.empty_like(param) param.param_grad_is_zero = True - output = mlp_autograd(input_, None, *params, gated, activation, None, False, True, recompute_level, True) + output = mlp_autograd( + input_, + None, + *params, + gated, + activation, + None, + False, + True, + recompute_level, + True, + None, + triton_available and torch.cuda.is_available(), + ) output.backward(output_grad) if i == 0: Assert.rms_close(output, output_ref, 1e-5) @@ -74,8 +88,8 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. 
@pytest.mark.slow @pytest.mark.skip("Dropless MoE is broken") -def test_dropless_mlp(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def test_dropless_mlp(testing_device): + device = torch.device(testing_device) num_experts = 4 experts_per_token = 4 tokens = 256 diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py index 899dad96..0ebf9c5a 100644 --- a/tests/functional/test_sparse_matmul.py +++ b/tests/functional/test_sparse_matmul.py @@ -12,7 +12,7 @@ output_sparse_matmul, ) from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda +from tests.utils.utils import requires_triton @dataclasses.dataclass @@ -46,12 +46,11 @@ def sparse_dim_expanded(self) -> int: def num_experts(self) -> int: return len(self.expert_begins) - @functools.cached_property - def sparse_map(self) -> SparseMap: + def get_sparse_map(self, device: torch.device) -> SparseMap: return SparseMap( num_experts=self.num_experts, - expert_ends=torch.tensor(self.expert_ends, device="cuda"), - expert_pad_begins=torch.tensor(self.expert_pad_begins, device="cuda"), + expert_ends=torch.tensor(self.expert_ends, device=device), + expert_pad_begins=torch.tensor(self.expert_pad_begins, device=device), num_rows=self.expert_ends[-1], # Not needed sparse_rows=None, @@ -60,8 +59,8 @@ def sparse_map(self) -> SparseMap: num_experts_per_token=None, ) - def normal(self, dim_0: int, dim_1: int) -> torch.Tensor: - return torch.normal(0, self.std, (dim_0, dim_1), device="cuda") + def normal(self, dim_0: int, dim_1: int, device: torch.device) -> torch.Tensor: + return torch.normal(0, self.std, (dim_0, dim_1), device=device) _SPARSE_TEST_DATAS = ( @@ -80,28 +79,28 @@ def normal(self, dim_0: int, dim_1: int) -> torch.Tensor: ) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_dense_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) - rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim) +def test_dense_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim, testing_device) output = dense_matmul(lhs, rhs) output_ref = torch.matmul(lhs, rhs) Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_output_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) - rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded) +def test_output_sparse_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded, testing_device) # Randomly initialize the output to ensure padded values have no effect. 
- out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) - output = output_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map, out) + out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + output = output_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device), out) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): @@ -114,14 +113,14 @@ def test_output_sparse_matmul(sparse_test_data): Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_input_inner_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) - rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim) +def test_input_inner_sparse_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim, testing_device) - output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device)) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): @@ -134,14 +133,14 @@ def test_input_inner_sparse_matmul(sparse_test_data): Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_input_row_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim) - rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) +def test_input_row_sparse_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) - output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device)) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 79817bb0..2886ab14 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -2,7 +2,7 @@ import torch from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType, TritonConfig +from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType from fast_llm.functional.triton.adam import triton_adam from fast_llm.functional.triton.mlp import ( torch_mlp_activation, @@ -25,71 +25,66 @@ rotary_embeddings_real, ) from fast_llm.utils import Assert, rms_diff -from tests.utils.utils import requires_cuda +from tests.utils.utils import requires_cuda, requires_triton -@requires_cuda -def test_triton_fill(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(425, 549, dtype=torch.bfloat16, device="cuda") - triton_fill(x, 32) +@requires_triton +def test_triton_fill(testing_device): + x = torch.randn(425, 549, dtype=torch.float16, 
device=testing_device) + triton_fill(x, 32, use_triton=True) assert x.min().item() == x.max().item() == 32 -@requires_cuda -def test_triton_copy(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(7563, dtype=torch.bfloat16, device="cuda") +@requires_triton +def test_triton_copy(testing_device): + x = torch.randn(7563, dtype=torch.float32, device=testing_device).to(torch.float16) x1 = x.clone() y = torch.zeros_like(x) Assert.all_different(x, y) - triton_copy(x, y) + triton_copy(x, y, use_triton=True) Assert.all_equal(x, y) Assert.all_equal(x, x1) -@requires_cuda -def test_triton_copy_cast(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(7563, dtype=torch.bfloat16, device="cuda") +@requires_triton +def test_triton_copy_cast(testing_device): + x = torch.randn(7563, dtype=torch.float32, device=testing_device).to(torch.float16) x1 = x.clone() y = torch.zeros_like(x, dtype=torch.float32) Assert.all_different(x.float(), y) - triton_copy(x, y) + triton_copy(x, y, use_triton=True) Assert.rms_close(x, y, 1e-4) Assert.all_equal(x, x1) -@requires_cuda -def test_triton_add(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(8934, dtype=torch.float32, device="cuda") +@requires_triton +def test_triton_add(testing_device): + x = torch.randn(8934, dtype=torch.float32, device=testing_device) x1 = x.clone() y = torch.zeros_like(x) y1 = y.clone() Assert.all_different(x, y) - z = triton_add(x, y) + z = triton_add(x, y, use_triton=True) z1 = x1 + y1 Assert.rms_close(z, z1, 1e-5) Assert.all_equal(x, x1) Assert.all_equal(y, y1) -@requires_cuda +@requires_triton @pytest.mark.parametrize( ("batch_size", "sequence_length", "num_heads", "head_size"), - [(4, 1024, 8, 128), (1, 32, 1, 16), (2, 2048, 2, 192), (3, 519, 7, 134), (2, 100000, 2, 4)], + [(4, 32, 2, 16), (1, 32, 1, 16), (2, 64, 2, 96), (3, 59, 7, 22)], ) -def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device="cuda") +def test_triton_rotary(batch_size, sequence_length, num_heads, head_size, testing_device): + x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device=testing_device) frequencies = ( DefaultRotaryConfig() .get_layer(TensorDim("", head_size)) ._get_frequencies( sequence_length, head_size, - device="cuda", + device=testing_device, ) ) @@ -110,15 +105,14 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): Assert.rms_close(y_real, y_triton, 1e-4) -@requires_cuda +@requires_triton @pytest.mark.parametrize("has_bias", [True, False]) @pytest.mark.parametrize("zero_centered", [True, False]) -def test_triton_normalization(has_bias, zero_centered): - assert TritonConfig.TRITON_ENABLED - input_ = torch.randn(4096, 1024, device="cuda", requires_grad=True) +def test_triton_normalization(has_bias, zero_centered, testing_device): + input_ = torch.randn(32, 128, device=testing_device, requires_grad=True) output_grad = torch.randn_like(input_) - weight = torch.randn(1024, device="cuda", requires_grad=True) + weight = torch.randn(128, device=testing_device, requires_grad=True) weight.grad_buffer = torch.empty_like(weight) weight.param_grad_is_zero = True @@ -160,7 +154,7 @@ def test_triton_normalization(has_bias, zero_centered): Assert.rms_close(bias_grad0, bias.grad, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", @@ -173,10 +167,9 @@ def 
test_triton_normalization(has_bias, zero_centered): ], ) @pytest.mark.parametrize("recompute", [True, False]) -def test_triton_mlp_activation(gated, activation, recompute): - assert TritonConfig.TRITON_ENABLED - input_ = torch.randn(1024, 4096 * (2 if gated else 1), device="cuda", requires_grad=True) - output_grad = torch.randn(1024, 4096, device="cuda") +def test_triton_mlp_activation(gated, activation, recompute, testing_device): + input_ = torch.randn(32, 128 * (2 if gated else 1), device=testing_device, requires_grad=True) + output_grad = torch.randn(32, 128, device=testing_device) output1, context = triton_mlp_activation_forward(input_, gated, activation) input_grad1, output3 = triton_mlp_activation_backward(output_grad, context, recompute) @@ -190,10 +183,9 @@ def test_triton_mlp_activation(gated, activation, recompute): Assert.rms_close(output1, output3, 1e-5) -@requires_cuda -def test_triton_adam(): - assert TritonConfig.TRITON_ENABLED - params = torch.randn(4576427, dtype=torch.float32, device="cuda") +@requires_triton +def test_triton_adam(testing_device): + params = torch.randn(45764, dtype=torch.float32, device=testing_device) grads = torch.randn_like(params) exp_avgs = torch.randn_like(params) exp_avg_sqs = torch.randn_like(params).abs() @@ -248,13 +240,14 @@ def compare(i, j, fn, arg): compare(0, 4, Assert.eq, 0) +# TODO: Failing with triton interpreter @requires_cuda @pytest.mark.parametrize( ("num_rows_dense", "num_experts", "num_experts_per_token"), [(2048, 8, 2), (2048, 6, 2), (2048, 8, 8), (256, 8, 2), (5627, 8, 2)], ) -def test_triton_sparse_map(num_rows_dense, num_experts, num_experts_per_token): - logits = torch.randn((num_rows_dense, num_experts), device="cuda") +def test_triton_sparse_map(num_rows_dense, num_experts, num_experts_per_token, testing_device): + logits = torch.randn((num_rows_dense, num_experts), device=testing_device) _, top_experts = torch.topk(logits, num_experts_per_token, dim=-1) sparse_map_triton = get_sparse_map(top_experts, num_experts, use_triton=True) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index e5419720..9ca78d64 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -9,10 +9,11 @@ from fast_llm.engine.config_utils import data_type from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedBackend -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat +from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available +from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward from fast_llm.utils import Assert @@ -104,8 +105,10 @@ def reference_dpo_loss( return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() -_BATCH_SHAPES = ((64,), (16, 8)) +# _BATCH_SHAPES = ((64,), (16, 8)) +_BATCH_SHAPES = ((1,),) _LOSS_PARAMETERS = ( + (8, 1.0, 1.0, False, DataType.float32, None), # Simple (500, 1.0, 1.0, False, DataType.float32, 
None), # Simple (256, 1.0, 1.0, False, DataType.float32, None), # Power of 2 (500, None, 1.0, False, DataType.float32, None), # No grad @@ -131,13 +134,13 @@ def _test_entropy_loss( group=None, ): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") + pytest.skip(reason="Reverse KL loss not implemented for target labels") # TODO: Test tensor-parallel implementation. logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) local_logits = split_op(logits, group, -1).contiguous() local_target = target if target_format == TargetFormat.labels else split_op(target, group, -1).contiguous() # Torch serves as the reference implementation. - out_ref, grad_ref = entropy_loss_forward_backward( + out_ref, grad_ref = torch_entropy_loss_forward_backward( logits=logits, target=target, loss_mask=loss_mask, @@ -145,9 +148,8 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.torch, ) - out_fused, grad_fused = entropy_loss_forward_backward( + out_fused, grad_fused = fused_entropy_loss_forward_backward( logits=local_logits, target=local_target, loss_mask=loss_mask, @@ -156,7 +158,6 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.fused, ) _compare_losses_and_grads( out_fused, @@ -168,11 +169,9 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type == EntropyLossType.reverse_kl or not triton_available: - # Triton implementation only supports cross-entropy. + if not triton_available: return - assert TritonConfig.TRITON_ENABLED - out_triton, grad_triton = entropy_loss_forward_backward( + out_triton, grad_triton = triton_entropy_loss_forward_backward( logits=local_logits, target=local_target, loss_mask=loss_mask, @@ -180,11 +179,18 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.triton, group=group, block_size=block_size, ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) + _compare_losses_and_grads( + out_triton, + out_ref, + grad_output is not None, + grad_triton, + grad_ref, + threshold=1e-5 if target_format != TargetFormat.probabilities and data_type == DataType.float32 else 1e-4, + group=group, + ) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py index 112c88a6..f34b9a35 100644 --- a/tests/layers/test_rotary.py +++ b/tests/layers/test_rotary.py @@ -8,24 +8,27 @@ from fast_llm.utils import Assert -def test_rotary_2d(): +def test_rotary_2d(testing_device): """ Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. 
""" head_dim = 16 num_heads = 8 - device = "cuda" if torch.cuda.is_available() else "cpu" patch_positions = torch.tensor( [[h, w] for h in range(4) for w in range(4)], dtype=torch.int64, - device=device, + device=testing_device, ) - query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=device).normal_() + query = torch.empty( + 2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=testing_device + ).normal_() key = torch.empty_like(query).normal_() pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) - pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to(device) + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to( + testing_device + ) # Convert patch positions (h, w) to Pixtral's linear position IDs # Pixtral expects: position_id = h * max_patches_per_side + w position_ids = ( @@ -37,7 +40,7 @@ def test_rotary_2d(): ) fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) - kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: device} + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: testing_device} fast_llm_rotary.preprocess(kwargs) output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index 777214aa..c12fe52e 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -85,7 +85,7 @@ def _compare_mixers( @pytest.mark.slow # Arguments ('seq_idx',) not implemented for torch implementation of 1d convolution. @pytest.mark.skipif(not transformers.utils.import_utils.is_causal_conv1d_available(), reason="GDN deps missing") -def test_gdn(): +def test_gdn(testing_device): dtype = torch.bfloat16 NUM_V_HEADS = 4 @@ -103,7 +103,7 @@ def test_gdn(): hf_layer = ( Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) - .to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=dtype) + .to(device=testing_device, dtype=dtype) .eval() ) fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 955fa534..1da26473 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -231,14 +231,14 @@ def do_load_and_compare_checkpoints( @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained( - model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints + model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints, testing_device ): # Test that loadind a pretrained model from either converted checkpoint always yields the exact same model. 
reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) reference_shard = safetensors.torch.load_file( - get_convert_path() / "rank_0.safetensors", device="cuda" if torch.cuda.is_available() else "cpu" + get_convert_path() / "rank_0.safetensors", device=str(testing_device) )[_WEIGHT_SHARD_SAVE_NAME] load_and_compare_checkpoints( FastLLMCheckpointFormat, @@ -304,8 +304,7 @@ def test_load_pretrained( @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_huggingface_model(model_testing_config, get_convert_path): - device = "cuda" if torch.cuda.is_available() else "cpu" +def test_huggingface_model(model_testing_config, get_convert_path, testing_device): distributed_update = {("distributed", "use_cuda"): torch.cuda.is_available()} if model_testing_config.checkpoint_format is None: return @@ -331,11 +330,11 @@ def test_huggingface_model(model_testing_config, get_convert_path): 384, size=(4, 100), dtype=torch.int64, - device=device, + device=testing_device, ) kwargs = {} if model_testing_config.model_type == "multimodal": - kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(device) + kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(testing_device) kwargs["image_sizes"] = torch.tensor( [ [20, 20], # Full image, 25 patches @@ -373,7 +372,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): errors = [] model_as_hf = ( model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True) - .to("cuda" if torch.cuda.is_available() else "cpu") + .to(testing_device) .eval() ) for name, model in zip( diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index ca92f0b7..8c131dfa 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -15,7 +15,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBatchConfig, GPTModelConfig -from tests.utils.utils import get_base_model, requires_cuda +from tests.utils.utils import get_base_model def create_test_batch( @@ -46,7 +46,7 @@ def get_minimal_model(): "embeddings": {"vocab_size": 1000}, "hidden_size": 64, }, - "distributed": {}, + "distributed": {"use_cuda": torch.cuda.is_available()}, }, ) model, distributed = get_base_model(config) @@ -82,7 +82,6 @@ def run_preprocess_batch(model, distributed_config, batch: LanguageModelBatch, p ) -@requires_cuda class TestLossMaskIntegration: """ Integration tests for loss_mask computation in preprocess_batch. diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 5a6aff83..a4f28d14 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -211,6 +211,8 @@ def update_and_add_testing_config( "save": True, "show": False, }, + # Triton kernels are extremely slow in interpreter mode. 
+ "enable_triton_kernels": torch.cuda.is_available(), # Uncomment to enable model debug logging: # "model_debug_level": _LOG_LEVEL, }, diff --git a/tests/utils/utils.py b/tests/utils/utils.py index f0ca20db..da293e1d 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -9,11 +9,13 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage +from fast_llm.functional.triton import triton_available from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +requires_triton = pytest.mark.skipif(not triton_available, reason="Triton is not available") @pytest.fixture(scope="session") From d99511b2ee83a7d41c0c68ebead9e08e6ddb5343 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 04:21:09 -0500 Subject: [PATCH 09/11] rename --- .../functional/triton/{cross_entropy.py => entropy_loss.py} | 0 fast_llm/layers/language_model/loss/entropy_loss.py | 2 +- tests/layers/test_lm_losses.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename fast_llm/functional/triton/{cross_entropy.py => entropy_loss.py} (100%) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/entropy_loss.py similarity index 100% rename from fast_llm/functional/triton/cross_entropy.py rename to fast_llm/functional/triton/entropy_loss.py diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 25e0c19b..f81e4e4b 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -4,7 +4,7 @@ from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward -from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 9ca78d64..2e878691 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -12,7 +12,7 @@ from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available -from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward From 094ac85615c81c30305bf142a5fed05ea4a43c41 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 05:12:54 -0500 Subject: [PATCH 10/11] Z loss --- fast_llm/functional/triton/__init__.py | 1 + fast_llm/functional/triton/entropy_loss.py | 9 +- fast_llm/functional/triton/z_loss.py | 138 ++++++++++++++++++ 
fast_llm/layers/language_model/loss/z_loss.py | 8 +- tests/layers/test_lm_losses.py | 31 +++- 5 files changed, 176 insertions(+), 11 deletions(-) create mode 100644 fast_llm/functional/triton/z_loss.py diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index 61ead1c6..f5b394bf 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -38,3 +38,4 @@ def tl_full(shape, value, dtype): else: tl_arange = tl.arange + tl_full = tl.full diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index 33504877..ad826f3e 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -636,7 +636,7 @@ def _rescale_sum_exp_logits( return sum_exp_logits * (local_max_logits - max_logits).exp() -def _parallel_sum_exp_logits( +def parallel_sum_exp_logits( sum_exp_logits: torch.Tensor, local_max_logits: torch.Tensor, group: torch.distributed.ProcessGroup | None, @@ -758,7 +758,7 @@ def triton_entropy_loss_forward_backward( col_min=n_cols * group.rank(), **kwargs, ) - max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) if grad_output is not None: @@ -827,11 +827,10 @@ def triton_entropy_loss_forward_backward( target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, - **backward_kwargs, ) - max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) if target_format == TargetFormat.logits: - target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( + target_max_logits, target_sum_exp_logits = parallel_sum_exp_logits( target_sum_exp_logits, local_target_max_logits, group ) if entropy_loss_type != EntropyLossType.reverse_kl: diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py new file mode 100644 index 00000000..298c3c2a --- /dev/null +++ b/fast_llm/functional/triton/z_loss.py @@ -0,0 +1,138 @@ +import torch + +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton.entropy_loss import ( + parallel_sum_exp_logits, + triton_cross_entropy_forward_from_labels_parallel_kernel, + triton_fused_softmax_base, +) + + +@triton_jit() +def triton_z_loss_forward_backward_kernel( + logits_ptr, + loss_mask_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + block_size: tl_constexpr, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. 
+ if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + + log_sum_exp_logits = tl.log(sum_exp_logits) + max_logits + + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, log_sum_exp_logits * log_sum_exp_logits) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + grad_losses *= 2 * log_sum_exp_logits / sum_exp_logits + # Run in reverse order to maximize input and cache reuse. + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, exp_logits * grad_losses, mask=mask + ) + + +def triton_z_loss_forward_backward( + logits: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + block_size: int | None = None, + num_warps: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert logits.is_contiguous() + if loss_mask is not None: + assert loss_mask.is_contiguous() + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "block_size": block_size, + "num_warps": num_warps, + } + grad_logits = None if grad_output is None else torch.empty_like(logits) + backward_kwargs = ( + {} + if grad_output is None + else { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / n_rows, + "grad_logits_stride_0": grad_logits.stride(-2), + } + ) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if group is None: + triton_z_loss_forward_backward_kernel[(n_rows,)]( + logits, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + **kwargs, + **backward_kwargs, + ) + else: + local_max_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + sum_exp_logits = torch.empty_like(local_max_logits) + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( + logits, + None, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + **kwargs, + ) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + 
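The parallel branch above first has each rank reduce its vocabulary shard to a per-row (local max, local sum-of-exponentials) pair, then combines those pairs across ranks with `parallel_sum_exp_logits` before running the z-loss kernel a second time with the global statistics. The helper's body lives in the entropy_loss module and is not shown here, but its call sites and the `_rescale_sum_exp_logits` context shown in patch 10 suggest the usual log-sum-exp combination. The sketch below restates that math, plus the per-row loss it feeds, in plain PyTorch; the function names are illustrative and this is a reference for checking the kernel, not the library's implementation.

    import torch
    import torch.distributed as dist

    def combine_sum_exp(local_sum_exp: torch.Tensor, local_max: torch.Tensor, group=None):
        # Combine per-shard softmax statistics: given max_r and sum_j exp(x_rj - max_r)
        # on each rank r, recover the global max and sum_j exp(x_j - global_max) as
        #   sum_r exp(max_r - global_max) * local_sum_r.
        global_max = local_max.clone()
        if group is not None:
            dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
        sum_exp = local_sum_exp * (local_max - global_max).exp()
        if group is not None:
            dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)
        return global_max, sum_exp

    def reference_z_loss(
        local_logits: torch.Tensor,
        loss_mask: torch.Tensor | None = None,
        logits_scale_factor: float = 1.0,
        group=None,
    ) -> torch.Tensor:
        # Per-row z-loss (log Z)^2 with Z = sum_j exp(scale * logit_j), where the
        # vocabulary dimension may be sharded over `group`. Masked rows contribute
        # zero loss but still count in the mean, matching `grad_output / n_rows` above.
        scaled = logits_scale_factor * local_logits.float()
        local_max = scaled.max(dim=-1).values
        local_sum_exp = (scaled - local_max.unsqueeze(-1)).exp().sum(dim=-1)
        global_max, sum_exp = combine_sum_exp(local_sum_exp, local_max, group)
        log_z = sum_exp.log() + global_max
        losses = log_z.square()
        if loss_mask is not None:
            losses = losses * loss_mask
        return losses.mean()

Backpropagating grad_output through reference_z_loss yields grad_output * 2 * log Z * softmax(scaled logits) * logits_scale_factor / n_rows for unmasked rows, which is what the kernel's backward loop stores.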
triton_z_loss_forward_backward_kernel[(n_rows,)]( + logits, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + **kwargs, + **backward_kwargs, + ) + return losses.mean(), grad_logits diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 82b8d531..0e675d33 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -2,7 +2,9 @@ import torch +from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_softmax_base +from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -20,7 +22,9 @@ def forward_backward( kwargs: dict[str, typing.Any], split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return z_loss_forward_backward( + return ( + triton_z_loss_forward_backward if TritonConfig.enabled(logits.device) else fused_z_loss_forward_backward + )( logits, self._get_loss_mask(kwargs, split_index), grad_output=self._get_grad_output(kwargs), @@ -44,7 +48,7 @@ def z_loss( @torch.compile -def z_loss_forward_backward( +def fused_z_loss_forward_backward( logits: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 2e878691..2f04a38e 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -13,9 +13,10 @@ from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.loss import loss_forward_backward -from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward +from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans from tests.utils.subtest import DistributedTestContext @@ -193,7 +194,9 @@ def _test_entropy_loss( ) -def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): +def _test_z_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, group=None +): logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) out_ref, grad_ref = loss_forward_backward( grad_output, @@ -202,7 +205,7 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los loss_mask, logits_scale_factor=logits_scale_factor, ) - out_fused, grad_fused = z_loss_forward_backward( + out_fused, grad_fused = fused_z_loss_forward_backward( split_op(logits, group, -1), loss_mask, grad_output, @@ -218,6 +221,25 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los threshold=1e-5 if data_type == DataType.float32 else 1e-4, group=group, ) + if not triton_available: + return + out_triton, grad_triton = triton_z_loss_forward_backward( + split_op(logits, group, -1), + loss_mask, + grad_output, 
+ group=group, + logits_scale_factor=logits_scale_factor, + block_size=block_size, + ) + _compare_losses_and_grads( + out_triton, + out_ref, + grad_output is not None, + grad_triton, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) @pytest.mark.slow @@ -257,7 +279,7 @@ def test_entropy_loss( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size): - _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) + _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size) @pytest.mark.skip(reason="DPO loss is broken") @@ -309,6 +331,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa logits_scale_factor, loss_masking, dtype, + block_size, test_context.group, ) From 35fd220a7661746884db1cc2a2f0655d9c8788b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 05:14:16 -0500 Subject: [PATCH 11/11] fix --- fast_llm/layers/language_model/loss/config.py | 6 ++++++ fast_llm/layers/language_model/loss/z_loss.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index d636d6af..97000312 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -160,6 +160,12 @@ class LanguageModelZLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, + ) + @property def loss_class(self) -> "type[LanguageModelZLoss]": from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 0e675d33..1df54f7a 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -23,7 +23,9 @@ def forward_backward( split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": return ( - triton_z_loss_forward_backward if TritonConfig.enabled(logits.device) else fused_z_loss_forward_backward + triton_z_loss_forward_backward + if TritonConfig.enabled(logits.device, self._config.use_triton) + else fused_z_loss_forward_backward )( logits, self._get_loss_mask(kwargs, split_index),