From 3ce8a79790c95617c0b47be81de258b3e4e66f87 Mon Sep 17 00:00:00 2001
From: parkminkyumin
Date: Mon, 19 Jan 2026 10:39:10 +0000
Subject: [PATCH 1/2] WIP: foobar MLIR lowering + related changes

---
 .gitignore | 6 +-
 PyTorchSimFrontend/extension_device.cpp | 1 +
 .../mlir/mlir_codegen_backend.py | 16 +-
 PyTorchSimFrontend/mlir/mlir_common.py | 119 +++-
 .../mlir/mlir_foobar_template.py | 117 ++++
 PyTorchSimFrontend/mlir/mlir_gemm_template.py | 3 +-
 PyTorchSimFrontend/mlir/mlir_lowering.py | 127 ++++-
 .../mlir/mlir_maxpool_template.py | 124 ++--
 PyTorchSimFrontend/mlir/mlir_scheduling.py | 533 +++++++++---------
 PyTorchSimFrontend/mlir/mlir_template.py | 364 ++++++++----
 Scheduler/scheduler.py | 12 +-
 tests/test_foobar.py | 48 ++
 tests/test_matmul.py | 5 +-
 tutorial/session1/ExecutionMode.ipynb | 113 +++-
 tutorial/session1/Inference.ipynb | 45 +-
 tutorial/session1/LogAnalysis.ipynb | 44 +-
 tutorial/session1/Mapping.ipynb | 106 +++-
 .../session1/tutorial_external_mapping.json | 2 +-
 tutorial/session2/Hands_on.ipynb | 74 ++-
 19 files changed, 1310 insertions(+), 549 deletions(-)
 create mode 100644 PyTorchSimFrontend/mlir/mlir_foobar_template.py
 create mode 100644 tests/test_foobar.py

diff --git a/.gitignore b/.gitignore
index b42d5f6b..ade5a563 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,9 @@ TOGSim/build/
 output
 togsim_results/*
 outputs/*
-experiments/artifact/logs/*
\ No newline at end of file
+experiments/artifact/logs/*
+# tutorial generated artifacts
+tutorial/session1/fused/
+tutorial/session1/togsim_results/
+tutorial/session2/togsim_results/
+tutorial/session2/fx_graph.svg
diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp
index 1a02bfe3..e7a20153 100644
--- a/PyTorchSimFrontend/extension_device.cpp
+++ b/PyTorchSimFrontend/extension_device.cpp
@@ -338,6 +338,7 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
   m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
   m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
+  m.impl("_foobar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
 }
 
 // This basic implementation doesn't bother dealing with different device indices
diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
index 6650f429..f6505daf 100644
--- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
+++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py
@@ -430,10 +430,17 @@ def exp2(operand, *args, var_info=None, **kwargs):
         # var_info = {operand: [tile_size, dtype]}
         # Ex) var_info[operand] = [8, "f32"]
-        ln2 = math.log(2)
-        coeff = ops.constant(ln2, "f32")
-        operand = ops.mul(operand, coeff)
-        return ops.exp(operand), var_info[operand]
+        # ln2 = math.log(2)
+        # coeff = ops.constant(ln2, "f32")
+        # operand = ops.mul(operand, coeff)
+        # return ops.exp(operand), var_info[operand]
+
+        tile_size, dtype = var_info[operand]
+        if tile_size > 1:
+            shape = f"vector<{tile_size}x{dtype}>"
+        else:
+            shape = dtype
+        return f'math.exp2 %{operand} : {shape}', [tile_size, dtype]
 
     @staticmethod
     def erf(operand, *args, var_info=None, **kwargs):
@@ -897,6 +904,7 @@ def broadcast(operand1, operand2, *args, var_info=None, **kwargs):
     "MVOUT1": 3,
 }
+# Translator that lowers loop-level IR to MLIR
 class MLIRKernel(mlir_common.BaseMLIRKernel):
     overrides = ExtensionOverrides
     newvar_prefix = "%"
     suffix = ""
diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py
index 4d33eea4..0327755b 100644
--- a/PyTorchSimFrontend/mlir/mlir_common.py
+++ b/PyTorchSimFrontend/mlir/mlir_common.py
@@ -137,43 +137,67 @@ def __init__(self, message="Recompilation requested."):
         super().__init__(self.message)
 
 class MLIRKernelArgs(common.KernelArgs):
+    """MLIR-specific kernel argument helper.
+
+    Role: collect buffer/constant/size info from the graph and build the argument
+    definitions and call-argument list matching the MLIR function signature.
+    Argument kinds (in/out/inout/var) are also tracked as bit flags.
+    """
     MLIR_ARGS_IN = 0x01
     MLIR_ARGS_OUT = 0x02
     MLIR_ARGS_INOUT = 0x04
     MLIR_ARGS_VAR = 0x08
 
    def __init__(self, tile_row=None, tile_col=None):
+        # Initialize the parent class and keep the MLIR-specific tile info
        super().__init__()
        self.tile_row = tile_row
        self.tile_col = tile_col
 
    @staticmethod
    def is_mlir_arg_in(value):
+        """Check whether the value is an 'input' or 'inout' argument.
+
+        Why: the argument classification affects code generation (read-only vs. write access).
+        """
        return (MLIRKernelArgs.MLIR_ARGS_IN & value) | (MLIRKernelArgs.MLIR_ARGS_INOUT & value)
 
    @staticmethod
    def is_mlir_arg_out(value):
+        """Check whether the value is an 'output' or 'inout' argument."""
        return (MLIRKernelArgs.MLIR_ARGS_OUT & value) | (MLIRKernelArgs.MLIR_ARGS_INOUT & value)
 
    @staticmethod
    def is_mlir_arg_inout(value):
+        """Check whether the value is an 'inout' (input/output) argument."""
        return MLIRKernelArgs.MLIR_ARGS_INOUT & value
 
    @staticmethod
    def get_mlir_shape(info):
+        """Build an MLIR memref shape string from dtype/numel info."""
        tensor_type = DTYPE_TO_MLIR[info[0]]
        return f"memref<{info[1]}x{tensor_type}>"
 
    def mlir_argdefs(self, extra_node=dict()):
+        """Collect buffer/constant/extra-node info from the graph and return the
+        MLIR argument definitions (arg_defs), call arguments (call_args),
+        argument attributes (arg_attributes), and buffer metadata (buffer_types).
+
+        Why: the wrapper code and the MLIR function declaration must be generated
+        consistently, so all argument information is provided in one place.
+        """
        buffer_types = {}
+        # Walk the buffers in the graph and collect their metadata
        for x in V.graph.buffers:
            if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled
                buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()]
+        # Handle graph inputs (including symbolic ones)
        for name, val in V.graph.graph_inputs.items():
            if isinstance(val, sympy.Expr):
                buffer_types[name] = [get_sympy_Expr_dtype(val), 1, [1], [1]]
            else:
                buffer_types[name] = [val.get_dtype(), val.get_numel(), val.get_size(), val.get_stride()]
+        # Merge constants / extra node info
        buffer_types.update(
            {name: [val.dtype, 1, [1], [1]] for name, val in V.graph.constants.items()}
        )
@@ -185,11 +209,13 @@ def mlir_argdefs(self, extra_node=dict()):
        arg_defs = []
        arg_attributes = []
        def set_info(outer, inner, arg_type):
+            # outer: name in the graph, inner: MLIR-internal name (%X)
            mlir_shape = self.get_mlir_shape(buffer_types[outer])
            arg_defs.append(f"%{inner}: {mlir_shape}")
            call_args.append(outer)
            arg_attributes.append([outer] + [[arg_type] + buffer_types[outer]])
 
+        # Register arguments per category: inplaced, input, output, sizevar, etc.
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
@@ -209,19 +235,34 @@ def set_info(outer, inner, arg_type):
        return arg_defs, call_args, arg_attributes, buffer_types
 
 class VectorLaneMapping():
+    """Utility class that manages vector lane (vlane) mapping information.
+
+    Role: computes how a given tile is split across vlanes (vlane split axis/stride,
+    number of vlanes used, etc.) and derives the per-lane work (tile size/stride) and the vectorization width.
+    """
    def __init__(self, vector_lane: int, forced_vec_size: int, vlane_split_axis: int, vlane_stride: int):
+        # Keep the hardware/mapping parameters
        self.vector_lane = vector_lane
        self.vlane_split_axis = vlane_split_axis
        self.vlane_stride = vlane_stride
        self.forced_vec_size = forced_vec_size
 
    def get_used_vlane(self, tile_size: list[int]):
+        """Compute the number of vlanes actually used for the given tile size.
+
+        Computation: the minimum of ceil(size along the split axis / vlane_stride) and the total vector_lane count.
+        Why: determines how that tile axis is distributed across vlanes so resources can be allocated accordingly.
+        """
        return min(
            math.ceil(tile_size[self.vlane_split_axis] / self.vlane_stride),
            self.vector_lane
        )
 
    def get_tile_size_per_lane(self, tile_size: list[int]):
+        """Return the per-lane tile size when the tile is divided across lanes.
+
+        Why: per-lane compute and memory requirements are estimated from each lane's tile size.
+        """
        per_lane = tile_size.copy()
        used = self.get_used_vlane(tile_size)
        if self.vlane_split_axis < 0 or self.vlane_split_axis >= len(per_lane):
@@ -230,20 +271,33 @@ def get_tile_size_per_lane(self, tile_size: list[int]):
        return per_lane
 
    def get_numel_per_lane(self, tile_size: list[int]):
+        # Compute the number of elements (= amount of work) in the per-lane tile
        return math.prod(self.get_tile_size_per_lane(tile_size))
 
    def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int]):
+        """Convert the original tile strides into per-lane strides.
+
+        Why: the memory access pattern is recomputed per lane to adjust DMA/local indexing.
+        """
        tile_stride = tile_stride.copy() # original strides
        get_tile_size_per_lane = self.get_tile_size_per_lane(tile_size)
        coeff = tile_size[self.vlane_split_axis]//get_tile_size_per_lane[self.vlane_split_axis]
-        # Propagate stride according to per-lane tile size
+        # Apply the stride correction needed when propagating to the per-lane tile size
        for i in range(len(tile_stride)):
            if tile_stride[i] > tile_stride[self.vlane_split_axis]:
                tile_stride[i] = tile_stride[i] // coeff
        return tile_stride
 
    def get_compute_vec_size(self, tile_size: list[int], reduction_numel: int, nr_rdim: int) -> int:
+        """Decide the vectorization (SIMD) width used for compute.
+
+        - If forced_vec_size is set, return the forced value.
+        - If there is a reduction, compute a suitable split size under its constraints.
+        - Otherwise pick an appropriate width among multiples of 8/4/2 based on the stride unit.
+
+        Purpose: fixes the vector-op width so codegen can match vector load/store and compute units.
+        """
        if self.forced_vec_size is not None:
            return self.forced_vec_size
 
@@ -481,16 +535,26 @@ def adjust(self, old: int, new: int, dim: int) -> int:
        raise extension_codecache.TileSizeError("Cannot find suitable tile size under the given constraints.")
 
 class MLIRMultiDimTile(TileAdjustMixin):
+    """Multi-dimensional tile descriptor (tile size/stride/vector mapping, etc.).
+
+    Role: holds the tile size, strides, axis order, and vlane mapping info, and
+    provides helpers that convert them into the form used in the generated MLIR.
+ """ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=None, forced_vec_size=None): super().__init__() + # 식별자 이름 self.name = "" + # 내부 타일 사이즈 리스트 self._tile_size = list(tile_size) + # 계산된 stride(초기에는 None, update_tile_stride로 설정) self._tile_stride = None + # 각 축에 대한 제약 정보(TileConstraint 객체 목록) self.tile_constraint = [TileConstraint(vlane_stride) for _ in tile_size] + # 기본 축 순서 self.tile_axis_order = list(range(len(tile_size))) self.update_tile_stride() - # Vector lane mapping config + # Vector lane 매핑 설정 보관 self.vmap = VectorLaneMapping( vector_lane=vector_lane, forced_vec_size=forced_vec_size, @@ -498,9 +562,10 @@ def __init__(self, tile_size, vector_lane, vlane_split_axis=None, vlane_stride=N vlane_stride=vlane_stride ) + # implicit dim, reduction 깊이 등 추가 메타 self.implicit_dim_size = None self.nr_rdim = 0 - self.offset = sympy.Integer(0) # Dram offset + self.offset = sympy.Integer(0) # DRAM offset을 sympy로 표현 def set_name(self, name: str): self.name = name def get_name(self) -> str: return self.name @@ -511,15 +576,21 @@ def get_nr_dim(self) -> str: return len(self._tile_size) def get_reduction_numel(self): return reduce(mul, self.get_tile_size()[-1*self.nr_rdim:], 1) def set_tile_size(self, tile_size, tile_axis_order=None, constraints=None): + # 타일 크기/순서/제약을 적용 후 stride를 갱신 self._tile_size = list(tile_size) self.tile_axis_order = list(range(len(tile_size))) if tile_axis_order is None else tile_axis_order self.update_tile_stride() def set_tile_size_stride(self, tile_size, tile_stride): + # 타일과 stride를 직접 설정할 때 사용 self._tile_size = list(tile_size) self._tile_stride = list(tile_stride) def update_tile_stride(self): + """타일 사이즈와 axis order로부터 각 축에 대한 stride를 계산. + + 이유: 각 차원의 stride는 메모리 인덱싱 및 DMA 인코딩에 필요합니다. + """ strides = [1] * len(self._tile_size) init = 1 @@ -534,13 +605,14 @@ def update_tile_stride(self): self._tile_stride = strides def get_dim_size(self, index): + # 정수 인덱스 또는 'indexN' 형태(sympy 변수)를 받아 해당 축 크기를 반환 if isinstance(index, int): return self._tile_size[index] elif "index" in str(index): return self._tile_size[int(str(index)[5:])] raise NotImplementedError("Unsupported format of index") - # Vector mapping delegation + # Vector mapping delegation - 내부 vmap으로 위임 def get_tile_size_per_lane(self): return self.vmap.get_tile_size_per_lane(self._tile_size) def get_used_vlane(self): return self.vmap.get_used_vlane(self._tile_size) def get_numel_per_lane(self): return self.vmap.get_numel_per_lane(self._tile_size) @@ -549,31 +621,56 @@ def get_compute_vec_size(self): return self.vmap.get_compute_vec_size(self._tile # Helper functions for codegen def get_mlir_shape(self, dtype): + # MLIR memref shape 문자열을 생성(예: memref<16x8xf32, 1>) shape = "x".join([str(dim) for dim in self._tile_size]) return f"memref<{shape}x{dtype}, 1>" def get_mlir_vshape(self, mlir_dtype): + # 벡터 타입 문자열(벡터화 크기 > 1인 경우 vector<...> 형식) return f"vector<{self.get_compute_vec_size()}x{mlir_dtype}>" if self.get_compute_vec_size() > 1 else f"{mlir_dtype}" class MLIRWrapperKenrelGroup(cpp.KernelGroup): + """MLIR용 래퍼 커널 그룹. 인자와 타일 정보를 보관하는 컨테이너. + + 역할: 기존 C++/CUDA용 KernelGroup을 확장하여 MLIR 전용 args와 tile descriptor를 갖도록 함. + wrapper/렌더러가 타일 정보를 읽고 코드생성에 사용할 수 있도록 중앙에서 관리합니다. 
+ """ def __init__(self): super().__init__() + # MLIR 전용 인자 관리 객체 self.args = MLIRKernelArgs() + # 코드 생성 시 사용될 MLIRMultiDimTile 객체(초기에는 None) self.tile_desc : MLIRMultiDimTile = None def set_tile_info(self, tile_desc : MLIRMultiDimTile): + # 외부에서 계산된 타일 설명자를 등록 self.tile_desc = tile_desc class BaseMLIRHardwareInfo(): + """하드웨어 관련 기본 정보 보관용 베이스 클래스. + + 역할: VPU/스패드/정밀도/코어 수 등 코드 생성 시 참조되는 하드웨어 파라미터를 중앙화합니다. + 확장자가 원하는 하드웨어 설정을 여기서 정의하여 다른 컴포넌트가 일관되게 사용하도록 합니다. + """ def __init__(self): - # Default HW setting + # Default HW setting (extension_config에서 값을 읽어 설정) + # vector_lane: VPU의 lane 수(벡터 연산 폭의 단위) self.vector_lane = extension_config.vpu_num_lanes + # spad_info: 스패드(로컬 메모리) 크기/설정 정보 self.spad_info = extension_config.CONFIG_SPAD_INFO + # 정밀도(byte 단위 등) - e.g., f32 -> 4 bytes self.precision = extension_config.CONFIG_PRECISION + # 병렬 코어 수 self.num_cores = extension_config.CONFIG_NUM_CORES + # 벡터 레지스터 길이(bits) self.vlen = extension_config.vpu_vector_length_bits class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): + """MLIR 코드 생성의 공통 기반 클래스. + + 역할: MLIR 전용 코드 버퍼, CSE, 변수/버퍼 메타, 루프 정보 및 재컴파일 로직을 제공하여 + 실제 커널별 구현이 이를 활용하도록 공통 기능을 캡슐화합니다. + """ newvar_prefix = "%" suffix = "" overrides = None @@ -581,23 +678,29 @@ class BaseMLIRKernel(common.Kernel, BaseMLIRHardwareInfo): store_format = None def __init__(self, kernel_group, reason=None): + """기본 상태(버퍼, 루프, CSE 등)를 초기화. + + reason: 재컴파일 원인을 전달(recodegen)하여 타일 조정/재시도 로직에 사용. + """ super().__init__(kernel_group.args) self.kernel_group = kernel_group - # Kernel iteration range info + # 루프/범위 관련 상태 self.call_ranges = None self.ranges = None self.reduction_depth = None self.itervars = None - # Code buffer + # 코드 버퍼: 벡터 연산 본문, 리덕션 접미사 등 self.vector_compute = IndentedBuffer() self.reductions_suffix = IndentedBuffer() self.cse = common.CSE(self.newvar_prefix, self.suffix) - # MLIR SSA tracker + # MLIR SSA 변수 정보 추적기 self.var_info = {} # MLIR variable info self.buffer_types : dict = None # format: dtype, numel, size, stride + # compute index 이름 및 루프 레벨 설정 self.compute_idx = "compute_idx" self.compute_body_loop = LoopLevel(self.compute_idx, 1) self.prologue_compute_body_loop = LoopLevel(self.compute_idx, 1) + # 재컴파일 이유 (예: spad overflow 등)를 저장 self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False diff --git a/PyTorchSimFrontend/mlir/mlir_foobar_template.py b/PyTorchSimFrontend/mlir/mlir_foobar_template.py new file mode 100644 index 00000000..13188c81 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_foobar_template.py @@ -0,0 +1,117 @@ +from typing import List, Optional +import sympy +from torch import empty_strided +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel +from torch._inductor.ir import IRNode, Buffer +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend import extension_config +from pathlib import Path +import json + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + %M_const = arith.constant {{ M }} : index + %N_const = arith.constant {{ N }} : index + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + linialg.copy {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to {{ 
+      {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }}
+    } { outer_loop=true }
+  } { outer_loop=true }
+  return
+}
+"""
+
+class MLIRFoobarTemplate(MLIRTemplate):
+
+    def __init__(self, input_nodes, layout, input_reorder=None):
+        # Initialize the MLIR template with the kernel name and input/output nodes.
+        super().__init__("kernel", input_nodes, layout, input_reorder)
+
+    def render(self,
+               kernel: MLIRTemplateKernel,
+               template_buffer_node = None,
+               epilogue_nodes: Optional[List[IRNode]] = None,
+               prologue_nodes: Optional[List[IRNode]] = None,
+               tile_info = None,
+               **kwargs):
+
+        if template_buffer_node is not None:
+            self.output_node = template_buffer_node
+        if epilogue_nodes is not None and len(epilogue_nodes) > 0:
+            self.output_node = epilogue_nodes[-1]
+
+        X = self.input_nodes[0]
+        Y = self.output_node
+
+        X_tensor = empty_strided(X.layout.size, X.layout.stride)
+
+        M = X_tensor.size()[0]
+        N = X_tensor.size()[1]
+        # path = Path(extension_config.codegen_external_mapping_file)
+        # with path.open("r") as f:
+        #     data = json.load(f)
+        # tile_info = data[f"{M}x{N}"]
+        # TILE_M, TILE_N = tile_info.values()
+        TILE_M = 64
+        TILE_N = 64
+
+        TILE = kernel.vector_lane
+
+        vlane_stride = 1
+        vlane_split_axis = 0
+        X_tile_size = [TILE_M,TILE_N]
+        X_tile_stride = [1, TILE_M]
+        X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
+        X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride)
+        X_tile_desc.set_name("X_buffer")
+        X_tile_desc.offset = X.get_layout().offset
+        X_stride = X.get_layout().stride
+        X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index1") * X_stride[1]]
+
+        Y_tile_size = [TILE_M,TILE_N]
+        Y_tile_stride = [1, TILE_M]
+        Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride)
+        Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride)
+        Y_tile_desc.set_name("Y_buffer")
+        Y_stride = Y.get_layout().stride
+        Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]]
+
+        # X_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32')
+        # Y_flat_mlir_shape = f"memref<{M}x{{DATA_STYPE}}>".replace('{DATA_STYPE}', 'f32')
+
+        kernel.render_options = dict(
+            KERNEL_NAME=self.name,
+            kernel=kernel,
+            M=M, N=N,
+            TILE=TILE,
+            TILE_M=TILE_M,
+            TILE_N=TILE_N,
+            X=X,
+            Y=Y,
+            X_idx=X_idx,
+            Y_idx=Y_idx,
+            X_tile_desc=X_tile_desc,
+            Y_tile_desc=Y_tile_desc,
+            #X_flat_mlir_shape=X_flat_mlir_shape,
+            #Y_flat_mlir_shape=Y_flat_mlir_shape,
+            DATA_STYPE="f32",
+            input_reorder=self.input_reorder,
+        )
+
+        kernel.epilogue_info = dict(
+            output_node=self.output_node.name,
+            sram_var="Y_buffer",
+            dram_var="Y",
+            dram_tile_desc=Y_tile_desc,
+        )
+
+        code = self._template_from_string(TEMPLATE).render(**kernel.render_options)
+        kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"]])
+        return code
\ No newline at end of file
diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py
index bbc63b45..3790b6df 100644
--- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py
+++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py
@@ -25,7 +25,7 @@
 func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} {
   {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }}
kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_ + desc, indent_size=2) }} {% if not Bias %} %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} {{ kernel.def_local_vars(indent_size=2) }} diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..ddc0bdb9 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,43 +1,65 @@ -from typing import List, Optional, Sequence - -import torch -from torch._inductor.lowering import lowerings, index_impl -from torch._inductor.kernel.mm_common import mm_args -# from torch._inductor.select_algorithm import ExternKernelChoice -from torch._inductor import ir -from torch._inductor.virtualized import V -from torch._inductor.ir import TensorBox -from PyTorchSimFrontend.extension_op import MLIRExternKernelChoice -from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate -from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate -from PyTorchSimFrontend.mlir.mlir_conv_template import MLIRConvTemplate -from PyTorchSimFrontend.mlir.mlir_conv_mt_template import MLIRConvMultiTileTemplate -from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate -from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate -from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate -from PyTorchSimFrontend import extension_config +# mlir_lowering.py +# 이 파일은 PyTorch Inductor의 lowering 단계에서 특정 aten 연산들을 MLIR 템플릿 또는 커스텀 구현으로 매핑합니다. +# 각 lowering 함수는 입력 IR(TensorBox)들을 받아 MLIR 템플릿을 생성하거나, 외부 커널 호출을 준비합니다. +from typing import List, Optional, Sequence + +import torch +from torch._inductor.lowering import lowerings, index_impl +from torch._inductor.kernel.mm_common import mm_args +# from torch._inductor.select_algorithm import ExternKernelChoice +from torch._inductor import ir +from torch._inductor.virtualized import V +from torch._inductor.ir import TensorBox +from PyTorchSimFrontend.extension_op import MLIRExternKernelChoice +from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate +from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate +from PyTorchSimFrontend.mlir.mlir_conv_template import MLIRConvTemplate +from PyTorchSimFrontend.mlir.mlir_conv_mt_template import MLIRConvMultiTileTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate +from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate +from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_foobar_template import MLIRFoobarTemplate +from PyTorchSimFrontend import extension_config + +# shortcut: aten ops에 접근하기 쉽게 변수에 할당합니다. aten = torch.ops.aten +# sparse mm 연산을 MLIR 외부 커널로 매핑하기 위한 래퍼를 생성합니다. (외부 커널 이름은 "custom_op::sparse_addmm") aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +#tuned_mm, tuned_addmm, tuned_bmm는 각각 mm, addmm, bmm 연산을 MLIR 템플릿으로 변환하는 함수입니다. + def tuned_mm(mat1, mat2, * ,layout=None): + # mm (행렬 곱) 연산을 받아 MLIR GEMM 템플릿으로 변환합니다. + # mm_args는 입력들의 형상/레이아웃 정보를 통일하고 (m,n,k,layout,mat1,mat2)를 반환합니다. m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + # GEMM 템플릿을 생성합니다. 레이아웃 정보는 템플릿 생성에 필요합니다. 
    mlir_template = MLIRGemmTemplate([mat1, mat2], layout)
+    # Generate the template and return the output IR node; generate() builds the template's IR graph.
    return mlir_template.generate(input_nodes=[mat1, mat2], layout=layout).output_node()
 
+
 def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
+    # Lower addmm (alpha * mat1 @ mat2 + beta * inp) to the MLIR GEMM template.
+    # mm_args expands inp to the matching output size and returns the required shape/layout info.
    m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
+    # Build the GEMM template including inp (the bias/accumulated input).
    mlir_template = MLIRGemmTemplate([mat1, mat2, inp_expanded], layout)
+    # Return the template's output node, which replaces the addmm result in the IR.
    return mlir_template.generate().output_node()
 
+
 def tuned_bmm(mat1, mat2, *, layout=None):
+    # Lower batched matrix multiplication (bmm) to the BMM template.
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
    mlir_template = MLIRBMMTemplate([mat1, mat2], layout)
+    # Generate the template and return its output node.
    return mlir_template.generate().output_node()
 
+
 def conv_layout(
    x: TensorBox,
    weight: TensorBox,
@@ -49,22 +71,30 @@ def conv_layout(
    output_padding: tuple[int, ...],
    groups: int,
 ) -> ir.Layout:
-    """Determine output layout for a convolution"""
+    """Determine output layout for a convolution
+
+    Runs the convolution under fake mode to obtain the output tensor's shape/stride and
+    returns an Inductor FixedLayout. The output layout must be known when building the template.
+    """
+    # V.graph.fake_mode lets us infer sizes without computing real data.
    with V.graph.fake_mode:
+        # Call aten.convolution to obtain the output size and stride.
        output = torch.ops.aten.convolution(
-            ir.ir_node_to_tensor(x, guard_shape=True),
+            ir.ir_node_to_tensor(x, guard_shape=True),  # convert the TensorBox to a tensor for size inference
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            stride,
-            tuple(V.graph.sizevars.size_hint(p) for p in padding),
+            tuple(V.graph.sizevars.size_hint(p) for p in padding),  # padding values use size hints via sizevars
            dilation,
            transposed,
            tuple(V.graph.sizevars.size_hint(p) for p in output_padding),
            groups,
        )
+    # Convert the output size/stride into Inductor's internal format.
    sizes = ir.convert_shape_to_inductor(output.size())
    stride = ir.convert_shape_to_inductor(output.stride())
+    # FixedLayout captures the device, dtype, size, and stride as a fixed layout.
    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
@@ -72,6 +102,7 @@ def conv_layout(
        stride,
    )
 
+
 def convolution(
    x: TensorBox,
    weight: TensorBox,
@@ -83,11 +114,13 @@ def convolution(
    output_padding: List[int],
    groups: int,
 ):
+    # Convert the incoming lists to tuples for immutability and consistent types.
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
 
+    # Collect the arguments needed for template creation into kwargs.
    kwargs = {
        "stride": stride,
        "padding": padding,
@@ -97,25 +130,35 @@ def convolution(
        "groups": groups,
    }
 
+    # A TensorBox can be a lazy representation, so it must be realized before it can be used to build the template.
    x.realize()
    weight.realize()
+    # The convolution templates may expect a channels-last layout; this helper guarantees it.
    x = ir.ExternKernel.require_channels_last(x)
+    # Read the batch size and input channel count to decide which template branch to take.
    BATCH = x.layout.size[0]
    I_C = x.layout.size[1]
    weight = ir.ExternKernel.require_channels_last(weight)
+    # Precompute the output layout and pass it to the template.
    layout = conv_layout(x, weight, None, **kwargs)
 
-    # Select conv kernel
+    # Select the appropriate convolution template, branching on single batch, stride, multi-tile, etc.
    if BATCH == 1 and stride[0] == 1 and extension_config.CONFIG_SINGLE_BATCH_CONV:
+        # Use the simplified implementation for batch=1 & stride=1
        mlir_template = MLIRConvSingleBatchTemplate([x, weight, bias], layout, **kwargs)
    elif BATCH == 1 and stride[0] != 1 and extension_config.CONFIG_SINGLE_BATCH_CONV:
+        # Use a different template for batch=1 & stride!=1
        mlir_template = MLIRConvSingleBatchStridedTemplate([x, weight, bias], layout, **kwargs)
    elif I_C < extension_config.vpu_num_lanes // 8 and extension_config.CONFIG_MULTI_TILE_CONV: # 8 is hard-coded for now. This should be changed to a better heuristic.
+        # Use the multi-tile template when the input channel count is small enough for the multi-tile strategy to pay off
        mlir_template = MLIRConvMultiTileTemplate([x, weight, bias], layout, **kwargs)
    else:
+        # Fall back to the generic convolution template
        mlir_template = MLIRConvTemplate([x, weight, bias], layout, **kwargs)
+    # Return the output node of the generated template.
    return mlir_template.generate().output_node()
 
+
 def maxpool_layout(
    x: TensorBox,
    kernel_size: List[int],
@@ -124,7 +167,11 @@ def maxpool_layout(
    dilation: List[int],
    ceil_mode: bool,
 ) -> ir.Layout:
-    """Determine output layout for a maxpool"""
+    """Determine output layout for a maxpool
+
+    Similar to conv_layout, this infers the output layout of the maxpool op.
+    """
+    # Use fake_mode to infer the output size without computing real data.
    with V.graph.fake_mode:
        output, _ = torch.ops.aten.max_pool2d_with_indices(
            ir.ir_node_to_tensor(x, guard_shape=True),
@@ -134,9 +181,11 @@ def maxpool_layout(
            dilation,
            ceil_mode,
        )
+    # Convert the output size and stride into Inductor's format
    sizes = ir.convert_shape_to_inductor(output.size())
    stride = ir.convert_shape_to_inductor(output.stride())
+    # Return the result as a FixedLayout
    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
@@ -144,6 +193,7 @@ def maxpool_layout(
        stride,
    )
 
+
 def custom_maxpool(
    x: TensorBox,
    kernel_size: List[int],
@@ -152,6 +202,7 @@ def custom_maxpool(
    dilation: List[int] = [1, 1],
    ceil_mode: bool = False
 ):
+    # Handle the maxpool call by templating it with MLIRMaxPoolTemplate.
    kwargs = {
        "kernel_size": kernel_size,
        "stride": stride,
@@ -159,33 +210,57 @@ def custom_maxpool(
        "dilation": dilation,
        "ceil_mode": ceil_mode,
    }
+    # Precompute the output layout
    layout = maxpool_layout(x, kernel_size, stride, padding, dilation, ceil_mode)
    mlir_template = MLIRMaxPoolTemplate([x], layout, **kwargs)
+    # Realize the TensorBox
    x.realize()
+    # Generate the template and return its output node.
    template_node = mlir_template.generate().output_node()
+    # indices are currently unused, so a dummy x is returned alongside the output (FIXME: handle indices properly)
    return template_node, x # FIXME: x is dummy IRNode, indices are not used in our case
 
+
 def sparse_addmm(*args, **kwargs):
-    _, sp_mat1, sp_mat2 = args
-    mat1_layout = sp_mat1.layout
+    # Example mapping of a sparse matrix op onto an external kernel
+    _, sp_mat1, sp_mat2 = args  # the first arg is the out placeholder, followed by the two sparse matrices
+    mat1_layout = sp_mat1.layout  # reference the sparse matrix layout
+    # Compute the output size from out's ranges and dims (this struct access is API-dependent).
    out_range = args[0].data.data.data.ranges
    size = [out_range[i] for i in args[0].data.dims]
+    # Build a FlexibleLayout to provide the layout info the external kernel needs.
    layout = ir.FlexibleLayout(
        device=mat1_layout.device,
        dtype=mat1_layout.dtype,
        size=size # FIXME: Example code for aten op overwrite by externkernel call
    )
+    # Bind to the external sparse matmul kernel and return the output node.
    return aten_spmm.bind((sp_mat1, sp_mat2), layout).output_node()
 
+
 def custom_unsafe_index(x, indices):
-    # We can't fuse indirect access + indexed_expression + computation
+    # Unsafe index access is handled with the plain index implementation,
+    # but a TensorBox input is realized first so the actual tensor/data is available.
+    # Note: indirect access + indexed_expression + computation cannot be fused.
    if isinstance(x, TensorBox):
        x.realize()
+    # Call index_impl with check=False to skip some checks (allowing the unsafe behavior).
    return index_impl(x, indices, check=False)
 
+def custom_foobar(a, *args, **kwargs):
+    a.realize()
+    layout = a.layout
+    mlir_template = MLIRFoobarTemplate([a], layout)
+    return mlir_template.generate().output_node()
+
+
+# Register the aten op overloads with the custom lowering functions defined above.
 lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()})
 lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()})
 lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()})
 lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()})
 lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()})
 lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()})
+
+lowerings.update({getattr(aten._foobar, overload): custom_foobar for overload in aten._foobar.overloads()})
+# Depending on the config, max_pool2d_with_indices can be replaced with the custom implementation (timing-oriented pooling).
 if extension_config.CONFIG_USE_TIMING_POOLING:
     lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template
\ No newline at end of file
diff --git a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py
index 3658f992..48bc476b 100644
--- a/PyTorchSimFrontend/mlir/mlir_maxpool_template.py
+++ b/PyTorchSimFrontend/mlir/mlir_maxpool_template.py
@@ -8,31 +8,34 @@
 import sympy
 
 # This template only represents the DMA operations
+# TEMPLATE defines the MLIR code for the max-pooling operation.
+# The MLIR dialect used here includes `func.func` for function definitions,
+# `affine.for` for loop constructs, and `memref` for memory operations.
 TEMPLATE = r"""
-{{kernel.def_global_vars()}}
+{{kernel.def_global_vars()}} # Define global variables for the kernel.
 
 func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} {
-  {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }}
-  {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }}
-  {{- kernel.def_local_vars(indent_size=2) }}
-  affine.for %index0 = 0 to {{ BCH }} step {{ out_tile }} {
-    affine.for %index1 = 0 to {{ W }} step {{ out_tile }} {
-      {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }}
-      {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }}
-    } { outer_loop=true }
-  } { outer_loop=true }
-  return
+  {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} # Define SRAM buffer for input X.
+  {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} # Define SRAM buffer for output Y.
+  {{- kernel.def_local_vars(indent_size=2) }} # Define local variables for the kernel.
+  affine.for %index0 = 0 to {{ BCH }} step {{ out_tile }} { # Outer loop over batch-channel-height.
+    affine.for %index1 = 0 to {{ W }} step {{ out_tile }} { # Inner loop over width.
+      {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} # DMA operation to move input data into SRAM.
+      {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} # DMA operation to move output data from SRAM to DRAM.
+    } { outer_loop=true } # Mark the inner loop as an outer loop for parallelization.
+  } { outer_loop=true } # Mark the outer loop as an outer loop for parallelization.
+ return # Return from the function. } """ class MLIRMaxPoolTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, kernel_size, stride, padding, dilation, ceil_mode, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.ceil_mode = ceil_mode + self.kernel_size = kernel_size # Size of the pooling kernel. + self.stride = stride # Stride of the pooling operation. + self.padding = padding # Padding applied to the input. + self.dilation = dilation # Dilation factor for the pooling kernel. + self.ceil_mode = ceil_mode # Whether to use ceil or floor for output size calculation. def render(self, kernel: MLIRTemplateKernel, @@ -41,58 +44,59 @@ def render(self, tile_info = None, **kwargs): if template_buffer_node is not None: - self.output_node = template_buffer_node + self.output_node = template_buffer_node # Set the output node if provided. if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = cast(Buffer, epilogue_nodes[-1]) - X = self.input_nodes[0] - Y = self.output_node - out_tile = kernel.vector_lane - in_tile = self.stride[0] * (out_tile - 1) + self.dilation[0] * (self.kernel_size[0] - 1) + 1 # padding should be considered? - 2 * self.padding - B = Y.get_size()[0] - C = Y.get_size()[1] - H = Y.get_size()[2] - W = Y.get_size()[3] - BCH = B * C * H - kernel.loop_size = None + self.output_node = cast(Buffer, epilogue_nodes[-1]) # Use the last epilogue node as the output. + X = self.input_nodes[0] # Input tensor. + Y = self.output_node # Output tensor. + out_tile = kernel.vector_lane # Tile size for the output. + in_tile = self.stride[0] * (out_tile - 1) + self.dilation[0] * (self.kernel_size[0] - 1) + 1 # Calculate input tile size. + + B = Y.get_size()[0] # Batch size. + C = Y.get_size()[1] # Number of channels. + H = Y.get_size()[2] # Height of the output tensor. + W = Y.get_size()[3] # Width of the output tensor. + BCH = B * C * H # Combined batch, channel, and height size. + kernel.loop_size = None # No specific loop size set. # Prepare tile descriptors - vlane_stride = 1 # Used dummy value - vlane_split_axis = 1 - X_tile_size = [in_tile, in_tile] - X_tile_stride = [1, in_tile] + vlane_stride = 1 # Stride for vector lanes (dummy value). + vlane_split_axis = 1 # Axis to split vector lanes. + X_tile_size = [in_tile, in_tile] # Tile size for input tensor. + X_tile_stride = [1, in_tile] # Stride for input tile. X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) - X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) - X_tile_desc.set_name("X_buffer") - X_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] # To keep index arguemnt order, we used index_list + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) # Set tile size and stride. + X_tile_desc.set_name("X_buffer") # Name the tile descriptor for input. + X_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] # Indexing for input tensor. - Y_tile_size = [out_tile, out_tile] - Y_tile_stride = [1, out_tile] + Y_tile_size = [out_tile, out_tile] # Tile size for output tensor. + Y_tile_stride = [1, out_tile] # Stride for output tile. 
Y_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) - Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) - Y_tile_desc.set_name("W_buffer") - Y_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) # Set tile size and stride. + Y_tile_desc.set_name("W_buffer") # Name the tile descriptor for output. + Y_idx = [sympy.Symbol("index0"), sympy.Symbol("index1")*W] # Indexing for output tensor. kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - X=X, - Y=Y, - BCH=BCH, - W=W, - out_tile=out_tile, - X_idx = X_idx, - Y_idx = Y_idx, - X_tile_desc = X_tile_desc, - Y_tile_desc = Y_tile_desc, - input_reorder = self.input_reorder + KERNEL_NAME=self.name, # Kernel name. + kernel=kernel, # Kernel object. + X=X, # Input tensor. + Y=Y, # Output tensor. + BCH=BCH, # Combined batch, channel, and height size. + W=W, # Width of the output tensor. + out_tile=out_tile, # Tile size for the output. + X_idx = X_idx, # Indexing for input tensor. + Y_idx = Y_idx, # Indexing for output tensor. + X_tile_desc = X_tile_desc, # Tile descriptor for input tensor. + Y_tile_desc = Y_tile_desc, # Tile descriptor for output tensor. + input_reorder = self.input_reorder # Input reorder option. ) kernel.epilogue_info = dict( - output_node = self.output_node.name, - sram_var = "Y_buffer", - dram_var = "Y", - dram_tile_desc = Y_tile_desc, + output_node = self.output_node.name, # Name of the output node. + sram_var = "Y_buffer", # SRAM variable for output. + dram_var = "Y", # DRAM variable for output. + dram_tile_desc = Y_tile_desc, # Tile descriptor for output in DRAM. ) - kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - kernel.add_loop_info([X.get_numel()], [kernel.vector_lane, kernel.vector_lane]) - return code + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} # Exception handling for output tensor. + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) # Render the MLIR code. + kernel.add_loop_info([X.get_numel()], [kernel.vector_lane, kernel.vector_lane]) # Add loop information. + return code # Return the rendered MLIR code. diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 23be941c..0eb1afd8 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -1,193 +1,193 @@ -import os -import math -import sympy -from functools import reduce -import operator -from sympy import symbols, sympify -from PyTorchSimFrontend import extension_config -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel - -from torch._inductor import config -from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode -from torch._inductor.utils import IndentedBuffer -from torch._inductor.virtualized import V -from torch._inductor.ir import LoopBody -from torch._inductor import dependencies - -from . import mlir_common -from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering - -class MLIRScheduling(BaseScheduling): - count = 0 - target_kernel = MLIRKernel - def __init__(self, scheduler): - self.scheduler = scheduler - self.scheduler.can_fuse_origin = self.scheduler.can_fuse - self.scheduler.can_fuse = self.can_fuse_with_exceptions - #self.scheduler.enter_context = self.enter_context_fixed # FIXME. 
Monkey patch: For fixing the inductor bug - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() - self._ready_to_flush = False - self.outer_function = set() - config.inplace_buffers = False # FIXME. inout kernel makes trouble.. So disabled it! - self.max_fusion_size = 5 - - def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: - # Extract base template node - base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] - base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] - if node1.get_device() != node2.get_device(): - return False - if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): - return False - - if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): - # For matmul/bmm+reduction case - size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) - target_symbol = symbols("r0") +import os +import math +import sympy +from functools import reduce +import operator +from sympy import symbols, sympify +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel + +from torch._inductor import config +from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode +from torch._inductor.utils import IndentedBuffer +from torch._inductor.virtualized import V +from torch._inductor.ir import LoopBody +from torch._inductor import dependencies + +from . import mlir_common +from . import mlir_lowering + + +class MLIRScheduling(BaseScheduling): + count = 0 + target_kernel = MLIRKernel # 사용할 MLIR 커널을 지정. + def __init__(self, scheduler): + self.scheduler = scheduler # 스케줄러를 인스턴스 변수로 저장합니다. + self.scheduler.can_fuse_origin = self.scheduler.can_fuse # 원래의 fusion 가능성을 설정합니다. + self.scheduler.can_fuse = self.can_fuse_with_exceptions # 예외가 있는 fusion 가능성을 설정합니다. + #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Inductor 버그 수정을 위한 몽키 패치 + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() # MLIRWrapperKernelGroup 인스턴스를 생성합니다. + self._ready_to_flush = False # 플러시 준비 상태를 초기화합니다. + self.outer_function = set() # 외부 함수 집합을 초기화합니다. + config.inplace_buffers = False # FIXME. inout 커널 문제로 비활성화합니다. + self.max_fusion_size = 5 # 최대 fusion 크기를 설정합니다. + + def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: # 예외가 있는 fusion 가능성을 확인하는 메서드 + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] # node1의 템플릿 노드를 가져옵니다. + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] # node2의 템플릿 노드를 가져옵니다. 
+ if node1.get_device() != node2.get_device(): # 두 노드의 장치가 다르면 + return False # fusion 불가능 + if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): # 두 노드가 스케줄러 노드가 아니면 + return False # fusion 불가능 + + if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: # 특정 조건을 만족하면 + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate # GEMM 템플릿을 가져옵니다. + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate # BMM 템플릿을 가져옵니다. + if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): # 매트릭스 곱셈과 축소의 경우 + # 매트릭스 곱셈/배치 매트릭스 곱셈 + 축소의 경우 + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) # 크기 일치 여부 확인 + target_symbol = symbols("r0") # 기호 r0을 정의합니다. try: - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] - stride = int(sympify(stride).coeff(target_symbol)) + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] # stride를 가져옵니다. + stride = int(sympify(stride).coeff(target_symbol)) # 기호를 사용하여 stride를 정수로 변환합니다. except: - return False - - # We can't fuse dim=-1 - layout_possible = stride != 1 - # Directed linked? - dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 - dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) - return size_match and layout_possible and dependency_check and dependency_size - - # For prologue fusion case - if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': - return False - if node1.is_reduction(): - return False - if len(node1.read_writes.writes) != 1: - return False - if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME - return False - - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): - return False - # We don't fuse this edge case... 
- if base_template_node2[0].group[1][0][0] == 1: - return False - - if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: - node1 = self.revert_group(node1) - return True - - return self.scheduler.can_fuse_origin(node1, node2) - - def _set_flush_status(self, status: bool): - self._ready_to_flush = status - - def can_fuse_vertical(self, node1, node2): - return self.can_fuse_horizontal(node1, node2) - - def can_fuse_horizontal(self, node1, node2): - if not extension_config.CONFIG_FUSION: - return False - if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: - return False - _, (vars1, reduce1) = node1.group - _, (vars2, reduce2) = node2.group - - # Reduction is currently not supported - if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: - return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users - if node1.is_reduction() or node2.is_reduction(): - return False - - # Can't fuse two template node - if node1.is_template() and node2.is_template(): - return False - - if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: - return False - - # Check template node fusion - if node1.is_template() or node2.is_template(): - # Don't fuse maxpool template code - from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) - template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) + return False # 예외 발생 시 fusion 불가능 + + # dim=-1로 fusion 불가능 + layout_possible = stride != 1 # 레이아웃 가능 여부 확인 + # 방향성 링크 확인 + dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users] # 의존성 확인 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) # 의존성 크기 확인 + return size_match and layout_possible and dependency_check and dependency_size # 모든 조건이 만족되면 fusion 가능 + + # 프로로그 fusion의 경우 + if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: # 특정 조건을 만족하면 + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate # GEMM 템플릿을 가져옵니다. + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate # BMM 템플릿을 가져옵니다. + target_node = base_template_node2[0].node # 타겟 노드를 설정합니다. + if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': # 특정 조건을 확인합니다. + return False # fusion 불가능 + if node1.is_reduction(): # node1이 축소 노드이면 + return False # fusion 불가능 + if len(node1.read_writes.writes) != 1: # node1의 쓰기 수가 1이 아니면 + return False # fusion 불가능 + if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): # node1이 타겟 노드의 입력이 아니면 + return False # fusion 불가능 + + # 현재 BMM과 MM만 프로로그 fusion을 지원합니다. + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): # 타겟 노드가 BMM 또는 GEMM 템플릿이 아니면 + return False # fusion 불가능 + # 이 엣지는 fusion하지 않습니다. + if base_template_node2[0].group[1][0][0] == 1: # 특정 조건을 확인합니다. 
+ return False # fusion 불가능 + + if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: # node1의 쓰기가 node2의 읽기와 일치하면 + node1 = self.revert_group(node1) # node1의 그룹을 되돌립니다. + return True # fusion 가능 + + return self.scheduler.can_fuse_origin(node1, node2) # 기본 fusion 가능성 확인 + + def _set_flush_status(self, status: bool): # 플러시 상태를 설정하는 메서드 + self._ready_to_flush = status # 플러시 준비 상태를 업데이트합니다. +#flush = 스케줄러가 모아둔 커널 그룹을 실제 코드로 반환 + def can_fuse_vertical(self, node1, node2): # 수직 fusion 가능성 확인 + return self.can_fuse_horizontal(node1, node2) # 수평 fusion 가능성 확인 + + def can_fuse_horizontal(self, node1, node2): # 수평 fusion 가능성 확인 + if not extension_config.CONFIG_FUSION: # fusion이 비활성화되어 있으면 + return False # fusion 불가능 + if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: # 두 노드의 크기가 최대 fusion 크기를 초과하면 + return False # fusion 불가능 + _, (vars1, reduce1) = node1.group # node1의 변수와 축소 정보를 가져옵니다. + _, (vars2, reduce2) = node2.group # node2의 변수와 축소 정보를 가져옵니다. + + # 축소는 현재 지원되지 않습니다. + if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: # 두 노드가 축소 노드이면 + return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users # 변수와 축소 정보가 일치하면 fusion 가능 + if node1.is_reduction() or node2.is_reduction(): # 하나라도 축소 노드이면 + return False # fusion 불가능 + + # 두 템플릿 노드는 fusion할 수 없습니다. + if node1.is_template() and node2.is_template(): # 두 노드가 템플릿이면 + return False # fusion 불가능 + + if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: # 안전하지 않은 인덱스가 포함되어 있으면 + return False # fusion 불가능 + + # 템플릿 노드 fusion 확인 + if node1.is_template() or node2.is_template(): # 하나라도 템플릿이면 + # maxpool 템플릿 코드는 fusion하지 않습니다. + from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate # MaxPool 템플릿을 가져옵니다. + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate # BMM 템플릿을 가져옵니다. + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate # GEMM 템플릿을 가져옵니다. + template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) # node1의 템플릿 노드를 가져옵니다. + template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) # node2의 템플릿 노드를 가져옵니다. if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ - template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): - return False - - # Pointwise check - v1_total = math.prod(vars1) if len(vars1) else 0 - v2_total = math.prod(vars2) if len(vars2) else 0 - if v1_total != v2_total: - return False - - # Pattern check - template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) - has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) - if not has_depedency: - return False - - # Revert act_node.group : simplify_and_reorder() modified _body, _size, group - if template_node.group != act_node.group: - # We don't fuse this case... 
- if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: - return False - - if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): - return False - self.revert_group(act_node) - return True - - # Check elementwise fusion - if vars1 == vars2 and reduce1 == reduce2: - return True - return False - - def revert_group(self, act_nodes, args=None, var_ranges=None): - for act_node in act_nodes.get_nodes(): - if args is None or var_ranges is None: + template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): # MaxPool 템플릿이면 + return False # fusion 불가능 + + # 포인트별 확인 + v1_total = math.prod(vars1) if len(vars1) else 0 # node1의 변수 곱을 계산합니다. + v2_total = math.prod(vars2) if len(vars2) else 0 # node2의 변수 곱을 계산합니다. + if v1_total != v2_total: # 두 변수 곱이 일치하지 않으면 + return False # fusion 불가능 + + # 패턴 확인 + template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) # 템플릿 노드와 활성 노드를 설정합니다. + has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) # 의존성 확인 + if not has_depedency: # 의존성이 없으면 + return False # fusion 불가능 + + # act_node.group 되돌리기: simplify_and_reorder()가 수정한 _body, _size, group + if template_node.group != act_node.group: # 그룹이 다르면 + # 이 경우 fusion하지 않습니다. + if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: # 특정 조건을 확인합니다. + return False # fusion 불가능 + + if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): # 크기가 다르면 + return False # fusion 불가능 + self.revert_group(act_node) # act_node의 그룹을 되돌립니다. + return True # fusion 가능 + + # 요소별 fusion 확인 + if vars1 == vars2 and reduce1 == reduce2: # 변수와 축소 정보가 일치하면 + return True # fusion 가능 + return False # fusion 불가능 + + def revert_group(self, act_nodes, args=None, var_ranges=None): # 그룹을 되돌리는 메서드 + for act_node in act_nodes.get_nodes(): # 각 활성 노드에 대해 + if args is None or var_ranges is None: # 인자나 변수 범위가 주어지지 않으면 args, var_ranges = dependencies.index_vars_no_squeeze( act_node.node.data.get_size(), act_node.node.data.get_reduction_size(), prefix="q" - ) + ) # 변수 인덱스를 설정합니다. body = LoopBody( act_node.node.get_store_function(), (args if act_node.node.get_reduction_type() else args[:1]), var_ranges, - ) - index_size = [] - reduce_size = [] - for v, s in var_ranges.items(): - if v in args[0]: - index_size.append(s) - else: - reduce_size.append(s) - node_device = act_node.get_device() - ranges = (index_size, reduce_size) - act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) - - def group_fn(self, sizes): - return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - - def codegen_nodes(self, nodes): + ) # 루프 본체를 정의합니다. + index_size = [] # 인덱스 크기 초기화 + reduce_size = [] # 축소 크기 초기화 + for v, s in var_ranges.items(): # 각 변수와 크기에 대해 + if v in args[0]: # 인자에 변수가 포함되어 있으면 + index_size.append(s) # 인덱스 크기에 추가 + else: # 그렇지 않으면 + reduce_size.append(s) # 축소 크기에 추가 + node_device = act_node.get_device() # 노드의 장치를 가져옵니다. + ranges = (index_size, reduce_size) # 인덱스와 축소 크기 범위를 설정합니다. + act_node._sizes, act_node._body, act_node.group = (ranges), body, (node_device, self.group_fn(ranges)) # 노드의 크기, 본체, 그룹을 업데이트합니다. + + def group_fn(self, sizes): # 그룹 함수를 정의합니다. 
+ return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) # 크기를 단순화하여 튜플로 반환합니다. + + def codegen_nodes(self, nodes): # 노드에 대한 코드를 생성하는 메서드 _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) - ).group + ).group # 가장 큰 축소 그룹을 찾습니다. - # Note: We assume that ther is at least one loop in the nodes - # But, inductor simplifies the group, there could be no loop - # In that case, we add dummy loop(size=1) to the group + # 노드에 루프가 적어도 하나 있다고 가정합니다. + # 그러나 인덕터가 그룹을 단순화하므로 루프가 없을 수도 있습니다. + # 그런 경우, 더미 루프(크기=1)를 그룹에 추가합니다. if len(group) == 0: for idx, node in enumerate(nodes): if len(node.node.data.get_size()) == 0: @@ -200,109 +200,110 @@ def codegen_nodes(self, nodes): sym0 = sympy.Symbol("q0") args = [[sym0] + [sympy.Number(0)] * (len(node.node.data.get_size())-1), []] var_ranges = {sym0: sympy.Number(1)} - self.revert_group(node, args, var_ranges) + self.revert_group(node, args, var_ranges) # 노드 그룹을 되돌립니다. _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) - ).group + ).group # 다시 가장 큰 축소 그룹을 찾습니다. - ex_kernel = self.target_kernel(kernel_group=self.kernel_group) - ex_kernel.kernel_group = self.kernel_group + ex_kernel = self.target_kernel(kernel_group=self.kernel_group) # 타겟 커널 인스턴스를 생성합니다. + ex_kernel.kernel_group = self.kernel_group # 커널 그룹을 설정합니다. - kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" - MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" # 커널 이름 후보를 설정합니다. + MLIRScheduling.count += 1 # 클래스 변수를 증가시킵니다. + src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) # 노드에 대한 코드를 생성합니다. kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) - ex_kernel.call_kernel(kernel_name) - _, args, _, _ = ex_kernel.args.mlir_argdefs() - args = ", ".join(args) - eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) + ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) # 커널을 정의합니다. + ex_kernel.call_kernel(kernel_name) # 커널을 호출합니다. + _, args, _, _ = ex_kernel.args.mlir_argdefs() # 커널 인자 정보를 가져옵니다. + args = ", ".join(args) # 인자를 문자열로 변환합니다. + eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) # EAGER 모드 설정을 가져옵니다. if (eager_mode): V.graph.wrapper_code.writeline( f"yield ({kernel_name}, ({args}))" - ) - self._set_flush_status(True) - - def ready_to_flush(self): - return self._ready_to_flush - - def codegen_sync(self): - pass - - def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() - self._set_flush_status(False) - - def define_function(self, kernel): - partial_code, function_name = kernel.def_function() - if partial_code is not None and function_name not in self.outer_function: - with V.set_kernel_handler(kernel): - code = partial_code.finalize() - wrapper = V.graph.wrapper_code - wrapper.header.writeline(code) - self.outer_function.add(function_name) - - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): - wrapper = V.graph.wrapper_code - if src_code in wrapper.src_to_kernel: - kernel_name = wrapper.src_to_kernel[src_code] + ) # EAGER 모드에서 커널 호출을 기록합니다. + self._set_flush_status(True) # 플러시 준비 상태를 True로 설정합니다. + + + def ready_to_flush(self): # 플러시 준비 상태 확인 + return self._ready_to_flush # 현재 플러시 준비 상태를 반환합니다. 
+
+    def codegen_sync(self):  # no synchronization code is needed for this backend
+        pass  # intentionally a no-op
+
+    def flush(self):  # emit pending kernels and reset the kernel group
+        self.kernel_group.codegen_define_and_call(V.graph.wrapper_code)  # generate define-and-call code for the group
+        self.kernel_group = mlir_common.MLIRWrapperKenrelGroup()  # start a fresh kernel group
+        self._set_flush_status(False)  # clear the flush flag
+
+    def define_function(self, kernel):  # define the outer function for a kernel, at most once
+        partial_code, function_name = kernel.def_function()  # get the partial function code
+        if partial_code is not None and function_name not in self.outer_function:  # valid code and not yet defined
+            with V.set_kernel_handler(kernel):  # set the kernel handler
+                code = partial_code.finalize()  # finalize the partial code
+            wrapper = V.graph.wrapper_code  # wrapper code of the graph
+            wrapper.header.writeline(code)  # write it into the header
+            self.outer_function.add(function_name)  # remember that it is defined
+
+    def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}):  # register kernel source with the wrapper
+        wrapper = V.graph.wrapper_code  # wrapper code of the graph
+        if src_code in wrapper.src_to_kernel:  # source already registered
+            kernel_name = wrapper.src_to_kernel[src_code]  # reuse the existing kernel name
         else:
-            wrapper.src_to_kernel[src_code] = kernel_name
-
-        codecache_def = IndentedBuffer()
-        codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ")
-        codecache_def.writeline(f"vectorlane_size={vector_lane},")
-        codecache_def.writeline(f"loop_size={loop_size},")
-        codecache_def.writeline(f"spad_info={spad_info},")
-        codecache_def.writeline(f"origins={origins},")
-        codecache_def.writeline("arg_attributes=arg_attributes,")
-        codecache_def.writeline(f"vlen={extension_config.vpu_vector_length_bits})")
-        wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False)
-        return kernel_name
-
-    def codegen_template(self, template_node, epilogue_nodes):
-        # Handle prologue pattern
-        prologue_nodes = []
-        if not template_node.is_template():
-            epilogue_nodes = [template_node] + epilogue_nodes
-        for i, node in enumerate(epilogue_nodes):
-            if node.is_template():
-                template_node = node
-                prologue_nodes = epilogue_nodes[:i]
-                epilogue_nodes = epilogue_nodes[i+1:]
+            wrapper.src_to_kernel[src_code] = kernel_name  # register the new source
+
+        codecache_def = IndentedBuffer()  # buffer for the codecache definition
+        codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ")  # compile the MLIR kernel asynchronously
+        codecache_def.writeline(f"vectorlane_size={vector_lane},")  # vector lane count
+        codecache_def.writeline(f"loop_size={loop_size},")  # loop sizes
+        codecache_def.writeline(f"spad_info={spad_info},")  # scratchpad info
+        codecache_def.writeline(f"origins={origins},")  # origin nodes
+        codecache_def.writeline("arg_attributes=arg_attributes,")  # argument attributes
+        codecache_def.writeline(f"vlen={extension_config.vpu_vector_length_bits})")  # VPU vector length in bits
+        wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False)  # register with the wrapper
+        return kernel_name  # return the kernel name
+
+    def codegen_template(self, template_node, epilogue_nodes):  # generate code for a template node and its fused neighbors
+        # Handle the prologue pattern
+        prologue_nodes = []  # nodes fused before the template
+        if not template_node.is_template():  # first node is not the template itself
+            epilogue_nodes = [template_node] + epilogue_nodes  # treat it as a candidate prologue/epilogue node
+        for i, node in enumerate(epilogue_nodes):  # locate the real template node
+            if node.is_template():  # found it
+                template_node = node  # this is the template
+                prologue_nodes = epilogue_nodes[:i]  # everything before it is prologue
+                epilogue_nodes = epilogue_nodes[i+1:]  # everything after it is epilogue
                break
 
-        # Generate template code
-        template_buffer = template_node.node
-        kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group)
-        _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs()
-        src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)
+        # Generate template code
+        template_buffer = template_node.node  # buffer of the template node
+        kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group)  # build kernel, tile candidates and render callable
+        _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs()  # collect buffer type info
+        src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes)  # render source for the template node
 
-        with V.set_kernel_handler(kernel):
+        with V.set_kernel_handler(kernel):  # set the kernel handler
             kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info,
-                                             kernel.loop_size, origins={str(i) for i in template_node.node.origins})
-        self.define_function(kernel)
+                                             kernel.loop_size, origins={str(i) for i in template_node.node.origins})  # register the kernel
+        self.define_function(kernel)  # define its outer function
 
-        kernel.call_kernel(kernel_name)
-        V.graph.removed_buffers |= kernel.removed_buffers
-        _, args, _, _ = self.kernel_group.args.mlir_argdefs()
-        eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False))
+        kernel.call_kernel(kernel_name)  # emit the kernel call
+        V.graph.removed_buffers |= kernel.removed_buffers  # propagate removed buffers
+        _, args, _, _ = self.kernel_group.args.mlir_argdefs()  # collect call arguments
+        eager_mode = int(os.environ.get('TOGSIM_EAGER_MODE', default=False))  # read the eager-mode flag
         if (eager_mode):
-            target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name + f"_{len(args)}"
-            args = ", ".join(args)
+            target_kernel_name = kernel_name if kernel.outer_func_name is None else kernel.outer_func_name + f"_{len(args)}"  # prefer the outer function name if present
+            args = ", ".join(args)  # join into an argument string
             V.graph.wrapper_code.writeline(
                 f"yield ({target_kernel_name}, ({args}))"
-            )
-        self._set_flush_status(True)
-
-    def enter_context_fixed(self, node):
-        def get_order(n):
-            if n not in self.scheduler.origin_to_index:
-                self.scheduler.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
-            return self.scheduler.origin_to_index[n]
-
-        origins = [(get_order(e), idx, e) for n in node.get_nodes() for idx, e in enumerate(n.node.origins)]
-        if origins:
-            _, _, last = max(origins)
-            V.graph.wrapper_code.enter_context(last)
+            )  # in eager mode, yield the kernel and its args
+        self._set_flush_status(True)  # mark the kernel group as ready to flush
+
+    def enter_context_fixed(self, node):  # fixed version of enter_context
+        def get_order(n):  # position of an origin node in the FX graph
+            if n not in self.scheduler.origin_to_index:  # lazily build the index map
+                self.scheduler.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})  # map every FX node to its position
+            return self.scheduler.origin_to_index[n]  # return its position
+
+        origins = [(get_order(e), idx, e) for n in node.get_nodes() for idx, e in enumerate(n.node.origins)]  # all origin nodes with their order
+        if origins:  # if any origins exist
+            _, _, last = max(origins)  # pick the latest origin
+ V.graph.wrapper_code.enter_context(last) # 컨텍스트 진입 diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index e493464a..5e0bc75c 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -1,54 +1,66 @@ -import functools -import itertools -import textwrap -import re -import os -import contextlib -import math -import sympy -from functools import reduce -import operator -from collections import OrderedDict - -from typing import List, Optional -from unittest.mock import patch - -from torch._inductor.codegen.common import KernelTemplate, ChoiceCaller, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer -from torch._inductor.select_algorithm import PartialRender -from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller -from torch._inductor.autotune_process import TensorMeta -from torch._inductor.virtualized import V, NullHandler, _ops as ops -from torch._inductor.utils import IndentedBuffer -from torch._inductor.codecache import write_atomic - -import PyTorchSimFrontend.extension_codecache as extension_codecache -from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest -from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo -from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction -from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode -from torch._inductor.codegen import common - -from PyTorchSimFrontend import extension_config -from . import mlir_common +# mlir_template.py +# MLIR 템플릿 기반의 커널 생성/타일링/매핑 유틸리티들을 포함하는 파일입니다. + +import functools # 함수 헬퍼(예: partial) 사용 +import itertools # 반복자 조합, 순열 등 유틸리티 +import textwrap # 코드 블록 정렬/포맷에 사용 +import re # 정규 표현식 처리 +import os # 파일/디렉터리 조작 +import contextlib # 컨텍스트 매니저 유틸리티 +import math # 수학 함수 (ceil, sqrt 등) +import sympy # 기호 수학(약수, 분해 등) 유틸리티 +from functools import reduce # 시퀀스 누적 연산에 사용 +import operator # 연산자 함수(곱셈 등) 사용 +from collections import OrderedDict # 순서 유지 딕셔너리 + +from typing import List, Optional # 타입 힌트 +from unittest.mock import patch # 테스트나 임시 패치용 + +# Inductor 내부의 공통 템플릿/유틸 가져오기 +from torch._inductor.codegen.common import KernelTemplate, ChoiceCaller, CSE, DeferredLine # 코드 생성 공통 유틸 +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer # IR 관련 타입 +from torch._inductor.select_algorithm import PartialRender # 부분 렌더링 도우미 +from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller # CUDA 호출러(참조) +from torch._inductor.autotune_process import TensorMeta # 오토튜닝을 위한 텐서 메타 +from torch._inductor.virtualized import V, NullHandler, _ops as ops # 가상화 유틸 +from torch._inductor.utils import IndentedBuffer # 들여쓰기 지원 버퍼 +from torch._inductor.codecache import write_atomic # 코드 캐시 쓰기 유틸 + +# 확장(Frontend) 관련 모듈 +import PyTorchSimFrontend.extension_codecache as extension_codecache # 확장된 코드 캐시 구현 +from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest # MLIR용 벤치 요청 구조 +from PyTorchSimFrontend.mlir.mlir_common import BaseMLIRHardwareInfo # 하드웨어 관련 정보 베이스 +from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel, reduction_init, reduction_partial_combine_vec, reduction_combine_vec, is_welford_reduction # MLIR 코드 생성 백엔드 함수 +from PyTorchSimFrontend.mlir.mlir_scheduling import SchedulerNode # 스케줄링 관련 노드 타입 +from torch._inductor.codegen import common # 공통 코드 생성 유틸 + +from PyTorchSimFrontend import extension_config # 확장 설정 로드 +from . 
import mlir_common # 같은 패키지의 공용 유틸 class IndentedBufferGroup: + """여러 IndentedBuffer( loads/compute/stores 등)를 그룹화하여 임시로 커널에 적용/복원하는 유틸. + + 사용 예: prologue/epilogue 등 특정 블록에서 별도의 버퍼로 코드 생성을 수행한 뒤 원상 복귀. + """ def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): + # kernel 참조와 여러 목적의 IndentedBuffer를 초기화합니다. self.kernel = kernel - self.body = IndentedBuffer() - self.loads = IndentedBuffer() - self.compute = IndentedBuffer() - self.stores = IndentedBuffer() - self.applys = IndentedBuffer() - self.dma_loads = IndentedBuffer() - self.dma_stores = IndentedBuffer() - self.spad_buffer = IndentedBuffer() + self.body = IndentedBuffer() # 전체 바디용 + self.loads = IndentedBuffer() # 로드 라인용 + self.compute = IndentedBuffer() # 계산 라인용 + self.stores = IndentedBuffer() # 저장 라인용 + self.applys = IndentedBuffer() # 후처리용 라인 + self.dma_loads = IndentedBuffer() # DMA 로드 전용 + self.dma_stores = IndentedBuffer() # DMA 저장 전용 + self.spad_buffer = IndentedBuffer() # 스패드 관련 라인 + # CSE(공통 하위식 제거) 인스턴스들: 이름 접두사로 구분 self.cse = common.CSE("%", "", name_prefix=f"{prefix}") self.apply_cse = common.CSE("%", "", name_prefix=f"{prefix}apply") - # Original buffers will be saved later in the 'with' block + # with 블록 진입 전 원래 버퍼들을 저장하기 위한 사전 self.original_buffers = {} def set_buffers(self): + # 현재 그룹의 버퍼들을 실제 커널의 속성으로 설정하여, 이후 생성되는 코드가 여기에 기록되게 합니다. self.kernel.loads = self.loads self.kernel.compute = self.compute self.kernel.stores = self.stores @@ -60,6 +72,7 @@ def set_buffers(self): self.kernel.apply_cse = self.apply_cse def restore_buffers(self): + # 저장해둔 원래 버퍼들을 복원합니다. self.kernel.loads = self.original_buffers['loads'] self.kernel.compute = self.original_buffers['compute'] self.kernel.stores = self.original_buffers['stores'] @@ -72,6 +85,7 @@ def restore_buffers(self): @contextlib.contextmanager def as_local(self): + # 컨텍스트 진입 시 현재 커널의 버퍼들을 저장하고 그룹 버퍼로 교체합니다. self.original_buffers = { 'loads': self.kernel.loads, 'compute': self.kernel.compute, @@ -84,12 +98,16 @@ def as_local(self): 'apply_cse': self.kernel.apply_cse, } try: - self.set_buffers() + self.set_buffers() # 그룹 버퍼로 교체 yield self finally: - self.restore_buffers() + self.restore_buffers() # 종료 시 복원 class MLIRTemplateKernel(MLIRKernel, BaseMLIRHardwareInfo): + """MLIR 기반 템플릿 커널을 표현하는 핵심 클래스입니다. + + 이 클래스는 템플릿 렌더링에 필요한 메타데이터, 루프/타일 정보, CSE, prologue/epilogue 버퍼 그룹 등을 관리합니다. 
+ """ def __init__(self, kernel_name, input_nodes, @@ -99,28 +117,38 @@ def __init__(self, outer_func_render=None, kernel_arg_attributes=None, reason=None) -> None: + # MLIRKernel 초기화: kernel_group이 주어지지 않으면 기본 Wrapper 그룹 사용 super().__init__(kernel_group if kernel_group is not None else mlir_common.MLIRWrapperKenrelGroup()) + # 식별자 및 입력/콜 사이즈 저장 self.kernel_name = kernel_name self.input_nodes = input_nodes self.call_size = call_size + # 노드 이름과 루프 정보를 위한 컨테이너 self.named_nodes = {} self.loop_info = {} + # outer function 관련 선택적 정보 self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render + # 커널 인자 속성을 외부에서 주입 가능 self.kernel_arg_attributes = kernel_arg_attributes + # 렌더 후크, 버퍼 이름, 렌더 옵션 self.render_hooks = OrderedDict() self.buffer_names = dict() self.render_options = dict() + # 타일/루프 관련 변수 self.tile_size = [] self.loop_size = None + # CSE(공통 하위식 제거) 인스턴스들: 맵/상수/할당 식별자에 사용 self.map_cse = CSE("#", self.suffix, name_prefix="t_map") self.const_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_const") self.alloc_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="t_alloc") + # Prologue/Epilogue에서 별도 버퍼 관리를 위한 그룹 self.prologue_buffer_group = IndentedBufferGroup(self, prefix="prologue_") self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") + # 전역 변수와 예외 노드 저장 self.global_vars = IndentedBuffer() self.exception_nodes = {} - # Reduction data structure + # Reduction 관련 상태와 버퍼 self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False self.reduction_body_loop = None @@ -128,11 +156,15 @@ def __init__(self, self.reduction_info = {} self.reduction_epilogue_result = {} self.reduction_mean = [] - # Dim info + # 차원(alias) 정보 및 이유(reason) self.dim_aliasing = {} self.reason = reason def reset(self, reason): + """커널 상태를 주어진 reason으로 재초기화합니다. + + 테스트나 재사용 시 인스턴스를 초기 상태로 되돌리기 위해 사용합니다. + """ self.__init__( self.kernel_name, self.input_nodes, self.call_size, self.kernel_group, @@ -141,7 +173,12 @@ def reset(self, reason): ) def add_loop_info(self, mat_size, tile_size): + """행렬 및 타일 크기로부터 각 루프 인덱스의 [start, end, stride] 정보를 생성하여 저장합니다. + + mat_size: 전체 루프 범위, tile_size: 각 차원에서의 타일 크기(스트라이드) + """ for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): + # index0, index1, ... 형태의 키로 루프 정보를 저장 self.loop_info[f"index{idx}"] = [0, loop_size, stride] def gemmini_gemm_mapping(self, M, N, K): @@ -200,40 +237,56 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + """GEMM용 타일 후보들을 생성하고 휴리스틱으로 우수 후보를 선택합니다. 
+ + 고려 항목: 스패드 사용량, lane 당 사용량, weight reuse, 최소 타일 수 등 + """ tile_candidates = [] + # 스패드/레인/정밀도 정보 spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane - max_spad_size = spad_size // 2 # double buffer - max_spad_per_lane = spad_size_per_lane // 2 # double buffer + max_spad_size = spad_size // 2 # double buffer을 고려한 최대 사용 가능 스패드 + max_spad_per_lane = spad_size_per_lane // 2 # lane 당 최대 스패드 minimum_n_tile = self.num_cores if min_tile else 1 + # 패딩 팩터 결정: 벡터 lane 단위로 패딩하거나 기본값 8을 사용 m_pad_factor = self.vector_lane if M > self.vector_lane else 8 n_pad_factor = self.vector_lane if N > self.vector_lane else 8 k_pad_factor = self.vector_lane if K > self.vector_lane else (8 if pad_k else 1) K = max(K, 8) + # 차원을 패딩하여 정렬 단위를 맞춤 M_padded = ((M + m_pad_factor - 1) // m_pad_factor) * m_pad_factor N_padded = ((N + n_pad_factor - 1) // n_pad_factor) * n_pad_factor K_padded = ((K + k_pad_factor - 1) // k_pad_factor) * k_pad_factor indexI, indexJ, indexK = (M_padded // self.vector_lane, N_padded // self.vector_lane, K_padded // self.vector_lane) max_used_spad_size = 0 - mapping = (self.vector_lane, self.vector_lane, self.vector_lane) + mapping = (self.vector_lane, self.vector_lane, self.vector_lane) # 기본 매핑 + # 타일 분할 후보의 약수를 이용하여 후보 범위를 만듭니다. tile_M_range = sympy.divisors(indexI) if M > self.vector_lane else [1] tile_N_range = sympy.divisors(indexJ) if N > self.vector_lane else [1] tile_K_range = sympy.divisors(indexK) if K > self.vector_lane else [1] - maximize_i_j = 1 # reuse weight - for k in tile_K_range: # store tile candidates for manual mapping + maximize_i_j = 1 # weight reuse를 극대화하기 위한 보조 변수 + for k in tile_K_range: # K 차원의 타일 후보 반복 (각 k는 factor) + # tile_K: 실제 타일의 K 크기. K가 vector_lane보다 큰 경우 벡터 레인 단위로 확장 tile_K = k * self.vector_lane if K > self.vector_lane else K_padded - for i in tile_M_range: + for i in tile_M_range: # M 차원 타일 후보 반복 + # tile_M: M 차원의 실제 타일 크기 (vector lane 단위 또는 패딩된 값) tile_M = i * self.vector_lane if M > self.vector_lane else M_padded - for j in tile_N_range: + for j in tile_N_range: # N 차원 타일 후보 반복 + # tile_N: N 차원의 실제 타일 크기 tile_N = j * self.vector_lane if N > self.vector_lane else N_padded + # 다음으로 각 후보에 대해 필요한 스패드 사용량(입력, 가중치, 출력 포함)을 추정합니다. + # used_spad_size는 전체 스패드 사용량(바이트 단위, precision을 곱함)을 의미합니다. used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision - weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) - input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) - output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) + # lane 당 가중치/입력/출력의 크기를 계산하여 lane 분산 관점에서의 사용량을 추정합니다. + weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) # 가중치 크기 per lane + input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) # 입력 크기 per lane + output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) # 출력 크기 per lane + # lane 당 사용량들을 합쳐 실제 lane 단위로 필요한 스패드 사용량을 계산합니다. used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: + # 디렉터리/파일에 후보를 기록하여 외부 검증/수집에 사용합니다. 
dir_path = f"{extension_config.CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" os.makedirs(dir_path, exist_ok=True) file_path = f"{dir_path}/gemm_{M}_{K}_{N}.txt" @@ -247,6 +300,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p with open(file_path, "a") as f: f.write(line_to_write) + # 휴리스틱 탐색: 후보들을 평가하여 최적 후보를 선정 for k in tile_K_range: # heuristic search tile_K = k * self.vector_lane if K > self.vector_lane else K_padded for i in tile_M_range: @@ -258,8 +312,10 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + # 전체 매트릭스에 필요한 타일 수 예측 (너무 작은 타일은 불리함) n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) + # 다양한 기준을 결합해 우수 후보 선정: 스패드 사용량, weight reuse, 최소 타일 수 등 if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: max_used_spad_size = used_spad_size maximize_i_j = tile_M * tile_N @@ -267,28 +323,37 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p if check_spad_size: tile_candidates.append((used_spad_size, (tile_M, tile_N, tile_K))) + # 사용량 기준으로 후보 정렬 및 반환 tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) tile_candidates = [v for _, v in tile_candidates] return tile_candidates def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + """컨볼루션을 GEMM으로 근사하여 타일 후보를 생성합니다. + + 변수 설명: K_H/K_W 필터 차원, O_H/O_W 출력 차원, stride/dilation 등의 파라미터를 고려합니다. 
+ """ tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane - max_spad_size = spad_size // 2 # double buffer - max_spad_per_lane = spad_size_per_lane // 2 # double buffer + max_spad_size = spad_size // 2 # double buffer 고려 + max_spad_per_lane = spad_size_per_lane // 2 # lane 당 최대 + # 후보 선정용 보조 변수 max_used_spad_size = 0 + # 먼저 GEMM 근사 값으로 M,N,K를 구합니다 (conv->GEMM 변환 관점) M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] - max_k_h_w = 1 # maximize kernel size - max_o_h_w = 1 # maximize output size - K = min(K, self.vector_lane) + max_k_h_w = 1 # kernel size 최대화 보조 + max_o_h_w = 1 # output size 최대화 보조 + K = min(K, self.vector_lane) # K는 vector lane 이하로 제한 for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): for k_h in sympy.divisors(K_H): for k_w in sympy.divisors(K_W): + # 입력(ih,iw) 크기 계산: output/stride/dilation 고려 i_h = 1 + (o_h - 1) * stride[0] + (k_h - 1) * dilation[0] i_w = 1 + (o_w - 1) * stride[1] + (k_w - 1) * dilation[1] + # 가중치/입력/출력의 스패드 사용량 계산 weight_size = k_w * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N @@ -297,6 +362,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + # lane 및 전체 스패드 제한을 넘지 않는지 확인 check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) @@ -313,6 +379,11 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation return tile_candidates def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + """Create convolution tiling candidates that allow multi-tile decomposition along kernel width. + + 설명: conv->GEMM 근사를 사용하되 K_W와 같은 커널 폭을 고려해 다중 타일 전략을 생성합니다. + 필요성: 일부 conv 설정에서 단일 타일로 충분히 표현할 수 없을 때, 효과적인 multi-tile 분해를 찾기 위해 사용됩니다. + """ tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -349,6 +420,11 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, return tile_candidates def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + """Create convolution tiling candidates targeting single-batch usage. + + 설명: 입력 배치가 1인 경우에 맞춘 conv 타일 후보를 생성합니다. stride/dilation 및 filter 크기를 반영합니다. + 필요성: 단일 배치에서 메모리/스패드 활용을 최적화하고 성능을 높이기 위해 사용됩니다. + """ tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -385,6 +461,11 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio return tile_candidates def meta_kernel(self): + """Prepare and register metadata needed by the wrapper and external tooling. + + 이 메서드는 wrapper 코드에 출력할 루프 정보와 인자 속성을 정리하여 등록합니다. + 목적: 생성된 커널 코드와 외부 툴(예: 검증/벤치마크)이 필요로 하는 메타정보를 제공하기 위함입니다. 
+ """ wrapper = V.graph.wrapper_code kernel_arg_attributes = self.kernel_arg_attributes _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -399,14 +480,26 @@ def meta_kernel(self): wrapper.add_import_once(f"arg_attributes = {arg_attributes}") def call_kernel(self, kernel_name): + """Generate and register the wrapper call to the compiled kernel. + + 역할: wrapper에 커널 호출 코드를 생성하여 외부(파이썬 또는 래퍼)에서 해당 커널을 실행할 수 있게 합니다. + 왜 필요한가: 템플릿으로 생성된 커널을 실제 호출 코드와 연결하기 위해 필요합니다. + """ wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args, cuda=False) - + + # node = schedule buffer def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): + """Generate source code for a template given its render and surrounding prologue/epilogue nodes. + + 이 함수는 주어진 템플릿(render)을 실행하여 부분 코드를 얻고, prologue/epilogue 노드를 코드화하며 + 필요한 load/store/reduction 훅을 교체하여 통합된 소스 코드를 반환합니다. + 왜 필요한가: 템플릿 기반 커널의 전체 소스(프로로그/에필로그 포함)를 일관되게 생성하기 위해 필요합니다. + """ with self as kernel: _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() for node in [template_node, *prologue_nodes, *epilogue_nodes]: @@ -489,6 +582,10 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ return src_code def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): + """For each tile candidate, generate code, run benchmark and collect results. + + 목적: 자동 튜닝을 위해 후보별로 코드를 생성하고 실행(벤치마크) 결과를 수집하여 최적안을 찾을 수 있게 합니다. + """ choices = [] for tile_info in tile_candidates: if extension_config.CONFIG_DEBUG_MODE: @@ -501,6 +598,10 @@ def make_choices(self, tile_candidates, render, template_node, prologue_nodes, e return choices def _log_autotune_result(self, best_choice, best_cycle): + """Log the result of autotuning (best tile size and cycles). + + 필요성: 자동 튜닝 결과를 사용자에게 알려주고 디버깅/분석에 사용됩니다. + """ tile_size = best_choice[2] print( f"[Auto-tune] Optimal tile size: {list(tile_size)}, " @@ -508,6 +609,11 @@ def _log_autotune_result(self, best_choice, best_cycle): ) def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): + """Top-level API to produce source for given template nodes. + + 동작: autotune 설정에 따라 자동 튜닝을 실행하거나(있다면), 첫 후보 또는 단일 타일로 코드를 생성합니다. + 왜 필요한가: 실제 커널 소스 생성의 진입점으로 상위 로직이 이 함수를 호출합니다. + """ if "autotune" in extension_config.codegen_mapping_strategy and len(tile_candidates): src_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) self.loop_size = loop_size @@ -534,6 +640,10 @@ def _prepare_simulator_headers(self, src_code): write_atomic(gem5_write_path, self.gem5_header.getvalue()) def codegen_prologue_body(self): + """Generate the prologue portion of the kernel body (DMA loads, spad setup, prologue compute). + + 왜 필요한가: prologue는 타일의 입력/가중치 로드와 초기화 작업을 수행하며, main compute 이전에 필요한 준비 코드를 제공합니다. + """ body = IndentedBuffer() with self.prologue_buffer_group.as_local(): body.splice(self.spad_buffer) @@ -553,6 +663,10 @@ def codegen_prologue_body(self): return body def codegen_epilogue_body(self): + """Generate the epilogue portion of the kernel body (stores, reduction handling, DMA outs). + + 목적: 메인 계산 후 출력 저장과 리덕션 처리 등 후처리를 관리하여 결과를 메모리로 내보내는 역할을 합니다. 
+ """ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] @@ -600,6 +714,11 @@ def def_kernel( names_str: str = "", input_reorder: Optional[List[int]] = None, ) -> str: + """Register kernel input/output names and hook to render function signature. + + 역할: 입력/출력 노드와 이름을 매핑하고, 렌더 시 사용할 인자 정의 훅을 등록합니다. + 왜 필요한가: 템플릿이 생성한 커널을 외부에서 호출할 때 정확한 인자 시그니처를 제공하기 위해 필요합니다. + """ names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -648,6 +767,10 @@ def def_conv_kernel( padded_input_size: List[int] = [], input_reorder: Optional[List[int]] = None, ) -> str: + """Define convolution-specific kernel signature and handle padded input size adjustments. + + 이유: convolution의 경우 파이썬 래퍼에서 패딩을 처리하므로 템플릿 시그니처에 패딩된 입력 크기를 반영해야 합니다. + """ names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -686,6 +809,10 @@ def kernel_hook(): # This function is for convolution wrapper function finalizing. def def_wrapper(self, only_store_buffer: bool = False, epilogue_buffer: str = False): + """Register a wrapper function signature hook used to finalize convolution wrappers. + + 목적: 파이썬 레벨의 래퍼 함수에서 사용할 인자 시그니처를 정의합니다(주로 buffer 이름만 전달). + """ def wrapper_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) wrapper_arg_defs = [arg.split('%')[1].split(':')[0] for arg in arg_defs] @@ -696,12 +823,24 @@ def wrapper_hook(): return "" def get_conv_inputs(self): + """Return mapping of convolution input buffer names used by the kernel. + + 유용성: convolution wrapper/외부 코드가 입력 버퍼 이름을 필요로 할 때 호출됩니다. + """ return self.kernel_group.args.input_buffers def get_conv_outputs(self): + """Return mapping of convolution output buffer names that are actively used (not REMOVED). + + 유용성: wrapper가 출력 버퍼들을 쿼리할 때 사용됩니다. + """ return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} def load_input(self, indent_size: int = 0): + """Create a render hook that prepares input (DMA-in and prologue) code for the kernel. + + 이유: 입력 데이터와 가중치를 타일까지 맞추어 DRAM에서 SRAM/SPAD로 불러오는 코드를 생성합니다. + """ def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() @@ -734,6 +873,10 @@ def hook(): return "" def store_output(self, indent_size: int = 0): + """Register a render hook that returns the epilogue (store/output) code. + + 목적: 커널의 출력 저장/후처리 코드를 템플릿 렌더링 과정에서 올바른 위치에 삽입하기 위해 필요합니다. + """ def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() @@ -744,6 +887,10 @@ def hook(): return "" def reduction_output(self, indent_size: int = 0): + """Register a hook that injects reduction-specific output code into rendered template. + + 이유: 리덕션 연산의 특수한 후처리 코드(축소 결과 집계 등)를 템플릿의 출력 부분에 주입하기 위해 사용됩니다. + """ def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() @@ -752,6 +899,10 @@ def hook(): return "" def def_function(self): + """Optionally define an outer (Python) function wrapper for the kernel. + + 목적: 외부에서 호출 가능한 파이썬 래퍼를 생성하거나, 없다면 None을 반환합니다. + """ _, call_args, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) @@ -763,6 +914,10 @@ def def_function(self): return None, None def def_global_vars(self): + """Register global variable definitions hook for the template rendering. 
+ + 이유: 템플릿에서 필요한 전역 변수(예: 헤더에 들어갈 상수)를 렌더 시 삽입하기 위해 사용됩니다. + """ key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() @@ -772,6 +927,10 @@ def hook(): return key def def_local_vars(self, indent_size=0): + """Register local variable definitions (constants and allocations) for rendering. + + 목적: 커널 내부에서 사용되는 상수/할당 변수를 정의하고 템플릿 내부에서 참조 가능하게 합니다. + """ key = "" def hook(): code = IndentedBuffer() @@ -786,6 +945,10 @@ def hook(): def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, subtile_size:list=[], async_type=None, indent_size=0): + """Generate DMA operation code (MVIN/MVOUT) for given DRAM variable and tile descriptor. + + 필요성: DRAM <-> SPAD(SRAM) 이동을 MLIR/시뮬레이터용 코드로 변환하기 위해 사용됩니다. subtile/async 옵션을 통해 세부 동작을 제어합니다. + """ # Prepare code block local_code = IndentedBuffer() with V.set_kernel_handler(self): @@ -828,6 +991,10 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): + """Define/get the SRAM (SPAD) buffer memref declaration for a given DRAM name and tile. + + 목적: 타일용 SRAM 전역 버퍼를 할당하고 해당 global memref 선언 코드를 반환합니다. + """ # Prepare code block with V.set_kernel_handler(self): dtype = self.named_nodes[dram_name].get_layout().dtype @@ -837,6 +1004,10 @@ def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): return textwrap.indent(code, " "*indent_size).strip() def render(self, template, kwargs, define_function=None): + """Render an MLIR template and attach rendering hooks. + + 역할: 주어진 템플릿을 실제 코드 문자열로 렌더링하고, 필요한 경우 define_function을 통해 훅을 등록합니다. + """ code = template.render(**kwargs) if define_function is not None: define_function(self) @@ -847,10 +1018,18 @@ def render(self, template, kwargs, define_function=None): ) def get_spad_size_per_lane(self, tile_m, tile_n): + """Estimate SPAD usage per lane given tile dimensions. + + 사용 이유: SPAD 사용량을 타일링/매핑 후보 평가에서 비교하기 위함입니다. + """ size = tile_m * ((tile_n + self.vector_lane - 1) // self.vector_lane) return max(size, 2) # vector load/store def load_epilogue(self, name: str, index: sympy.Expr): + """Load data from SRAM (epilogue path) into vector registers for computation. + + 목적: epilogue 모드에서 SRAM에서 벡터를 읽어오는 코드를 생성하여 리덕션/후처리 계산에 사용됩니다. + """ index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) @@ -911,6 +1090,10 @@ def load_epilogue(self, name: str, index: sympy.Expr): return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): + """Store a computed value back into SRAM and schedule DMA out if necessary (epilogue path). + + 필요성: epilogue에서 계산된 값을 SRAM에 저장하고 최종적으로 DRAM으로 MVOUT을 생성하여 결과를 내보냅니다. + """ index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) @@ -961,6 +1144,10 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): self.dma_stores.writeline(DeferredLine(name, code)) def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): + """Handle generation of partial reduction storage and merging logic for a reduction operation. + + 필요성: 리덕션의 중간 결과를 로드/결합/저장하여 최종 결과를 생성하는데 필요한 코드를 생성합니다. Welford 등 특별한 리덕션도 처리합니다. 
+ """ argmax_or_argmin = reduction_type in {"argmax", "argmin"} if argmax_or_argmin: raise NotImplementedError() #TODO: argmin, argmax @@ -1024,6 +1211,10 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): return sram_var def store_reduction_epilogue(self, name, index, value): + """Finalize the reduction by combining partial results and emitting MVOUT to DRAM. + + 필요성: 여러 단계로 나뉜 리덕션의 파셜 결과들을 합치고, 최종적으로 DRAM에 저장하는 절차를 담당합니다. + """ index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) @@ -1067,47 +1258,6 @@ def store_reduction_epilogue(self, name, index, value): line = f"{operation} %{init_vec}, %{value}[{compute_index_var}] : {partial_tile_shape}, {partial_vshape}" self.reductions_suffix.writeline(line) - # 2 step reduction - new_vec_size = 2 - new_vshape = f"vector<{partial_vec_size//new_vec_size}x{new_vec_size}x{mlir_dtype}>" - new_reduced_shape = f"vector<{new_vec_size}x{mlir_dtype}>" - out = self.cse.generate(self.reductions_suffix, f"vector.shape_cast %{out} : {partial_vshape} to {new_vshape}") - init_vec = self.const_cse.generate(self.const_buffer, f"vector.broadcast %{init} : {mlir_dtype} to {new_reduced_shape}") - out = self.cse.generate(self.reductions_suffix, reduction_combine_vec(self.reduction_info[value][0], out, init_vec, axis=0, shape=new_vshape, reduced_shape=new_reduced_shape)) - out2 = self.cse.generate(self.reductions_suffix, f"vector.shuffle %{out}, %{out} [1, 0] : {new_reduced_shape}, {new_reduced_shape}") - - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - self.register_var_info(out, [new_vec_size, mlir_dtype]) - self.register_var_info(out2, [new_vec_size, mlir_dtype]) - out = reduction_partial_combine_vec(self.reduction_info[value][0], out, out2) - self.compute, self.reductions_suffix = self.reductions_suffix, self.compute - - if self.welford_reduce_out is not None: - # NOTE: It not a real welford algorithm... 
We just used E(X^2) - E(X)^2 - divider = self.cse.generate(self.reductions_suffix, f"arith.constant {float(self.r_dim_size)} : f32") - if self.buffer_types[name][1] > 1: - divider_vec = self.cse.generate(self.reductions_suffix, f"vector.broadcast %{divider} : f32 to {new_reduced_shape}") - else: - divider_vec = divider - - if self.current_node.node.origin_node: # FIXME: This is a temporary solution - # mean = SUM(X) / N - self.reduction_mean.append(self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}")) - out = self.reduction_mean[i] - else: - # m2 = (E(X^2) - E(X)^2) * N - sqr_mean = self.cse.generate(self.reductions_suffix, f"arith.divf %{out}, %{divider_vec} : {new_reduced_shape}") - mean_sqr = self.cse.generate(self.reductions_suffix, f"arith.mulf %{self.reduction_mean[i]}, %{self.reduction_mean[i]} : {new_reduced_shape}") - variance = self.cse.generate(self.reductions_suffix, f"arith.subf %{sqr_mean}, %{mean_sqr} : {new_reduced_shape}") - m2 = self.cse.generate(self.reductions_suffix, f"arith.mulf %{variance}, %{divider_vec} : {new_reduced_shape}") - out = m2 - - final_zero_var_list[-1] = f"%{body_index_var}" - final_compute_index_var = ",".join(final_zero_var_list) - operation = "affine.vector_store" - line = f"{operation} %{out}, %{sram_var}[{final_compute_index_var}] : {final_tile_shape}, {new_reduced_shape}" - self.reductions_suffix.writeline(DeferredLine(name, line)) - # MVOUT Encoding # Generate DMA instruction attribute = f"{{dram_stride={dram_stride}, sram_stride={final_tile_stride}, padding=0}}" @@ -1116,6 +1266,10 @@ def store_reduction_epilogue(self, name, index, value): self.reductions_suffix.writeline(DeferredLine(name, code)) def set_tile_size(self, template_fusion_info, prologue=False): + """Configure tile descriptor and related loop/reduction state based on template fusion info. + + 왜 필요한가: 템플릿이 요구하는 타일 크기/벡터화 정보를 커널 상태에 반영하고, 리덕션일 경우 관련 루프 및 벡터 크기를 조정합니다. + """ tile_desc = template_fusion_info["dram_tile_desc"] if "dim_aliasing" in template_fusion_info: self.dim_aliasing = template_fusion_info["dim_aliasing"] @@ -1147,6 +1301,10 @@ def set_tile_size(self, template_fusion_info, prologue=False): return tile_desc def rename_indexing(self, index) -> sympy.Expr: + """Apply dim_aliasing substitutions safely to avoid cyclic renames. + + 필요성: dim aliasing을 적용할 때 이름 충돌(서로 바꾸는 케이스)을 피하기 위해 임시 이름을 사용하여 안전하게 치환합니다. + """ for dim_name, dim_aliased_name in self.dim_aliasing.items(): index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) # To avoid this case ({"index0":"index1", "index1":"index0"}) diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index ffe8e4fc..9633c27d 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -35,7 +35,7 @@ def poisson_request_generator(lambda_requests, max_msec_time=None): yield current_time -class Request: +class Request: # model을 배치단위로 request 받고 처리 """ Each request has model name, it's own id, and requested time. """ request_id = 0 QUEUED = 1 @@ -54,7 +54,7 @@ def __init__(self, model:str, batchable_input_tensor : List[torch.Tensor], self.id = self.allocate_id() self.request_queue_idx = request_queue_idx - def allocate_id(self): + def allocate_id(self): # id 할당 allocated_id = Request.request_id Request.request_id += 1 return allocated_id @@ -67,7 +67,7 @@ def set_finished(self, finish_time): self.state = self.FINISHED self.finish_time.append(finish_time) - def get_latency(self): + def get_latency(self): # 실행시간 구하기 # Todo. 
Provide Toke-By-Token if self.state == self.FINISHED: turnaround_time = self.finish_time[-1] - self.arrival_time @@ -86,7 +86,7 @@ def get_latency(self): return turnaround_time, response_time, tbt_time - def free_memory(self): + def free_memory(self): """ Free memory resources that are allocated for handle this request """ return @@ -130,14 +130,14 @@ def get_batchable_input(self): def get_shared_input(self): return self.batched_req[0].shared_input_tensor - def get_input(self): + def get_input(self): # batch tensor + shared tensor return self.get_batchable_input() + self.get_shared_input() def __str__(self): return f"DNN Model: {self.model_name}, Partion idx: {self.partition_idx} Req: {self.batched_req[0]}" @staticmethod - def register_model(model_name : str, compiled_model): + def register_model(model_name : str, compiled_model): # modelmap list에 model 등록 SchedulerDNNModel.MODEL_MAP[model_name] = compiled_model class PyTorchSimRunner: diff --git a/tests/test_foobar.py b/tests/test_foobar.py new file mode 100644 index 00000000..10292d87 --- /dev/null +++ b/tests/test_foobar.py @@ -0,0 +1,48 @@ +import torch +import torch._dynamo + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_foobar(device, size=(128, 128)): + def vector_foobar(a): + return torch._foobar(a) + + x = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(vector_foobar) + res = opt_fn(x) + + out = x.cpu() + test_result("Foobar", res, out) + + +if __name__ == "__main__": + import os + import sys + import argparse + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + parser = argparse.ArgumentParser(description="Run Foobar test with dynamic shape") + parser.add_argument('--shape', type=str, default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + test_foobar(device, (1, 1)) + test_foobar(device, (47, 10)) + test_foobar(device, (128, 128)) + test_foobar(device, shape) \ No newline at end of file diff --git a/tests/test_matmul.py b/tests/test_matmul.py index cd30bd30..180358c9 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -27,7 +27,8 @@ def custom_matmul(a, b): w1 = weight.to(device=device) x2 = input.to("cpu") w2 = weight.to("cpu") - opt_fn = torch.compile(dynamic=False)(custom_matmul) + opt_fn = torch.compile(dynamic=False)(custom_matmul) # dynamic = False -> 모든 tensor의 shape은 compile 시점에 고정 + # -> loop 범위, 메모리 주소, 타일링 크기 전부 상수로 박기 -> loop unrolling 가능(스케줄링 단순화) -> 성능 상승 res = opt_fn(x1, w1) y = custom_matmul(x2, w2) test_result("Matmul Forward", res, y) @@ -52,7 +53,7 @@ def custom_matmul(bias, a, b): def test_addmm2(device, input_size=128, hidden_size=128, output_size=128): def custom_matmul(bias, a, b): - return torch.matmul(a, b) #+ bias + return torch.matmul(a, b) #+ bias ?! 
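# Illustrative sketch (not part of the applied patch): the dynamic=False note added in
# test_matmul.py above can be checked with plain CPU tensors and public torch.compile
# APIs only. With dynamic shapes disabled, input shapes are specialized at compile
# time, so a new input shape triggers a recompile with the new sizes baked in as
# constants; the helper function name here is hypothetical.
import torch

def matmul_fn(a, b):
    return torch.matmul(a, b)

compiled = torch.compile(dynamic=False)(matmul_fn)
out = compiled(torch.randn(8, 16), torch.randn(16, 4))   # compiled for (8,16) x (16,4)
assert out.shape == (8, 4)
out2 = compiled(torch.randn(4, 16), torch.randn(16, 4))  # new shape -> recompilation
assert out2.shape == (4, 4)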
torch.manual_seed(0) input = torch.randn(input_size, hidden_size) weight = torch.randn(hidden_size, output_size) diff --git a/tutorial/session1/ExecutionMode.ipynb b/tutorial/session1/ExecutionMode.ipynb index 22e00bed..fc501cf1 100644 --- a/tutorial/session1/ExecutionMode.ipynb +++ b/tutorial/session1/ExecutionMode.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,9 +29,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n", + "Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/npu/build.ninja...\n", + "Building extension module npu...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "Loading extension module npu...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n", + "Wrapper Codegen Path = /tmp/torchinductor_root/yr/cyrl4zqohiglmrez32dmaijhd3sfdh4xabea5splhxwtwckiykpr.py\n", + "[Gem5] Gem5 is running.. \n", + "[Spike] Running Spike simulator\n", + "[TOGSim] TOGSim is running. \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_084328.log\"\n" + ] + } + ], "source": [ "from Scheduler.scheduler import PyTorchSimRunner\n", "device = PyTorchSimRunner.setup_device().custom_device()\n", @@ -52,9 +76,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Spike] Running Spike simulator\n" + ] + } + ], "source": [ "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.json\"\n", "\n", @@ -74,9 +106,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[TOGSim] TOGSim is running. \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_084421.log\"\n" + ] + } + ], "source": [ "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", "\n", @@ -97,9 +138,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wrapper Codegen Path = /tmp/torchinductor_root/m6/cm63zhmgb7n2askwt37lf72xuvbgpk6uvtmexreuxosqt3g5466s.py\n", + "[Gem5] Gem5 is running... \n", + "[TOGSim] TOGSim is running.. 
\n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_085446.log\"\n" + ] + } + ], "source": [ "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_timing_only.json\"\n", "\n", @@ -112,11 +164,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-12-31 08:54:45.643] [info] Total execution cycles: 299095\n" + ] + } + ], "source": [ - "!cat /root/workspace/PyTorchSim/outputs/20251202_160520/togsim_result.log | grep \"Total execution cycle\"" + "!cat /workspace/PyTorchSim/togsim_results/20251231_085446.log | grep \"Total execution cycle\"" ] }, { @@ -128,9 +188,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[TOGSim] TOGSim is running... \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_085548.log\"\n" + ] + } + ], "source": [ "os.environ['TOGSIM_CONFIG']=f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_2_cores.json\"\n", "\n", @@ -143,11 +212,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-12-31 08:55:47.922] [info] Total execution cycles: 167394\n" + ] + } + ], "source": [ - "!cat /root/workspace/PyTorchSim/outputs/20251202_160547/togsim_result.log | grep \"Total execution cycle\"" + "!cat /workspace/PyTorchSim/togsim_results/20251231_085548.log | grep \"Total execution cycle\"" ] }, { @@ -160,7 +237,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "base", "language": "python", "name": "python3" }, diff --git a/tutorial/session1/Inference.ipynb b/tutorial/session1/Inference.ipynb index a49e2440..c55f2907 100644 --- a/tutorial/session1/Inference.ipynb +++ b/tutorial/session1/Inference.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -53,9 +53,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n", + "No modifications detected for re-loaded extension module npu, skipping build step...\n", + "Loading extension module npu...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Spike] Running Spike simulator\n", + "[TOGSim] TOGSim is running.. 
\n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_081752.log\"\n" + ] + } + ], "source": [ "from Scheduler.scheduler import PyTorchSimRunner\n", "device = PyTorchSimRunner.setup_device().custom_device()\n", @@ -70,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -92,9 +111,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------------------\n", + "|MatMul Test Passed|\n", + "--------------------\n" + ] + } + ], "source": [ "test_result(\"MatMul\", npu_out, cpu_out)" ] @@ -102,7 +131,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "base", "language": "python", "name": "python3" }, diff --git a/tutorial/session1/LogAnalysis.ipynb b/tutorial/session1/LogAnalysis.ipynb index 4f1e17cb..3b5ad4c6 100644 --- a/tutorial/session1/LogAnalysis.ipynb +++ b/tutorial/session1/LogAnalysis.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -31,9 +31,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n", + "Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/npu/build.ninja...\n", + "Building extension module npu...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "Loading extension module npu...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n", + "Wrapper Codegen Path = /tmp/torchinductor_root/yr/cyrl4zqohiglmrez32dmaijhd3sfdh4xabea5splhxwtwckiykpr.py\n", + "[Gem5] Gem5 is running.. \n", + "[TOGSim] TOGSim is running. \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/tutorial/session1/togsim_results/20251231_091154.log\"\n" + ] + } + ], "source": [ "from Scheduler.scheduler import PyTorchSimRunner\n", "device = PyTorchSimRunner.setup_device().custom_device()\n", @@ -54,9 +77,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[TOGSim] TOGSim is running. 
\n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/tutorial/session1/togsim_results/20251231_092630.log\"\n" + ] + } + ], "source": [ "os.environ['TOGSIM_DEBUG_LEVEL']=\"trace\"\n", "\n", @@ -77,7 +109,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "base", "language": "python", "name": "python3" }, diff --git a/tutorial/session1/Mapping.ipynb b/tutorial/session1/Mapping.ipynb index b02c98fe..32924565 100644 --- a/tutorial/session1/Mapping.ipynb +++ b/tutorial/session1/Mapping.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -29,9 +29,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n", + "Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/npu/build.ninja...\n", + "Building extension module npu...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "Loading extension module npu...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n", + "Wrapper Codegen Path = /tmp/torchinductor_root/yr/cyrl4zqohiglmrez32dmaijhd3sfdh4xabea5splhxwtwckiykpr.py\n", + "[Gem5] Gem5 is running.. \n", + "[Spike] Running Spike simulator\n", + "[TOGSim] TOGSim is running. \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_075825.log\"\n" + ] + } + ], "source": [ "from Scheduler.scheduler import PyTorchSimRunner\n", "device = PyTorchSimRunner.setup_device().custom_device()\n", @@ -45,11 +69,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-12-31 07:58:25.086] [info] Total execution cycles: 48555\n" + ] + } + ], "source": [ - "!cat /root/workspace/PyTorchSim/outputs/20251202_154524/togsim_result.log | grep \"Total execution cycle\"" + "!cat /workspace/PyTorchSim/togsim_results/20251231_075825.log | grep \"Total execution cycle\"" ] }, { @@ -62,9 +94,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wrapper Codegen Path = /tmp/torchinductor_root/pc/cpcaxj7e7vdnspoz7rgrba6r2bofbgzjblvyfh2j37ta32vme4zs.py\n", + "[Gem5] Gem5 is running. \n", + "[Spike] Running Spike simulator\n", + "[TOGSim] TOGSim is running.. 
\n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_083119.log\"\n" + ] + } + ], "source": [ "torch._dynamo.reset()\n", "\n", @@ -79,11 +123,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-12-31 08:31:18.508] [info] Total execution cycles: 46168\n" + ] + } + ], "source": [ - "!cat /root/workspace/PyTorchSim/outputs/20251202_141933/togsim_result.log | grep \"Total execution cycle\"" + "!cat /workspace/PyTorchSim/togsim_results/20251231_083119.log | grep \"Total execution cycle\"" ] }, { @@ -95,9 +147,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Wrapper Codegen Path = /tmp/torchinductor_root/ev/cev4vqqayc33pq3b5lgvnxf5wb2tuzsftkk7xqnmhe7skb2pqvpf.py\n", + "[Gem5] Gem5 is running.. \n", + "[Spike] Running Spike simulator\n", + "[TOGSim] TOGSim is running. \n", + "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251231_083259.log\"\n" + ] + } + ], "source": [ "torch._dynamo.reset()\n", "\n", @@ -112,11 +176,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2025-12-31 08:13:54.628] [info] Total execution cycles: 48515\n" + ] + } + ], "source": [ - "!cat /root/workspace/PyTorchSim/outputs/20251202_141951/togsim_result.log | grep \"Total execution cycle\"" + "!cat /workspace/PyTorchSim/togsim_results/20251231_081355.log | grep \"Total execution cycle\"" ] }, { @@ -129,7 +201,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "base", "language": "python", "name": "python3" }, diff --git a/tutorial/session1/tutorial_external_mapping.json b/tutorial/session1/tutorial_external_mapping.json index 3982d950..cc3501e0 100644 --- a/tutorial/session1/tutorial_external_mapping.json +++ b/tutorial/session1/tutorial_external_mapping.json @@ -1,6 +1,6 @@ { "1024_1024_1024" : { - "TILE_M" : 512, + "TILE_M" : 1024, "TILE_N" : 512, "TILE_K" : 512 } diff --git a/tutorial/session2/Hands_on.ipynb b/tutorial/session2/Hands_on.ipynb index 33ec1a28..4456700e 100644 --- a/tutorial/session2/Hands_on.ipynb +++ b/tutorial/session2/Hands_on.ipynb @@ -13,8 +13,7 @@ "Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...\n", "Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/npu/build.ninja...\n", "Building extension module npu...\n", - "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", - "Loading extension module npu...\n" + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n" ] }, { @@ -23,6 +22,13 @@ "text": [ "ninja: no work to do.\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading extension module npu...\n" + ] } ], "source": [ @@ -32,6 +38,8 @@ "import torch._dynamo\n", "import torch.utils.cpp_extension\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", + "os.environ['TOGSIM_CONFIG'] = f\"{base_dir}/tutorial/session1/togsim_configs/togsim_config_functional_only.yml\"\n", + "os.environ['TORCHSIM_DUMP_PATH'] = os.path.join(os.getcwd(), \"togsim_results\")\n", "sys.path.append(base_dir)\n", "\n", "from Scheduler.scheduler import PyTorchSimRunner\n", @@ -70,24 +78,49 @@ "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "[2025-12-03 02:02:14,679] [0/0] torch._inductor.debug: [WARNING] model___9 debug trace: /tmp/torchinductor_root/uu/cuumxtbdv4ukzpymchmrda2exohouwcdybawmj2v7jog4vbvoycf.debug\n" + "Wrapper Codegen Path = /tmp/torchinductor_root/lv/clvbryu3oxrzed6ugznbccuozwlntk5l4ww72d3h43thlpgidq35.py\n", + "[Gem5] Gem5 is running... [Gem5] Gem5 simulation failed with error: \"Redirecting stdout and stderr to /workspace/PyTorchSim/tutorial/session2/togsim_results/outputs/55kxfvowrj4/m5out/sto.log\n", + "\"\n", + "\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Wrapper Codegen Path = /tmp/torchinductor_root/uu/cuumxtbdv4ukzpymchmrda2exohouwcdybawmj2v7jog4vbvoycf.py\n", - "[Gem5] Gem5 is running... \n", - "[Spike] Running Spike simulator\n", - "[TOGSim] TOGSim is running.. \n", - "[TOGSim] Simulation log is stored to \"/workspace/PyTorchSim/togsim_results/20251203_020218.log\"\n", - "------------------\n", - "|exp2 Test Passed|\n", - "------------------\n" + "ename": "RuntimeError", + "evalue": "Gem5 Simulation Failed: \"Redirecting stdout and stderr to /workspace/PyTorchSim/tutorial/session2/togsim_results/outputs/55kxfvowrj4/m5out/sto.log\n\"", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/workspace/PyTorchSim/Simulator/simulator.py:176\u001b[0m, in \u001b[0;36mCycleSimulator.compile_and_simulate\u001b[0;34m(self, target_binary, array_size, vectorlane_size, silent_mode)\u001b[0m\n\u001b[1;32m 175\u001b[0m progress_thread\u001b[38;5;241m.\u001b[39mstart()\n\u001b[0;32m--> 176\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43msubprocess\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcheck_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgem5_cmd\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstderr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubprocess\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDEVNULL\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 177\u001b[0m finished \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/subprocess.py:421\u001b[0m, in \u001b[0;36mcheck_output\u001b[0;34m(timeout, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 419\u001b[0m kwargs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m empty\n\u001b[0;32m--> 421\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mpopenargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstdout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mPIPE\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheck\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 422\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mstdout\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/subprocess.py:526\u001b[0m, in \u001b[0;36mrun\u001b[0;34m(input, capture_output, timeout, check, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 525\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check \u001b[38;5;129;01mand\u001b[39;00m retcode:\n\u001b[0;32m--> 526\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CalledProcessError(retcode, process\u001b[38;5;241m.\u001b[39margs,\n\u001b[1;32m 527\u001b[0m output\u001b[38;5;241m=\u001b[39mstdout, stderr\u001b[38;5;241m=\u001b[39mstderr)\n\u001b[1;32m 528\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m CompletedProcess(process\u001b[38;5;241m.\u001b[39margs, retcode, stdout, stderr)\n", + "\u001b[0;31mCalledProcessError\u001b[0m: Command '['/workspace/gem5/build/RISCV/gem5.opt', '-r', '--stdout-file=sto.log', '-d', '/workspace/PyTorchSim/tutorial/session2/togsim_results/outputs/55kxfvowrj4/m5out', '/workspace/PyTorchSim/gem5_script/script_systolic.py', '-c', '/workspace/PyTorchSim/tutorial/session2/togsim_results/outputs/55kxfvowrj4/cycle_bin', '--vlane', '128']' returned non-zero exit status 1.", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m func \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mexp2\n\u001b[1;32m 5\u001b[0m opt_fn \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcompile(dynamic\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)(func)\n\u001b[0;32m----> 6\u001b[0m npu_out \u001b[38;5;241m=\u001b[39m \u001b[43mopt_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnpu_x\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m cpu_out \u001b[38;5;241m=\u001b[39m func(cpu_x)\n\u001b[1;32m 8\u001b[0m test_result(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexp2\u001b[39m\u001b[38;5;124m\"\u001b[39m, npu_out, cpu_out)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 487\u001b[0m dynamo_config_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 491\u001b[0m set_eval_frame(prior)\n", + "File 
\u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:15\u001b[0m, in \u001b[0;36mwrap_inline..inner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrap_inline\u001b[39m(fn):\n\u001b[1;32m 11\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;124;03m Create an extra frame around fn that is not in skipfiles\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inner\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 487\u001b[0m dynamo_config_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 491\u001b[0m set_eval_frame(prior)\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:17\u001b[0m, in \u001b[0;36mwrap_inline..inner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:901\u001b[0m, in \u001b[0;36maot_module_simplified..forward\u001b[0;34m(*runtime_args)\u001b[0m\n\u001b[1;32m 899\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(params_flat)\n\u001b[1;32m 900\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(runtime_args)\n\u001b[0;32m--> 901\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfull_args\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:81\u001b[0m, in \u001b[0;36mmake_boxed_func..g\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mg\u001b[39m(args):\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:94\u001b[0m, in \u001b[0;36mcreate_runtime_wrapper..runtime_wrapper\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 89\u001b[0m \u001b[38;5;66;03m# When we have an inference graph, we run with torch.no_grad.\u001b[39;00m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;66;03m# It's possible to get an inference graph with inputs that require grad,\u001b[39;00m\n\u001b[1;32m 91\u001b[0m \u001b[38;5;66;03m# in which case we want to make sure autograd is disabled\u001b[39;00m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;66;03m# (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)\u001b[39;00m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 94\u001b[0m all_outs \u001b[38;5;241m=\u001b[39m call_func_at_runtime_with_args(\n\u001b[1;32m 95\u001b[0m compiled_fn,\n\u001b[1;32m 96\u001b[0m args,\n\u001b[1;32m 97\u001b[0m disable_amp\u001b[38;5;241m=\u001b[39mdisable_amp,\n\u001b[1;32m 98\u001b[0m )\n\u001b[1;32m 100\u001b[0m num_mutated_runtime_inps \u001b[38;5;241m=\u001b[39m runtime_metadata\u001b[38;5;241m.\u001b[39mnum_mutated_inp_runtime_indices\n\u001b[1;32m 101\u001b[0m num_intermediate_bases \u001b[38;5;241m=\u001b[39m runtime_metadata\u001b[38;5;241m.\u001b[39mnum_intermediate_bases\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:105\u001b[0m, in \u001b[0;36mcall_func_at_runtime_with_args\u001b[0;34m(f, args, steal_args, disable_amp)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(f, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_boxed_call\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 105\u001b[0m out \u001b[38;5;241m=\u001b[39m normalize_as_list(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 107\u001b[0m \u001b[38;5;66;03m# TODO: Please remove soon\u001b[39;00m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;66;03m# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670\u001b[39;00m\n\u001b[1;32m 109\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 110\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYour compiler for AOTAutograd is returning a function that doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt take boxed arguments. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSee https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 113\u001b[0m )\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:118\u001b[0m, in \u001b[0;36maot_dispatch_base..rng_functionalization_wrapper\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fw\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py:864\u001b[0m, in \u001b[0;36mCompiledFxGraph.__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 863\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs: List[Any]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 864\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_current_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py:892\u001b[0m, in \u001b[0;36m_run_from_cache\u001b[0;34m(compiled_graph, inputs)\u001b[0m\n\u001b[1;32m 884\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m compiled_graph\u001b[38;5;241m.\u001b[39martifact_path\n\u001b[1;32m 885\u001b[0m compiled_graph\u001b[38;5;241m.\u001b[39mcompiled_artifact \u001b[38;5;241m=\u001b[39m PyCodeCache\u001b[38;5;241m.\u001b[39mload_by_key_path(\n\u001b[1;32m 886\u001b[0m compiled_graph\u001b[38;5;241m.\u001b[39mcache_key,\n\u001b[1;32m 887\u001b[0m compiled_graph\u001b[38;5;241m.\u001b[39martifact_path,\n\u001b[1;32m 888\u001b[0m compiled_graph\u001b[38;5;241m.\u001b[39mcache_linemap,\n\u001b[1;32m 889\u001b[0m compiled_graph\u001b[38;5;241m.\u001b[39mconstants,\n\u001b[1;32m 890\u001b[0m )\u001b[38;5;241m.\u001b[39mcall\n\u001b[0;32m--> 892\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_graph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompiled_artifact\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/tmp/torchinductor_root/lv/clvbryu3oxrzed6ugznbccuozwlntk5l4ww72d3h43thlpgidq35.py:96\u001b[0m, in \u001b[0;36mcall\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 94\u001b[0m buf0 \u001b[38;5;241m=\u001b[39m empty((\u001b[38;5;241m16\u001b[39m, \u001b[38;5;241m16\u001b[39m), device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnpu\u001b[39m\u001b[38;5;124m'\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 95\u001b[0m sram_plan_prefix(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbuf0\u001b[39m\u001b[38;5;124m'\u001b[39m, buf0)\n\u001b[0;32m---> 96\u001b[0m \u001b[43mextension_kernel_0\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg0_1\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbuf0\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 97\u001b[0m sram_plan_postfix(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marg0_1\u001b[39m\u001b[38;5;124m'\u001b[39m, arg0_1)\n\u001b[1;32m 
98\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m arg0_1\n", + "File \u001b[0;32m/workspace/PyTorchSim/PyTorchSimFrontend/extension_codecache.py:264\u001b[0m, in \u001b[0;36mCustomAsyncCompile.mlir..dummy_simulator\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdummy_simulator\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 263\u001b[0m \u001b[38;5;66;03m# Wait for compilation\u001b[39;00m\n\u001b[0;32m--> 264\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[43mfuture\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresult\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfilelock\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m FileLock\n\u001b[1;32m 266\u001b[0m lock_dir \u001b[38;5;241m=\u001b[39m get_lock_dir()\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/concurrent/futures/_base.py:458\u001b[0m, in \u001b[0;36mFuture.result\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CancelledError()\n\u001b[1;32m 457\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_state \u001b[38;5;241m==\u001b[39m FINISHED:\n\u001b[0;32m--> 458\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__get_result\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTimeoutError\u001b[39;00m()\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/concurrent/futures/_base.py:403\u001b[0m, in \u001b[0;36mFuture.__get_result\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception:\n\u001b[1;32m 402\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 403\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_exception\n\u001b[1;32m 404\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 405\u001b[0m \u001b[38;5;66;03m# Break a reference cycle with the exception in self._exception\u001b[39;00m\n\u001b[1;32m 406\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/opt/conda/lib/python3.10/concurrent/futures/thread.py:58\u001b[0m, in \u001b[0;36m_WorkItem.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[1;32m 60\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuture\u001b[38;5;241m.\u001b[39mset_exception(exc)\n", + "File \u001b[0;32m/workspace/PyTorchSim/PyTorchSimFrontend/extension_codecache.py:245\u001b[0m, in \u001b[0;36mCustomAsyncCompile.mlir..task\u001b[0;34m()\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtask\u001b[39m():\n\u001b[0;32m--> 245\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[43mMLIRCodeCache\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43msource_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 246\u001b[0m \u001b[43m \u001b[49m\u001b[43mvaldiation_wrapper_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_binary_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 247\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_binary_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_binary_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 248\u001b[0m \u001b[43m \u001b[49m\u001b[43marg_attributes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43marg_attributes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvectorlane_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvectorlane_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[43m \u001b[49m\u001b[43mtile_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtile_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspad_info\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mspad_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morigins\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morigins\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 250\u001b[0m \u001b[43m \u001b[49m\u001b[43msilent_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msilent_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m key\n", + "File \u001b[0;32m/workspace/PyTorchSim/PyTorchSimFrontend/extension_codecache.py:216\u001b[0m, in \u001b[0;36mMLIRCodeCache.load\u001b[0;34m(cls, source_code, validation_wrapper_name, validation_binary_name, cycle_wrapper_name, cycle_binary_name, arg_attributes, vectorlane_size, spad_info, origins, silent_mode, **kwargs)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;66;03m# Run cyclesim\u001b[39;00m\n\u001b[1;32m 215\u001b[0m cyclesim \u001b[38;5;241m=\u001b[39m CycleSimulator()\n\u001b[0;32m--> 216\u001b[0m cycle_list \u001b[38;5;241m=\u001b[39m \u001b[43mcyclesim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile_and_simulate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwrite_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcycle_binary_name\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m \u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43marray_size\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mvectorlane_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msilent_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msilent_mode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[38;5;66;03m# Create TOG\u001b[39;00m\n\u001b[1;32m 219\u001b[0m w_offset, x_offset \u001b[38;5;241m=\u001b[39m vectorlane_size, vectorlane_size\n", + "File \u001b[0;32m/workspace/PyTorchSim/Simulator/simulator.py:186\u001b[0m, in \u001b[0;36mCycleSimulator.compile_and_simulate\u001b[0;34m(self, target_binary, array_size, vectorlane_size, silent_mode)\u001b[0m\n\u001b[1;32m 184\u001b[0m finished \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 185\u001b[0m progress_thread\u001b[38;5;241m.\u001b[39mjoin()\n\u001b[0;32m--> 186\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGem5 Simulation Failed: \u001b[39m\u001b[38;5;130;01m\\\"\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39moutput\u001b[38;5;241m.\u001b[39mdecode()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\\"\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 188\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdir_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/stats.txt\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m stat_file:\n\u001b[1;32m 189\u001b[0m raw_list \u001b[38;5;241m=\u001b[39m stat_file\u001b[38;5;241m.\u001b[39mreadlines()\n", + "\u001b[0;31mRuntimeError\u001b[0m: Gem5 Simulation Failed: \"Redirecting stdout and stderr to /workspace/PyTorchSim/tutorial/session2/togsim_results/outputs/55kxfvowrj4/m5out/sto.log\n\"" ] } ], @@ -104,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "5bfdf22f-e749-41a5-a2cf-dcbb630bfb83", "metadata": {}, "outputs": [ @@ -185,11 +218,8 @@ " affine.for %compute_idx = 0 to 2 step 2\n", " {\n", " %tmp0 = affine.vector_load %spad0[%compute_idx] : memref<256xf32, 1>, vector<2xf32>\n", - " %tmp1 = arith.constant 0.69314718055994528623 : f32\n", - " %tmp2 = vector.broadcast %tmp1 : f32 to vector<2xf32>\n", - " %tmp3 = arith.mulf %tmp0, %tmp2 : vector<2xf32>\n", - " %tmp4 = math.exp %tmp3 : vector<2xf32>\n", - " affine.vector_store %tmp4, %spad1[%compute_idx] : memref<256xf32, 1>, vector<2xf32>\n", + " %tmp1 = math.exp2 %tmp0 : vector<2xf32>\n", + " affine.vector_store %tmp1, %spad1[%compute_idx] : memref<256xf32, 1>, vector<2xf32>\n", " } {inner_loop=false}\n", " memref.dma_start %spad1[%const0], %out_ptr0[%index0], %const2, %alloc1[%const0], %const0, %const1 : memref<256xf32, 1>, memref<256xf32>, memref<1xi32> {dram_stride=[1], sram_stride=[1], padding=0}\n", " } {outer_loop=true}\n", @@ -232,7 +262,7 @@ } ], "source": [ - "!cat /tmp/torchinductor_root/uu/cuumxtbdv4ukzpymchmrda2exohouwcdybawmj2v7jog4vbvoycf.py" + "!cat /tmp/torchinductor_root/lv/clvbryu3oxrzed6ugznbccuozwlntk5l4ww72d3h43thlpgidq35.py" ] } ], From a9c4c1892caa1457ac2cd4ca60fed6d4bcb2176e Mon Sep 17 00:00:00 2001 From: parkminkyumin Date: Sun, 25 Jan 2026 09:38:18 +0000 Subject: [PATCH 2/2] Add updated documentation for adding an ATen operator (MLIR path) --- docs/add_ATen_operator.md | 662 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 662 
insertions(+)
 create mode 100644 docs/add_ATen_operator.md

diff --git a/docs/add_ATen_operator.md b/docs/add_ATen_operator.md
new file mode 100644
index 00000000..8f730d79
--- /dev/null
+++ b/docs/add_ATen_operator.md
@@ -0,0 +1,662 @@
+# How to Add a New ATen Operator via MLIR Templates in PyTorchSim
+
+---
+
+## Overview
+PyTorchSim executes PyTorch programs by **lowering ATen operators into MLIR**,
+followed by backend-specific code generation and simulation.
+This guide helps contributors add support for **new ATen operators** by defining custom lowerings that route them through MLIR templates in the MLIR-based execution path.
+
+
+We use a dummy ATen operator, `torch._foobar()`, as a minimal example. Although `_foobar` has trivial semantics, it still exercises the full
+integration workflow:
+
+- Defining an MLIR template
+- Adding a custom lowering
+- Registering the lowering
+- Validating correctness with a test
+
+
+
+---
+## Background
+
+Before diving into the step-by-step implementation, it helps to understand how an ATen operator flows through PyTorchSim’s MLIR-based pipeline. There are two main codegen paths: the **ordinary path** (generic lowering → standard MLIR codegen) and the **template path** (lowering → MLIRTemplate → TemplateBuffer → template-driven codegen). Supporting a new ATen operator typically means mapping it at the lowering stage into the template path so its code is generated by a custom MLIR template.
+
+### Graph capture (`test_<op>.py`)
+
+`torch.compile` runs `Dynamo` to trace the Python function and produce an `FX graph`.
+That `FX graph` is then processed by `AOTAutograd`, which lowers it into an ATen-based graph.
+Next, `Inductor` applies transformations such as decomposition and functionalization, and lowers the `ATen graph` into `Inductor IR`, which contains `ordinary IR nodes` and, when a custom lowering is used, `template-backed nodes`. At this point, `aten.<op>` becomes eligible for custom handling in the lowering stage.
+
+### Lowering stage (`mlir_lowering.py`)
+
+`aten.<op>` is dispatched to a custom function through the lowerings registry (`lowerings.update({aten._foobar: custom_foobar})`).
+
+The lowering:
+* Materializes the input (`TensorBox.realize()`)
+* Creates the MLIR template instance from `mlir_<op>_template.py`
+* Returns a template-backed `TemplateBuffer` via `generate().output_node()`
+
+This is the primary hook where new ATen operator support is introduced.
+
+### Scheduling stage (`mlir_scheduling.py`)
+
+The template-backed node enters the scheduler and follows the
+`codegen_template()` path.
+
+At this stage:
+- Scheduling/Tiling decisions are applied
+- MLIR source code is rendered from the template
+- The kernel is prepared for registration
+
+### Kernel registration (`define_kernel()`)
+
+The generated MLIR source (`src_code`) is registered via the wrapper in `define_kernel()`. This emits a `custom_async_compile.mlir(...)` call into the wrapper so the kernel can be compiled, cached, and reused at runtime.
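+
+To see what this pipeline starts from, it can help to dump the graph that Dynamo captures before any lowering runs. The sketch below is illustrative only: `inspect_backend` is a throwaway debugging backend, not part of PyTorchSim, and it prints the Dynamo-level FX graph (the `torch._foobar` call appears there; AOTAutograd only later rewrites it into the `aten._foobar` node that the custom lowering intercepts). Setting `TORCH_COMPILE_DEBUG=1` is another way to inspect the post-AOTAutograd and Inductor artifacts.
+
+```python
+import torch
+
+def inspect_backend(gm, example_inputs):
+    # Print the Dynamo-captured FX graph before it is handed to AOTAutograd/Inductor.
+    gm.graph.print_tabular()
+    return gm.forward  # run eagerly; we only want to look at the graph
+
+compiled = torch.compile(lambda a: torch._foobar(a), backend=inspect_backend)
+compiled(torch.randn(128, 128))
+```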
+
+
+---
+
+## Table of Contents
+
+- [Toy Example: `_foobar` as a Dummy ATen Op](#toy-example-_foobar-as-a-dummy-aten-op)
+
+- [Step 1 — Create `mlir_<op>_template.py`](#step-1--create-mlir_op_templatepy)
+
+- [Step 2 — Add a Custom Lowering in `mlir_lowering.py`](#step-2--add-a-custom-lowering-in-mlir_loweringpy)
+  - [Step 2.1 — Define the Custom Lowering Function](#step-21--define-the-custom-lowering-function)
+  - [Step 2.2 — Register the Lowering via `lowerings.update`](#step-22--register-the-lowering-via-loweringsupdate)
+
+- [Step 3 — Add a Test (`test_<op>.py`)](#step-3--add-a-test-test_oppy)
+
+- [Summary](#summary)
+
+
+
+---
+## Toy Example: `_foobar` as a Dummy ATen Op
+
+`_foobar` is a deliberately trivial ATen operator exposed in PyTorch as [`torch._foobar()`](https://docs.pytorch.org/cppdocs/api/function_namespaceat_1a79acf777b53d06d41bdfa735b767506a.html).
+
+### Signature
+
+From the PyTorch C++ API, `_foobar` has the following signature:
+
+* `at::_foobar(const Tensor& self, bool arg1=true, bool arg2=true, bool arg3=true)`
+
+In other words, it takes:
+- One input tensor (`self`)
+- Three optional boolean flags (`arg1`, `arg2`, `arg3`)
+
+### Behavior
+
+In our Python checks, `_foobar` behaves like a simple **copy/identity-style op**:
+- The output has the same shape and dtype as the input
+- The output values match the input values
+
+
+Once you understand the full flow with `_foobar`, you can replace it with another ATen operator and expand
+the lowering/template logic as needed.
+
+---
+
+
+
+## Step 1 — Create `mlir_<op>_template.py`
+
+This step defines the MLIR template for the new ATen operator. In `_foobar`, the template is a minimal copy-style kernel that allocates `SRAM buffers` for `X` and `Y`, then uses a tiled `affine.for` loop to move an `X tile` into the `X_buffer` via `DMA`, copy it into `Y_buffer` with `vector.load/store`, and write the `Y tile` back to `DRAM`. The file also sets up tile descriptors, indexing, and render options that fill the Jinja placeholders used for code generation.
+
+### File skeleton and imports
+
+Pulls in the pieces needed to define and render an MLIR template, including:
+
+* The base `MLIRTemplate` / `MLIRTemplateKernel` classes for template codegen
+* `IRNode` from Inductor IR for node typing
+* MLIR helpers from `mlir_common` (e.g., tile descriptors)
+* `sympy` and `empty_strided` for shape/stride computations
+* A standard utility (`typing`) used by the template
+
+```python
+from typing import List, Optional
+import sympy
+from torch import empty_strided
+from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel
+from torch._inductor.ir import IRNode
+from PyTorchSimFrontend.mlir import mlir_common
+```
+
+### MLIR string template
+
+This template defines the raw **MLIR kernel** as a **Jinja-style string**; placeholders such as `{{ M }}`, `{{ N }}`, `{{ TILE_M }}`, `{{ TILE_N }}`, and the tile descriptors (`X_tile_desc`, `Y_tile_desc`) are filled in during `render()` via `kernel.render_options`.
+
+**In this `_foobar` implementation, the template:**
+
+- Allocates SRAM buffers for `X` and `Y` with `def_sram_buffer`.
+- Builds a 2D tiled loop nest over `M` and `N` using `affine.for` and step sizes `TILE_M`, `TILE_N`.
+- Uses explicit DMA ops (`MVIN` / `MVOUT`) to move tiles between DRAM and SRAM.
+- Copies data from `%X_buffer` to `%Y_buffer` using `vector.load` / `vector.store` (vectorized copy).
+
+**For each tile of size (`TILE_M`, `TILE_N`):**
+
+The tile is transferred via DMA from DRAM into SRAM (`%X_buffer`).
+The copy runs over rows (`i` increases by 1 each iteration) and columns in vector-sized chunks(`j` increases by `kernel.vector_lane`). +Each row is copied using `vector.load` / `vector.store` with `kernel.vector_lane` elements at a time. +If `TILE_N` = 128, one vector operation copies a whole row; +if `TILE_N` = 256, the row is copied in two vector chunks. (`vector_lane` = 128) +The completed tile in `%Y_buffer` is transferred via DMA back to DRAM(the output tensor `Y`). + + +`X_buffer` and `Y_buffer` are placed in `scratchpad (SPAD)`. Each tile is `TILE_M` * `TILE_N` elements, and each element is `f32` (4 bytes), so +`SPAD usage = 2 * TILE_M * TILE_N * 4 bytes` +Given scratchpad-size = 131072, we need: + +**2 * TILE_M * TILE_N * 4 <= 131072** + +Candidate sizes such as (128,128), (64,256), and (32,512), etc. fit. (128×128) had the lowest cycles and fewer tiles than options like (32×512). With vector utilization near 100% and low DMA activity, the kernel was compute‑bound, so reducing tile count mattered most. So we chose (128×128). + +If you need to pick or extend MLIR ops used in the template, refer to the official MLIR dialect list: +[MLIR dialect](https://mlir.llvm.org/docs/Dialects/) + +```python +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{- kernel.def_local_vars(indent_size=2) }} + %M_const = arith.constant {{ M }} : index + %N_const = arith.constant {{ N }} : index + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + affine.for %i = 0 to {{ TILE_M }} { + affine.for %j = 0 to {{ TILE_N }} step {{ kernel.vector_lane }} { + %v = vector.load %X_buffer[%i, %j] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.vector_lane }}x{{ DATA_STYPE }}> + vector.store %v, %Y_buffer[%i, %j] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.vector_lane }}x{{ DATA_STYPE }}> + } + } + {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} + } { outer_loop=true } + } { outer_loop=true } + return +} +""" +``` + + + +### Template class definition & `__init__` + +Defines the `MLIRFoobarTemplate` class, which inherits from `MLIRTemplate`. +Initializes the template by calling the base class constructor to register. + +```python +class MLIRFoobarTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) +``` + +### Render entry point (`render`) + +Defines the main render entry point for the template. It selects the output node, `epilogue_nodes`/`prologue_nodes` are only used for fusion scenarios and are not otherwise consumed in this kernel. The method then prepares tile descriptors and indices, fills `kernel.render_options`, and returns the rendered MLIR code. + +```python + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): +``` + +### Output selection and basic setup + +Binds `X` and `Y`, then materializes an `empty_strided` tensor from `X.layout` to read the concrete shape and derive `M` and `N`. 
For this template, tile sizes are fixed to `TILE_M = 128` and `TILE_N = 128`. + +```python + if template_buffer_node is not None: + self.output_node = template_buffer_node + + X = self.input_nodes[0] + Y = self.output_node + + X_tensor = empty_strided(X.layout.size, X.layout.stride) + M = X_tensor.size()[0] + N = X_tensor.size()[1] + + TILE_M = 128 + TILE_N = 128 +``` + +### Tile descriptors and indices + +Defines 2D tile descriptors for `X` and `Y` using `MLIRMultiDimTile`. Each descriptor records the tile shape, vector‑lane mapping (`vlane_split_axis`, `vlane_stride`), and stride metadata that drive SRAM/DMA codegen. Here both tiles use `[TILE_M, TILE_N]` with explicit strides `[1, TILE_M]`. `vlane_split_axis` selects which axis is lane‑mapped, and `vlane_stride` sets the per‑lane spacing. This lane mapping is used when generating DMA metadata and when interpreting the SPAD layout for vectorized access. The input tile also carries the DRAM offset, and indexing is built from the 2D loop variables (`index0`, `index1`) multiplied by the DRAM strides of `X`/`Y`. + +```python + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_M,TILE_N] + X_tile_stride = [1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_tile_desc.offset = X.get_layout().offset + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index1") * X_stride[1]] + + Y_tile_size = [TILE_M,TILE_N] + Y_tile_stride = [1, TILE_M] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] +``` + +### Render options + +Populates `kernel.render_options` with all values needed by the MLIR string template: kernel name/handle, tensor shape (`M`, `N`), tile sizes (`TILE_M`, `TILE_N`), input/output buffers (`X`, `Y`), index expressions (`X_idx`, `Y_idx`), tile descriptors (`X_tile_desc`, `Y_tile_desc`), data type, and input reordering. These options are then used to render `TEMPLATE`. + +```python + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + M=M, N=N, + TILE_M=TILE_M, + TILE_N=TILE_N, + X=X, + Y=Y, + X_idx=X_idx, + Y_idx=Y_idx, + X_tile_desc=X_tile_desc, + Y_tile_desc=Y_tile_desc, + DATA_STYPE="f32", + input_reorder=self.input_reorder, + ) +``` + +### Epilogue + +Stores epilogue metadata for the output tile—output_node, DRAM/SRAM names, and the tile descriptor—so the codegen can configure epilogue tile size and generate the final DMA/MVOUT path for the output. + +```python + kernel.epilogue_info = dict( + output_node=self.output_node.name, + sram_var="Y_buffer", + dram_var="Y", + dram_tile_desc=Y_tile_desc, + ) + +``` + +### Render MLIR, register loop metadata and return + +Renders the MLIR string by applying `kernel.render_options` to `TEMPLATE`, records loop bounds/strides via `kernel.add_loop_info` for wrapper metadata, and returns the final MLIR source string. 
+ +```python + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"]]) + + return code + +``` + +### Full `mlir_foobar_template.py` + +Copy-paste the full reference implementation below to create `mlir_foobar_template.py`. + +```python +from typing import List, Optional +import sympy +from torch import empty_strided +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel +from torch._inductor.ir import IRNode +from PyTorchSimFrontend.mlir import mlir_common + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X], outputs=[Y], names_str="X, Y", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} + {{- kernel.def_local_vars(indent_size=2) }} + %M_const = arith.constant {{ M }} : index + %N_const = arith.constant {{ N }} : index + affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { + affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { + {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, indent_size=6) }} + affine.for %i = 0 to {{ TILE_M }} { + affine.for %j = 0 to {{ TILE_N }} step {{ kernel.vector_lane }} { + %v = vector.load %X_buffer[%i, %j] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.vector_lane }}x{{ DATA_STYPE }}> + vector.store %v, %Y_buffer[%i, %j] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.vector_lane }}x{{ DATA_STYPE }}> + } + } + {{ kernel.def_dma_op("MVOUT", "Y", Y_idx, Y_tile_desc, indent_size=6) }} + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + +class MLIRFoobarTemplate(MLIRTemplate): + + def __init__(self, input_nodes, layout, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + + if template_buffer_node is not None: + self.output_node = template_buffer_node + + + X = self.input_nodes[0] + Y = self.output_node + + X_tensor = empty_strided(X.layout.size, X.layout.stride) + M = X_tensor.size()[0] + N = X_tensor.size()[1] + + TILE_M = 128 + TILE_N = 128 + + vlane_stride = 1 + vlane_split_axis = 1 + X_tile_size = [TILE_M,TILE_N] + X_tile_stride = [1, TILE_M] + X_tile_desc = mlir_common.MLIRMultiDimTile(X_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + X_tile_desc.set_tile_size_stride(X_tile_size, X_tile_stride) + X_tile_desc.set_name("X_buffer") + X_tile_desc.offset = X.get_layout().offset + X_stride = X.get_layout().stride + X_idx = [sympy.Symbol("index0") * X_stride[0], sympy.Symbol("index1") * X_stride[1]] + + Y_tile_size = [TILE_M,TILE_N] + Y_tile_stride = [1, TILE_M] + Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) + Y_tile_desc.set_name("Y_buffer") + Y_stride = Y.get_layout().stride + Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] + + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + M=M, N=N, + TILE_M=TILE_M, + TILE_N=TILE_N, + X=X, + Y=Y, + 
X_idx=X_idx, + Y_idx=Y_idx, + X_tile_desc=X_tile_desc, + Y_tile_desc=Y_tile_desc, + DATA_STYPE="f32", + input_reorder=self.input_reorder, + ) + + kernel.epilogue_info = dict( + output_node=self.output_node.name, + sram_var="Y_buffer", + dram_var="Y", + dram_tile_desc=Y_tile_desc, + ) + + kernel.exception_nodes["Y"] = {"numel" : Y.get_numel()} + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["M"], kernel.render_options["N"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"]]) + return code + + +``` + + +--- + +## Step 2 — Add a Custom Lowering in `mlir_lowering.py` + +This step connects the ATen operator (`aten._foobar`) to the MLIR template. +All changes in this step are made in **`mlir_lowering.py`**. + +It consists of two parts: +1. Defining the custom lowering function +2. Registering that function so Inductor actually uses it + +--- + +### Step 2.1 — Define the Custom Lowering Function + +The custom lowering function specifies **how `aten._foobar` should be lowered** +during the Inductor lowering stage. + +For this example, the lowering is intentionally minimal and preserves the +MLIR/template path, making it suitable for testing and demonstration purposes. + +```python +def custom_foobar(a, *args, **kwargs): + a.realize() + layout = a.layout + mlir_template = MLIRFoobarTemplate([a], layout) + return mlir_template.generate().output_node() +``` + +This function follows the standard pattern for custom lowerings: + +* `a`: +The actual input tensor. During lowering, this is typically wrapped in a +`TensorBox`, which is Inductor’s IR wrapper that carries the tensor along +with its layout/metadata and deferred computation context. + +* `a.realize()`: +Materializes the input `TensorBox` so the MLIR template sees a concrete buffer / IR node. +This is a safe default pattern to ensure shape/stride/layout metadata is available. + +* `layout = a.layout`: +Gets the Inductor layout object, which encapsulates device, dtype, +shape (size), and stride information. The MLIR template uses this to build +memref types and derive tiling/indexing behavior. + +* `MLIRFoobarTemplate([a], layout)`: +Instantiates the MLIR template with the input node and its layout. + +* `generate().output_node()`: +Builds the template-backed Inductor IR node and returns it as the lowering +result. + + +--- + +### Step 2.2 — Register the Lowering via `lowerings.update` + +Defining the lowering function is only the first step. You also need to register it so Inductor can dispatch to it during lowering. + + +```python +lowerings.update({getattr(aten._foobar, overload): custom_foobar for overload in aten._foobar.overloads()}) +``` +The `lowerings` registry is consulted by Inductor during the lowering phase. +When the ATen graph contains `aten._foobar`, Inductor looks it up here and dispatches to custom_foobar instead of the default path. + +Using `aten._foobar.overloads()` ensures that all overload variants of the +operator are covered, even if multiple signatures exist. + + +--- +## Step 3 — Add a Test (`test_.py`) + +This step validates the custom lowering end-to-end. +The test ensures that the operator is correctly captured, lowered, compiled, +and executed through the PyTorchSim MLIR path. + +### Test helper for correctness + +Defines a small helper function that compares the compiled output against a +CPU reference result and reports pass/fail status. +This helper can be reused across multiple operator tests. 
+ +```python +import torch +import torch._dynamo + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) +``` + +### Define a small wrapper function + +Defines a minimal wrapper function `vector_foobar` that calls +`torch._foobar(a)`. + +This function is what `torch.compile` captures into the graph, making it the +entry point for FX Graph → ATen → Inductor → custom lowering. + +```python +def test_foobar(device, size=(128, 128)): + def vector_foobar(a): + return torch._foobar(a) +``` + +### Create input and compile + +Creates a random input tensor and compiles the wrapper function using +`torch.compile(dynamic=False)`. + +```python + x = torch.randn(size).to(device=device) + opt_fn = torch.compile(dynamic=False)(vector_foobar) + res = opt_fn(x) +``` + +### Run and verify + +Executes the compiled function and compares the result against a reference +output. + +For `_foobar`, the reference output is simply the input tensor moved to CPU, +since the operator behaves like an identity op. + +```python + out = x.cpu() + test_result("Foobar", res, out) +``` + +### `__main__` + +Defines the command-line entry point for the test. +This section: + +* Sets up the custom PyTorchSim device and runner + +* Parses shape arguments + +* Runs the test across multiple input sizes + +```python +if __name__ == "__main__": + import os + import sys + import argparse + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + + parser = argparse.ArgumentParser(description="Run Foobar test with dynamic shape") + parser.add_argument('--shape', type=str, default="(512,768)") + args = parser.parse_args() + shape = tuple(map(int, args.shape.strip('()').split(','))) + + from Scheduler.scheduler import PyTorchSimRunner + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + test_foobar(device, (1, 1)) + test_foobar(device, (47, 10)) + test_foobar(device, (128, 128)) + test_foobar(device, shape) +``` + +Full `test_foobar.py` +Copy-paste the full reference implementation below to create `test_foobar.py`. 
+
+```python
+import torch
+import torch._dynamo
+
+def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4):
+    if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol):
+        message = f"|{name} Test Passed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+    else:
+        message = f"|{name} Test Failed|"
+        print("-" * len(message))
+        print(message)
+        print("-" * len(message))
+        print("custom out: ", out.cpu())
+        print("cpu out: ", cpu_out)
+        exit(1)
+
+def test_foobar(device, size=(128, 128)):
+    def vector_foobar(a):
+        return torch._foobar(a)
+
+    x = torch.randn(size).to(device=device)
+    opt_fn = torch.compile(dynamic=False)(vector_foobar)
+    res = opt_fn(x)
+
+    out = x.cpu()
+    test_result("Foobar", res, out)
+
+
+if __name__ == "__main__":
+    import os
+    import sys
+    import argparse
+    sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim'))
+
+    parser = argparse.ArgumentParser(description="Run Foobar test with dynamic shape")
+    parser.add_argument('--shape', type=str, default="(512,768)")
+    args = parser.parse_args()
+    shape = tuple(map(int, args.shape.strip('()').split(',')))
+
+    from Scheduler.scheduler import PyTorchSimRunner
+    module = PyTorchSimRunner.setup_device()
+    device = module.custom_device()
+    test_foobar(device, (1, 1))
+    test_foobar(device, (47, 10))
+    test_foobar(device, (128, 128))
+    test_foobar(device, shape)
+```
+---
+## Summary
+
+To add support for a new ATen operator in PyTorchSim’s MLIR path, you:
+
+1. Define an MLIR template (`mlir_<op>_template.py`),
+2. Implement and register a custom lowering in `mlir_lowering.py`,
+3. Validate the integration with a dedicated test (`test_<op>.py`).
+
+The `_foobar` example illustrates the complete integration flow for adding a new
+ATen operator in PyTorchSim. You can use this example as a reference when
+extending PyTorchSim to support additional ATen operators.
+
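+Putting Step 2 together, the additions to `mlir_lowering.py` amount to only a few lines. The sketch below is a recap rather than new machinery: it assumes `aten` and `lowerings` are already imported in that file, as they are for the existing lowerings, and that the template from Step 1 was saved as `PyTorchSimFrontend/mlir/mlir_foobar_template.py`.
+
+```python
+from PyTorchSimFrontend.mlir.mlir_foobar_template import MLIRFoobarTemplate
+
+def custom_foobar(a, *args, **kwargs):
+    # Materialize the input so its layout metadata is concrete, then hand it to the template.
+    a.realize()
+    layout = a.layout
+    mlir_template = MLIRFoobarTemplate([a], layout)
+    return mlir_template.generate().output_node()
+
+# Route every overload of aten._foobar to the custom lowering.
+lowerings.update({getattr(aten._foobar, overload): custom_foobar
+                  for overload in aten._foobar.overloads()})
+```
+
+With these pieces in place, running `python tests/test_foobar.py --shape "(256,256)"` from the repository root should exercise the whole path, from graph capture through the MLIR template to the simulator, end to end.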