Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,29 @@ def install_requirements() -> List[str]:


def test_requirements() -> List[str]:
"""Test dependencies for TE/JAX extensions."""
return ["numpy", "triton"]
"""Test dependencies for TE/JAX extensions.

Triton Package Selection:
The triton package is selected based on NVTE_USE_PYTORCH_TRITON environment variable:

Default (NVTE_USE_PYTORCH_TRITON unset or "0"):
Returns 'triton' - OpenAI's standard package from PyPI.
Install with: pip install triton

NVTE_USE_PYTORCH_TRITON=1:
Returns 'pytorch-triton' - for mixed JAX+PyTorch environments.
Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121

Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder.
"""
use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0")))

triton_package = "pytorch-triton" if use_pytorch_triton else "triton"

return [
"numpy",
triton_package,
]


def xla_path() -> str:
Expand Down
14 changes: 12 additions & 2 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@


def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
"""Install dependencies for TE/PyTorch extensions.

IMPORTANT - PyTorch Index Required for pytorch-triton:
These dependencies MUST be installed using PyTorch's package index:

pip install pytorch-triton --index-url https://download.pytorch.org/whl/<version??>

- pytorch-triton is only available from PyTorch's index (not PyPI)
- The 'pytorch-triton' package on PyPI is a placeholder that will fail
- torch.compile() requires pytorch-triton, not OpenAI's 'triton' package
"""
return [
"torch>=2.1",
"einops",
Expand All @@ -22,7 +32,7 @@ def install_requirements() -> List[str]:
"packaging",
"pydantic",
"nvdlfw-inspect",
"triton",
"pytorch-triton",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch-triton should be the default for transformerengine-pytorch. Not just a placeholder. It should be used, all the time when pytorch framework is used.

triton is the default for jax, unless either in 2 scenarios happen:

  • there is both jax and pytorch installed, and they are using TE pytorch to call the triton kernels
  • The user specify NVTE_USE_PYTORCH_TRITON=1 while using TE jax, to make sure there is no performance diff between using different versions of triton, between pytorch and jax.

]


Expand Down
33 changes: 32 additions & 1 deletion transformer_engine/jax/triton_extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,33 @@
IMPORTANT: This module requires Triton to be installed. If you don't have Triton,
use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives).

Install Triton: pip install triton

Triton Package Options:
-----------------------
There are two compatible Triton packages:

1. Standard 'triton' from OpenAI (recommended for JAX-only environments):
pip install triton

2. 'pytorch-triton' from PyTorch's index (for mixed JAX+PyTorch environments):
pip install torch --index-url https://download.pytorch.org/whl/cu121
# pytorch-triton is automatically installed as a dependency

Both packages work with JAX Triton kernels. The pytorch-triton package
has version format "X.Y.Z+<commit_sha>" (e.g., "3.0.0+45fff310c8").

WARNING: Do NOT run 'pip install pytorch-triton' directly! The package on PyPI
is a placeholder that will fail with "RuntimeError: Should never be installed".
The real pytorch-triton only comes bundled with PyTorch from PyTorch's index.


Environment Variables:
NVTE_USE_PYTORCH_TRITON: If set to "1", acknowledge using pytorch-triton
for JAX Triton kernels (suppresses compatibility warnings). Set this
when both JAX and PyTorch are installed in the same environment.

Example:
export NVTE_USE_PYTORCH_TRITON=1


Usage:
Expand All @@ -23,6 +49,11 @@ def lowering(ctx, x, **kwargs):

# Use permutation functions
from transformer_engine.jax.triton_extensions import make_row_id_map, permute_with_mask_map

# Check Triton package info
from transformer_engine.jax.triton_extensions import get_triton_info
info = get_triton_info()
print(f"Using Triton {info['version']} from {info['source']}")
"""

from .utils import *
Expand Down
164 changes: 163 additions & 1 deletion transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,33 @@

This module provides utility functions for integrating Triton kernels into
JAX primitives. Triton is only imported when this module is used.

Triton Package Compatibility:
There are two Triton packages that can be used:

1. 'triton' (from OpenAI/PyPI): Standard package, works with JAX out of the box.
Install with: pip install triton

2. 'pytorch-triton' (from PyTorch's index): Bundled with PyTorch, includes
PyTorch-specific patches. Version format: "3.0.0+<commit_sha>"

IMPORTANT: The 'pytorch-triton' package on PyPI (version 0.0.1) is a
placeholder that will NOT work. The real pytorch-triton is only available
from PyTorch's package index and is auto-installed with PyTorch:
pip install torch --index-url https://download.pytorch.org/whl/cu121

pytorch-triton has been tested to work with JAX Triton kernels.

Environment Variables:
NVTE_USE_PYTORCH_TRITON: If set to "1", explicitly acknowledge using
pytorch-triton for JAX Triton kernels (suppresses warnings). This is
useful when both JAX and PyTorch are installed in the same environment.
Default is "0".
"""

import hashlib
import os
import warnings
from typing import Any, Callable, Mapping
import zlib

Expand All @@ -17,6 +41,114 @@
import jax.numpy as jnp


# Placeholder package version on PyPI that should never be used
_PYTORCH_TRITON_PLACEHOLDER_VERSION = "0.0.1"


def _detect_triton_package():
"""Detect which Triton package is installed and validate compatibility.

Returns:
tuple: (triton_version: str or None, is_pytorch_triton: bool, is_placeholder: bool)

The function detects:
- None: Triton not installed
- Standard triton from OpenAI (versions like "3.1.0")
- Real pytorch-triton from PyTorch's index (versions like "3.0.0+45fff310c8")
- Placeholder pytorch-triton from PyPI (version "0.0.1" - broken, raises RuntimeError)
"""
try:
import triton

