//===- SparseIterationToScf.cpp -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM Exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
                             SmallVectorImpl<Type> &fields) {
  // Position and coordinate buffers in the sparse structure.
  if (enc.getLvlType(lvl).isWithPosLT())
    fields.push_back(enc.getPosMemRefType());
  if (enc.getLvlType(lvl).isWithCrdLT())
    fields.push_back(enc.getCrdMemRefType());
  // One index for the shape bound (the result of lvlOp).
  fields.push_back(IndexType::get(enc.getContext()));
}

static std::optional<LogicalResult>
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {

  auto idxTp = IndexType::get(itSp.getContext());
  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
    convertLevelType(itSp.getEncoding(), l, fields);

  // Two indices for lower and upper bound (we only need one pair for the last
  // iteration space).
  fields.append({idxTp, idxTp});
  return success();
}

static std::optional<LogicalResult>
convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
  // The actual iterator values (updated on every iteration).
  auto idxTp = IndexType::get(itTp.getContext());
  // TODO: handle batch dimension.
  assert(itTp.getEncoding().getBatchLvlRank() == 0);
  if (!itTp.isUnique()) {
    // Segment high for non-unique iterator.
    fields.push_back(idxTp);
  }
  fields.push_back(idxTp);
  return success();
}
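
// Illustrative only, not part of the upstream file: assuming a pre-built
// `IteratorType itTp` with no batch levels, the converted iterator storage is
// a single `index` cursor, plus a leading segment-high `index` when the
// iterator is non-unique:
//
//   SmallVector<Type> fields;
//   (void)convertIteratorType(itTp, fields);
//   // unique iterator:     fields == {index}
//   // non-unique iterator: fields == {index, index}
//
// Similarly, convertIterSpaceType appends, per level, the position and/or
// coordinate memref types plus one `index` for the level bound, followed by a
// final (lo, hi) pair of `index` values for the whole space.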

static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
                       Value loopCrd,
                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
                       ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
                       ArrayRef<Value> userReduc) {
  if (newBlocks.empty())
    return userReduc;

  // The current branch that we are handling.
  Block *newBlock = newBlocks.front();
  Block *oldBlock = oldBlocks.front();
  Value casePred = constantI1(rewriter, loc, true);
  I64BitSet caseBits =
      op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
  for (unsigned i : caseBits.bits()) {
    SparseIterator *it = iters[i].get();
    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                it->getCrd(), loopCrd);
    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

  // Erase the empty block.
  rewriter.eraseBlock(&ifOp.getThenRegion().front());
  // Set up block arguments: user-provided values -> loop coord -> iterators.
  SmallVector<Value> blockArgs(userReduc);
  blockArgs.push_back(loopCrd);
  for (unsigned idx : caseBits.bits())
    llvm::append_range(blockArgs, iters[idx]->getCursor());

  // Map the old block arguments, because the dialect conversion driver does
  // not immediately perform SSA value replacements; this function still sees
  // the old uses.
  IRMapping mapping;
  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
    mapping.map(from, to);
  }

  // Clone the region; we cannot erase it yet because the same region might be
  // a subcase of multiple lattice points.
  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
                             ifOp.getThenRegion().begin(), mapping);
  // Remove the block arguments; they were already replaced via `mapping`.
  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());

  // Replace sparse_tensor::YieldOp with scf::YieldOp.
  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
  ValueRange yields = spY.getResults();
  rewriter.eraseOp(spY);
  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
  rewriter.create<scf::YieldOp>(loc, yields);

  // Generate the remaining cases recursively.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
                                          newBlocks.drop_front(),
                                          oldBlocks.drop_front(), userReduc);
  if (!res.empty())
    rewriter.create<scf::YieldOp>(loc, res);

  rewriter.setInsertionPointAfter(ifOp);
  return ifOp.getResults();
}
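
// Illustrative only, not part of the upstream file: for a case that involves
// two sparse iterators i0 and i1, the nest emitted above looks roughly like
//
//   %p0   = arith.cmpi eq, %crd_i0, %loopCrd : index
//   %p1   = arith.cmpi eq, %crd_i1, %loopCrd : index
//   %pred = arith.andi %p0, %p1 : i1
//   scf.if %pred -> (...) {
//     <cloned body of this case region, ending in scf.yield>
//   } else {
//     <recursively generated nest for the remaining cases>
//   }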

static ValueRange genLoopWithIterator(
    PatternRewriter &rewriter, Location loc, SparseIterator *it,
    ValueRange reduc,
    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
                                    Region &loopBody, SparseIterator *it,
                                    ValueRange reduc)>
        bodyBuilder) {
  if (it->iteratableByFor()) {
    auto [lo, hi] = it->genForCond(rewriter, loc);
    Value step = constantIndex(rewriter, loc, 1);
    scf::ForOp forOp = rewriter.create<scf::ForOp>(
        loc, lo, hi, step, reduc,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
          // Empty builder function to ensure that no terminator is created.
        });
    {
      OpBuilder::InsertionGuard guard(rewriter);
      it->linkNewScope(forOp.getInductionVar());
      rewriter.setInsertionPointToStart(forOp.getBody());
      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
                                           it, forOp.getRegionIterArgs());

      rewriter.setInsertionPointToEnd(forOp.getBody());
      rewriter.create<scf::YieldOp>(loc, ret);
    }
    return forOp.getResults();
  }

  SmallVector<Value> ivs(reduc);
  llvm::append_range(ivs, it->getCursor());

  TypeRange types = ValueRange(ivs).getTypes();
  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
  {
    OpBuilder::InsertionGuard guard(rewriter);
    // Generate the loop conditions.
    SmallVector<Location> l(types.size(), loc);
    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
    rewriter.setInsertionPointToStart(before);
    ValueRange bArgs = before->getArguments();
    auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

    // Delegate loop body generation.
    Region &dstRegion = whileOp.getAfter();
    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
    ValueRange aArgs = whileOp.getAfterArguments();
    it->linkNewScope(aArgs.drop_front(reduc.size()));
    aArgs = aArgs.take_front(reduc.size());

    rewriter.setInsertionPointToStart(after);
    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
    rewriter.setInsertionPointToEnd(after);

    // Forward the iterator and yield.
    SmallVector<Value> yields;
    llvm::append_range(yields, ret);
    llvm::append_range(yields, it->forward(rewriter, loc));
    rewriter.create<scf::YieldOp>(loc, yields);
  }
  return whileOp.getResults().drop_front(it->getCursor().size());
}
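
// Illustrative only, not part of the upstream file: an iterator that can be
// traversed with a plain induction variable lowers to
//
//   scf.for %iv = %lo to %hi step %c1 iter_args(%r = %init) -> (...) { ... }
//
// while any other iterator lowers to an scf.while whose "before" region tests
// the iterator condition and whose "after" region runs the body and then
// forwards the cursor:
//
//   scf.while (%r = %init, %pos = %start) : (...) -> (...) {
//     %cond = <iterator condition on %pos>
//     scf.condition(%cond) %r, %pos : ...
//   } do {
//   ^bb0(%r: ..., %pos: index):
//     <loop body>
//     scf.yield %r_next, %pos_next : ...
//   }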

namespace {

/// Sparse codegen rule for the extract iteration space operator.
class ExtractIterSpaceConverter
    : public OpConversionPattern<ExtractIterSpaceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Construct the iteration space.
    SparseIterationSpace space(loc, rewriter,
                               llvm::getSingleElement(adaptor.getTensor()), 0,
                               op.getLvlRange(), adaptor.getParentIter());

    SmallVector<Value> result = space.toValues();
    rewriter.replaceOpWithMultiple(op, {result});
    return success();
  }
};

/// Sparse codegen rule for the extract value operator.
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value pos = adaptor.getIterator().back();
    Value valBuf = rewriter.create<ToValuesOp>(
        loc, llvm::getSingleElement(adaptor.getTensor()));
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
    return success();
  }
};

class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getCrdUsedLvls().empty())
      return rewriter.notifyMatchFailure(
          op, "non-empty coordinates list not implemented.");

    Location loc = op.getLoc();

    auto iterSpace = SparseIterationSpace::fromValues(
        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);

    std::unique_ptr<SparseIterator> it =
        iterSpace.extractIterator(rewriter, loc);

    SmallVector<Value> ivs;
    for (ValueRange inits : adaptor.getInitArgs())
      llvm::append_range(ivs, inits);

    // Type conversion on the iterate op block.
    unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
    TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
    if (failed(typeConverter->convertSignatureArgs(
            op.getBody()->getArgumentTypes(), signatureConversion)))
      return rewriter.notifyMatchFailure(
          op, "failed to convert iterate region argument types");

    Block *block = rewriter.applySignatureConversion(
        op.getBody(), signatureConversion, getTypeConverter());
    ValueRange ret = genLoopWithIterator(
        rewriter, loc, it.get(), ivs,
        [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
                SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
          SmallVector<Value> blockArgs(reduc);
          // TODO: also append coordinates if used.
          // blockArgs.push_back(it->deref(rewriter, loc));
          llvm::append_range(blockArgs, it->getCursor());

          Block *dstBlock = &loopBody.getBlocks().front();
          rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
                                     blockArgs);
          auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
          // We cannot use ValueRange here because the operation holding the
          // values will be destroyed.
          SmallVector<Value> result(yield.getResults());
          rewriter.eraseOp(yield);
          return result;
        });

    rewriter.replaceOp(op, ret);
    return success();
  }
};

class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(op.getSpaceDim() == 1 && "Not implemented");
    Location loc = op.getLoc();

    I64BitSet denseBits(0);
    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
        denseBits.set(idx);

    // If there exists a case that only contains dense spaces (i.e., the case
    // bits are a subset of the dense bits), or a fully empty case (due to
    // complements), we need a universal index to forward the co-iteration
    // loop.
    bool needUniv =
        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
          // A case for complement.
          if (caseBits.count() == 0)
            return true;
          // An all-dense case.
          return caseBits.isSubSetOf(denseBits);
        });
    assert(!needUniv && "Not implemented");
    (void)needUniv;

    SmallVector<Block *> newBlocks;
    DenseMap<Block *, Block *> newToOldBlockMap;
    for (Region &region : op.getCaseRegions()) {
      // Do a one-shot type conversion on all region blocks, since the same
      // region might be used multiple times.
      Block *block = &region.getBlocks().front();
      TypeConverter::SignatureConversion blockTypeMapping(
          block->getArgumentTypes().size());
      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
                                                     blockTypeMapping))) {
        return rewriter.notifyMatchFailure(
            op, "failed to convert coiterate region argument types");
      }

      newBlocks.push_back(rewriter.applySignatureConversion(
          block, blockTypeMapping, getTypeConverter()));
      newToOldBlockMap[newBlocks.back()] = block;
    }

    SmallVector<SparseIterationSpace> spaces;
    SmallVector<std::unique_ptr<SparseIterator>> iters;
    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
      // TODO: do we really need tid?
      spaces.push_back(SparseIterationSpace::fromValues(
          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
      // Extract the iterator.
      iters.push_back(spaces.back().extractIterator(rewriter, loc));
    }

    auto getFilteredIters = [&iters](I64BitSet caseBits) {
      // Retrieves a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters;
      for (auto idx : caseBits.bits())
        validIters.push_back(iters[idx].get());
      return validIters;
    };

    // Get the flattened user-provided loop reduction values.
    SmallVector<Value> userReduc;
    for (ValueRange r : adaptor.getInitArgs())
      llvm::append_range(userReduc, r);

    // TODO: we need to sort the cases such that they appear in lexical order.
    // Although sparsification always generates cases in that order, it might
    // not be the case for human-written code.

    // Generate a loop sequence, one loop per case.
    for (auto [r, caseBits] :
         llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
      assert(caseBits.count() > 0 && "Complement space not implemented");

      // Retrieve a vector of pointers to the iterators used in the case.
      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);

      if (validIters.size() > 1) {
        auto [loop, loopCrd] =
            genCoIteration(rewriter, loc, validIters, userReduc,
                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);

        // 1st. Find all the cases that are strict subsets of the current case
        // condition, and generate one branch per subcase inside the loop.
        // The set of subcases is never empty, since it contains at least the
        // current region itself.
        // TODO: these cases should be sorted.
        SmallVector<Region *> subCases =
            op.getSubCasesOf(r->getParent()->getRegionNumber());
        SmallVector<Block *> newBlocks, oldBlocks;
        for (Region *r : subCases) {
          newBlocks.push_back(&r->front());
          oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
        }
        assert(!subCases.empty());

        ValueRange res = genCoIterateBranchNest(
            rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);

        SmallVector<Value> nextIterYields(res);
        // 2nd. Forward the iterators.
        for (SparseIterator *it : validIters) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
          it->forwardIf(rewriter, loc, cmp);
          llvm::append_range(nextIterYields, it->getCursor());
        }
        rewriter.create<scf::YieldOp>(loc, nextIterYields);

        // Exit the loop and relink the iterator SSA values.
        rewriter.setInsertionPointAfter(loop);
        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
        for (SparseIterator *it : validIters)
          iterVals = it->linkNewScope(iterVals);
        assert(iterVals.empty());

        ValueRange curResult = loop->getResults().take_front(userReduc.size());
        userReduc.assign(curResult.begin(), curResult.end());
      } else {
        // This is a simple iteration loop.
        assert(caseBits.count() == 1);

        Block *block = r;
        ValueRange curResult = genLoopWithIterator(
            rewriter, loc, validIters.front(), userReduc,
            /*bodyBuilder=*/
            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
                    SparseIterator *it,
                    ValueRange reduc) -> SmallVector<Value> {
              SmallVector<Value> blockArgs(reduc);
              blockArgs.push_back(it->deref(rewriter, loc));
              llvm::append_range(blockArgs, it->getCursor());

              Block *dstBlock = &dstRegion.getBlocks().front();
              rewriter.inlineBlockBefore(
                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
              SmallVector<Value> result(yield.getResults());
              rewriter.eraseOp(yield);
              return result;
            });

        userReduc.assign(curResult.begin(), curResult.end());
      }
    }

    rewriter.replaceOp(op, userReduc);
    return success();
  }
};

} // namespace

mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertIteratorType);
  addConversion(convertIterSpaceType);

  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
                              ValueRange inputs, Location loc) -> Value {
    return builder
        .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
        .getResult(0);
  });
}

void mlir::populateLowerSparseIterationToSCFPatterns(
    const TypeConverter &converter, RewritePatternSet &patterns) {

  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
               SparseIterateOpConverter, SparseCoIterateOpConverter>(
      converter, patterns.getContext());
}
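
// Illustrative only, not part of the upstream file: a minimal sketch of how
// the patterns registered above could be driven through the dialect
// conversion framework. The helper name `lowerSparseIterationDemo` is
// hypothetical; the real lowering runs inside a pass, but the driver setup
// looks essentially like this.
static LogicalResult lowerSparseIterationDemo(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  SparseIterationTypeConverter converter;
  ConversionTarget target(*ctx);
  // The sparse iteration ops handled above must be rewritten away; all other
  // operations remain legal by default under partial conversion.
  target.addIllegalOp<ExtractIterSpaceOp, ExtractValOp, IterateOp,
                      CoIterateOp>();
  populateLowerSparseIterationToSCFPatterns(converter, patterns);
  return applyPartialConversion(module, target, std::move(patterns));
}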