From 377ca9cefda3bb557cb1097bdabd626ba2cedf05 Mon Sep 17 00:00:00 2001
From: parkminkyumin
Date: Thu, 22 Jan 2026 07:57:32 +0000
Subject: [PATCH] Docs: update add_aten_operator tutorial

---
 docs/mlir/add_aten_operator.md | 343 ++++++++++++++-------------------
 1 file changed, 149 insertions(+), 194 deletions(-)

diff --git a/docs/mlir/add_aten_operator.md b/docs/mlir/add_aten_operator.md
index 847651c1..057f1932 100644
--- a/docs/mlir/add_aten_operator.md
+++ b/docs/mlir/add_aten_operator.md
@@ -1,17 +1,14 @@
-# How to add support for a new ATen Operator in PyTorchSim
+# How to Add a New ATen Operator via MLIR Templates in PyTorchSim

---

## Overview

PyTorchSim executes PyTorch programs by **lowering ATen operators into MLIR**,
followed by backend-specific code generation and simulation.

-This wiki describes contributors through the process of adding support for **new
-ATen operators** by defining custom lowerings in PyTorchSim’s MLIR-based
-execution path.
+This wiki helps contributors add support for **new ATen operators** by defining custom lowerings that route operators through MLIR templates in the MLIR-based execution path.

-We use a dummy operator, `torch._foobar()`, as a minimal example.
-Although `_foobar` has trivial semantics, it still exercises the full
+We use a dummy ATen operator, `torch._foobar()`, as a minimal example. Although `_foobar` has trivial semantics, it still exercises the full
integration workflow:

- Defining an MLIR template
@@ -20,55 +17,42 @@ integration workflow:
- Validating correctness with a test

+
---

## Background

-Before diving into the step-by-step implementation, it helps to understand
-how an ATen operator flows through PyTorchSim’s MLIR-based pipeline. At a high level, supporting a new ATen operator means intercepting the
-operator during the lowering stage and redirecting it to a custom MLIR template.
-Understanding this flow makes the implementation steps clearer.
-
-At a high level, adding support for a new ATen op means intercepting this flow
-at the lowering stage and redirecting it to a custom MLIR template.
+Before diving into the step-by-step implementation, it helps to understand how an ATen operator flows through PyTorchSim’s MLIR-based pipeline. There are two main codegen paths: the **ordinary path** (generic lowering → standard MLIR codegen) and the **template path** (lowering → MLIRTemplate → TemplateBuffer → template-driven codegen). Supporting a new ATen operator typically means mapping it at the lowering stage into the template path so its code is generated by a custom MLIR template.

### Graph capture (`test_<op>.py`)

-When `torch.compile` captures a Python function, each `torch.<op>(...)` call
-is recorded in the graph as an ATen operator (`aten.<op>`).
-
-This is the first point where the operator becomes visible to the lowering
-pipeline and eligible for custom handling.
+`torch.compile` runs Dynamo to trace the Python function and produce an FX graph.
+That FX graph is then processed by AOTAutograd, which lowers it into an ATen-based graph.
+Next, Inductor applies transformations such as decomposition and functionalization, and lowers the ATen graph into Inductor IR, which contains ordinary IR nodes and, when a custom lowering is used, template-backed nodes. At this point, `aten.<op>` becomes eligible for custom handling in the lowering stage.

### Lowering stage (`mlir_lowering.py`)

-During lowering, a custom function (e.g., `custom_<op>`) is invoked for
-`aten.<op>`.
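Before the lowering hook can do anything, the operator has to survive graph capture. One quick sanity check (a sketch using a throwaway debug backend, not part of PyTorchSim) is to compile with a backend that only prints the captured FX graph; the ATen-level graph that Inductor lowers is produced later by AOTAutograd:

```python
import torch

def print_graph_backend(gm, example_inputs):
    # Print the Dynamo-captured FX graph; _foobar should show up as a call_function node.
    print(gm.graph)
    return gm.forward  # fall back to eager execution

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch._foobar(x)

f(torch.randn(8, 8))
```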
+`aten.<op>` is dispatched to a custom function through the lowerings registry (`lowerings.update({aten._foobar: custom_foobar})`).

-The lowering:
-- materializes inputs in Inductor IR if necessary,
-- constructs an MLIR template instance from `mlir_<op>_template.py`,
-- replaces the original ATen op with a template-backed Inductor IR node.
+The lowering:
+* Materializes the input (`TensorBox.realize()`)
+* Creates the MLIR template instance from `mlir_<op>_template.py`
+* Returns a template-backed `TemplateBuffer` via `generate().output_node()`

-This is the key hook where new ATen operator support is introduced.
+This is the primary hook where new ATen operator support is introduced.

### Scheduling stage (`mlir_scheduling.py`)

-The template-backed node enters the scheduler and is routed through
-`codegen_template()`.
+The template-backed node enters the scheduler and follows the
+`codegen_template()` path.

At this stage:
-- scheduling decisions are applied,
-- the MLIR source code for the kernel is generated,
-- the kernel is prepared for registration.
+- Scheduling and tiling decisions are applied
+- MLIR source code is rendered from the template
+- The kernel is prepared for registration

### Kernel registration (`define_kernel()`)

-The generated MLIR source (`src_code`) is registered as a compilable and
-cacheable kernel via the wrapper.
-
-After registration, the kernel becomes part of the code generation artifacts
-and can be reused across runs without regenerating MLIR.
-
+The generated MLIR source (`src_code`) is registered via the wrapper in `define_kernel()`. This emits a `custom_async_compile.mlir(...)` call into the wrapper so the kernel can be compiled, cached, and reused at runtime.

---

@@ -76,26 +60,23 @@

## Table of Contents

- [Toy Example: `_foobar` as a Dummy ATen Op](#toy-example-_foobar-as-a-dummy-aten-op)
+
- [Step 1 — Create `mlir_<op>_template.py`](#step-1--create-mlir_op_templatepy)
+
- [Step 2 — Add a Custom Lowering in `mlir_lowering.py`](#step-2--add-a-custom-lowering-in-mlir_loweringpy)
  - [Step 2.1 — Define the Custom Lowering Function](#step-21--define-the-custom-lowering-function)
-  - [Step 2.2 — Register the Lowering via `loweringsupdate`](#step-22--register-the-lowering-via-loweringsupdate)
-- [Step 3 — Add a Test (`test_<op>.py`)](#step-3--add-a-test-test_oppy)
-- [Summary](#summary)
-
-
-
-
+  - [Step 2.2 — Register the Lowering via `lowerings.update`](#step-22--register-the-lowering-via-loweringsupdate)
+- [Step 3 — Add a Test (`test_<op>.py`)](#step-3--add-a-test-test_oppy)
+- [Summary](#summary)

---

## Toy Example: `_foobar` as a Dummy ATen Op

-`_foobar` is a deliberately trivial ATen operator exposed in PyTorch as `torch._foobar`
-(and also accessible via `torch.ops.aten._foobar`).
+`_foobar` is a deliberately trivial ATen operator exposed in PyTorch as [`torch._foobar()`](https://docs.pytorch.org/cppdocs/api/function_namespaceat_1a79acf777b53d06d41bdfa735b767506a.html).
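Before touching the MLIR path, it is worth confirming the operator's eager behavior directly. A small sketch, based on the signature and behavior described below:

```python
import torch

# _foobar behaves like an identity/copy-style op in eager mode:
# same shape and dtype as the input, identical values.
x = torch.randn(4, 4)
y = torch._foobar(x)
assert y.shape == x.shape and y.dtype == x.dtype
assert torch.equal(x, y)
```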
### Signature

@@ -104,14 +85,14 @@ From the PyTorch C++ API, `_foobar` has the following signature:

* `at::_foobar(const Tensor& self, bool arg1=true, bool arg2=true)`

In other words, it takes:
-- one input tensor (`self`)
-- three optional boolean flags (`arg1`, `arg2`)
+- One input tensor (`self`)
+- Two optional boolean flags (`arg1`, `arg2`)

### Behavior

In our Python checks, `_foobar` behaves like a simple **copy/identity-style op**:
-- the output has the same shape and dtype as the input
-- the output values match the input values
+- The output has the same shape and dtype as the input
+- The output values match the input values

Once you understand
the full flow with `_foobar`, you can replace it with another ATen op and expand
@@ -123,40 +104,38 @@ the lowering/template logic as needed.

---

## Step 1 — Create `mlir_<op>_template.py`

-This step defines an MLIR template for the new ATen operator.
-For `_foobar`, the template implements a minimal **identity-style kernel** that copies input elements to output elements.
-
-We walk through the file top-down, highlighting the role of each section.
+This step defines the MLIR template for the new ATen operator. For `_foobar`, the template is a minimal copy-style kernel that loads tiles of `X` into SRAM and writes them back to `Y` using explicit DMA ops and a tiled `affine.for` loop nest. The template sets up tile descriptors, indexing, and render options that drive code generation. We walk through the file top-down, highlighting the role of each section.

### File skeleton and imports

-Defines the core dependencies used to build an MLIR template, including:
-- the base `MLIRTemplate` / `MLIRTemplateKernel` classes
-- shared utilities from `mlir_common`
-- symbolic helpers (e.g., `sympy`) for shape expressions
+Pulls in the pieces needed to define and render an MLIR template, including:
+
+* The base `MLIRTemplate` / `MLIRTemplateKernel` classes for template codegen
+* `IRNode` and `Buffer` types from Inductor IR for node typing and output binding
+* MLIR helpers from `mlir_common` (e.g., tile descriptors)
+* `sympy` and `empty_strided` for shape/stride computations
+* A standard utility (`typing`) used by the template

```python
from typing import List, Optional
import sympy
-
+from torch import empty_strided
from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel
from torch._inductor.ir import IRNode, Buffer
from PyTorchSimFrontend.mlir import mlir_common
```

### MLIR string template

+Defines the raw MLIR kernel as a string template with placeholders that are filled at render time.
+In this `_foobar` implementation, the template:

-Defines the raw MLIR code as a string template.
-This template contains symbolic placeholders such as:
+* Allocates SRAM buffers for `X` and `Y` via `def_sram_buffer`
+* Builds a 2D tiled loop nest over `M` and `N` using `affine.for` with steps `TILE_M` and `TILE_N`
+* Uses explicit DMA ops (`MVIN` / `MVOUT`) to move tiles between DRAM and SRAM
+* Performs the copy with `linalg.copy` inside the tile

-* number of elements in X (`{{ M }}`)
-
-* tile size (`{{ TILE }}`)
-
-* input/output memref shapes
-
-In this foobar example, the kernel performs 1D tiling over M and, within each tile, copies elements one by one from X to Y. It does not use SRAM buffers and does not emit DMA ops (MVIN/MVOUT) — all accesses are direct DRAM `memref.load`/`memref.store` operations. The placeholders are filled later via `kernel.render_options`.
For more complex ATen ops (e.g., multi‑dimensional tiling, SRAM/DMA usage, prologue/epilogue fusion), see the next WIKI page. +Key placeholders include `{{ M }}`, `{{ N }}`, `{{ TILE_M }}`, `{{ TILE_N }}`, tile descriptors (`X_tile_desc`, `Y_tile_desc`), and kernel/signature helpers, all populated via `kernel.render_options` in `render()`. ```python @@ -164,16 +143,17 @@ TEMPLATE = r""" {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{- kernel.def_local_vars(indent_size=2) }} %M_const = arith.constant {{ M }} : index - affine.for %index0 = 0 to {{ M }} step {{ TILE }} { - affine.for %t = 0 to {{ TILE }} step 1 { - %g = arith.addi %index0, %t : index - %cond = arith.cmpi slt, %g, %M_const : index - scf.if %cond { - %val = memref.load %X[%g] : {{ X_flat_mlir_shape }} - memref.store %val, %Y[%g] : {{ Y_flat_mlir_shape }} - } - } + %N_const = arith.constant {{ N }} : index + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + linalg.copy ins(%X_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}) outs(%Y_buffer : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}) + {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} + } { outer_loop=true } } { outer_loop=true } return } @@ -190,13 +170,12 @@ Initializes the template by calling the base class constructor to register. ```python class MLIRFoobarTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): - # Initialize the MLIR template with the kernel name and input/output nodes. super().__init__("kernel", input_nodes, layout, input_reorder) ``` ### Render entry point (`render`) -Defines the main render entry point for the template. It selects the output node (template buffer/epilogue), prepares tile descriptors and indices, fills `kernel.render_options`, and returns the rendered MLIR code. +Defines the main render entry point for the template. It selects the output node, `prologue_nodes`/`epilogue_nodes` are only used for fusion scenarios and are not otherwise consumed in this kernel. The method then prepares tile descriptors and indices, fills `kernel.render_options`, and returns the rendered MLIR code. ```python def render(self, @@ -210,79 +189,66 @@ Defines the main render entry point for the template. It selects the output node ### Output selection and basic setup -Selects the output buffer, binds symbolic names for input (`X`) and output (`Y`), -computes the number of elements in X (`M`), and derives the tile size (`TILE`) from the kernel -configuration. +Sets the output to template_buffer_node, binds X/Y, and reads the concrete shape by creating an `empty_strided` tensor from `X.layout` to extract `M` and `N`. Tile sizes are then fixed to `TILE_M = 8` and `TILE_N = 8` for this template. 
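For reference, `empty_strided` is only used here to turn the symbolic layout into a concrete tensor whose `.size()` can be read. A small illustration with a hypothetical 128×128 contiguous layout (the template itself passes `X.layout.size` and `X.layout.stride`):

```python
import torch

# empty_strided allocates an uninitialized tensor with the given size/stride;
# .size() then yields the concrete M and N used for tiling.
t = torch.empty_strided((128, 128), (128, 1))
M, N = t.size()
print(M, N)  # 128 128
```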
```python if template_buffer_node is not None: self.output_node = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = epilogue_nodes[-1] - X = self.input_nodes[0] + X = self.input_nodes[0] Y = self.output_node - M = X.get_numel() - TILE = kernel.vector_lane + X_tensor = empty_strided(X.layout.size, X.layout.stride) + M = X_tensor.size()[0] + N = X_tensor.size()[1] + + TILE_M = 8 + TILE_N = 8 ``` ### Tile descriptors and indices -Defines tile descriptors for the input and output tensors. -A tile descriptor (`MLIRMultiDimTile`) is a small metadata object that captures a tile’s shape, stride, and vector‑lane mapping. It is later used to form SRAM buffer types, DMA parameters, and indexing decisions. In this _foobar example, both X and Y use 1D tiles of size TILE, and both are indexed by the same loop variable (`index0`) to represent elementwise copy. - +Defines 2D tile descriptors for `X` and `Y` using `MLIRMultiDimTile`, which captures tile shape, vector-lane mapping, and stride metadata for SRAM/DMA generation. +Here, both tiles use `[TILE_M, TILE_N]` with explicit strides `[1, TILE_M]`, the input tile also carries the DRAM offset, and indexing is built from the 2D loop variables (`index0`, `index1`) multiplied by the DRAM strides of X/Y. ```python - vlane_stride = 1 - vlane_split_axis = 0 - X_tile_size = [TILE] - X_tile_stride = [1] + vlane_stride = 8 + vlane_split_axis = 1 + X_tile_size = [TILE_M,TILE_N] + X_tile_stride = [1, TILE_M] X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) X_tile_desc.set_name("X_buffer") - X_idx = [sympy.Symbol("index0")] + X_tile_desc.offset = X.get_layout().offset + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index1") * X_stride[1]] - - Y_tile_size = [TILE] - Y_tile_stride = [1] + Y_tile_size = [TILE_M,TILE_N] + Y_tile_stride = [1, TILE_M] Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("Y_buffer") - Y_idx = [sympy.Symbol("index0")] -``` - -### Memref shape strings - -Defines the memref type strings used in the MLIR template for both input and output. -A memref is MLIR’s memory reference type, it describes a buffer in memory by its shape and element type and optionally layout/stride information. -For example, memref<128xf32> is a 1D buffer of 128 floats. In this foobar example, both X and Y are treated as 1D buffers, so we use memref<{M}xf32>: - -```python - X_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32') - Y_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32') + Y_stride = Y.get_layout().stride + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] ``` ### Render options -Collects all symbolic values and configuration parameters into -`kernel.render_options`. These options are later used to render the MLIR string -template. +Populates `kernel.render_options` with all values needed by the MLIR string template: kernel name/handle, tensor shape (`M`, `N`), tile sizes (`TILE_M`, `TILE_N`), input/output buffers (`X`, `Y`), index expressions (`X_idx`, `Y_idx`), tile descriptors (`X_tile_desc`, `Y_tile_desc`), data type, and input reordering. These options are then used to render `TEMPLATE`. 
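Conceptually, the render options are just a dictionary whose entries are substituted into the `{{ ... }}` placeholders of `TEMPLATE`. A minimal sketch of that mechanism using plain `jinja2` (PyTorchSim goes through its own `_template_from_string` helper, so this is only an illustration and assumes `jinja2` is installed):

```python
from jinja2 import Template

# Stand-in fragment of TEMPLATE; the real rendering call is
# self._template_from_string(TEMPLATE).render(**kernel.render_options).
fragment = r"affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} {"
print(Template(fragment).render(M=128, TILE_M=8))
# -> affine.for %index0 = 0 to 128 step 8 {
```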
```python kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - M=M, - TILE=TILE, + M=M, N=N, + TILE_M=TILE_M, + TILE_N=TILE_N, X=X, Y=Y, X_idx=X_idx, Y_idx=Y_idx, X_tile_desc=X_tile_desc, Y_tile_desc=Y_tile_desc, - X_flat_mlir_shape=X_flat_mlir_shape, - Y_flat_mlir_shape=Y_flat_mlir_shape, DATA_STYPE="f32", input_reorder=self.input_reorder, ) @@ -290,8 +256,7 @@ template. ### Epilogue -Records metadata related to output buffers and element counts, which is useful for -exception handling and debugging. +Stores epilogue metadata for the output tile—output_node, DRAM/SRAM names, and the tile descriptor—so the codegen can configure epilogue tile size and generate the final DMA/MVOUT path for the output. ```python kernel.epilogue_info = dict( @@ -303,15 +268,17 @@ exception handling and debugging. ``` -### Render MLIR and add loop metadata & Return +### Render MLIR, register loop metadata and return -Renders the final MLIR code by substituting placeholders in the template. -And returns the final MLIR string that will be consumed by the kernel. +Renders the MLIR string by applying `kernel.render_options` to `TEMPLATE`, records loop bounds/strides via `kernel.add_loop_info` for wrapper metadata, and returns the final MLIR source string. ```python + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"]]) + + return code - return code ``` ### Full `mlir_foobar_template.py` @@ -321,26 +288,28 @@ Copy-paste the full reference implementation below to create `mlir_foobar_templa ```python from typing import List, Optional import sympy - +from torch import empty_strided from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel from torch._inductor.ir import IRNode, Buffer from PyTorchSimFrontend.mlir import mlir_common - +from PyTorchSimFrontend import extension_config +from pathlib import Path TEMPLATE = r""" {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{- kernel.def_local_vars(indent_size=2) }} %M_const = arith.constant {{ M }} : index - affine.for %index0 = 0 to {{ M }} step {{ TILE }} { - affine.for %t = 0 to {{ TILE }} step 1 { - %g = arith.addi %index0, %t : index - %cond = arith.cmpi slt, %g, %M_const : index - scf.if %cond { - %val = memref.load %X[%g] : {{ X_flat_mlir_shape }} - memref.store %val, %Y[%g] : {{ Y_flat_mlir_shape }} - } - } + %N_const = arith.constant {{ N }} : index + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + linalg.copy ins(%X_buffer : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}) outs(%Y_buffer : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}) + {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} + } { outer_loop=true } } { outer_loop=true } return } @@ -349,7 +318,6 @@ func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_ class MLIRFoobarTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): - # Initialize the MLIR template with the kernel name and input/output 
nodes. super().__init__("kernel", input_nodes, layout, input_reorder) def render(self, @@ -359,55 +327,54 @@ class MLIRFoobarTemplate(MLIRTemplate): prologue_nodes: Optional[List[IRNode]] = None, tile_info = None, **kwargs): - """Render the MLIR code for the `torch._foobar()` operation. - This method generates the MLIR code by filling in the placeholders in the - `TEMPLATE` string with the appropriate values for the input/output tensors, - tile sizes, and other parameters. - """ if template_buffer_node is not None: self.output_node = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = epilogue_nodes[-1] + # if epilogue_nodes is not None and len(epilogue_nodes) > 0: + # self.output_node = epilogue_nodes[-1] X = self.input_nodes[0] - Y = self.output_node + Y = self.output_node - M = X.get_numel() - TILE = kernel.vector_lane + X_tensor = empty_strided(X.layout.size, X.layout.stride) + M = X_tensor.size()[0] + N = X_tensor.size()[1] - vlane_stride = 1 - vlane_split_axis = 0 - X_tile_size = [TILE] - X_tile_stride = [1] + TILE_M = 8 + TILE_N = 8 + + vlane_stride = 8 + vlane_split_axis = 1 + X_tile_size = [TILE_M,TILE_N] + X_tile_stride = [1, TILE_M] X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) X_tile_desc.set_name("X_buffer") - X_idx = [sympy.Symbol("index0")] + X_tile_desc.offset = X.get_layout().offset + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index1") * X_stride[1]] - Y_tile_size = [TILE] - Y_tile_stride = [1] + Y_tile_size = [TILE_M,TILE_N] + Y_tile_stride = [1, TILE_M] Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("Y_buffer") - Y_idx = [sympy.Symbol("index0")] + Y_stride = Y.get_layout().stride + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] - X_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32') - Y_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32') kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - M=M, - TILE=TILE, + M=M, N=N, + TILE_M=TILE_M, + TILE_N=TILE_N, X=X, Y=Y, X_idx=X_idx, Y_idx=Y_idx, X_tile_desc=X_tile_desc, Y_tile_desc=Y_tile_desc, - X_flat_mlir_shape=X_flat_mlir_shape, - Y_flat_mlir_shape=Y_flat_mlir_shape, DATA_STYPE="f32", input_reorder=self.input_reorder, ) @@ -418,10 +385,13 @@ class MLIRFoobarTemplate(MLIRTemplate): dram_var="Y", dram_tile_desc=Y_tile_desc, ) - + + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"]]) return code + ``` @@ -433,8 +403,8 @@ This step connects the ATen operator (`aten._foobar`) to the MLIR template. All changes in this step are made in **`mlir_lowering.py`**. It consists of two parts: -1. Defining the custom lowering function, and -2. Registering that function so Inductor actually uses it. +1. Defining the custom lowering function +2. 
Registering that function so Inductor actually uses it --- @@ -443,7 +413,7 @@ It consists of two parts: The custom lowering function specifies **how `aten._foobar` should be lowered** during the Inductor lowering stage. -For this tutorial, the lowering is intentionally minimal and preserves the +For this example, the lowering is intentionally minimal and preserves the MLIR/template path, making it suitable for testing and demonstration purposes. ```python @@ -462,7 +432,7 @@ The actual input tensor. During lowering, this is typically wrapped in a with its layout/metadata and deferred computation context. * `a.realize()`: -Materializes the input so the MLIR template sees a concrete buffer / IR node. +Materializes the input `TensorBox` so the MLIR template sees a concrete buffer / IR node. This is a safe default pattern to ensure shape/stride/layout metadata is available. * `layout = a.layout`: @@ -473,32 +443,27 @@ memref types and derive tiling/indexing behavior. * `MLIRFoobarTemplate([a], layout)`: Instantiates the MLIR template with the input node and its layout. -`generate().output_node()`: +* `generate().output_node()`: Builds the template-backed Inductor IR node and returns it as the lowering result. -This pattern is the baseline shape you will reuse for other operators, adding -more logic as the operator becomes more complex. --- -### Step 2.2 — Register the Lowering via `lowerings.update(...)` +### Step 2.2 — Register the Lowering via `lowerings.update` + +Defining the lowering function is only the first step. You also need to register it so Inductor can dispatch to it during lowering. -Defining the lowering function alone is not enough. -It must be registered so that Inductor actually invokes it. ```python lowerings.update({getattr(aten._foobar, overload): custom_foobar for overload in aten._foobar.overloads()}) ``` -The `lowerings` table is consulted by Inductor during the lowering phase. -When the ATen graph contains `aten._foobar`, Inductor looks up the operator in -this table and invokes `custom_foobar` instead of the default lowering. +The `lowerings` registry is consulted by Inductor during the lowering phase. +When the ATen graph contains `aten._foobar`, Inductor looks it up here and dispatches to custom_foobar instead of the default path. Using `aten._foobar.overloads()` ensures that all overload variants of the operator are covered, even if multiple signatures exist. -This is the wiring step that activates the MLIR template path defined in -Step 1 and implemented in Step 2.1. --- ## Step 3 — Add a Test (`test_.py`) @@ -535,11 +500,11 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): ### Define a small wrapper function -Defines a minimal wrapper function (vector_foobar) that calls -torch._foobar(a). +Defines a minimal wrapper function `vector_foobar` that calls +`torch._foobar(a)`. -This function is what torch.compile captures into the graph, making it the -entry point for Dynamo → ATen → Inductor → custom lowering. +This function is what `torch.compile` captures into the graph, making it the +entry point for FX Graph → ATen → Inductor → custom lowering. ``` def test_foobar(device, size=(128, 128)): @@ -550,10 +515,7 @@ def test_foobar(device, size=(128, 128)): ### Create input and compile Creates a random input tensor and compiles the wrapper function using -torch.compile(dynamic=False). - -This ensures the execution path goes through: -Dynamo → ATen → Inductor → custom lowering → MLIR template. +`torch.compile(dynamic=False)`. 
```python x = torch.randn(size).to(device=device) @@ -566,7 +528,7 @@ Dynamo → ATen → Inductor → custom lowering → MLIR template. Executes the compiled function and compares the result against a reference output. -For _foobar, the reference output is simply the input tensor moved to CPU, +For `_foobar`, the reference output is simply the input tensor moved to CPU, since the operator behaves like an identity op. ```python @@ -579,11 +541,11 @@ since the operator behaves like an identity op. Defines the command-line entry point for the test. This section: -* Sets up the custom PyTorchSim device and runner, +* Sets up the custom PyTorchSim device and runner -* Parses shape arguments, +* Parses shape arguments -* Runs the test across multiple input sizes. +* Runs the test across multiple input sizes ```python if __name__ == "__main__": @@ -668,14 +630,7 @@ To add support for a new ATen operator in PyTorchSim’s MLIR path, you: 2. Implement and register a custom lowering in `mlir_lowering.py`, 3. Validate the integration with a dedicated test (`test_.py`). -During testing, `torch.compile` captures the operator into the ATen graph. -The custom lowering replaces the ATen op with a template-backed Inductor IR node, -which then flows through scheduling, MLIR code generation, and kernel -registration before execution. - The `_foobar` example illustrates the complete integration flow for adding a new ATen operator in PyTorchSim. You can use this example as a reference when -extending PyTorchSim to support additional ATen operators. For operators with -more complex semantics, refer to the follow-up documentation for guidance on -advanced lowering, layout handling, and performance considerations. +extending PyTorchSim to support additional ATen operators.
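As a condensed reference, the wiring from Steps 1 and 2 boils down to the sketch below. The names `MLIRMyOpTemplate` and `custom_my_op` are placeholders for your operator, and `lowerings` / `aten` are assumed to be the same objects already used in `mlir_lowering.py`:

```python
# Hypothetical skeleton following the _foobar pattern; adapt the names to your op.
def custom_my_op(a):
    a.realize()            # materialize the input TensorBox
    layout = a.layout      # reuse the input layout for the output
    return MLIRMyOpTemplate([a], layout).generate().output_node()

lowerings.update({getattr(aten.my_op, overload): custom_my_op
                  for overload in aten.my_op.overloads()})
```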