MLIR  20.0.0git
SparseSpaceCollapse.cpp
Go to the documentation of this file.
1 //===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/IRMapping.h"
11 #include "mlir/Transforms/Passes.h"
12 
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_SPARSESPACECOLLAPSE
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
19 } // namespace mlir
20 
21 #define DEBUG_TYPE "sparse-space-collapse"
22 
23 using namespace mlir;
24 using namespace sparse_tensor;
25 
26 namespace {
27 
28 struct CollapseSpaceInfo {
29  ExtractIterSpaceOp space;
30  IterateOp loop;
31 };
32 
33 bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
34  auto pIterArgs = parent.getRegionIterArgs();
35  auto nInitArgs = node.getInits();
36  if (pIterArgs.size() != nInitArgs.size())
37  return false;
38 
39  // Two loops are collapsable if they are perfectly nested.
40  auto pYields = parent.getYieldedValues();
41  auto nResult = node.getLoopResults().value();
42 
43  bool yieldEq =
44  llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
45  return std::get<0>(zipped) == std::get<1>(zipped);
46  });
47 
48  // Parent iter_args should be passed directly to the node's init_args.
49  bool iterArgEq =
50  llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
51  return std::get<0>(zipped) == std::get<1>(zipped);
52  });
53 
54  return yieldEq && iterArgEq;
55 }
56 
57 bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
58  ExtractIterSpaceOp curSpace) {
59 
60  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
61  Value spaceVal = space.getExtractedSpace();
62  if (spaceVal.hasOneUse())
63  return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
64  return nullptr;
65  };
66 
67  if (toCollapse.empty()) {
68  // Collapse root.
69  if (auto itOp = getIterateOpOverSpace(curSpace)) {
70  CollapseSpaceInfo &info = toCollapse.emplace_back();
71  info.space = curSpace;
72  info.loop = itOp;
73  return true;
74  }
75  return false;
76  }
77 
78  auto parent = toCollapse.back().space;
79  auto pItOp = toCollapse.back().loop;
80  auto nItOp = getIterateOpOverSpace(curSpace);
81 
82  // Can only collapse spaces extracted from the same tensor.
83  if (parent.getTensor() != curSpace.getTensor()) {
84  LLVM_DEBUG({
85  llvm::dbgs()
86  << "failed to collpase spaces extracted from different tensors.";
87  });
88  return false;
89  }
90 
91  // Can only collapse consecutive simple iteration on one tensor (i.e., no
92  // coiteration).
93  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
94  pItOp.getIterator() != curSpace.getParentIter() ||
95  curSpace->getParentOp() != pItOp.getOperation()) {
96  LLVM_DEBUG(
97  { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
98  return false;
99  }
100 
101  if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
102  LLVM_DEBUG({
103  llvm::dbgs()
104  << "failed to collapse IterateOps that are not perfectly nested.";
105  });
106  return false;
107  }
108 
109  CollapseSpaceInfo &info = toCollapse.emplace_back();
110  info.space = curSpace;
111  info.loop = nItOp;
112  return true;
113 }
114 
/// Collapses the chain of nested sparse iteration spaces in `toCollapse` into
/// a single iteration space iterated by a single `sparse_tensor.iterate` loop.
///
/// `toCollapse` is ordered from the outermost space (the root, at the front)
/// to the innermost one (the leaf, at the back), as built by
/// `legalToCollapse`. Chains with fewer than two entries are left untouched.
///
/// The rewrite proceeds in four steps:
///   1. Extract, right before the root, one space over the root's tensor. It
///      spans from the root's low level to the leaf's high level and uses the
///      root's parent iterator.
///   2. Clone the innermost loop over that new space. Its init operands are
///      remapped to the root loop's init operands.
///   3. Re-create the coordinate block arguments of all outer loops on the
///      clone, and merge every loop's used-coordinate level set, shifted by
///      the accumulated space dimensions.
///   4. Replace the root loop's results with the clone's, then erase the root
///      loop (which, by construction of the chain, encloses the rest of the
///      chain) and the root space.
void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
  if (toCollapse.size() < 2)
    return;

  ExtractIterSpaceOp root = toCollapse.front().space;
  ExtractIterSpaceOp leaf = toCollapse.back().space;
  Location loc = root.getLoc();

  // Each extracted space is consumed by exactly one op: its iterate loop.
  assert(root->hasOneUse() && leaf->hasOneUse());

  // Insert collapsed operation at the same scope as root operation.
  OpBuilder builder(root);

  // Construct the collapsed iteration space.
  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
      leaf.getHiLvl());

  // The root loop is the sole user of the root space (asserted above).
  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
  auto innermost = toCollapse.back().loop;

  // The clone iterates the collapsed space instead of the leaf space, and
  // starts from the root loop's init values. Perfect nesting guarantees that
  // the innermost inits are the values forwarded from the root's inits.
  IRMapping mapper;
  mapper.map(leaf, collapsedSpace.getExtractedSpace());
  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
    mapper.map(std::get<0>(z), std::get<1>(z));

  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
  // NOTE(review): nothing is created through `builder` after this point (only
  // `getIndexType()` is queried), so this insertion point has no visible
  // effect.
  builder.setInsertionPointToStart(cloned.getBody());

  // Merge the used-coordinate level sets: each outer loop's set is shifted
  // past the space dimensions of the loops before it. Each outer loop's
  // coordinate block arguments are re-created on the clone, in chain order,
  // starting at block-argument index 1. Their uses (including those inside
  // the clone) are redirected to the new arguments.
  I64BitSet crdUsedLvls;
  unsigned shift = 0, argIdx = 1;
  for (auto info : toCollapse.drop_back()) {
    I64BitSet set = info.loop.getCrdUsedLvls();
    crdUsedLvls |= set.lshift(shift);
    shift += info.loop.getSpaceDim();
    for (BlockArgument crd : info.loop.getCrds()) {
      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
          argIdx++, builder.getIndexType(), crd.getLoc());
      crd.replaceAllUsesWith(collapsedCrd);
    }
  }
  // The innermost loop's own levels come last.
  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
  // The clone's iterator now walks the collapsed space.
  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
  cloned.setCrdUsedLvls(crdUsedLvls);

  rItOp.replaceAllUsesWith(cloned.getResults());
  // Erase collapsed loops.
  rItOp.erase();
  root.erase();
}
165 
166 struct SparseSpaceCollapsePass
167  : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
168  SparseSpaceCollapsePass() = default;
169 
170  void runOnOperation() override {
171  func::FuncOp func = getOperation();
172 
173  // A naive (experimental) implementation to collapse consecutive sparse
174  // spaces. It does NOT handle complex cases where multiple spaces are
175  // extracted in the same basic block. E.g.,
176  //
177  // %space1 = extract_space %t1 ...
178  // %space2 = extract_space %t2 ...
179  // sparse_tensor.iterate(%sp1) ...
180  //
182  func->walk([&](ExtractIterSpaceOp op) {
183  if (!legalToCollapse(toCollapse, op)) {
184  // if not legal to collapse one more space, collapse the existing ones
185  // and clear.
186  collapseSparseSpace(toCollapse);
187  toCollapse.clear();
188  }
189  });
190 
191  collapseSparseSpace(toCollapse);
192  }
193 };
194 
195 } // namespace
196 
197 std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
198  return std::make_unique<SparseSpaceCollapsePass>();
199 }
This class represents an argument of a Block.
Definition: Value.h:319
IndexType getIndexType()
Definition: Builders.cpp:95
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
user_range getUsers() const
Definition: Value.h:228
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
Definition: SparseTensor.h:64
I64BitSet & lshift(unsigned offset)
Definition: SparseTensor.h:94
Include the generated interface declarations.
std::unique_ptr< Pass > createSparseSpaceCollapsePass()