17 #define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
21 #define DEBUG_TYPE "sparse-space-collapse"
24 using namespace sparse_tensor;
28 struct CollapseSpaceInfo {
29 ExtractIterSpaceOp space;
33 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34 auto pIterArgs = parent.getRegionIterArgs();
35 auto nInitArgs = node.getInits();
36 if (pIterArgs.size() != nInitArgs.size())
40 auto pYields = parent.getYieldedValues();
41 auto nResult = node.getLoopResults().value();
44 llvm::all_of(llvm::zip_equal(pYields, nResult), [](
auto zipped) {
45 return std::get<0>(zipped) == std::get<1>(zipped);
50 llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](
auto zipped) {
51 return std::get<0>(zipped) == std::get<1>(zipped);
54 return yieldEq && iterArgEq;
58 ExtractIterSpaceOp curSpace) {
60 auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61 Value spaceVal = space.getExtractedSpace();
63 return llvm::dyn_cast<IterateOp>(*spaceVal.
getUsers().begin());
67 if (toCollapse.empty()) {
69 if (
auto itOp = getIterateOpOverSpace(curSpace)) {
70 CollapseSpaceInfo &info = toCollapse.emplace_back();
71 info.space = curSpace;
78 auto parent = toCollapse.back().space;
79 auto pItOp = toCollapse.back().loop;
80 auto nItOp = getIterateOpOverSpace(curSpace);
83 if (parent.getTensor() != curSpace.getTensor()) {
86 <<
"failed to collpase spaces extracted from different tensors.";
93 if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94 pItOp.getIterator() != curSpace.getParentIter() ||
95 curSpace->getParentOp() != pItOp.getOperation()) {
97 { llvm::dbgs() <<
"failed to collapse non-consecutive IterateOps."; });
101 if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
104 <<
"failed to collapse IterateOps that are not perfectly nested.";
109 CollapseSpaceInfo &info = toCollapse.emplace_back();
110 info.space = curSpace;
116 if (toCollapse.size() < 2)
119 ExtractIterSpaceOp root = toCollapse.front().space;
120 ExtractIterSpaceOp leaf = toCollapse.back().space;
123 assert(root->hasOneUse() && leaf->hasOneUse());
129 auto collapsedSpace = builder.
create<ExtractIterSpaceOp>(
130 loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
133 auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
134 auto innermost = toCollapse.back().loop;
137 mapper.
map(leaf, collapsedSpace.getExtractedSpace());
138 for (
auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
139 mapper.
map(std::get<0>(z), std::get<1>(z));
141 auto cloned = llvm::cast<IterateOp>(builder.
clone(*innermost, mapper));
145 unsigned shift = 0, argIdx = 1;
146 for (
auto info : toCollapse.drop_back()) {
147 I64BitSet set = info.loop.getCrdUsedLvls();
148 crdUsedLvls |= set.
lshift(shift);
149 shift += info.loop.getSpaceDim();
151 BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
156 crdUsedLvls |= innermost.getCrdUsedLvls().
lshift(shift);
157 cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
158 cloned.setCrdUsedLvls(crdUsedLvls);
160 rItOp.replaceAllUsesWith(cloned.getResults());
166 struct SparseSpaceCollapsePass
167 :
public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168 SparseSpaceCollapsePass() =
default;
170 void runOnOperation()
override {
171 func::FuncOp func = getOperation();
182 func->walk([&](ExtractIterSpaceOp op) {
183 if (!legalToCollapse(toCollapse, op)) {
186 collapseSparseSpace(toCollapse);
191 collapseSparseSpace(toCollapse);
198 return std::make_unique<SparseSpaceCollapsePass>();
This class represents an argument of a Block.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
A simple wrapper to encode a bitset of (at most 64) levels, currently used by the sparse_tensor.iterate operation for the set of levels on which the coordinates should be loaded.
I64BitSet & lshift(unsigned offset)
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseSpaceCollapsePass()