diff --git a/kernel/include/vx_intrinsics.h b/kernel/include/vx_intrinsics.h index 0d5cda880f..7dfa1fd10f 100644 --- a/kernel/include/vx_intrinsics.h +++ b/kernel/include/vx_intrinsics.h @@ -287,34 +287,34 @@ inline __attribute__((const)) int vx_shfl_idx(size_t value, int bval, int cval, // TILE LOAD T: Load 1KB from ptr[TILE] to tile register index 'dst_treg' // Each load uses I-type encoding: rd=dst tile index, rs1=src_gpr, imm=ptr immediate -inline void vx_lt(int dst_treg, int src_gpr, size_t ptr_imm) { +inline void vx_lt(int dst_treg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 0, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_treg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD U: Load 1KB from ptr[TILE] to ureg index 'dst_ureg' -inline void vx_lu(int dst_ureg, int src_gpr, size_t ptr_imm) { +inline void vx_lu(int dst_ureg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 1, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_ureg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD V: Load 1KB from ptr[TILE] to vreg index 'dst_vreg' -inline void vx_lv(int dst_vreg, int src_gpr, size_t ptr_imm) { +inline void vx_lv(int dst_vreg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 2, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_vreg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE LOAD M: Load 1KB from ptr[TILE] to mreg index 'dst_mreg' -inline void vx_lm(int dst_mreg, int src_gpr, size_t ptr_imm) { +inline void vx_lm(int dst_mreg, size_t src_gpr, size_t ptr_imm) { __asm__ volatile (".insn i %0, 3, x%1, %2, %3" :: "i"(RISCV_CUSTOM1), "i"(dst_mreg), "r"(src_gpr), "i"(ptr_imm) : "memory"); } // TILE STORE T: Store 1KB from treg index 'src_treg' to ptr[TILE] // Store uses S-type encoding: rs1=src_gpr, rs2=src_treg index, imm=ptr immediate -inline void vx_st(int src_gpr, size_t ptr_imm, int src_treg) { - __asm__ volatile (".insn s %0, 0, %1, x%2, %3" - :: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(src_treg), "i"(ptr_imm) : "memory"); +inline void vx_st(size_t src_gpr, size_t ptr_imm, int src_treg) { + __asm__ volatile (".insn s %0, 0, x%3, %2(%1)" + :: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(ptr_imm), "i"(src_treg) : "memory"); } // ----------------------------------------------------------------------------- diff --git a/kernel/include/vx_sparse.h b/kernel/include/vx_sparse.h index da2d7738af..da6300d5da 100644 --- a/kernel/include/vx_sparse.h +++ b/kernel/include/vx_sparse.h @@ -200,57 +200,102 @@ struct wmma_context { if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } - // For sparse format: when meta_src is provided, data stride is K/2 (not K) - // because each row has K/2 values (2 per block of 4) - size_t data_ldm = (meta_src != nullptr) ? (ldm / 2) : ldm; - auto base = reinterpret_cast(src) + block_row * data_ldm + block_col; // Metadata pointer is pre-offset to tile position (like data pointer) - // For metadata: stride is based on number of K-blocks per row in the FULL matrix - // This is ldm/4 (K/4), not affected by tile boundaries const uint32_t* meta_base = meta_src ? reinterpret_cast(meta_src) : nullptr; - // NOTE: meta_ldm uses full matrix K for stride, not tile dimensions uint32_t meta_ldm = meta_src ? 
(ldm / 4) : 0; - detail::unroll_for([&](auto r) { - uint32_t block_m = r / cfg::k_steps; - uint32_t block_k = r % cfg::k_steps; - uint32_t elem_row = block_m * m_stride; - uint32_t elem_col = block_k * k_stride; - uint32_t meta_value = 0; - - if (meta_base) { - // Metadata uses ABSOLUTE matrix positions (not tile-relative) - // meta_row_base = tile_row (absolute row offset for this tile) - // meta_col_base = k_tile (absolute K offset for this tile) - uint32_t abs_row = meta_row_base + block_row + elem_row; - uint32_t abs_k_block = (meta_col_base / 4) + block_k; // K-block index in full matrix + if (meta_src != nullptr) { + // SPARSE LOADING: Use metadata to place values in correct k_step registers + // data_ldm is K/2 for sparse (compressed values) + size_t data_ldm = ldm / 2; + // For sparse, don't add block_col to base - we compute sparse_idx separately + auto data_base = reinterpret_cast(src) + block_row * data_ldm; + + // First, load metadata for each M row that this thread handles + // and distribute sparse values to the correct k_step registers + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; - // Metadata is stored in row-major format with meta_ldm entries per row + // Get metadata for this row (absolute position in matrix) + uint32_t abs_row = meta_row_base + block_row + elem_row; + uint32_t abs_k_block = (meta_col_base / 4); // K-block index for this tile const uint32_t *meta_ptr = meta_base + static_cast(abs_row) * meta_ldm + abs_k_block; - meta_value = *meta_ptr; - } - - if constexpr (Frag::Use == matrix_a) { + uint32_t meta_value = *meta_ptr; dst.metadata[r] = meta_value; - } - if constexpr (src_layout == col_major) { - static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); - std::swap(elem_row, elem_col); - auto ptr = base + elem_row * data_ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + + // meta_value is a bitmask: bits 0-3 indicate which of 4 K positions have values + // block_k indicates which pair of K positions this register is for: + // block_k=0 -> K positions 0,1 (bits 0,1) + // block_k=1 -> K positions 2,3 (bits 2,3) + uint8_t meta_byte = meta_value & 0xFF; + uint32_t k_start = block_k * cfg::tcK; // Start K position for this register + uint32_t k_end = k_start + cfg::tcK; // End K position + + // Count how many sparse values come BEFORE this k_step for this row + uint32_t sparse_offset = 0; + for (uint32_t pos = 0; pos < k_start; ++pos) { + if (meta_byte & (1u << pos)) { + sparse_offset++; + } + } + + // For fp32 with tcK=2, each register holds 1 fp32 value + // block_col determines which position within the tcK pair: 0 or 1 + // So the target K position is: k_start + block_col + uint32_t target_pos = k_start + block_col; + + vreg_t loaded_val = 0.0f; + + if (target_pos < 4) { + // Count sparse values before target_pos to get the sparse index + uint32_t sparse_idx = sparse_offset; + for (uint32_t pos = k_start; pos < target_pos; ++pos) { + if (meta_byte & (1u << pos)) { + sparse_idx++; + } + } + + // Check if target position has a sparse value + if (meta_byte & (1u << target_pos)) { + auto ptr = data_base + elem_row * data_ldm + sparse_idx; + loaded_val = *ptr; + } + // else: loaded_val stays 0.0f (position was pruned) + } + + dst.data[r] = loaded_val; + }); + } else { + // DENSE LOADING: Original non-sparse path + auto base = 
reinterpret_cast(src) + block_row * ldm + block_col; + + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + + dst.metadata[r] = 0; + + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } } else { - dst.data[r] = input_acessor_t::pack_row(ptr, data_ldm); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); } - } else { - // row_major layout - // For sparse format, use data_ldm (K/2) instead of ldm (K) - auto ptr = base + elem_row * data_ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); - } - }); + }); + } } else if constexpr (Frag::Use == matrix_b) { // Load column-major matrix B uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); diff --git a/sim/simx/core.cpp b/sim/simx/core.cpp index bb5b48fe88..49617318e2 100644 --- a/sim/simx/core.cpp +++ b/sim/simx/core.cpp @@ -372,6 +372,9 @@ void Core::issue() { #endif #ifdef EXT_TCU_ENABLE case FUType::TCU: ++perf_stats_.scrb_tcu; break; + #endif + #ifdef EXT_VEGETA_ENABLE + case FUType::VEGETA: ++perf_stats_.scrb_vegeta; break; #endif default: assert(false); } diff --git a/sim/simx/core.h b/sim/simx/core.h index 859491e53a..dcc675a51c 100644 --- a/sim/simx/core.h +++ b/sim/simx/core.h @@ -66,6 +66,9 @@ class Core : public SimObject { #endif #ifdef EXT_TCU_ENABLE uint64_t scrb_tcu; + #endif + #ifdef EXT_VEGETA_ENABLE + uint64_t scrb_vegeta; #endif uint64_t ifetches; uint64_t loads; @@ -93,6 +96,9 @@ class Core : public SimObject { #endif #ifdef EXT_TCU_ENABLE , scrb_tcu(0) + #endif + #ifdef EXT_VEGETA_ENABLE + , scrb_vegeta(0) #endif , ifetches(0) , loads(0) diff --git a/sim/simx/decode.cpp b/sim/simx/decode.cpp index 94c5bae7e2..eb5df47089 100644 --- a/sim/simx/decode.cpp +++ b/sim/simx/decode.cpp @@ -1149,7 +1149,6 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) { } break; #endif #ifdef EXT_VEGETA_ENABLE - case 3: { switch (funct3) { case 0: { // WMMA @@ -1190,7 +1189,6 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) { std::abort(); } } break; - #endif default: std::abort(); diff --git a/sim/simx/emulator.cpp b/sim/simx/emulator.cpp index 8c77a40328..d9f026c410 100644 --- a/sim/simx/emulator.cpp +++ b/sim/simx/emulator.cpp @@ -501,6 +501,9 @@ Word Emulator::get_csr(uint32_t addr, uint32_t wid, uint32_t tid) { #endif #ifdef EXT_VPU_ENABLE CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vpu); + #endif + #ifdef EXT_VEGETA_ENABLE + CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vegeta); #endif CSR_READ_64(VX_CSR_MPM_SCRB_CSRS, core_perf.scrb_csrs); CSR_READ_64(VX_CSR_MPM_SCRB_WCTL, core_perf.scrb_wctl); diff --git a/sim/simx/execute.cpp b/sim/simx/execute.cpp index d2252e8ce8..027232dcb1 100644 --- a/sim/simx/execute.cpp +++ b/sim/simx/execute.cpp @@ -149,9 +149,15 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) { << ", PC=0x" << std::hex 
<< warp.PC << std::dec << " (#" << instr.getUUID() << ")"); // fetch register values +#ifdef EXT_VEGETA_ENABLE if (rsrc0.type != RegType::None && rsrc0.type != RegType::Tile) fetch_registers(rs1_data, wid, 0, rsrc0); if (rsrc1.type != RegType::None && rsrc1.type != RegType::Tile) fetch_registers(rs2_data, wid, 1, rsrc1); if (rsrc2.type != RegType::None && rsrc2.type != RegType::Tile) fetch_registers(rs3_data, wid, 2, rsrc2); +#else + if (rsrc0.type != RegType::None) fetch_registers(rs1_data, wid, 0, rsrc0); + if (rsrc1.type != RegType::None) fetch_registers(rs2_data, wid, 1, rsrc1); + if (rsrc2.type != RegType::None) fetch_registers(rs3_data, wid, 2, rsrc2); +#endif uint32_t thread_start = 0; for (; thread_start < num_threads; ++thread_start) { @@ -1546,16 +1552,66 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) { } }, [&](VegetaTcuType tcu_type) { - auto tpuArgs = std::get(instrArgs); switch (tcu_type) { - case VegetaTcuType::TILE_GEMM_T: - case VegetaTcuType::TILE_GEMM_U: - case VegetaTcuType::TILE_GEMM_V: - case VegetaTcuType::TILE_GEMM_R: - // TODO: Implement TILE_GEMM execution - std::abort(); - break; + case VegetaTcuType::TILE_GEMM_T: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + + // Dense tile × Dense tile → Tile (T × T → T) + sparse_unit_->tile_gemm_t(dst_reg, src1_reg, src2_reg); + rd_write = false; // Writes to tile registers, not scalar registers + } break; + case VegetaTcuType::TILE_GEMM_U: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + + // Sparse tile (2:4) × Dense tile → Tile (T × U → T) + // Metadata assumed to be in corresponding m-register (same index as src1) + sparse_unit_->tile_gemm_u(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; + } break; + case VegetaTcuType::TILE_GEMM_V: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + + // Sparse tile (1:4) × Dense tile → Tile (T × V → T) + sparse_unit_->tile_gemm_v(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; + } break; + case VegetaTcuType::TILE_GEMM_R: { + auto trace_data = std::make_shared(); + trace->data = trace_data; + assert(warp.tmask.count() == num_threads); + + // Extract tile register indices from instruction + uint32_t dst_reg = rdest.idx; + uint32_t src1_reg = instr.getSrcReg(0).idx; + uint32_t src2_reg = instr.getSrcReg(1).idx; + + // Row-wise sparse tile × Dense tile → Tile (T × U → U) + sparse_unit_->tile_gemm_r(dst_reg, src1_reg, src2_reg, src1_reg); + rd_write = false; + } break; case VegetaTcuType::WMMA: { + auto tpuArgs = std::get(instrArgs); auto trace_data = std::make_shared(); trace->data = trace_data; assert(warp.tmask.count() == num_threads); diff --git a/sim/simx/sparse_unit.cpp b/sim/simx/sparse_unit.cpp index f1df6ddfe9..35b9b36b41 100644 --- a/sim/simx/sparse_unit.cpp +++ b/sim/simx/sparse_unit.cpp @@ -146,9 +146,6 @@ 
struct FEDP{ // Sparse FEDP: uses metadata to select which values from fragB to use // fragA is sparse (2:4), fragB is dense // metadata contains bitmasks indicating which 2 of 4 positions are non-zero -// Each metadata byte is a bitmask where bits 0-3 indicate positions (0-3) that are kept -// For fp16/bf16: each register has 2 elements, metadata byte covers 4 elements = 2 registers -// For fp32: each register has 1 element, metadata byte covers 4 elements = 4 registers template struct SparseFEDP { using itype = typename It::dtype; @@ -158,20 +155,15 @@ struct SparseFEDP { static_assert(i_ratio * sizeof(itype) == sizeof(uint32_t), "SparseFEDP: tcK * i_ratio must be <= 32"); auto acc = bit_cast(c_val); - // Process registers in blocks of 4 elements - // Each metadata byte covers 4 elements: for fp16 that's 2 registers, for fp32 that's 4 registers constexpr uint32_t regs_per_block = (i_ratio == 2) ? 2 : 4; for (uint32_t z = 0; z < cfg::tcK; z += regs_per_block) { - // Get metadata for this block of 4 elements uint32_t block_idx = z / regs_per_block; uint32_t meta = (block_idx < 8) ? metadata[block_idx] : 0; uint8_t meta_byte = meta & 0xFF; - // Process each of the 4 positions in this block for (uint32_t pos = 0; pos < 4; ++pos) { if (meta_byte & (1u << pos)) { - // This position is non-zero, find which register and element it corresponds to uint32_t reg_idx = z + (pos / i_ratio); uint32_t elem_idx = pos % i_ratio; @@ -187,41 +179,18 @@ struct SparseFEDP { } }; -// Specialized version for fp32->fp32: each register contains 1 element -// Metadata byte encodes which 2 of 4 consecutive elements (across 4 registers) are non-zero -// The metadata byte is a bitmask where bits 0-3 indicate which positions (0-3) are kept template <> struct SparseFEDP { static uint32_t eval(const reg_data_t *a_row, const reg_data_t *b_col, uint32_t c_val, const uint32_t* metadata) { + __unused(metadata); auto acc = bit_cast(c_val); - // Process registers in groups of 4 (one block of 4 elements = 4 registers) - for (uint32_t z = 0; z < cfg::tcK; z += 4) { - // Calculate which block this is (each block = 4 registers = 4 elements) - uint32_t block_idx = z / 4; - - // Get metadata for this block - // Each metadata uint32_t contains 4 bytes (one per block of 4 elements) - uint32_t meta_reg_idx = block_idx / 4; // Which metadata register (0-7) - if (meta_reg_idx >= 8) break; // Only 8 metadata registers available - - uint32_t meta = metadata[meta_reg_idx]; - // Extract the byte for this block (each uint32_t has 4 bytes) - uint32_t byte_offset = block_idx % 4; - uint8_t meta_byte = (meta >> (byte_offset * 8)) & 0xFF; - - // Decode which 2 of 4 positions are non-zero - // The byte is a bitmask where bits 0-3 indicate which positions (0-3) are kept - // For 2:4 sparsity, exactly 2 bits should be set - for (uint32_t i = 0; i < 4 && (z + i) < cfg::tcK; ++i) { - if (meta_byte & (1u << i)) { - // This position is non-zero, multiply and accumulate - auto a_val = bit_cast(a_row[z + i].u32); - auto b_val = bit_cast(b_col[z + i].u32); - acc = FMA::eval(a_val, b_val, acc); - } - } + for (uint32_t z = 0; z < cfg::tcK; ++z) { + auto a_val = bit_cast(a_row[z].u32); + auto b_val = bit_cast(b_col[z].u32); + acc = FMA::eval(a_val, b_val, acc); } + return bit_cast(acc); } }; @@ -327,10 +296,10 @@ class SparseUnit::Impl { , core_(core) , arch_(arch) , perf_stats_() - , tile_reg_file_(8, std::vector>(16, std::vector(32, 0.0f))) - , metadata_reg_file_(8, std::vector>(16, std::vector(32, 0))) + , tile_reg_file_(8, std::vector>(16, 
std::vector(16, 0.0f))) + , metadata_reg_file_(8, std::vector>(16, std::vector(16, 0))) { - // Register file initialized: 8 registers, each 16x32 fp32 elements + // Register file initialized: 8 registers, each 16x16 fp32 elements } ~Impl() { @@ -360,19 +329,49 @@ class SparseUnit::Impl { continue; auto trace = input.front(); int delay = 0; - auto tcu_type = std::get(trace->op_type); + #ifdef EXT_VEGETA_ENABLE + if (std::holds_alternative(trace->op_type)) { + auto tcu_type = std::get(trace->op_type); + switch (tcu_type) { + case VegetaTcuType::TILE_GEMM_T: + case VegetaTcuType::TILE_GEMM_U: + case VegetaTcuType::TILE_GEMM_V: + case VegetaTcuType::TILE_GEMM_R: + case VegetaTcuType::WMMA: + delay = 4; + break; + default: + std::abort(); + } + DT(3, simobject_->name() << ": op=" << tcu_type << ", " << *trace); + } else if (std::holds_alternative(trace->op_type)) { + auto lsu_type = std::get(trace->op_type); + switch (lsu_type) { + case VegetaLsuType::TILE_LOAD_T: + case VegetaLsuType::TILE_LOAD_U: + case VegetaLsuType::TILE_LOAD_V: + case VegetaLsuType::TILE_LOAD_M: + case VegetaLsuType::TILE_STORE_T: + delay = 2; + break; + default: + std::abort(); + } + DT(3, simobject_->name() << ": op=" << lsu_type << ", " << *trace); + } else { + std::abort(); + } + #else + auto tcu_type = std::get(trace->op_type); switch (tcu_type) { - case VegetaTcuType::TILE_GEMM_T: - case VegetaTcuType::TILE_GEMM_U: - case VegetaTcuType::TILE_GEMM_V: - case VegetaTcuType::TILE_GEMM_R: - case VegetaTcuType::WMMA: + case TcuType::WMMA: delay = 4; break; default: std::abort(); } DT(3, simobject_->name() << ": op=" << tcu_type << ", " << *trace); + #endif simobject_->Outputs.at(iw).push(trace, 2 + delay); input.pop(); } @@ -422,12 +421,8 @@ class SparseUnit::Impl { auto c_val = rs3_data.at(idx).u32; // Map metadata from fragment registers to K dimension registers - // Fragment register r maps to: block_m = r / cfg::k_steps, block_k = r % cfg::k_steps - // For K dimension register z, we need fragment register r where block_k corresponds to z uint32_t meta_for_k[8] = {0}; for (uint32_t z = 0; z < cfg::tcK && z < 8; ++z) { - // Compute which fragment register contains data for K dimension z and M dimension i - // Fragment register index = a_off + i * cfg::tcK + z uint32_t frag_reg_idx = a_off + i * cfg::tcK + z; if (frag_reg_idx < 8) { meta_for_k[z] = meta[frag_reg_idx]; @@ -435,7 +430,6 @@ class SparseUnit::Impl { } // Perform sparse-dense FEDP: fragA is sparse, fragB is dense - // Use metadata to select which values from fragB to multiply auto d_val = sparse_fedp(a_row, b_col, c_val, meta_for_k); rd_data.at(idx).u64 = nan_box(d_val); @@ -459,6 +453,214 @@ class SparseUnit::Impl { } } + // TILE_GEMM_T: Dense tile × Dense tile = Tile (T × T → T) + // Tiles are 16×16, so this computes: C[16×16] = A[16×16] × B[16×16] + void tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(src2_treg < tile_reg_file_.size() && "Source2 tile register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; + const auto& tile_b = tile_reg_file_[src2_treg]; + + // Matrix multiplication: C[16×16] = A[16×16] × B[16×16] + // C += A × B (accumulate to existing value) + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + 
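+        // Inner product over K = 16 with read-modify-write accumulation:
+        // C[i][j] += sum_k A[i][k] * B[k][j]; a caller wanting plain C = A × B
+        // must zero the destination tile first.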
float sum = tile_dst[i][j]; // Accumulate to existing value + for (uint32_t k = 0; k < TILE_DIM; ++k) { + sum += tile_a[i][k] * tile_b[k][j]; + } + tile_dst[i][j] = sum; + } + } + + DP(2, "TILE_GEMM_T: dst_t" << dst_treg << " = t" << src1_treg << " × t" << src2_treg); + } + + // TILE_GEMM_U: Sparse tile (2:4) × Dense tile = Tile (T × U → T) + void tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile + const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // U-register maps to 2 T-registers + std::vector src2_tregs = map_ureg_to_treg(src2_ureg); + + // For 2:4 sparsity, each 4-element block has 2 non-zero values + // Metadata byte indicates which 2 positions are non-zero + // We process 2 T-registers as one U-register + + // U-register spans 2 T-registers, giving K dimension of 2*TILE_DIM = 32 + // A is stored in compressed form: 16 values per row representing 32 logical positions + // Metadata indicates which 2 out of every 4 logical positions are stored + + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + float sum = tile_dst[i][j]; // Accumulate + + // Iterate through compressed A values and map to logical K positions + uint32_t compressed_idx = 0; // Index into compressed storage (tile_a) + + // Process 8 groups of 4 logical K positions (covering K=0..31) + for (uint32_t k_grp = 0; k_grp < 8; ++k_grp) { + uint8_t mask = meta_a[i][k_grp]; // Metadata for this 4-element group + uint32_t k_base = k_grp * 4; // Base logical K position for this group + + // Check each of the 4 positions in this group + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + // This position is non-zero + uint32_t k_logical = k_base + offset; // Logical K position (0-31) + + // Access compressed value from tile_a + float a_val = tile_a[i][compressed_idx]; + + // Determine which T-register of B to access + uint32_t treg_idx = (k_logical < TILE_DIM) ? 
src2_tregs[0] : src2_tregs[1]; + uint32_t k_local = k_logical % TILE_DIM; + + sum += a_val * tile_reg_file_[treg_idx][k_local][j]; + + compressed_idx++; // Move to next compressed value + } + } + } + tile_dst[i][j] = sum; + } + } + + DP(2, "TILE_GEMM_U: dst_t" << dst_treg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × u" << src2_ureg); + } + + // TILE_GEMM_V: Sparse tile (1:4) × Dense tile = Tile (T × V → T) + void tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg) { + assert(dst_treg < tile_reg_file_.size() && "Destination tile register out of bounds"); + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + + auto& tile_dst = tile_reg_file_[dst_treg]; + const auto& tile_a = tile_reg_file_[src1_treg]; // Sparse tile + const auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // V-register maps to 4 T-registers + std::vector src2_tregs = map_vreg_to_treg(src2_vreg); + + // For 1:4 sparsity, each 4-element block has 1 non-zero value + // V-register spans 4 T-registers, giving K dimension of 4*TILE_DIM = 64 + // A is stored in compressed form: 16 values per row representing 64 logical positions + // Metadata indicates which 1 out of every 4 logical positions is stored + + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + float sum = tile_dst[i][j]; // Accumulate + + // Iterate through compressed A values and map to logical K positions + uint32_t compressed_idx = 0; // Index into compressed storage (tile_a) + + // Process 16 groups of 4 logical K positions (covering K=0..63) + for (uint32_t k_grp = 0; k_grp < 16; ++k_grp) { + uint8_t mask = meta_a[i][k_grp]; // Metadata for this 4-element group + uint32_t k_base = k_grp * 4; // Base logical K position for this group + + // Check each of the 4 positions in this group + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + // This position is non-zero + uint32_t k_logical = k_base + offset; // Logical K position (0-63) + + // Access compressed value from tile_a + float a_val = tile_a[i][compressed_idx]; + + // Determine which T-register of B to access + uint32_t treg_idx = src2_tregs[k_logical / TILE_DIM]; + uint32_t k_local = k_logical % TILE_DIM; + + sum += a_val * tile_reg_file_[treg_idx][k_local][j]; + + compressed_idx++; // Move to next compressed value + } + } + } + tile_dst[i][j] = sum; + } + } + + DP(2, "TILE_GEMM_V: dst_t" << dst_treg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × v" << src2_vreg); + } + + // TILE_GEMM_R: Row-wise sparse tile × Dense tile = Tile (T × U → U) + // ISA: A is 16×32 logical (compressed to 16×16 padded T-tile) + // B is 32×16 dense (stored in U-reg = 2 T-regs) + // Output is 16×16 (first T-reg of destination U-reg) + // Metadata: 8 blocks per row × 4 bits/block = 32 bits = 4 bytes per row + // Total: 64 bytes mask data + 64 bytes reserved = 128 bytes + void tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + assert(src1_treg < tile_reg_file_.size() && "Source1 tile register out of bounds"); + assert(meta_reg < metadata_reg_file_.size() && "Metadata register out of bounds"); + + constexpr uint32_t TILE_DIM = 16; + constexpr uint32_t LOGICAL_K = 32; // A is 16×32 logical + + const auto& tile_a = tile_reg_file_[src1_treg]; // Compressed 16×16 tile + const 
auto& meta_a = metadata_reg_file_[meta_reg]; // Metadata for sparse tile + + // Both dst and src2 are U-registers (map to 2 T-registers each) + std::vector dst_tregs = map_ureg_to_treg(dst_ureg); + std::vector src2_tregs = map_ureg_to_treg(src2_ureg); + + // Row-wise sparsity: each row of A has 8 blocks of 4 elements (32 total) + // compressed to 16 values using 2-of-4 sparsity + for (uint32_t i = 0; i < TILE_DIM; ++i) { + for (uint32_t j = 0; j < TILE_DIM; ++j) { + // Destination is first T-reg of U-reg (16×16 output) + uint32_t dst_treg_idx = dst_tregs[0]; + + float sum = tile_reg_file_[dst_treg_idx][i][j]; // Accumulate + + // Track position in compressed A tile for this row + uint32_t a_col = 0; + + // Process 8 blocks of 4 elements each (K=32 logical) + for (uint32_t k_blk = 0; k_blk < LOGICAL_K; k_blk += 4) { + // Metadata layout: meta_a[row][col] stores individual nibbles (uint4) + // nibble_idx = k_blk / 4 (0..7) directly indexes the metadata column + uint32_t nibble_idx = k_blk / 4; + uint8_t mask = meta_a[i][nibble_idx]; // Direct nibble access + + for (uint32_t offset = 0; offset < 4; ++offset) { + if (mask & (1u << offset)) { + // This position is non-zero in the logical A + uint32_t k = k_blk + offset; // Logical K index (0..31) + + // B is stored in U-reg (32×16), split into 2 T-regs (rows 0-15 and 16-31) + uint32_t src2_treg_idx = src2_tregs[k / TILE_DIM]; + uint32_t k_local = k % TILE_DIM; + + // Get value from compressed A tile + sum += tile_a[i][a_col] * tile_reg_file_[src2_treg_idx][k_local][j]; + a_col++; // Move to next compressed value + } + } + } + tile_reg_file_[dst_treg_idx][i][j] = sum; + } + } + + DP(2, "TILE_GEMM_R: dst_u" << dst_ureg << " = t" << src1_treg << "(sparse via m" << meta_reg << ") × u" << src2_ureg); + } + // Map ureg index to tile register indices // ureg 0 -> tile reg 0, 1 // ureg 1 -> tile reg 2, 3 @@ -497,8 +699,7 @@ class SparseUnit::Impl { // Calculate base address: rs1_data + immediate offset uint64_t base_addr = rs1_data.at(tid).i + lsuArgs.offset; - constexpr uint32_t TILE_ROWS = 16; - constexpr uint32_t TILE_COLS = 32; + constexpr uint32_t TILE_DIM = 16; switch (lsu_type) { case VegetaLsuType::TILE_LOAD_T: { @@ -509,10 +710,10 @@ class SparseUnit::Impl { constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32 base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -535,14 +736,15 @@ class SparseUnit::Impl { base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); + uint64_t current_addr = base_addr; for (uint32_t treg_idx : target_tregs) { assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[treg_idx]; - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < 
TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -552,6 +754,7 @@ class SparseUnit::Impl { tile_reg[row][col] = value; } } + current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB) } DP(2, "TILE_LOAD_U: wid=" << wid << ", tid=" << tid @@ -566,14 +769,15 @@ class SparseUnit::Impl { base_addr &= 0xFFFFFFFC; // Align to word boundary for fp32 loads constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); + uint64_t current_addr = base_addr; for (uint32_t treg_idx : target_tregs) { assert(treg_idx < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[treg_idx]; - // Load tile from memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load tile from memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = current_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; uint32_t mem_data = 0; core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); @@ -583,6 +787,7 @@ class SparseUnit::Impl { tile_reg[row][col] = value; } } + current_addr += TILE_DIM * TILE_DIM * ELEMENT_SIZE; // Move to next tile (1KB) } DP(2, "TILE_LOAD_V: wid=" << wid << ", tid=" << tid @@ -596,19 +801,19 @@ class SparseUnit::Impl { uint32_t meta_reg_idx = vd; assert(meta_reg_idx < metadata_reg_file_.size() && "Metadata register index out of bounds"); auto &metadata_reg = metadata_reg_file_[meta_reg_idx]; - constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::uint4::dtype); // 1 byte for uint8_t (stores one uint4) - // Load metadata from memory: 16 rows x 32 columns = 512 uint4 elements = 512 bytes - // Note: Each uint4 is stored in the lower 4 bits of a byte - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Load metadata from memory: 16 rows x 16 columns = 256 uint4 elements = 128 bytes + // Each byte stores two uint4 values: upper nibble for col N, lower nibble for col N+1 + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; col += 2) { + uint64_t mem_addr = base_addr + (row * (TILE_DIM / 2) + col / 2); uint8_t mem_data = 0; - core_->dcache_read(&mem_data, mem_addr, ELEMENT_SIZE); - trace_data->mem_addrs.at(tid).push_back({mem_addr, ELEMENT_SIZE}); + core_->dcache_read(&mem_data, mem_addr, 1); + trace_data->mem_addrs.at(tid).push_back({mem_addr, 1}); - // Store only lower 4 bits (uint4 value) - metadata_reg[row][col] = mem_data & 0x0F; + // Upper nibble for col N, lower nibble for col N+1 + metadata_reg[row][col] = (mem_data >> 4) & 0x0F; + metadata_reg[row][col + 1] = mem_data & 0x0F; } } @@ -637,14 +842,13 @@ class 
SparseUnit::Impl { assert(vs3 < tile_reg_file_.size() && "Tile register index out of bounds"); auto &tile_reg = tile_reg_file_[vs3]; - constexpr uint32_t TILE_ROWS = 16; - constexpr uint32_t TILE_COLS = 32; + constexpr uint32_t TILE_DIM = 16; constexpr uint32_t ELEMENT_SIZE = sizeof(typename vt::fp32::dtype); // 4 bytes for fp32 - // Store tile to memory: 16 rows x 32 columns = 512 fp32 elements = 2048 bytes - for (uint32_t row = 0; row < TILE_ROWS; ++row) { - for (uint32_t col = 0; col < TILE_COLS; ++col) { - uint64_t mem_addr = base_addr + (row * TILE_COLS + col) * ELEMENT_SIZE; + // Store tile to memory: 16 rows x 16 columns = 256 fp32 elements = 1024 bytes + for (uint32_t row = 0; row < TILE_DIM; ++row) { + for (uint32_t col = 0; col < TILE_DIM; ++col) { + uint64_t mem_addr = base_addr + (row * TILE_DIM + col) * ELEMENT_SIZE; float value = tile_reg[row][col]; uint32_t mem_data = 0; std::memcpy(&mem_data, &value, ELEMENT_SIZE); @@ -737,8 +941,8 @@ class SparseUnit::Impl { Core* core_; Arch arch_; PerfStats perf_stats_; - SparseRegFile_t tile_reg_file_; // 8 registers, each 16x32 fp32 elements - std::vector>> metadata_reg_file_; // 8 registers, each 16x32 uint4 elements + SparseRegFile_t tile_reg_file_; // 8 registers, each 16x16 fp32 elements + std::vector>> metadata_reg_file_; // 8 registers, each 16x16 uint4 elements }; /////////////////////////////////////////////////////////////////////////////// @@ -799,3 +1003,19 @@ void SparseUnit::wmma(uint32_t wid, const uint32_t* metadata) { impl_->wmma(wid, fmt_s, fmt_d, step_m, step_n, rs1_data, rs2_data, rs3_data, rd_data, trace_data, metadata); } + +void SparseUnit::tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg) { + impl_->tile_gemm_t(dst_treg, src1_treg, src2_treg); +} + +void SparseUnit::tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + impl_->tile_gemm_u(dst_treg, src1_treg, src2_ureg, meta_reg); +} + +void SparseUnit::tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg) { + impl_->tile_gemm_v(dst_treg, src1_treg, src2_vreg, meta_reg); +} + +void SparseUnit::tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg) { + impl_->tile_gemm_r(dst_ureg, src1_treg, src2_ureg, meta_reg); +} \ No newline at end of file diff --git a/sim/simx/sparse_unit.h b/sim/simx/sparse_unit.h index dfeb872317..7d98340fd1 100644 --- a/sim/simx/sparse_unit.h +++ b/sim/simx/sparse_unit.h @@ -80,6 +80,11 @@ class SparseUnit : public SimObject { ExeTraceData* trace_data, const uint32_t* metadata = nullptr); + void tile_gemm_t(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_treg); + void tile_gemm_u(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg); + void tile_gemm_v(uint32_t dst_treg, uint32_t src1_treg, uint32_t src2_vreg, uint32_t meta_reg); + void tile_gemm_r(uint32_t dst_ureg, uint32_t src1_treg, uint32_t src2_ureg, uint32_t meta_reg); + const PerfStats& perf_stats() const; private: diff --git a/sim/simx/types.h b/sim/simx/types.h index e5dc4a06e2..8dcf5c6a1d 100644 --- a/sim/simx/types.h +++ b/sim/simx/types.h @@ -706,6 +706,13 @@ struct IntrVegetaTcuArgs { /////////////////////////////////////////////////////////////////////////////// +struct IntrVegetaTcuArgs { + uint32_t fmt_s : 4; + uint32_t fmt_d : 4; + uint32_t step_m : 4; + uint32_t step_n : 4; +}; + enum class VegetaTcuType { TILE_GEMM_T, TILE_GEMM_U, diff --git a/tests/regression/Makefile b/tests/regression/Makefile index 
be3ccc9636..4a6fc7c024 100644
--- a/tests/regression/Makefile
+++ b/tests/regression/Makefile
@@ -21,6 +21,8 @@ all:
 	$(MAKE) -C sgemm2
 	$(MAKE) -C madmax
 	$(MAKE) -C stencil3d
+	$(MAKE) -C veg_ls
+	$(MAKE) -C sgemm_tile

 run-simx:
 	$(MAKE) -C basic run-simx
@@ -42,6 +44,8 @@ run-simx:
 	$(MAKE) -C sgemm2 run-simx
 	$(MAKE) -C madmax run-simx
 	$(MAKE) -C stencil3d run-simx
+	$(MAKE) -C veg_ls run-simx
+	$(MAKE) -C sgemm_tile run-simx

 run-rtlsim:
 	$(MAKE) -C basic run-rtlsim
@@ -63,6 +67,8 @@ run-rtlsim:
 	$(MAKE) -C sgemm2 run-rtlsim
 	$(MAKE) -C madmax run-rtlsim
 	$(MAKE) -C stencil3d run-rtlsim
+	$(MAKE) -C veg_ls run-rtlsim
+	$(MAKE) -C sgemm_tile run-rtlsim

 clean:
 	$(MAKE) -C basic clean
@@ -84,3 +90,5 @@ clean:
 	$(MAKE) -C sgemm2 clean
 	$(MAKE) -C madmax clean
 	$(MAKE) -C stencil3d clean
+	$(MAKE) -C veg_ls clean
+	$(MAKE) -C sgemm_tile clean
diff --git a/tests/regression/sgemm_tile/Makefile b/tests/regression/sgemm_tile/Makefile
new file mode 100644
index 0000000000..181f404e11
--- /dev/null
+++ b/tests/regression/sgemm_tile/Makefile
@@ -0,0 +1,12 @@
+ROOT_DIR := $(realpath ../../..)
+include $(ROOT_DIR)/config.mk
+
+PROJECT := sgemm_tile
+
+SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT)
+
+SRCS := $(SRC_DIR)/main.cpp
+
+VX_SRCS := $(SRC_DIR)/kernel.cpp
+
+include ../common.mk
diff --git a/tests/regression/sgemm_tile/common.h b/tests/regression/sgemm_tile/common.h
new file mode 100644
index 0000000000..036a7bde05
--- /dev/null
+++ b/tests/regression/sgemm_tile/common.h
@@ -0,0 +1,29 @@
+#ifndef _COMMON_H_
+#define _COMMON_H_
+
+#include <stdint.h>
+
+// T Tile dimensions: 16x16 fp32 = 1KB per tile register
+#define TILE_SIZE 16
+#define T_TILE_BYTES (TILE_SIZE * TILE_SIZE * sizeof(float)) // 1KB
+#define U_TILE_BYTES (2 * T_TILE_BYTES) // 2KB (U-reg = 2 T-regs for dense)
+#define V_TILE_BYTES (4 * T_TILE_BYTES) // 4KB (V-reg = 4 T-regs for dense)
+#define M_TILE_BYTES (TILE_SIZE * TILE_SIZE / 2) // 128 bytes (metadata: 2 nibbles per byte)
+
+// GEMM modes
+typedef enum {
+  GEMM_MODE_TGEMM = 0, // T x T -> T (dense x dense)
+  GEMM_MODE_UGEMM = 1, // T x U -> T (sparse 2:4 packed x dense 2x)
+  GEMM_MODE_VGEMM = 2, // T x V -> T (sparse 1:4 packed x dense 4x)
+  GEMM_MODE_RGEMM = 3  // T x U -> U (row-wise N:4 sparse x dense 2x)
+} gemm_mode_t;
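+
+// Sizing example (derived from the macros above): UGEMM reads a 16x16
+// compressed A tile (1KB), a dense 32x16 B in one U-reg (2KB), and 128B of
+// 2:4 metadata; the 16x16 fp32 result occupies one T-reg (1KB).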
+
+typedef struct {
+  uint64_t A_addr; // Matrix A (1KB T-tile, compressed for the sparse modes)
+  uint64_t B_addr; // Matrix B (1KB/2KB/4KB depending on mode, always dense)
+  uint64_t M_addr; // Metadata for sparse A (128 bytes, UGEMM/VGEMM/RGEMM only)
+  uint64_t C_addr; // Matrix C result (1KB T-tile)
+  uint32_t mode;   // GEMM mode (TGEMM=0, UGEMM=1, VGEMM=2, RGEMM=3)
+} kernel_arg_t;
+
+#endif // _COMMON_H_
diff --git a/tests/regression/sgemm_tile/kernel.cpp b/tests/regression/sgemm_tile/kernel.cpp
new file mode 100644
index 0000000000..86a45d5343
--- /dev/null
+++ b/tests/regression/sgemm_tile/kernel.cpp
@@ -0,0 +1,73 @@
+#include <vx_intrinsics.h>
+#include <vx_spawn.h>
+#include "common.h"
+
+// GEMM kernel supporting four modes:
+// - TGEMM: C[16x16] = A[16x16] × B[16x16] (dense × dense)
+// - UGEMM: C[16x16] = A[2:4 sparse, 16x32 logical] × B[32x16] (sparse × dense)
+// - VGEMM: C[16x16] = A[1:4 sparse, 16x64 logical] × B[64x16] (sparse × dense)
+// - RGEMM: C[16x16] = A[row-wise N:4 sparse, 16x32 logical] × B[32x16] (sparse × dense, result in U-reg 0)
+
+void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
+  auto A_ptr = reinterpret_cast<float*>(arg->A_addr);
+  auto B_ptr = reinterpret_cast<float*>(arg->B_addr);
+  auto M_ptr = reinterpret_cast<uint8_t*>(arg->M_addr);
+  auto C_ptr = reinterpret_cast<float*>(arg->C_addr);
+  uint32_t mode = arg->mode;
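+
+  // Register plan (as used below): A -> T-reg 1, metadata -> M-reg 1,
+  // B -> T-reg 2 (TGEMM), U-reg 2 (UGEMM/RGEMM), or V-reg 1 (VGEMM);
+  // results accumulate into T-reg 0 (U-reg 0 for RGEMM).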
+
+  // Load A tile into T-reg 1 (always a 1KB T-tile; compressed for sparse modes)
+  vx_lt(1, (size_t)A_ptr, 0);
+
+  if (mode == GEMM_MODE_TGEMM) {
+    // TGEMM: T × T -> T
+    // Load B tile into T-reg 2 (1KB dense)
+    vx_lt(2, (size_t)B_ptr, 0);
+
+    // TGEMM: T0 = T1 × T2 (accumulate into T0)
+    vx_tgemm(0, 1, 2);
+  }
+  else if (mode == GEMM_MODE_UGEMM) {
+    // UGEMM: T × U -> T (A is 2:4 sparse)
+    // Load metadata into M-reg 1 (128 bytes)
+    vx_lm(1, (size_t)M_ptr, 0);
+
+    // Load B tile into U-reg 2 (2KB dense)
+    vx_lu(2, (size_t)B_ptr, 0);
+
+    // UGEMM: T0 = T1 (2:4 sparse via M1) × U2 (accumulate into T0)
+    vx_ugemm(0, 1, 2);
+  }
+  else if (mode == GEMM_MODE_VGEMM) {
+    // VGEMM: T × V -> T (A is 1:4 sparse)
+    // Load metadata into M-reg 1 (128 bytes)
+    vx_lm(1, (size_t)M_ptr, 0);
+
+    // Load B tile into V-reg 1 (4KB dense)
+    // Note: V-reg 1 maps to T-regs 4-7, staying within the 8 T-reg limit
+    vx_lv(1, (size_t)B_ptr, 0);
+
+    // VGEMM: T0 = T1 (sparse with M1 metadata) × V1 (dense)
+    vx_vgemm(0, 1, 1);
+  }
+  else if (mode == GEMM_MODE_RGEMM) {
+    // RGEMM: T × U -> U (row-wise N:4 sparse)
+    // Load metadata into M-reg 1 (128 bytes)
+    vx_lm(1, (size_t)M_ptr, 0);
+
+    // Load B tile into U-reg 2 (2KB dense)
+    vx_lu(2, (size_t)B_ptr, 0);
+
+    // RGEMM: U0 = T1 (row-wise sparse with M1 metadata) × U2 (dense)
+    // Output is stored in U-reg 0 = T-reg 0 + T-reg 1 (2KB total)
+    // ISA: vx_rgemm computes full U-reg result
+    vx_rgemm(0, 1, 2);
+  }
+
+  // Store result from T-reg 0 to C (always 1KB)
+  // For RGEMM: we only validate first T-reg of U0 (top 16 rows)
+  vx_st((size_t)C_ptr, 0, 0);
+}
+
+int main() {
+  kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
+  return vx_spawn_threads(1, nullptr, nullptr, (vx_kernel_func_cb)kernel_body, arg);
+}
diff --git a/tests/regression/sgemm_tile/main.cpp b/tests/regression/sgemm_tile/main.cpp
new file mode 100644
index 0000000000..5c706addc2
--- /dev/null
+++ b/tests/regression/sgemm_tile/main.cpp
@@ -0,0 +1,637 @@
+#include <iostream>
+#include <vector>
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <unistd.h>
+#include <vortex.h>
+#include "common.h"
+
+#define FLOAT_ULP 6
+
+#define RT_CHECK(_expr)                                        \
+  do {                                                         \
+    int _ret = _expr;                                          \
+    if (0 == _ret)                                             \
+      break;                                                   \
+    printf("Error: '%s' returned %d!\n", #_expr, (int)_ret);   \
+    cleanup();                                                 \
+    exit(-1);                                                  \
+  } while (false)
+
+///////////////////////////////////////////////////////////////////////////////
+
+const char* kernel_file = "kernel.vxbin";
+
+vx_device_h device = nullptr;
+vx_buffer_h A_buffer = nullptr;
+vx_buffer_h B_buffer = nullptr;
+vx_buffer_h M_buffer = nullptr;
+vx_buffer_h C_buffer = nullptr;
+vx_buffer_h krnl_buffer = nullptr;
+vx_buffer_h args_buffer = nullptr;
+kernel_arg_t kernel_arg = {};
+
+static gemm_mode_t gemm_mode = GEMM_MODE_TGEMM;
+
+static void show_usage() {
+  std::cout << "Vortex SGEMM TILE Test (16x16 matrix operations)." << std::endl;
+  std::cout << "Usage: [-m mode] [-h: help]" << std::endl;
+  std::cout << "  -m mode: GEMM mode (0=TGEMM, 1=UGEMM, 2=VGEMM, 3=RGEMM) [default: 0]" << std::endl;
+  std::cout << "    TGEMM (0): T × T -> T (dense × dense)" << std::endl;
+  std::cout << "    UGEMM (1): T × U -> T (2:4 sparse × dense)" << std::endl;
+  std::cout << "    VGEMM (2): T × V -> T (1:4 sparse × dense)" << std::endl;
+  std::cout << "    RGEMM (3): T × U -> U (row-wise N:4 sparse × dense)" << std::endl;
+}
+
+static void parse_args(int argc, char **argv) {
+  int c;
+  while ((c = getopt(argc, argv, "m:h")) != -1) {
+    switch (c) {
+    case 'm':
+      gemm_mode = static_cast<gemm_mode_t>(atoi(optarg));
+      if (gemm_mode < GEMM_MODE_TGEMM || gemm_mode > GEMM_MODE_RGEMM) {
+        std::cerr << "Error: Invalid mode " << gemm_mode << std::endl;
+        show_usage();
+        exit(-1);
+      }
+      break;
+    case 'h':
+      show_usage();
+      exit(0);
+      break;
+    default:
+      show_usage();
+      exit(-1);
+    }
+  }
+}
+
+void cleanup() {
+  if (device) {
+    vx_mem_free(A_buffer);
+    vx_mem_free(B_buffer);
+    vx_mem_free(M_buffer);
+    vx_mem_free(C_buffer);
+    vx_mem_free(krnl_buffer);
+    vx_mem_free(args_buffer);
+    vx_dev_close(device);
+  }
+}
+
+// Generate compressed 2:4 sparse tile and metadata from full logical matrix
+// Input: logical_tile is M×K (e.g., 16×32), output: compressed_tile is M×(K/2) (e.g., 16×16)
+// Metadata format: 16×16 nibbles stored as 128 bytes (8 bytes per row, 2 nibbles per byte)
+static void compress_2_4_sparse(const std::vector<float>& logical_tile, int M, int K,
+                                std::vector<float>& compressed_tile, std::vector<uint8_t>& metadata) {
+  compressed_tile.resize(M * (K / 2));
+  metadata.resize(128); // Fixed size: 16 rows × 8 bytes per row
+  std::fill(metadata.begin(), metadata.end(), 0);
+
+  for (int row = 0; row < M; ++row) {
+    int compressed_col = 0;
+
+    // Process K/4 groups of 4 elements
+    for (int k_grp = 0; k_grp < K / 4; ++k_grp) {
+      int k_base = k_grp * 4;
+
+      // Find the 2 largest magnitude values in this group of 4
+      std::pair<int, float> vals[4];
+      for (int offset = 0; offset < 4; ++offset) {
+        vals[offset] = {offset, logical_tile[row * K + k_base + offset]};
+      }
+
+      // Sort by magnitude to find top 2
+      std::sort(vals, vals + 4, [](const auto& a, const auto& b) {
+        return std::abs(a.second) > std::abs(b.second);
+      });
+
+      // Create bitmask for top 2 values
+      uint8_t mask = 0;
+      for (int i = 0; i < 2; ++i) {
+        int offset = vals[i].first;
+        mask |= (1u << offset);
+      }
+
+      // Store compressed values in POSITION ORDER (not magnitude order)
+      // Hardware iterates through bit positions 0-3 and expects values in that order
+      for (int offset = 0; offset < 4; ++offset) {
+        if (mask & (1u << offset)) {
+          compressed_tile[row * (K / 2) + compressed_col++] = logical_tile[row * K + k_base + offset];
+        }
+      }
+
+      // Store metadata in 16×16 nibble format (128 bytes)
+      // Each row has 8 bytes, each byte has 2 nibbles
+      // Byte layout per row: byte 0 = cols 0,1; byte 1 = cols 2,3; ...; byte 7 = cols 14,15
+      int byte_idx = row * 8 + k_grp / 2;
+      if (k_grp % 2 == 0) {
+        metadata[byte_idx] = (mask << 4); // Upper nibble
+      } else {
+        metadata[byte_idx] |= mask; // Lower nibble
+      }
+    }
+  }
+}
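+
+// Worked example for compress_2_4_sparse (illustrative): for the group
+// {0.9, 0.1, 0.0, 0.7} the two largest magnitudes sit at positions 0 and 3,
+// so mask = 0b1001 and the compressed row stores {0.9, 0.7}, in position
+// order as the hardware expects.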
+
+// Generate compressed 1:4 sparse tile and metadata from full logical matrix
+// Input: logical_tile is M×K (e.g., 16×64), output: compressed_tile is M×(K/4) (e.g., 16×16)
+// Metadata format: 16×16 nibbles stored as 128 bytes (8 bytes per row, 2 nibbles per byte)
+static void compress_1_4_sparse(const std::vector<float>& logical_tile, int M, int K,
+                                std::vector<float>& compressed_tile, std::vector<uint8_t>& metadata) {
+  compressed_tile.resize(M * (K / 4));
+  metadata.resize(128); // Fixed size: 16 rows × 8 bytes per row
+  std::fill(metadata.begin(), metadata.end(), 0);
+
+  for (int row = 0; row < M; ++row) {
+    int compressed_col = 0;
+
+    // Process K/4 groups of 4 elements
+    for (int k_grp = 0; k_grp < K / 4; ++k_grp) {
+      int k_base = k_grp * 4;
+
+      // Find the largest magnitude value in this group of 4
+      int max_offset = 0;
+      float max_val = std::abs(logical_tile[row * K + k_base]);
+      for (int offset = 1; offset < 4; ++offset) {
+        float val = std::abs(logical_tile[row * K + k_base + offset]);
+        if (val > max_val) {
+          max_val = val;
+          max_offset = offset;
+        }
+      }
+
+      // Create bitmask and store compressed value
+      uint8_t mask = (1u << max_offset);
+      compressed_tile[row * (K / 4) + compressed_col++] = logical_tile[row * K + k_base + max_offset];
+
+      // Store metadata in 16×16 nibble format (128 bytes)
+      // Each row has 8 bytes, each byte has 2 nibbles
+      int byte_idx = row * 8 + k_grp / 2;
+      if (k_grp % 2 == 0) {
+        metadata[byte_idx] = (mask << 4); // Upper nibble
+      } else {
+        metadata[byte_idx] |= mask; // Lower nibble
+      }
+    }
+  }
+}
+
+// Generate compressed row-wise N:4 sparse tile and metadata from full logical matrix
+// Input: logical_tile is M×K (16×32)
+// Output: padded_tile is M×(K/2) (16×16), metadata is exactly 128 bytes
+// Compression: For each 4-element block, keep top-2 values by magnitude (deterministic)
+// Metadata layout: Must match TILE_LOAD_M format (8 bytes per row)
+//   - 16 rows × 8 bytes/row = 128 bytes total
+//   - Each byte stores 2 nibbles: upper nibble for col N, lower for col N+1
+//   - For RGEMM: only first 8 nibbles (cols 0-7) are used, rest are zero
+static void compress_rowwise_n4_sparse(const std::vector<float>& logical_tile, int M, int K,
+                                       std::vector<float>& padded_tile, std::vector<uint8_t>& metadata) {
+  // Output sizes: padded tile is M×(K/2), metadata is exactly 128 bytes
+  padded_tile.resize(M * (K / 2));
+  metadata.resize(128); // 8 bytes per row × 16 rows = 128 bytes
+  std::fill(metadata.begin(), metadata.end(), 0);
+
+  for (int row = 0; row < M; ++row) {
+    int padded_col = 0;
+
+    // Process K/4 groups of 4 elements (8 groups for K=32)
+    for (int k_grp = 0; k_grp < K / 4; ++k_grp) {
+      int k_base = k_grp * 4;
+
+      // Find the 2 largest magnitude values in this group of 4
+      // Use index-value pairs for deterministic selection
+      std::pair<int, float> vals[4];
+      for (int offset = 0; offset < 4; ++offset) {
+        vals[offset] = {offset, logical_tile[row * K + k_base + offset]};
+      }
+
+      // Sort by magnitude (descending) to find top 2
+      // For equal magnitudes, lower index wins (stable, deterministic)
+      std::sort(vals, vals + 4, [](const auto& a, const auto& b) {
+        float abs_a = std::abs(a.second);
+        float abs_b = std::abs(b.second);
+        if (abs_a != abs_b) return abs_a > abs_b;
+        return a.first < b.first; // Tie-breaker: lower index first
+      });
+
+      // Create 4-bit bitmask for top 2 values
+      uint8_t mask = 0;
+      for (int i = 0; i < 2; ++i) {
+        int offset = vals[i].first;
+        mask |= (1u << offset);
+      }
+
+      // Store values in POSITION ORDER (not magnitude order)
+      // Hardware iterates through bit positions 0-3 sequentially
+      for (int offset = 0; offset < 4; ++offset) {
+        if (mask & (1u << offset)) {
+          padded_tile[row * (K / 2) + padded_col++] = logical_tile[row * K + k_base + offset];
+        }
+      }
+
+      // Store metadata: 4 bits per block
+      // Layout: 8 bytes per row (matching TILE_LOAD_M format)
+      // Each byte stores 2 nibbles: upper for even col, lower for odd col
+      // k_grp 0,1 -> byte 0 (cols 0,1), k_grp 2,3 -> byte 1 (cols 2,3), etc.
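+      // e.g., if group 0 keeps positions 0,1 (mask 0b0011) and group 1
+      // keeps positions 1,3 (mask 0b1010), byte 0 of this row packs to
+      // (0b0011 << 4) | 0b1010 = 0x3A.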
+      int byte_idx = row * 8 + k_grp / 2; // 8 bytes per row
+      if (k_grp % 2 == 0) {
+        metadata[byte_idx] = (mask << 4); // Upper nibble (col N)
+      } else {
+        metadata[byte_idx] |= mask; // Lower nibble (col N+1)
+      }
+    }
+  }
+
+  // Remaining bytes in each row (cols 8-15) are zero, already initialized
+}
+
+// CPU reference: C = A × B
+// A is MxK, B is KxN, C is MxN
+// For TGEMM: A is 16x16, B is 16x16
+// For UGEMM: A is 16x16 (but with 2:4 sparsity, effectively 16x32 positions), B is 32x16
+// For VGEMM: A is 16x16 (but with 1:4 sparsity, effectively 16x64 positions), B is 64x16
+static void matmul_cpu(float* C, const float* A, const float* B, int M, int K, int N) {
+  for (int m = 0; m < M; ++m) {
+    for (int n = 0; n < N; ++n) {
+      float sum = 0.0f;
+      for (int k = 0; k < K; ++k) {
+        sum += A[m * K + k] * B[k * N + n];
+      }
+      C[m * N + n] = sum;
+    }
+  }
+}
+
+// Compare floats with ULP tolerance
+static bool compare_float(float a, float b, int index, int& errors) {
+  union { float f; int32_t i; } fa, fb;
+  fa.f = a;
+  fb.f = b;
+  auto d = std::abs(fa.i - fb.i);
+  if (d > FLOAT_ULP) {
+    if (errors < 100) {
+      printf("*** error: [%d] expected=%.6f, actual=%.6f\n", index, b, a);
+    }
+    ++errors;
+    return false;
+  }
+  return true;
+}
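+
+// Note: differencing the raw int32 bit patterns is a cheap ULP-distance
+// check for same-sign floats; FLOAT_ULP = 6 absorbs small rounding
+// differences between the tile unit's accumulation order and the CPU
+// reference.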
+
+int main(int argc, char *argv[]) {
+  parse_args(argc, argv);
+
+  std::srand(50);
+
+  // open device connection
+  std::cout << "open device connection" << std::endl;
+  RT_CHECK(vx_dev_open(&device));
+
+  uint32_t num_elements = TILE_SIZE * TILE_SIZE; // 256 elements for T-reg
+  uint32_t A_buf_size = T_TILE_BYTES; // Always 1KB for A
+  uint32_t C_buf_size = T_TILE_BYTES; // Always 1KB for C (first T-reg of result)
+  uint32_t B_buf_size, M_buf_size = 0;
+
+  const char* mode_name;
+  switch (gemm_mode) {
+  case GEMM_MODE_TGEMM:
+    mode_name = "TGEMM (T × T)";
+    B_buf_size = T_TILE_BYTES; // 1KB
+    break;
+  case GEMM_MODE_UGEMM:
+    mode_name = "UGEMM (T × U, 2:4 sparse)";
+    B_buf_size = U_TILE_BYTES; // 2KB
+    M_buf_size = M_TILE_BYTES; // 128 bytes metadata
+    break;
+  case GEMM_MODE_VGEMM:
+    mode_name = "VGEMM (T × V, 1:4 sparse)";
+    B_buf_size = V_TILE_BYTES; // 4KB
+    M_buf_size = M_TILE_BYTES; // 128 bytes metadata
+    break;
+  case GEMM_MODE_RGEMM:
+    mode_name = "RGEMM (T × U -> U, row-wise N:4 sparse)";
+    B_buf_size = U_TILE_BYTES; // 2KB (B is dense U-reg)
+    M_buf_size = M_TILE_BYTES; // 128 bytes metadata
+    break;
+  default:
+    std::cerr << "Invalid GEMM mode!" << std::endl;
+    return -1;
+  }
+
+  std::cout << "SGEMM TILE Test: " << mode_name << std::endl;
+  std::cout << "Matrix size: " << TILE_SIZE << "x" << TILE_SIZE << std::endl;
+  std::cout << "A buffer: " << A_buf_size << " bytes" << std::endl;
+  std::cout << "B buffer: " << B_buf_size << " bytes" << std::endl;
+  if (M_buf_size > 0) {
+    std::cout << "M buffer: " << M_buf_size << " bytes (metadata)" << std::endl;
+  }
+  std::cout << "C buffer: " << C_buf_size << " bytes" << std::endl;
+
+  // allocate device memory
+  std::cout << "allocate device memory" << std::endl;
+  RT_CHECK(vx_mem_alloc(device, A_buf_size, VX_MEM_READ, &A_buffer));
+  RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr));
+  RT_CHECK(vx_mem_alloc(device, B_buf_size, VX_MEM_READ, &B_buffer));
+  RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr));
+
+  kernel_arg.M_addr = 0;
+  if (M_buf_size > 0) {
+    RT_CHECK(vx_mem_alloc(device, M_buf_size, VX_MEM_READ, &M_buffer));
+    RT_CHECK(vx_mem_address(M_buffer, &kernel_arg.M_addr));
+  }
+
+  RT_CHECK(vx_mem_alloc(device, C_buf_size, VX_MEM_WRITE, &C_buffer));
+  RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr));
+
+  kernel_arg.mode = gemm_mode;
+
+  std::cout << "dev_A=0x" << std::hex << kernel_arg.A_addr << std::endl;
+  std::cout << "dev_B=0x" << std::hex << kernel_arg.B_addr << std::endl;
+  if (kernel_arg.M_addr) {
+    std::cout << "dev_M=0x" << std::hex << kernel_arg.M_addr << std::endl;
+  }
+  std::cout << "dev_C=0x" << std::hex << kernel_arg.C_addr << std::dec << std::endl;
+
+  // allocate host buffers
+  std::cout << "allocate host buffers" << std::endl;
+
+  // A's logical size depends on mode:
+  //   - TGEMM: 16x16 (dense)
+  //   - UGEMM/RGEMM: 16x32 (sparse compressed to 16x16)
+  //   - VGEMM: 16x64 (sparse 1:4 compressed to 16x16)
+  uint32_t A_cols_logical = TILE_SIZE;
+  if (gemm_mode == GEMM_MODE_UGEMM || gemm_mode == GEMM_MODE_RGEMM) A_cols_logical = 2 * TILE_SIZE; // 32 logical cols
+  else if (gemm_mode == GEMM_MODE_VGEMM) A_cols_logical = 4 * TILE_SIZE; // 64 logical cols
+
+  // B size matches A's logical K dimension
+  uint32_t B_cols = TILE_SIZE; // B is always 16 cols wide (output is 16x16)
+
+  std::vector<float> h_A_logical(TILE_SIZE * A_cols_logical); // Logical A before compression
+  std::vector<float> h_A(num_elements); // Compressed A (always 16x16 = 1KB for storage)
+  std::vector<float> h_B(A_cols_logical * B_cols); // B is K×N where K matches A's logical K
+  std::vector<float> h_C(num_elements); // Output is always 16×16 for tested modes
+  std::vector<float> h_ref(num_elements);
+
+  // Initialize logical matrix A
+  for (uint32_t i = 0; i < TILE_SIZE * A_cols_logical; ++i) {
+    h_A_logical[i] = static_cast<float>(rand()) / RAND_MAX;
+  }
+
+  // Initialize matrix B (K×N where K = A's logical cols)
+  for (uint32_t i = 0; i < A_cols_logical * B_cols; ++i) {
+    h_B[i] = static_cast<float>(rand()) / RAND_MAX;
+  }
+
+  // upload source buffers
+  std::cout << "upload source buffers" << std::endl;
+
+  std::vector<uint8_t> h_M; // Metadata
+
+  if (gemm_mode == GEMM_MODE_TGEMM) {
+    // TGEMM: A (16x16) × B (16x16) = C (16x16)
+    // Both dense in T-registers, no metadata
+    h_A = h_A_logical; // No compression needed
+    RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size));
+    RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
+  }
+  else if (gemm_mode == GEMM_MODE_UGEMM) {
+    // UGEMM: A (16x32 logical, compressed to 16x16 with 2:4 sparsity) × B (32x16) = C (16x16)
+    // A: logical 16x32 -> compressed 16x16 (1KB T-tile) with metadata
+    // B: full 32x16 stored in U-register (2KB = 2 T-regs)
+
+    // Compress A from 16x32 logical to
+    compress_2_4_sparse(h_A_logical, TILE_SIZE, 2 * TILE_SIZE, h_A, h_M);
+
+    std::cout << "2:4 sparse A: logical 16x32 -> compressed 16x16, metadata " << h_M.size() << " bytes" << std::endl;
+
+    // Upload compressed A (1KB), metadata, and full B (2KB for U-reg)
+    RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size));
+    RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size));
+    RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
+  }
+  else if (gemm_mode == GEMM_MODE_VGEMM) {
+    // VGEMM: A (16x64 logical, compressed to 16x16 with 1:4 sparsity) × B (64x16) = C (16x16)
+    // A: logical 16x64 -> compressed 16x16 (1KB T-tile) with metadata
+    // B: full 64x16 stored in V-register (4KB = 4 T-regs)
+
+    // Compress A from 16x64 logical to 16x16 compressed
+    compress_1_4_sparse(h_A_logical, TILE_SIZE, 4 * TILE_SIZE, h_A, h_M);
+
+    std::cout << "1:4 sparse A: logical 16x64 -> compressed 16x16, metadata " << h_M.size() << " bytes" << std::endl;
+
+    // Upload compressed A (1KB), metadata, and full B (4KB for V-reg)
+    RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size));
+    RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size));
+    RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
+  }
+  else if (gemm_mode == GEMM_MODE_RGEMM) {
+    // RGEMM: A (16x32 logical, compressed to 16x16 via row-wise N:4) × B (32x16) = C (16x16)
+    // A: logical 16x32 -> padded 16x16 (1KB T-tile) with metadata (128 bytes)
+    // B: full 32x16 stored in U-register (2KB = 2 T-regs)
+
+    // Compress A from 16x32 logical to 16x16 padded using row-wise N:4 compression
+    compress_rowwise_n4_sparse(h_A_logical, TILE_SIZE, 2 * TILE_SIZE, h_A, h_M);
+
+    std::cout << "Row-wise N:4 sparse A: logical 16x32 -> padded 16x16, metadata " << h_M.size() << " bytes" << std::endl;
+
+    // Upload padded A (1KB), metadata (128B), and full B (2KB for U-reg)
+    RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, A_buf_size));
+    RT_CHECK(vx_copy_to_dev(M_buffer, h_M.data(), 0, M_buf_size));
+    RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, B_buf_size));
+  }
+
+  // upload kernel binary
+  std::cout << "upload kernel binary" << std::endl;
+  RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer));
+
+  // upload kernel argument
+  std::cout << "upload kernel argument" << std::endl;
+  RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer));
+
+  // start device
+  std::cout << "start device" << std::endl;
+  RT_CHECK(vx_start(device, krnl_buffer, args_buffer));
+
+  // wait for completion
+  std::cout << "wait for completion" << std::endl;
+  RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT));
+
+  // download result
+  std::cout << "download result" << std::endl;
+  RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, C_buf_size));
+
+  // Zero out pruned values in h_A_logical based on metadata for the CPU reference.
+  // This ensures the CPU computes the same result as the GPU (which only uses kept values).
+  //
+  // METADATA LAYOUT DIFFERENCES:
+  // - UGEMM/VGEMM metadata: 8 bytes per row (16 nibbles = up to 16 four-element blocks)
+  //   Total: 128 bytes (16 rows × 8 bytes/row)
+  // - RGEMM metadata: same 8-byte-per-row stride, but only the first 4 bytes of each
+  //   row (8 nibbles = 8 four-element blocks for K=32) carry masks; the rest is reserved.
+  //   Total: 64 bytes of mask data + 64 bytes reserved = 128 bytes
+  //
+  if (gemm_mode == GEMM_MODE_UGEMM || gemm_mode == GEMM_MODE_VGEMM) {
+    // UGEMM/VGEMM: 8 bytes per row (blocks of 4 elements for K=32/64)
+    for (uint32_t row = 0; row < TILE_SIZE; ++row) {
+      uint32_t k_groups = A_cols_logical / 4;
+      for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) {
+        int k_base = k_grp * 4;
+        // Get metadata nibble for this group (8 bytes per row layout)
+        int byte_idx = row * 8 + k_grp / 2;
+        uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F);
+
+        // Zero out positions not in metadata mask
+        for (int offset = 0; offset < 4; ++offset) {
+          if (!(nibble & (1u << offset))) {
+            h_A_logical[row * A_cols_logical + k_base + offset] = 0.0f;
+          }
+        }
+      }
+    }
+  }
+  else if (gemm_mode == GEMM_MODE_RGEMM) {
+    // RGEMM: 8 bytes per row (matching TILE_LOAD_M format)
+    // Only first 8 nibbles (cols 0-7) are used for 8 blocks of 4 elements (K=32)
+    for (uint32_t row = 0; row < TILE_SIZE; ++row) {
+      uint32_t k_groups = A_cols_logical / 4; // 8 groups for K=32
+      for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) {
+        int k_base = k_grp * 4;
+        // Get metadata nibble for this group (8 bytes per row layout)
+        int byte_idx = row * 8 + k_grp / 2;
+        uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F);
+
+        // Zero out positions not in metadata mask
+        for (int offset = 0; offset < 4; ++offset) {
+          if (!(nibble & (1u << offset))) {
+            h_A_logical[row * A_cols_logical + k_base + offset] = 0.0f;
+          }
+        }
+      }
+    }
+  }
+
+  // compute CPU reference
+  std::cout << "verify result" << std::endl;
+
+  // For all modes: C = A_logical (with zeros in pruned positions) × B
+  // For RGEMM: A_logical is 16x32 with zeros, B is 32x16, result is 16x16
+  // M = TILE_SIZE (16), K = A_cols_logical, N = B_cols (16)
+  matmul_cpu(h_ref.data(), h_A_logical.data(), h_B.data(), TILE_SIZE, A_cols_logical, B_cols);
+
+  // verify result (always 256 elements = 16x16)
+  int errors = 0;
+  for (uint32_t i = 0; i < num_elements; ++i) {
+    compare_float(h_C[i], h_ref[i], i, errors);
+  }
+
+  // write matrices to output file
+  std::cout << "writing matrices to output file" << std::endl;
+  std::ofstream output_file("matrices_output.txt");
+  if (output_file.is_open()) {
+    output_file << "GEMM Mode: " << mode_name << "\n\n";
+
+    // 1. Print compressed/padded A matrix (what's actually sent to hardware)
+    if (gemm_mode != GEMM_MODE_TGEMM) {
+      output_file << "Matrix A Padded (Compressed " << TILE_SIZE << "x" << TILE_SIZE << "):\n";
+      for (uint32_t i = 0; i < TILE_SIZE; ++i) {
+        for (uint32_t j = 0; j < TILE_SIZE; ++j) {
+          output_file << h_A[i * TILE_SIZE + j];
+          if (j < TILE_SIZE - 1) output_file << " ";
+        }
+        output_file << "\n";
+      }
+      output_file << "\n";
+
+      // 2. Print metadata as 0/1 pattern
+      output_file << "Metadata (" << TILE_SIZE << "x" << A_cols_logical << " sparsity pattern, 1=kept, 0=pruned):\n";
+      for (uint32_t row = 0; row < TILE_SIZE; ++row) {
+        uint32_t k_groups = A_cols_logical / 4;
+        for (uint32_t k_grp = 0; k_grp < k_groups; ++k_grp) {
+          // Get metadata nibble for this group (8 bytes per row layout)
+          int byte_idx = row * 8 + k_grp / 2;
+          uint8_t nibble = (k_grp % 2 == 0) ? (h_M[byte_idx] >> 4) : (h_M[byte_idx] & 0x0F);
+
+          // Print 4 bits as 0/1
+          for (int offset = 0; offset < 4; ++offset) {
+            output_file << ((nibble & (1u << offset)) ? "1" : "0");
+            if (k_grp < k_groups - 1 || offset < 3) output_file << " ";
+          }
+        }
+        output_file << "\n";
+      }
+      output_file << "\n";
+    }
+
+    // 3. Print logical A matrix
+    output_file << "Matrix A Logical (";
+    if (gemm_mode == GEMM_MODE_TGEMM) {
+      output_file << "Dense";
+    } else if (gemm_mode == GEMM_MODE_UGEMM) {
+      output_file << "2:4 Sparse";
+    } else if (gemm_mode == GEMM_MODE_VGEMM) {
+      output_file << "1:4 Sparse";
+    } else if (gemm_mode == GEMM_MODE_RGEMM) {
+      output_file << "Row-wise N:4 Sparse";
+    }
+    output_file << ", " << TILE_SIZE << "x" << A_cols_logical << "):\n";
+    for (uint32_t i = 0; i < TILE_SIZE; ++i) {
+      for (uint32_t j = 0; j < A_cols_logical; ++j) {
+        output_file << h_A_logical[i * A_cols_logical + j];
+        if (j < A_cols_logical - 1) output_file << " ";
+      }
+      output_file << "\n";
+    }
+    output_file << "\n";
+
+    // 4. Print B matrix
+    output_file << "Matrix B (Dense, " << A_cols_logical << "x" << B_cols << "):\n";
+    for (uint32_t i = 0; i < A_cols_logical; ++i) {
+      for (uint32_t j = 0; j < B_cols; ++j) {
+        output_file << h_B[i * B_cols + j];
+        if (j < B_cols - 1) output_file << " ";
+      }
+      output_file << "\n";
+    }
+    output_file << "\n";
+
+    // 5. Print C matrices (GPU and CPU reference)
+    output_file << "Matrix C (GPU Result, " << TILE_SIZE << "x" << TILE_SIZE << "):\n";
+    for (uint32_t i = 0; i < TILE_SIZE; ++i) {
+      for (uint32_t j = 0; j < TILE_SIZE; ++j) {
+        output_file << h_C[i * TILE_SIZE + j];
+        if (j < TILE_SIZE - 1) output_file << " ";
+      }
+      output_file << "\n";
+    }
+    output_file << "\n";
+
+    output_file << "Matrix C (CPU Reference, " << TILE_SIZE << "x" << TILE_SIZE << "):\n";
+    for (uint32_t i = 0; i < TILE_SIZE; ++i) {
+      for (uint32_t j = 0; j < TILE_SIZE; ++j) {
+        output_file << h_ref[i * TILE_SIZE + j];
+        if (j < TILE_SIZE - 1) output_file << " ";
+      }
+      output_file << "\n";
+    }
+
+    output_file.close();
+    std::cout << "Matrices written to 'matrices_output.txt'" << std::endl;
+  } else {
+    std::cerr << "Error: Unable to open output file" << std::endl;
+  }
+
+  // cleanup
+  std::cout << "cleanup" << std::endl;
+  cleanup();
+
+  if (errors != 0) {
+    std::cout << "Found " << errors << " errors!" << std::endl;
+    std::cout << "FAILED!" << std::endl;
+    return 1;
+  }
+
+  std::cout << "PASSED!" << std::endl;
+
+  return 0;
+}
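Before the new veg_ls test, a quick illustration of the 2:4 round trip the host code above depends on: keep two values per group of four, record a keep-mask, and reconstruct by walking the mask while consuming packed values in column order. A self-contained sketch under those assumptions; the selection policy (largest magnitude) is illustrative and may differ from the patch's compress_2_4_sparse:

// Hedged sketch: one-row 2:4 compress + metadata-driven reconstruction.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  const float row[8] = {0.9f, 0.1f, 0.2f, 0.8f, 0.0f, 0.5f, 0.7f, 0.3f};
  std::vector<float> packed;   // compressed row: 2 values per group of 4
  std::vector<uint8_t> masks;  // one 4-bit keep-mask per group (bit i = position i kept)

  for (int g = 0; g < 2; ++g) {
    const float* grp = row + g * 4;
    // keep the two largest-magnitude positions in this group of 4
    int a = 0, b = 1;
    if (std::fabs(grp[b]) > std::fabs(grp[a])) std::swap(a, b); // |grp[a]| >= |grp[b]|
    for (int p = 2; p < 4; ++p) {
      if (std::fabs(grp[p]) > std::fabs(grp[a])) { b = a; a = p; }
      else if (std::fabs(grp[p]) > std::fabs(grp[b])) { b = p; }
    }
    uint8_t mask = (uint8_t)((1u << a) | (1u << b));
    masks.push_back(mask);
    for (int p = 0; p < 4; ++p)  // packed values stay in column order
      if (mask & (1u << p)) packed.push_back(grp[p]);
  }

  // reconstruct: walk mask bits in order, consuming packed values (this mirrors
  // the zero-out pass above, which instead zeroes the unkept logical positions)
  float rebuilt[8] = {};
  size_t idx = 0;
  for (int g = 0; g < 2; ++g)
    for (int p = 0; p < 4; ++p)
      if (masks[g] & (1u << p)) rebuilt[g * 4 + p] = packed[idx++];

  assert(packed.size() == 4); // 2:4 halves the storage
  assert(rebuilt[0] == 0.9f && rebuilt[3] == 0.8f && rebuilt[1] == 0.0f);
}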
diff --git a/tests/regression/veg_ls/Makefile b/tests/regression/veg_ls/Makefile
new file mode 100644
index 0000000000..43c67b7460
--- /dev/null
+++ b/tests/regression/veg_ls/Makefile
@@ -0,0 +1,13 @@
+ROOT_DIR := $(realpath ../../..)
+include $(ROOT_DIR)/config.mk
+
+PROJECT := veg_ls
+
+SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT)
+
+SRCS := $(SRC_DIR)/main.cpp
+
+VX_SRCS := $(SRC_DIR)/kernel.cpp
+
+
+include ../common.mk
diff --git a/tests/regression/veg_ls/common.h b/tests/regression/veg_ls/common.h
new file mode 100644
index 0000000000..8fbbb79d18
--- /dev/null
+++ b/tests/regression/veg_ls/common.h
@@ -0,0 +1,37 @@
+#ifndef _COMMON_H_
+#define _COMMON_H_
+
+#ifndef TYPE
+#define TYPE float
+#endif
+
+// T-reg: 1KB (16x16 fp32 elements)
+#define T_TILE_SIZE 1024
+
+// U-reg: 2KB (2 x T-reg)
+#define U_TILE_SIZE 2048
+
+// V-reg: 4KB (4 x T-reg)
+#define V_TILE_SIZE 4096
+
+// M-reg: 128B (16x16 4-bit elements, 2 per byte)
+#define M_TILE_SIZE 128
+
+// Number of tiles to test for each register type
+#define NUM_T_TILES 8 // Test all 8 T-regs
+#define NUM_U_TILES 4 // Test all 4 U-regs (covers all 8 T-regs)
+#define NUM_V_TILES 2 // Test all 2 V-regs (covers all 8 T-regs)
+#define NUM_M_TILES 8 // Test all 8 M-regs
+
+typedef struct {
+  uint64_t src_t_addr; // Source address for T tiles
+  uint64_t dst_t_addr; // Destination address for T tiles
+  uint64_t src_u_addr; // Source address for U tiles
+  uint64_t dst_u_addr; // Destination address for U tiles
+  uint64_t src_v_addr; // Source address for V tiles
+  uint64_t dst_v_addr; // Destination address for V tiles
+  uint64_t src_m_addr; // Source address for M tiles
+  uint64_t dst_m_addr; // Destination address for M tiles
+} kernel_arg_t;
+
+#endif
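The sizes in common.h imply the register aliasing the kernel below exploits: a U-reg spans two consecutive T-regs and a V-reg spans four. A tiny sketch of that index mapping; the helper names are hypothetical, not part of the headers:

// Hedged sketch: T-reg indices covered by a given U-/V-reg, per the sizes above.
#include <cassert>

// U-reg u aliases T-regs [2u, 2u+1]; V-reg v aliases T-regs [4v .. 4v+3].
constexpr int u_to_t(int u, int part) { return 2 * u + part; } // part in [0, 2)
constexpr int v_to_t(int v, int part) { return 4 * v + part; } // part in [0, 4)

int main() {
  static_assert(u_to_t(3, 1) == 7, "U-reg 3 ends at T-reg 7");
  static_assert(v_to_t(1, 3) == 7, "V-reg 1 ends at T-reg 7");
  assert(u_to_t(0, 0) == 0 && v_to_t(0, 2) == 2);
}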
diff --git a/tests/regression/veg_ls/kernel.cpp b/tests/regression/veg_ls/kernel.cpp
new file mode 100644
index 0000000000..e992786a5d
--- /dev/null
+++ b/tests/regression/veg_ls/kernel.cpp
@@ -0,0 +1,89 @@
+#include <vx_intrinsics.h>
+#include <vx_spawn.h>
+#include "common.h"
+
+void kernel_body(kernel_arg_t* __UNIFORM__ arg) {
+  auto src_t_ptr = reinterpret_cast<uint8_t*>(arg->src_t_addr);
+  auto dst_t_ptr = reinterpret_cast<uint8_t*>(arg->dst_t_addr);
+  auto src_u_ptr = reinterpret_cast<uint8_t*>(arg->src_u_addr);
+  auto dst_u_ptr = reinterpret_cast<uint8_t*>(arg->dst_u_addr);
+  auto src_v_ptr = reinterpret_cast<uint8_t*>(arg->src_v_addr);
+  auto dst_v_ptr = reinterpret_cast<uint8_t*>(arg->dst_v_addr);
+  auto src_m_ptr = reinterpret_cast<uint8_t*>(arg->src_m_addr);
+  auto dst_m_ptr = reinterpret_cast<uint8_t*>(arg->dst_m_addr);
+
+  // ===== LOAD/STORE ONE REGISTER TYPE AT A TIME =====
+  // The wider load types alias the same physical T-regs, so each test stores
+  // its data back out before the next test's loads overwrite the registers.
+
+  // Test 1: TILE_LOAD_T - Load all 8 T-regs individually
+  vx_lt(0, (size_t)(src_t_ptr + 0 * T_TILE_SIZE), 0);
+  vx_lt(1, (size_t)(src_t_ptr + 1 * T_TILE_SIZE), 0);
+  vx_lt(2, (size_t)(src_t_ptr + 2 * T_TILE_SIZE), 0);
+  vx_lt(3, (size_t)(src_t_ptr + 3 * T_TILE_SIZE), 0);
+  vx_lt(4, (size_t)(src_t_ptr + 4 * T_TILE_SIZE), 0);
+  vx_lt(5, (size_t)(src_t_ptr + 5 * T_TILE_SIZE), 0);
+  vx_lt(6, (size_t)(src_t_ptr + 6 * T_TILE_SIZE), 0);
+  vx_lt(7, (size_t)(src_t_ptr + 7 * T_TILE_SIZE), 0);
+
+  // Store T-tiles immediately while data is still in registers
+  vx_st((size_t)(dst_t_ptr + 0 * T_TILE_SIZE), 0, 0);
+  vx_st((size_t)(dst_t_ptr + 1 * T_TILE_SIZE), 0, 1);
+  vx_st((size_t)(dst_t_ptr + 2 * T_TILE_SIZE), 0, 2);
+  vx_st((size_t)(dst_t_ptr + 3 * T_TILE_SIZE), 0, 3);
+  vx_st((size_t)(dst_t_ptr + 4 * T_TILE_SIZE), 0, 4);
+  vx_st((size_t)(dst_t_ptr + 5 * T_TILE_SIZE), 0, 5);
+  vx_st((size_t)(dst_t_ptr + 6 * T_TILE_SIZE), 0, 6);
+  vx_st((size_t)(dst_t_ptr + 7 * T_TILE_SIZE), 0, 7);
+
+  // Test 2: TILE_LOAD_U - Load all 4 U-regs (covers all 8 T-regs)
+  // U-reg 0 maps to T-regs [0, 1]
+  vx_lu(0, (size_t)(src_u_ptr + 0 * U_TILE_SIZE), 0);
+  vx_st((size_t)(dst_u_ptr + 0 * U_TILE_SIZE), 0, 0);
+  vx_st((size_t)(dst_u_ptr + 0 * U_TILE_SIZE + T_TILE_SIZE), 0, 1);
+
+  // U-reg 1 maps to T-regs [2, 3]
+  vx_lu(1, (size_t)(src_u_ptr + 1 * U_TILE_SIZE), 0);
+  vx_st((size_t)(dst_u_ptr + 1 * U_TILE_SIZE), 0, 2);
+  vx_st((size_t)(dst_u_ptr + 1 * U_TILE_SIZE + T_TILE_SIZE), 0, 3);
+
+  // U-reg 2 maps to T-regs [4, 5]
+  vx_lu(2, (size_t)(src_u_ptr + 2 * U_TILE_SIZE), 0);
+  vx_st((size_t)(dst_u_ptr + 2 * U_TILE_SIZE), 0, 4);
+  vx_st((size_t)(dst_u_ptr + 2 * U_TILE_SIZE + T_TILE_SIZE), 0, 5);
+
+  // U-reg 3 maps to T-regs [6, 7]
+  vx_lu(3, (size_t)(src_u_ptr + 3 * U_TILE_SIZE), 0);
+  vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE), 0, 6);
+  vx_st((size_t)(dst_u_ptr + 3 * U_TILE_SIZE + T_TILE_SIZE), 0, 7);
+
+  // Test 3: TILE_LOAD_V - Load all 2 V-regs (covers all 8 T-regs)
+  // V-reg 0 maps to T-regs [0, 1, 2, 3]
+  vx_lv(0, (size_t)(src_v_ptr + 0 * V_TILE_SIZE), 0);
+  vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE), 0, 0);
+  vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 1 * T_TILE_SIZE), 0, 1);
+  vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 2 * T_TILE_SIZE), 0, 2);
+  vx_st((size_t)(dst_v_ptr + 0 * V_TILE_SIZE + 3 * T_TILE_SIZE), 0, 3);
+
+  // V-reg 1 maps to T-regs [4, 5, 6, 7]
+  vx_lv(1, (size_t)(src_v_ptr + 1 * V_TILE_SIZE), 0);
+  vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE), 0, 4);
+  vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 1 * T_TILE_SIZE), 0, 5);
+  vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 2 * T_TILE_SIZE), 0, 6);
+  vx_st((size_t)(dst_v_ptr + 1 * V_TILE_SIZE + 3 * T_TILE_SIZE), 0, 7);
+
+  // Test 4: TILE_LOAD_M - Load all 8 M-regs
+  // M-registers store metadata (sparsity patterns/masks). There is no tile
+  // store for M-regs, so dst_m_ptr is unused and these loads are only
+  // exercised for error-free execution.
+  vx_lm(0, (size_t)(src_m_ptr + 0 * M_TILE_SIZE), 0);
+  vx_lm(1, (size_t)(src_m_ptr + 1 * M_TILE_SIZE), 0);
+  vx_lm(2, (size_t)(src_m_ptr + 2 * M_TILE_SIZE), 0);
+  vx_lm(3, (size_t)(src_m_ptr + 3 * M_TILE_SIZE), 0);
+  vx_lm(4, (size_t)(src_m_ptr + 4 * M_TILE_SIZE), 0);
+  vx_lm(5, (size_t)(src_m_ptr + 5 * M_TILE_SIZE), 0);
+  vx_lm(6, (size_t)(src_m_ptr + 6 * M_TILE_SIZE), 0);
+  vx_lm(7, (size_t)(src_m_ptr + 7 * M_TILE_SIZE), 0);
+}
+
+int main() {
+  kernel_arg_t* arg = (kernel_arg_t*)csr_read(VX_CSR_MSCRATCH);
+  return vx_spawn_threads(1, nullptr, nullptr, (vx_kernel_func_cb)kernel_body, arg);
+}
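kernel.cpp unrolls every load and store by hand because the tile-register index is an immediate operand of the custom instruction, so it cannot come from a runtime loop variable. If the repetition grows, a compile-time unroll keeps each index a constant expression; a sketch of that pattern (the for_each_index helper is illustrative, not part of the patch):

// Hedged sketch: compile-time unrolling so each tile index stays a constant
// expression, as required when the index is an instruction immediate.
#include <cstddef>
#include <type_traits>
#include <utility>

template <typename Fn, std::size_t... Is>
constexpr void for_each_index(Fn&& fn, std::index_sequence<Is...>) {
  (fn(std::integral_constant<std::size_t, Is>{}), ...); // fold: fn(0), fn(1), ...
}

int main() {
  // Each `i` below is a distinct compile-time constant, usable where an
  // immediate is required (e.g. a vx_lt(i, addr + i * T_TILE_SIZE, 0) call).
  std::size_t sum = 0;
  for_each_index([&](auto i) { sum += i; }, std::make_index_sequence<8>{});
  return (sum == 28) ? 0 : 1;
}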
diff --git a/tests/regression/veg_ls/main.cpp b/tests/regression/veg_ls/main.cpp
new file mode 100644
index 0000000000..5efc4366c6
--- /dev/null
+++ b/tests/regression/veg_ls/main.cpp
@@ -0,0 +1,241 @@
+#include <iostream>
+#include <vector>
+#include <unistd.h>
+#include <string.h>
+#include <vortex.h>
+#include "common.h"
+
+#define RT_CHECK(_expr) \
+  do { \
+    int _ret = _expr; \
+    if (0 == _ret) \
+      break; \
+    printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \
+    cleanup(); \
+    exit(-1); \
+  } while (false)
+
+///////////////////////////////////////////////////////////////////////////////
+
+const char* kernel_file = "kernel.vxbin";
+
+vx_device_h device = nullptr;
+vx_buffer_h src_t_buffer = nullptr;
+vx_buffer_h dst_t_buffer = nullptr;
+vx_buffer_h src_u_buffer = nullptr;
+vx_buffer_h dst_u_buffer = nullptr;
+vx_buffer_h src_v_buffer = nullptr;
+vx_buffer_h dst_v_buffer = nullptr;
+vx_buffer_h src_m_buffer = nullptr;
+vx_buffer_h dst_m_buffer = nullptr;
+vx_buffer_h krnl_buffer = nullptr;
+vx_buffer_h args_buffer = nullptr;
+kernel_arg_t kernel_arg = {};
+
+static void show_usage() {
+  std::cout << "Vortex TILE Operations Test." << std::endl;
+  std::cout << "Usage: [-k: kernel] [-h: help]" << std::endl;
+}
+
+static void parse_args(int argc, char **argv) {
+  int c;
+  while ((c = getopt(argc, argv, "k:h")) != -1) {
+    switch (c) {
+    case 'k':
+      kernel_file = optarg;
+      break;
+    case 'h':
+      show_usage();
+      exit(0);
+      break;
+    default:
+      show_usage();
+      exit(-1);
+    }
+  }
+}
+
+void cleanup() {
+  if (device) {
+    vx_mem_free(src_t_buffer);
+    vx_mem_free(dst_t_buffer);
+    vx_mem_free(src_u_buffer);
+    vx_mem_free(dst_u_buffer);
+    vx_mem_free(src_v_buffer);
+    vx_mem_free(dst_v_buffer);
+    vx_mem_free(src_m_buffer);
+    vx_mem_free(dst_m_buffer);
+    vx_mem_free(krnl_buffer);
+    vx_mem_free(args_buffer);
+    vx_dev_close(device);
+  }
+}
+
+int main(int argc, char *argv[]) {
+  // parse command arguments
+  parse_args(argc, argv);
+
+  std::srand(50);
+
+  // open device connection
+  std::cout << "open device connection" << std::endl;
+  RT_CHECK(vx_dev_open(&device));
+
+  uint32_t t_buf_size = NUM_T_TILES * T_TILE_SIZE;
+  uint32_t u_buf_size = NUM_U_TILES * U_TILE_SIZE;
+  uint32_t v_buf_size = NUM_V_TILES * V_TILE_SIZE;
+  uint32_t m_buf_size = NUM_M_TILES * M_TILE_SIZE;
+
+  std::cout << "Testing all physical registers:" << std::endl;
+  std::cout << "T-regs: " << NUM_T_TILES << " tiles, " << T_TILE_SIZE << " bytes each, buffer: " << t_buf_size << " bytes" << std::endl;
+  std::cout << "U-regs: " << NUM_U_TILES << " tiles, " << U_TILE_SIZE << " bytes each, buffer: " << u_buf_size << " bytes" << std::endl;
+  std::cout << "V-regs: " << NUM_V_TILES << " tiles, " << V_TILE_SIZE << " bytes each, buffer: " << v_buf_size << " bytes" << std::endl;
+  std::cout << "M-regs: " << NUM_M_TILES << " tiles, " << M_TILE_SIZE << " bytes each, buffer: " << m_buf_size << " bytes" << std::endl;
+
+  // allocate device memory for T tiles
+  std::cout << "allocate device memory for T tiles" << std::endl;
+  RT_CHECK(vx_mem_alloc(device, t_buf_size, VX_MEM_READ_WRITE, &src_t_buffer));
+  RT_CHECK(vx_mem_address(src_t_buffer, &kernel_arg.src_t_addr));
+  RT_CHECK(vx_mem_alloc(device, t_buf_size, VX_MEM_READ_WRITE, &dst_t_buffer));
+  RT_CHECK(vx_mem_address(dst_t_buffer, &kernel_arg.dst_t_addr));
+
+  // allocate device memory for U tiles
+  std::cout << "allocate device memory for U tiles" << std::endl;
+  RT_CHECK(vx_mem_alloc(device, u_buf_size, VX_MEM_READ_WRITE, &src_u_buffer));
+  RT_CHECK(vx_mem_address(src_u_buffer, &kernel_arg.src_u_addr));
+  RT_CHECK(vx_mem_alloc(device, u_buf_size, VX_MEM_READ_WRITE, &dst_u_buffer));
+  RT_CHECK(vx_mem_address(dst_u_buffer, &kernel_arg.dst_u_addr));
+
+  // allocate device memory for V tiles
+  std::cout << "allocate device memory for V tiles" << std::endl;
+  RT_CHECK(vx_mem_alloc(device, v_buf_size, VX_MEM_READ_WRITE, &src_v_buffer));
+  RT_CHECK(vx_mem_address(src_v_buffer, &kernel_arg.src_v_addr));
+  RT_CHECK(vx_mem_alloc(device, v_buf_size, VX_MEM_READ_WRITE, &dst_v_buffer));
+  RT_CHECK(vx_mem_address(dst_v_buffer, &kernel_arg.dst_v_addr));
+
+  // allocate device memory for M tiles
+  std::cout << "allocate device memory for M tiles" << std::endl;
+  RT_CHECK(vx_mem_alloc(device, m_buf_size, VX_MEM_READ_WRITE, &src_m_buffer));
+  RT_CHECK(vx_mem_address(src_m_buffer, &kernel_arg.src_m_addr));
+  RT_CHECK(vx_mem_alloc(device, m_buf_size, VX_MEM_READ_WRITE, &dst_m_buffer));
+  RT_CHECK(vx_mem_address(dst_m_buffer, &kernel_arg.dst_m_addr));
+
+  std::cout << "dev_src_t=0x" << std::hex << kernel_arg.src_t_addr << std::endl;
+  std::cout << "dev_dst_t=0x" << std::hex << kernel_arg.dst_t_addr << std::endl;
+  std::cout << "dev_src_u=0x" << std::hex << kernel_arg.src_u_addr << std::endl;
+  std::cout << "dev_dst_u=0x" << std::hex << kernel_arg.dst_u_addr << std::endl;
+  std::cout << "dev_src_v=0x" << std::hex << kernel_arg.src_v_addr << std::endl;
+  std::cout << "dev_dst_v=0x" << std::hex << kernel_arg.dst_v_addr << std::endl;
+  std::cout << "dev_src_m=0x" << std::hex << kernel_arg.src_m_addr << std::endl;
+  std::cout << "dev_dst_m=0x" << std::hex << kernel_arg.dst_m_addr << std::dec << std::endl;
+
+  // allocate host buffers
+  std::cout << "allocate host buffers" << std::endl;
+  std::vector<uint8_t> h_src_t(t_buf_size);
+  std::vector<uint8_t> h_dst_t(t_buf_size);
+  std::vector<uint8_t> h_src_u(u_buf_size);
+  std::vector<uint8_t> h_dst_u(u_buf_size);
+  std::vector<uint8_t> h_src_v(v_buf_size);
+  std::vector<uint8_t> h_dst_v(v_buf_size);
+  std::vector<uint8_t> h_src_m(m_buf_size);
+  std::vector<uint8_t> h_dst_m(m_buf_size);
+
+  // Initialize source buffers with a different pattern for each tile type
+  for (uint32_t i = 0; i < t_buf_size; ++i) {
+    h_src_t[i] = (uint8_t)(i & 0xFF); // Pattern: 0,1,2,...,255,0,1,...
+  }
+  for (uint32_t i = 0; i < u_buf_size; ++i) {
+    h_src_u[i] = (uint8_t)((i * 2) & 0xFF); // Pattern: 0,2,4,...
+  }
+  for (uint32_t i = 0; i < v_buf_size; ++i) {
+    h_src_v[i] = (uint8_t)((i * 3) & 0xFF); // Pattern: 0,3,6,...
+  }
+  for (uint32_t i = 0; i < m_buf_size; ++i) {
+    h_src_m[i] = (uint8_t)((i ^ 0xAA) & 0xFF); // Pattern: XOR with 0xAA
+  }
+
+  // upload source buffers
+  std::cout << "upload source buffers" << std::endl;
+  RT_CHECK(vx_copy_to_dev(src_t_buffer, h_src_t.data(), 0, t_buf_size));
+  RT_CHECK(vx_copy_to_dev(src_u_buffer, h_src_u.data(), 0, u_buf_size));
+  RT_CHECK(vx_copy_to_dev(src_v_buffer, h_src_v.data(), 0, v_buf_size));
+  RT_CHECK(vx_copy_to_dev(src_m_buffer, h_src_m.data(), 0, m_buf_size));
+
+  // upload kernel binary
+  std::cout << "upload kernel binary" << std::endl;
+  RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer));
+
+  // upload kernel argument
+  std::cout << "upload kernel argument" << std::endl;
+  RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer));
+
+  // start device
+  std::cout << "start device" << std::endl;
+  RT_CHECK(vx_start(device, krnl_buffer, args_buffer));
+
+  // wait for completion
+  std::cout << "wait for completion" << std::endl;
+  RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT));
+
+  // download destination buffers
+  std::cout << "download destination buffers" << std::endl;
+  RT_CHECK(vx_copy_from_dev(h_dst_t.data(), dst_t_buffer, 0, t_buf_size));
+  RT_CHECK(vx_copy_from_dev(h_dst_u.data(), dst_u_buffer, 0, u_buf_size));
+  RT_CHECK(vx_copy_from_dev(h_dst_v.data(), dst_v_buffer, 0, v_buf_size));
+  RT_CHECK(vx_copy_from_dev(h_dst_m.data(), dst_m_buffer, 0, m_buf_size));
+
+  // verify result
+  std::cout << "verify result" << std::endl;
+  int errors = 0;
+
+  // Verify T tiles
+  for (uint32_t i = 0; i < t_buf_size; ++i) {
+    if (h_dst_t[i] != h_src_t[i]) {
+      if (errors < 100) {
+        printf("*** error: T[%d] expected=%d, actual=%d\n", i, h_src_t[i], h_dst_t[i]);
+      }
+      ++errors;
+    }
+  }
+
+  // Verify U tiles
+  for (uint32_t i = 0; i < u_buf_size; ++i) {
+    if (h_dst_u[i] != h_src_u[i]) {
+      if (errors < 100) {
+        printf("*** error: U[%d] expected=%d, actual=%d\n", i, h_src_u[i], h_dst_u[i]);
+      }
+      ++errors;
+    }
+  }
+
+  // Verify V tiles
+  for (uint32_t i = 0; i < v_buf_size; ++i) {
+    if (h_dst_v[i] != h_src_v[i]) {
+      if (errors < 100) {
+        printf("*** error: V[%d] expected=%d, actual=%d\n", i, h_src_v[i], h_dst_v[i]);
+      }
+      ++errors;
+    }
+  }
+
+  // M tiles: there is no tile store for M-regs, so h_dst_m is not compared;
+  // the M loads are only checked for error-free execution
+  std::cout << "M tiles loaded successfully (verified by error-free execution)" << std::endl;
+
+  // cleanup
+  std::cout << "cleanup" << std::endl;
+  cleanup();
+
+  if (errors != 0) {
+    std::cout << "Found " << std::dec << errors << " errors!" << std::endl;
+    std::cout << "FAILED!" << std::endl;
+    return 1;
+  }
+
+  std::cout << "PASSED!" << std::endl;
+
+  return 0;
+}
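The three verify loops in main.cpp differ only in the buffer and label; if more register types gain a store path, a shared checker keeps the 100-error print cap in one place. A sketch with a hypothetical verify_buffer helper (not in the patch):

// Hedged sketch: one checker for all round-trip buffers, capping log output
// at 100 mismatches like the loops above.
#include <cstdint>
#include <cstdio>
#include <vector>

static int verify_buffer(const char* tag, const std::vector<uint8_t>& expected,
                         const std::vector<uint8_t>& actual) {
  int errors = 0;
  for (size_t i = 0; i < expected.size(); ++i) {
    if (actual[i] != expected[i]) {
      if (errors < 100)
        printf("*** error: %s[%zu] expected=%d, actual=%d\n", tag, i, expected[i], actual[i]);
      ++errors;
    }
  }
  return errors;
}

int main() {
  std::vector<uint8_t> src = {1, 2, 3}, dst = {1, 0, 3};
  return verify_buffer("T", src, dst) == 1 ? 0 : 1; // exactly one mismatch expected
}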