Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/IRMutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,15 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) {
}
}

template<typename... Lambdas>
auto mutate_with(const IRNode *ir, Lambdas &&...lambdas) -> IRHandle {
if (ir->node_type <= StrongestExprNodeType) {
return mutate_with(Expr((const BaseExprNode *)ir), std::forward<Lambdas>(lambdas)...);
} else {
return mutate_with(Stmt((const BaseStmtNode *)ir), std::forward<Lambdas>(lambdas)...);
}
}

/** A helper function for mutator-like things to mutate regions */
template<typename Mutator, typename... Args>
std::pair<Region, bool> mutate_region(Mutator *mutator, const Region &bounds, Args &&...args) {
Expand Down
183 changes: 159 additions & 24 deletions src/StageStridedLoads.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "StageStridedLoads.h"
#include "CSE.h"
#include "ExprUsesVar.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
Expand Down Expand Up @@ -95,12 +96,15 @@ class FindStridedLoads : public IRVisitor {
base = base_add->a;
offset = *off;
}
} else if (auto off = as_const_int(base)) {
base = 0;
offset = *off;
}

// TODO: We do not yet handle nested vectorization here for
// ramps which have not already collapsed. We could potentially
// handle more interesting types of shuffle than simple flat slices.
if (stride >= 2 && stride < r->lanes && r->stride.type().is_scalar()) {
if (stride >= 2 && stride <= r->lanes && r->stride.type().is_scalar()) {
const IRNode *s = scope;
const Allocate *a = nullptr;
if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) {
Expand Down Expand Up @@ -154,17 +158,35 @@ class FindStridedLoads : public IRVisitor {
// Replace a bunch of load expressions in a stmt
class ReplaceStridedLoads : public IRMutator {
public:
std::map<std::pair<const Allocate *, const Load *>, Expr> replacements;
std::map<const Load *, Expr> replacements;
std::map<const Allocate *, int> padding;
Scope<const Allocate *> allocation_scope;
std::map<const IRNode *, std::vector<std::pair<std::string, Expr>>> let_injections;

Stmt mutate(const Stmt &s) override {
Stmt stmt = IRMutator::mutate(s);
auto it = let_injections.find(s.get());
if (it != let_injections.end()) {
for (const auto &[name, value] : it->second) {
stmt = LetStmt::make(name, value, stmt);
}
}
return stmt;
}

Expr mutate(const Expr &e) override {
Expr expr = IRMutator::mutate(e);
auto it = let_injections.find(e.get());
if (it != let_injections.end()) {
for (const auto &[name, value] : it->second) {
expr = Let::make(name, value, expr);
}
}
return expr;
}

protected:
Expr visit(const Load *op) override {
const Allocate *alloc = nullptr;
if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) {
alloc = *a_ptr;
}
auto it = replacements.find({alloc, op});
auto it = replacements.find(op);
if (it != replacements.end()) {
return mutate(it->second);
} else {
Expand All @@ -173,7 +195,6 @@ class ReplaceStridedLoads : public IRMutator {
}

Stmt visit(const Allocate *op) override {
ScopedBinding bind(allocation_scope, op->name, op);
auto it = padding.find(op);
Stmt s = IRMutator::visit(op);
if (it == padding.end()) {
Expand All @@ -191,12 +212,96 @@ class ReplaceStridedLoads : public IRMutator {
using IRMutator::visit;
};

const IRNode *innermost_containing_node(const IRNode *root, const std::set<const Load *> &exprs) {
const IRNode *result = nullptr;
// The innermost containing stmt is whichever stmt node contains the
// largest number of our exprs, with ties breaking inwards.
int seen = 0, best = 0;
mutate_with(root, //
[&](auto *self, const Stmt &s) {
int old = seen;
self->mutate_base(s);
if (old == 0 && seen > best) {
result = s.get();
best = seen;
}
return s; //
},
[&](auto *self, const Expr &e) {
int old = seen;
const Load *l = e.as<Load>();
if (l && exprs.count(l)) {
seen++;
};
self->mutate_base(e);
if (old == 0 && seen > best) {
result = e.get();
best = seen;
}
return e; //
});
internal_assert(seen) << "None of the exprs were found\n";
return result;
}

bool can_hoist_shared_load(const IRNode *n, const std::string &buf, const Expr &idx) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any risk of hoisting past an extern stage with a side-effect?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there might be. I'll try to construct a failure

// Check none of the variables the idx depends on are defined somewhere
// within this stmt, there are no stores to the given buffer in the stmt,
// and other side-effecty things that might either write to the buffer or
// guard against out of bounds reads don't occur either.
bool result = true;
visit_with(n, //
[&](auto *self, const Let *let) { //
result &= !expr_uses_var(idx, let->name);
self->visit_base(let);
},
[&](auto *self, const LetStmt *let) { //
result &= !expr_uses_var(idx, let->name);
self->visit_base(let);
},
[&](auto *self, const For *loop) { //
result &= !expr_uses_var(idx, loop->name);
self->visit_base(loop);
},
[&](auto *self, const Allocate *alloc) { //
result &= alloc->name != buf;
self->visit_base(alloc);
},
[&](auto *self, const Store *store) { //
result &= store->name != buf;
self->visit_base(store);
},
[&](auto *self, const AssertStmt *a) { //
// Extern stages always come with asserts, even when
// no_asserts is on (they get stripped later), so this also
// guards against writes to the buffer happening in an extern
// stage. copy_to_host/device calls also come with asserts.
result = false;
});
return result;
}

} // namespace

Stmt stage_strided_loads(const Stmt &s) {
Stmt stage_strided_loads(const Stmt &stmt) {
FindStridedLoads finder;
ReplaceStridedLoads replacer;

// Make all strided loads distinct IR nodes so that we can uniquely identify
// them by address. We may want to mutate the same load node in different
// ways depending on the surrounding context.
Stmt s = mutate_with(stmt, [&](auto *self, const Load *l) {
const Ramp *r = l->index.as<Ramp>();
if (l->type.is_scalar() || (r && is_const_one(r->stride))) {
// Definitely not a strided load
return self->visit_base(l);
} else {
// Might be a strided load after simplification
return Load::make(l->type, l->name, self->mutate(l->index), l->image, l->param,
self->mutate(l->predicate), l->alignment);
}
});

// Find related clusters of strided loads anywhere in the stmt. While this
// appears to look globally, it requires expressions to match exactly, so
// really it's only going to find things inside the same loops and let
Expand All @@ -205,7 +310,6 @@ Stmt stage_strided_loads(const Stmt &s) {

for (const auto &l : finder.found_loads) {
const FindStridedLoads::Key &k = l.first;
const Allocate *alloc = k.allocation;
const std::map<int64_t, std::vector<const Load *>> &v = l.second;

// Find clusters of strided loads that can share the same dense load.
Expand All @@ -225,16 +329,47 @@ Stmt stage_strided_loads(const Stmt &s) {
// We have a complete cluster of loads. Make a single dense load
int lanes = k.lanes * k.stride;
int64_t first_offset = load->first;
Expr idx = Ramp::make(k.base + (int)first_offset, make_one(k.base.type()), lanes);
Expr base = common_subexpression_elimination(k.base);
Expr idx = Ramp::make(base + (int)first_offset, make_one(k.base.type()), lanes);
Type t = k.type.with_lanes(lanes);
const Load *op = load->second[0];

std::set<const Load *> all_loads;
for (auto l = load; l != v.end() && l->first < first_offset + k.stride; l++) {
all_loads.insert(l->second.begin(), l->second.end());
}

Expr shared_load = Load::make(t, k.buf, idx, op->image, op->param,
const_true(lanes), op->alignment);
shared_load = common_subexpression_elimination(shared_load);
for (; load != v.end() && load->first < first_offset + k.stride; load++) {
Expr shuf = Shuffle::make_slice(shared_load, load->first - first_offset, k.stride, k.lanes);
for (const Load *l : load->second) {
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);

// We now need to pick a site to place our shared dense load. We
// can't lift the shared load further out than k.scope, because that
// marks the outermost point at which the loads are known to both
// definitely occur (we don't want to escape an if statement) and
// produce a single fixed value (we don't want to cross over a store
// to the same buffer). If k.scope is null, the loads are valid
// everywhere, so we can hoist the single shared dense load to
// wherever we like.
const IRNode *outermost = k.scope ? k.scope : s.get();
const IRNode *let_site = innermost_containing_node(outermost, all_loads);
if (can_hoist_shared_load(let_site, k.buf, idx)) {
std::string name = unique_name('t');
Expr var = Variable::make(shared_load.type(), name);
for (; load != v.end() && load->first < first_offset + k.stride; load++) {
int row = load->first - first_offset;
Expr shuf = Shuffle::make_slice(var, row, k.stride, k.lanes);
for (const Load *l : load->second) {
replacer.replacements.emplace(l, shuf);
}
}
replacer.let_injections[let_site].emplace_back(name, shared_load);
} else {
for (; load != v.end() && load->first < first_offset + k.stride; load++) {
int row = load->first - first_offset;
Expr shuf = Shuffle::make_slice(shared_load, row, k.stride, k.lanes);
for (const Load *l : load->second) {
replacer.replacements.emplace(l, shuf);
}
}
}
}
Expand All @@ -243,7 +378,7 @@ Stmt stage_strided_loads(const Stmt &s) {
// picked up in a cluster, but for whom we know it's safe to do a
// dense load before their start.
for (const auto &[offset, loads] : reverse_view(v)) {
if (replacer.replacements.count({alloc, loads[0]})) {
if (replacer.replacements.count(loads[0])) {
continue;
}
int64_t delta = k.stride - 1;
Expand All @@ -261,14 +396,14 @@ Stmt stage_strided_loads(const Stmt &s) {
dense_load = common_subexpression_elimination(dense_load);
Expr shuf = Shuffle::make_slice(dense_load, delta, k.stride, k.lanes);
for (const Load *l : loads) {
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
replacer.replacements.emplace(l, shuf);
}
}

// Look for any loads we can densify because an overlapping load occurs
// in any parent scope.
for (const auto &[offset, loads] : reverse_view(v)) {
if (replacer.replacements.count({alloc, loads[0]})) {
if (replacer.replacements.count(loads[0])) {
continue;
}
int64_t min_offset = offset;
Expand Down Expand Up @@ -299,7 +434,7 @@ Stmt stage_strided_loads(const Stmt &s) {
dense_load = common_subexpression_elimination(dense_load);
Expr shuf = Shuffle::make_slice(dense_load, offset - final_offset, k.stride, k.lanes);
for (const Load *l : loads) {
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
replacer.replacements.emplace(l, shuf);
}
}

Expand All @@ -308,7 +443,7 @@ Stmt stage_strided_loads(const Stmt &s) {
// external allocations by doing a dense load at a trimmed size. We rely
// on codegen to do a good job at loading vectors of a funny size.
for (const auto &[offset, loads] : v) {
if (replacer.replacements.count({alloc, loads[0]})) {
if (replacer.replacements.count(loads[0])) {
continue;
}

Expand All @@ -332,7 +467,7 @@ Stmt stage_strided_loads(const Stmt &s) {
dense_load = common_subexpression_elimination(dense_load);
Expr shuf = Shuffle::make_slice(dense_load, offset - first_offset, k.stride, k.lanes);
for (const Load *l : loads) {
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
replacer.replacements.emplace(l, shuf);
}

} else if (k.lanes % 2 == 0) {
Expand All @@ -355,7 +490,7 @@ Stmt stage_strided_loads(const Stmt &s) {
Expr shuf2 = Shuffle::make_slice(dense_load2, delta, k.stride, k.lanes / 2);
Expr shuf = Shuffle::make_concat({shuf1, shuf2});
for (const Load *l : loads) {
replacer.replacements.emplace(std::make_pair(alloc, l), shuf);
replacer.replacements.emplace(l, shuf);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ set_target_properties(correctness_async
correctness_sliding_over_guard_with_if
correctness_sliding_reduction
correctness_sliding_window
correctness_stage_strided_loads
correctness_storage_folding
PROPERTIES ENABLE_EXPORTS TRUE)

Expand Down
Loading
Loading