diff --git a/src/IRMutator.h b/src/IRMutator.h index c170b37eb42b..4c06e12eda97 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -343,6 +343,15 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) { } } +template +auto mutate_with(const IRNode *ir, Lambdas &&...lambdas) -> IRHandle { + if (ir->node_type <= StrongestExprNodeType) { + return mutate_with(Expr((const BaseExprNode *)ir), std::forward(lambdas)...); + } else { + return mutate_with(Stmt((const BaseStmtNode *)ir), std::forward(lambdas)...); + } +} + /** A helper function for mutator-like things to mutate regions */ template std::pair mutate_region(Mutator *mutator, const Region &bounds, Args &&...args) { diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index 85691921bc8d..5073d6194522 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -1,5 +1,6 @@ #include "StageStridedLoads.h" #include "CSE.h" +#include "ExprUsesVar.h" #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" @@ -95,12 +96,15 @@ class FindStridedLoads : public IRVisitor { base = base_add->a; offset = *off; } + } else if (auto off = as_const_int(base)) { + base = 0; + offset = *off; } // TODO: We do not yet handle nested vectorization here for // ramps which have not already collapsed. We could potentially // handle more interesting types of shuffle than simple flat slices. - if (stride >= 2 && stride < r->lanes && r->stride.type().is_scalar()) { + if (stride >= 2 && stride <= r->lanes && r->stride.type().is_scalar()) { const IRNode *s = scope; const Allocate *a = nullptr; if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { @@ -154,17 +158,35 @@ class FindStridedLoads : public IRVisitor { // Replace a bunch of load expressions in a stmt class ReplaceStridedLoads : public IRMutator { public: - std::map, Expr> replacements; + std::map replacements; std::map padding; - Scope allocation_scope; + std::map>> let_injections; + + Stmt mutate(const Stmt &s) override { + Stmt stmt = IRMutator::mutate(s); + auto it = let_injections.find(s.get()); + if (it != let_injections.end()) { + for (const auto &[name, value] : it->second) { + stmt = LetStmt::make(name, value, stmt); + } + } + return stmt; + } + + Expr mutate(const Expr &e) override { + Expr expr = IRMutator::mutate(e); + auto it = let_injections.find(e.get()); + if (it != let_injections.end()) { + for (const auto &[name, value] : it->second) { + expr = Let::make(name, value, expr); + } + } + return expr; + } protected: Expr visit(const Load *op) override { - const Allocate *alloc = nullptr; - if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { - alloc = *a_ptr; - } - auto it = replacements.find({alloc, op}); + auto it = replacements.find(op); if (it != replacements.end()) { return mutate(it->second); } else { @@ -173,7 +195,6 @@ class ReplaceStridedLoads : public IRMutator { } Stmt visit(const Allocate *op) override { - ScopedBinding bind(allocation_scope, op->name, op); auto it = padding.find(op); Stmt s = IRMutator::visit(op); if (it == padding.end()) { @@ -191,12 +212,96 @@ class ReplaceStridedLoads : public IRMutator { using IRMutator::visit; }; +const IRNode *innermost_containing_node(const IRNode *root, const std::set &exprs) { + const IRNode *result = nullptr; + // The innermost containing stmt is whichever stmt node contains the + // largest number of our exprs, with ties breaking inwards. + int seen = 0, best = 0; + mutate_with(root, // + [&](auto *self, const Stmt &s) { + int old = seen; + self->mutate_base(s); + if (old == 0 && seen > best) { + result = s.get(); + best = seen; + } + return s; // + }, + [&](auto *self, const Expr &e) { + int old = seen; + const Load *l = e.as(); + if (l && exprs.count(l)) { + seen++; + }; + self->mutate_base(e); + if (old == 0 && seen > best) { + result = e.get(); + best = seen; + } + return e; // + }); + internal_assert(seen) << "None of the exprs were found\n"; + return result; +} + +bool can_hoist_shared_load(const IRNode *n, const std::string &buf, const Expr &idx) { + // Check none of the variables the idx depends on are defined somewhere + // within this stmt, there are no stores to the given buffer in the stmt, + // and other side-effecty things that might either write to the buffer or + // guard against out of bounds reads don't occur either. + bool result = true; + visit_with(n, // + [&](auto *self, const Let *let) { // + result &= !expr_uses_var(idx, let->name); + self->visit_base(let); + }, + [&](auto *self, const LetStmt *let) { // + result &= !expr_uses_var(idx, let->name); + self->visit_base(let); + }, + [&](auto *self, const For *loop) { // + result &= !expr_uses_var(idx, loop->name); + self->visit_base(loop); + }, + [&](auto *self, const Allocate *alloc) { // + result &= alloc->name != buf; + self->visit_base(alloc); + }, + [&](auto *self, const Store *store) { // + result &= store->name != buf; + self->visit_base(store); + }, + [&](auto *self, const AssertStmt *a) { // + // Extern stages always come with asserts, even when + // no_asserts is on (they get stripped later), so this also + // guards against writes to the buffer happening in an extern + // stage. copy_to_host/device calls also come with asserts. + result = false; + }); + return result; +} + } // namespace -Stmt stage_strided_loads(const Stmt &s) { +Stmt stage_strided_loads(const Stmt &stmt) { FindStridedLoads finder; ReplaceStridedLoads replacer; + // Make all strided loads distinct IR nodes so that we can uniquely identify + // them by address. We may want to mutate the same load node in different + // ways depending on the surrounding context. + Stmt s = mutate_with(stmt, [&](auto *self, const Load *l) { + const Ramp *r = l->index.as(); + if (l->type.is_scalar() || (r && is_const_one(r->stride))) { + // Definitely not a strided load + return self->visit_base(l); + } else { + // Might be a strided load after simplification + return Load::make(l->type, l->name, self->mutate(l->index), l->image, l->param, + self->mutate(l->predicate), l->alignment); + } + }); + // Find related clusters of strided loads anywhere in the stmt. While this // appears to look globally, it requires expressions to match exactly, so // really it's only going to find things inside the same loops and let @@ -205,7 +310,6 @@ Stmt stage_strided_loads(const Stmt &s) { for (const auto &l : finder.found_loads) { const FindStridedLoads::Key &k = l.first; - const Allocate *alloc = k.allocation; const std::map> &v = l.second; // Find clusters of strided loads that can share the same dense load. @@ -225,16 +329,47 @@ Stmt stage_strided_loads(const Stmt &s) { // We have a complete cluster of loads. Make a single dense load int lanes = k.lanes * k.stride; int64_t first_offset = load->first; - Expr idx = Ramp::make(k.base + (int)first_offset, make_one(k.base.type()), lanes); + Expr base = common_subexpression_elimination(k.base); + Expr idx = Ramp::make(base + (int)first_offset, make_one(k.base.type()), lanes); Type t = k.type.with_lanes(lanes); const Load *op = load->second[0]; + + std::set all_loads; + for (auto l = load; l != v.end() && l->first < first_offset + k.stride; l++) { + all_loads.insert(l->second.begin(), l->second.end()); + } + Expr shared_load = Load::make(t, k.buf, idx, op->image, op->param, const_true(lanes), op->alignment); - shared_load = common_subexpression_elimination(shared_load); - for (; load != v.end() && load->first < first_offset + k.stride; load++) { - Expr shuf = Shuffle::make_slice(shared_load, load->first - first_offset, k.stride, k.lanes); - for (const Load *l : load->second) { - replacer.replacements.emplace(std::make_pair(alloc, l), shuf); + + // We now need to pick a site to place our shared dense load. We + // can't lift the shared load further out than k.scope, because that + // marks the outermost point at which the loads are known to both + // definitely occur (we don't want to escape an if statement) and + // produce a single fixed value (we don't want to cross over a store + // to the same buffer). If k.scope is null, the loads are valid + // everywhere, so we can hoist the single shared dense load to + // wherever we like. + const IRNode *outermost = k.scope ? k.scope : s.get(); + const IRNode *let_site = innermost_containing_node(outermost, all_loads); + if (can_hoist_shared_load(let_site, k.buf, idx)) { + std::string name = unique_name('t'); + Expr var = Variable::make(shared_load.type(), name); + for (; load != v.end() && load->first < first_offset + k.stride; load++) { + int row = load->first - first_offset; + Expr shuf = Shuffle::make_slice(var, row, k.stride, k.lanes); + for (const Load *l : load->second) { + replacer.replacements.emplace(l, shuf); + } + } + replacer.let_injections[let_site].emplace_back(name, shared_load); + } else { + for (; load != v.end() && load->first < first_offset + k.stride; load++) { + int row = load->first - first_offset; + Expr shuf = Shuffle::make_slice(shared_load, row, k.stride, k.lanes); + for (const Load *l : load->second) { + replacer.replacements.emplace(l, shuf); + } } } } @@ -243,7 +378,7 @@ Stmt stage_strided_loads(const Stmt &s) { // picked up in a cluster, but for whom we know it's safe to do a // dense load before their start. for (const auto &[offset, loads] : reverse_view(v)) { - if (replacer.replacements.count({alloc, loads[0]})) { + if (replacer.replacements.count(loads[0])) { continue; } int64_t delta = k.stride - 1; @@ -261,14 +396,14 @@ Stmt stage_strided_loads(const Stmt &s) { dense_load = common_subexpression_elimination(dense_load); Expr shuf = Shuffle::make_slice(dense_load, delta, k.stride, k.lanes); for (const Load *l : loads) { - replacer.replacements.emplace(std::make_pair(alloc, l), shuf); + replacer.replacements.emplace(l, shuf); } } // Look for any loads we can densify because an overlapping load occurs // in any parent scope. for (const auto &[offset, loads] : reverse_view(v)) { - if (replacer.replacements.count({alloc, loads[0]})) { + if (replacer.replacements.count(loads[0])) { continue; } int64_t min_offset = offset; @@ -299,7 +434,7 @@ Stmt stage_strided_loads(const Stmt &s) { dense_load = common_subexpression_elimination(dense_load); Expr shuf = Shuffle::make_slice(dense_load, offset - final_offset, k.stride, k.lanes); for (const Load *l : loads) { - replacer.replacements.emplace(std::make_pair(alloc, l), shuf); + replacer.replacements.emplace(l, shuf); } } @@ -308,7 +443,7 @@ Stmt stage_strided_loads(const Stmt &s) { // external allocations by doing a dense load at a trimmed size. We rely // on codegen to do a good job at loading vectors of a funny size. for (const auto &[offset, loads] : v) { - if (replacer.replacements.count({alloc, loads[0]})) { + if (replacer.replacements.count(loads[0])) { continue; } @@ -332,7 +467,7 @@ Stmt stage_strided_loads(const Stmt &s) { dense_load = common_subexpression_elimination(dense_load); Expr shuf = Shuffle::make_slice(dense_load, offset - first_offset, k.stride, k.lanes); for (const Load *l : loads) { - replacer.replacements.emplace(std::make_pair(alloc, l), shuf); + replacer.replacements.emplace(l, shuf); } } else if (k.lanes % 2 == 0) { @@ -355,7 +490,7 @@ Stmt stage_strided_loads(const Stmt &s) { Expr shuf2 = Shuffle::make_slice(dense_load2, delta, k.stride, k.lanes / 2); Expr shuf = Shuffle::make_concat({shuf1, shuf2}); for (const Load *l : loads) { - replacer.replacements.emplace(std::make_pair(alloc, l), shuf); + replacer.replacements.emplace(l, shuf); } } } diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 690081f5ce4b..77066a8392bd 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -457,6 +457,7 @@ set_target_properties(correctness_async correctness_sliding_over_guard_with_if correctness_sliding_reduction correctness_sliding_window + correctness_stage_strided_loads correctness_storage_folding PROPERTIES ENABLE_EXPORTS TRUE) diff --git a/test/correctness/stage_strided_loads.cpp b/test/correctness/stage_strided_loads.cpp index f791385f7c25..757f71acd487 100644 --- a/test/correctness/stage_strided_loads.cpp +++ b/test/correctness/stage_strided_loads.cpp @@ -3,6 +3,18 @@ using namespace Halide; using namespace Halide::Internal; +// An extern stage needed below +extern "C" HALIDE_EXPORT_SYMBOL int make_data(halide_buffer_t *out) { + if (out->is_bounds_query()) { + return 0; + } + int *dst = (int *)out->host; + for (int x = 0; x < out->dim[0].extent; x++) { + dst[x] = x + out->dim[0].min; + } + return 0; +} + class CheckForStridedLoads : public IRMutator { using IRMutator::visit; @@ -86,10 +98,7 @@ int main(int argc, char **argv) { f(x) += {buf(2 * x), buf(2 * x + 1)}; f.update().vectorize(x, 8, TailStrategy::RoundUp); - // In this case, the dense load appears twice across the two store - // statements for the two tuple components, but it will get deduped by - // llvm. - checker.check(f, 2); + checker.check(f, 1); } { @@ -113,7 +122,7 @@ int main(int argc, char **argv) { g.vectorize(x, 8, TailStrategy::RoundUp); f.compute_at(g, x).vectorize(x); - checker.check(g, 2); + checker.check(g, 1); } { @@ -125,7 +134,7 @@ int main(int argc, char **argv) { g(x) = f(x); g.vectorize(x, 8, TailStrategy::RoundUp); - checker.check(g, 2); + checker.check(g, 1); } { @@ -135,7 +144,7 @@ int main(int argc, char **argv) { f(x, c) = buf(4 * x + c) + 4 * x; f.vectorize(x, 8, TailStrategy::RoundUp).bound(c, 0, 4).unroll(c).reorder(c, x); - checker.check(f, 4); + checker.check(f, 1); } { @@ -152,7 +161,7 @@ int main(int argc, char **argv) { f.tile(x, y, xi, yi, 8, 8, TailStrategy::RoundUp).vectorize(xi).reorder(c, x, y); g.compute_at(f, x).vectorize(x); h.compute_at(f, x).vectorize(x); - checker.check(f, 2); + checker.check(f, 1); } // We can always densify strided loads to internal allocations, because we @@ -181,7 +190,7 @@ int main(int argc, char **argv) { { Func f; Var x; - f(x) = buf(16 * x) + buf(16 * x + 15); + f(x) = buf(17 * x) + buf(17 * x + 15); f.vectorize(x, 16, TailStrategy::RoundUp); checker.check_not(f, 0); @@ -258,6 +267,63 @@ int main(int argc, char **argv) { } } + // Check we don't hoist a shared load past a store to the same buffer. + { + Func f, g, h; + Var x, y, xo, xi, xio, xii; + f(x) = x; + g(x, y) = f(2 * x + y); + h(x) = g(x, 0) + g(x, 1); + + // Construct a situation like: + // alloc f + // for x: + // f([0...14]) = [0...14] + // strided load from f [0, 2, 4, ... 14] + // f[15] = 15 + // strided load from f [1, 3, 5, ... 15] + + // if a single shared load is hoisted before the stores to f + // stage, it will see garbage. If it's hoisted to just after the first + // store, it'll see garbage for the last lane + + h.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + g.compute_at(h, x).unroll(y).vectorize(x); + f.store_at(g, Var::outermost()).compute_at(g, y); + + Buffer out = h.realize({1024}); + + for (int x = 0; x < out.width(); x++) { + int correct = 4 * x + 1; + if (out(x) != correct) { + printf("out(%d) = %d instead of %d\n", x, out(x), correct); + } + } + } + + // Check we don't hoist a shared load before an extern stage that writes to + // the same buffer. Same as above, but uses an extern stage to fill f. + { + Func f, g, h; + Var x, y, xo, xi, xio, xii; + f.define_extern("make_data", {}, Int(32), 1); + g(x, y) = f(2 * x + y); + h(x) = g(x, 0) + g(x, 1); + + h.compute_root().vectorize(x, 8, TailStrategy::RoundUp); + g.compute_at(h, x).unroll(y).vectorize(x); + f.store_at(g, Var::outermost()).bound_storage(_0, 16).compute_at(g, y); + + Buffer out = h.realize({1024}); + + for (int x = 0; x < out.width(); x++) { + int correct = 4 * x + 1; + if (out(x) != correct) { + printf("out(%d) = %d instead of %d\n", x, out(x), correct); + } + } + } + printf("Success!\n"); return 0;