triton_version = getattr(triton, "__version__", "unknown")
except ImportError:
return None, False, False
except RuntimeError as e:
# The placeholder pytorch-triton package from PyPI raises:
# RuntimeError: "Should never be installed"
if "Should never be installed" in str(e):
return _PYTORCH_TRITON_PLACEHOLDER_VERSION, False, True
raise

# Check for placeholder package (version 0.0.1 from PyPI)
is_placeholder = triton_version == _PYTORCH_TRITON_PLACEHOLDER_VERSION

# Real pytorch-triton versions have a commit SHA suffix like "3.0.0+45fff310c8"
is_pytorch_triton = "+" in triton_version and len(triton_version.split("+")[-1]) >= 8

return triton_version, is_pytorch_triton, is_placeholder


def _check_triton_compatibility():
"""Check Triton package compatibility and emit warnings if necessary.

This function handles the case where both JAX and PyTorch may be installed,
each expecting different Triton packages:
- JAX typically uses the standard 'triton' package from OpenAI
- PyTorch uses 'pytorch-triton' which is versioned with commit SHAs

The NVTE_USE_PYTORCH_TRITON environment variable can be used to explicitly
acknowledge using pytorch-triton with JAX (suppresses warnings).

Raises:
ImportError: If triton is not installed or the placeholder package is detected.
"""
triton_version, is_pytorch_triton, is_placeholder = _detect_triton_package()

# Handle placeholder package from PyPI
if is_placeholder:
raise ImportError(
"Detected the placeholder 'pytorch-triton' package (version 0.0.1) from PyPI.\n"
"This is NOT a functional Triton installation.\n\n"
"The placeholder package exists to prevent namespace conflicts. To fix this:\n\n"
"Option 1 - Use standard Triton (recommended for JAX-only environments):\n"
" pip uninstall pytorch-triton triton\n"
" pip install triton\n\n"
"Option 2 - Use real pytorch-triton (for mixed JAX+PyTorch environments):\n"
" pip uninstall pytorch-triton triton\n"
" pip install torch --index-url https://download.pytorch.org/whl/cu121\n"
" # pytorch-triton is automatically installed as a torch dependency\n\n"
"Note: Do NOT run 'pip install pytorch-triton' directly - this installs\n"
"the broken placeholder. The real pytorch-triton only comes from PyTorch's index."
)

if triton_version is None:
raise ImportError(
"Triton is required for transformer_engine.jax.triton_extensions.\n\n"
"Option 1 - Install standard Triton (recommended for JAX-only):\n"
" pip install triton\n\n"
"Option 2 - Install PyTorch with pytorch-triton (for mixed environments):\n"
" pip install torch --index-url https://download.pytorch.org/whl/cu121\n\n"
"If you don't need Triton, use transformer_engine.jax.cpp_extensions instead."
)

use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
use_pytorch_triton_explicit = use_pytorch_triton_env in ("1", "true", "yes")

if is_pytorch_triton:
if use_pytorch_triton_explicit:
# User explicitly opted in - just log info (no warning)
pass # Silent acknowledgment, no warning needed
else:
# pytorch-triton detected but user didn't explicitly opt in
warnings.warn(
f"Detected pytorch-triton package (version {triton_version}) instead of the"
" standard 'triton' package from OpenAI. This typically happens when PyTorch is"
" installed alongside JAX.\n\npytorch-triton is compatible with JAX Triton"
" kernels. To suppress this warning, set:\n export"
" NVTE_USE_PYTORCH_TRITON=1\n\nAlternatively, for a JAX-only environment:\n - Use"
" separate virtual environments for JAX and PyTorch, or\n - Use"
" transformer_engine.jax.cpp_extensions instead (CUDA-based, no Triton needed)",
category=UserWarning,
stacklevel=3,
)

return triton_version, is_pytorch_triton


# Perform compatibility check and get triton info
_TRITON_VERSION, _IS_PYTORCH_TRITON = _check_triton_compatibility()

try:
from jax._src.lib import gpu_triton
from triton.compiler import compiler as tc
Expand All @@ -30,12 +162,42 @@
) from e


__all__ = ["triton_call_lowering"]
__all__ = ["triton_call_lowering", "get_triton_info"]

# Triton kernel cache (module-level, shared across all kernels)
_TRITON_KERNEL_CACHE = {}


def get_triton_info():
"""Get information about the installed Triton package.

Returns:
dict: Dictionary containing:
- version (str): Triton version string (e.g., "3.1.0" or "3.0.0+45fff310c8")
- is_pytorch_triton (bool): True if using real pytorch-triton from PyTorch's index
- is_openai_triton (bool): True if using standard triton from OpenAI/PyPI
- env_acknowledged (bool): True if NVTE_USE_PYTORCH_TRITON=1 is set
- source (str): "pytorch" or "openai" indicating the package source

Example:
from transformer_engine.jax.triton_extensions import get_triton_info
info = get_triton_info()
print(f"Triton version: {info['version']} (from {info['source']})")
if info['is_pytorch_triton']:
print("Using pytorch-triton - compatible with both PyTorch and JAX")
"""
use_pytorch_triton_env = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower()
env_acknowledged = use_pytorch_triton_env in ("1", "true", "yes")

return {
"version": _TRITON_VERSION,
"is_pytorch_triton": _IS_PYTORCH_TRITON,
"is_openai_triton": not _IS_PYTORCH_TRITON,
"env_acknowledged": env_acknowledged and _IS_PYTORCH_TRITON,
"source": "pytorch" if _IS_PYTORCH_TRITON else "openai",
}


def get_triton_dtype(aval):
"""Convert JAX dtype to Triton type string.

Expand Down
Loading