MLIR 22.0.0git
SparseSpaceCollapse.cpp
Go to the documentation of this file.
1//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/IRMapping.h"
12
15
16namespace mlir {
17#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
19} // namespace mlir
20
21#define DEBUG_TYPE "sparse-space-collapse"
22
23using namespace mlir;
24using namespace sparse_tensor;
25
26namespace {
27
28struct CollapseSpaceInfo {
29 ExtractIterSpaceOp space;
30 IterateOp loop;
31};
32
33bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34 auto pIterArgs = parent.getRegionIterArgs();
35 auto nInitArgs = node.getInits();
36 if (pIterArgs.size() != nInitArgs.size())
37 return false;
38
39 // Two loops are collapsable if they are perfectly nested.
40 auto pYields = parent.getYieldedValues();
41 auto nResult = node.getLoopResults().value();
42
43 bool yieldEq =
44 llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
45 return std::get<0>(zipped) == std::get<1>(zipped);
46 });
47
48 // Parent iter_args should be passed directly to the node's init_args.
49 bool iterArgEq =
50 llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
51 return std::get<0>(zipped) == std::get<1>(zipped);
52 });
53
54 return yieldEq && iterArgEq;
55}
56
57bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
58 ExtractIterSpaceOp curSpace) {
59
60 auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61 Value spaceVal = space.getExtractedSpace();
62 if (spaceVal.hasOneUse())
63 return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
64 return nullptr;
65 };
66
67 if (toCollapse.empty()) {
68 // Collapse root.
69 if (auto itOp = getIterateOpOverSpace(curSpace)) {
70 CollapseSpaceInfo &info = toCollapse.emplace_back();
71 info.space = curSpace;
72 info.loop = itOp;
73 return true;
74 }
75 return false;
76 }
77
78 auto parent = toCollapse.back().space;
79 auto pItOp = toCollapse.back().loop;
80 auto nItOp = getIterateOpOverSpace(curSpace);
81
82 // Can only collapse spaces extracted from the same tensor.
83 if (parent.getTensor() != curSpace.getTensor()) {
84 LLVM_DEBUG({
85 llvm::dbgs()
86 << "failed to collpase spaces extracted from different tensors.";
87 });
88 return false;
89 }
90
91 // Can only collapse consecutive simple iteration on one tensor (i.e., no
92 // coiteration).
93 if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94 pItOp.getIterator() != curSpace.getParentIter() ||
95 curSpace->getParentOp() != pItOp.getOperation()) {
96 LLVM_DEBUG(
97 { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
98 return false;
99 }
100
101 if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
102 LLVM_DEBUG({
103 llvm::dbgs()
104 << "failed to collapse IterateOps that are not perfectly nested.";
105 });
106 return false;
107 }
108
109 CollapseSpaceInfo &info = toCollapse.emplace_back();
110 info.space = curSpace;
111 info.loop = nItOp;
112 return true;
113}
114
115void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
116 if (toCollapse.size() < 2)
117 return;
118
119 ExtractIterSpaceOp root = toCollapse.front().space;
120 ExtractIterSpaceOp leaf = toCollapse.back().space;
121 Location loc = root.getLoc();
122
123 assert(root->hasOneUse() && leaf->hasOneUse());
124
125 // Insert collapsed operation at the same scope as root operation.
126 OpBuilder builder(root);
127
128 // Construct the collapsed iteration space.
129 auto collapsedSpace = ExtractIterSpaceOp::create(
130 builder, loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
131 leaf.getHiLvl());
132
133 auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
134 auto innermost = toCollapse.back().loop;
135
136 IRMapping mapper;
137 mapper.map(leaf, collapsedSpace.getExtractedSpace());
138 for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
139 mapper.map(std::get<0>(z), std::get<1>(z));
140
141 auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
142 builder.setInsertionPointToStart(cloned.getBody());
143
144 I64BitSet crdUsedLvls;
145 unsigned shift = 0, argIdx = 1;
146 for (auto info : toCollapse.drop_back()) {
147 I64BitSet set = info.loop.getCrdUsedLvls();
148 crdUsedLvls |= set.lshift(shift);
149 shift += info.loop.getSpaceDim();
150 for (BlockArgument crd : info.loop.getCrds()) {
151 BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
152 argIdx++, builder.getIndexType(), crd.getLoc());
153 crd.replaceAllUsesWith(collapsedCrd);
154 }
155 }
156 crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
157 cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
158 cloned.setCrdUsedLvls(crdUsedLvls);
159
160 rItOp.replaceAllUsesWith(cloned.getResults());
161 // Erase collapsed loops.
162 rItOp.erase();
163 root.erase();
164}
165
166struct SparseSpaceCollapsePass
167 : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168 SparseSpaceCollapsePass() = default;
169
170 void runOnOperation() override {
171 func::FuncOp func = getOperation();
172
173 // A naive (experimental) implementation to collapse consecutive sparse
174 // spaces. It does NOT handle complex cases where multiple spaces are
175 // extracted in the same basic block. E.g.,
176 //
177 // %space1 = extract_space %t1 ...
178 // %space2 = extract_space %t2 ...
179 // sparse_tensor.iterate(%sp1) ...
180 //
182 func->walk([&](ExtractIterSpaceOp op) {
183 if (!legalToCollapse(toCollapse, op)) {
184 // if not legal to collapse one more space, collapse the existing ones
185 // and clear.
186 collapseSparseSpace(toCollapse);
187 toCollapse.clear();
188 }
189 });
190
191 collapseSparseSpace(toCollapse);
192 }
193};
194
195} // namespace
196
197std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
198 return std::make_unique<SparseSpaceCollapsePass>();
199}
This class represents an argument of a Block.
Definition Value.h:309
IndexType getIndexType()
Definition Builders.cpp:51
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
user_range getUsers() const
Definition Value.h:218
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
I64BitSet & lshift(unsigned offset)
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseSpaceCollapsePass()