MLIR  20.0.0git
SparseIterationToScf.cpp
Go to the documentation of this file.
1 
2 #include "Utils/CodegenUtils.h"
3 #include "Utils/LoopEmitter.h"
5 
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 
15 /// Assert that the given value range contains a single value and return it.
17  assert(values.size() == 1 && "expected single value");
18  return values.front();
19 }
20 
21 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
22  SmallVectorImpl<Type> &fields) {
23  // Position and coordinate buffer in the sparse structure.
24  if (enc.getLvlType(lvl).isWithPosLT())
25  fields.push_back(enc.getPosMemRefType());
26  if (enc.getLvlType(lvl).isWithCrdLT())
27  fields.push_back(enc.getCrdMemRefType());
28  // One index for shape bound (result from lvlOp).
29  fields.push_back(IndexType::get(enc.getContext()));
30 }
31 
32 static std::optional<LogicalResult>
33 convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34 
35  auto idxTp = IndexType::get(itSp.getContext());
36  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
37  convertLevelType(itSp.getEncoding(), l, fields);
38 
39  // Two indices for lower and upper bound (we only need one pair for the last
40  // iteration space).
41  fields.append({idxTp, idxTp});
42  return success();
43 }
44 
45 static std::optional<LogicalResult>
46 convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
47  // The actually Iterator Values (that are updated every iteration).
48  auto idxTp = IndexType::get(itTp.getContext());
49  // TODO: handle batch dimension.
50  assert(itTp.getEncoding().getBatchLvlRank() == 0);
51  if (!itTp.isUnique()) {
52  // Segment high for non-unique iterator.
53  fields.push_back(idxTp);
54  }
55  fields.push_back(idxTp);
56  return success();
57 }
58 
59 static ValueRange
60 genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
61  Value loopCrd,
62  ArrayRef<std::unique_ptr<SparseIterator>> iters,
63  ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64  ArrayRef<Value> userReduc) {
65  if (newBlocks.empty())
66  return userReduc;
67 
68  // The current branch that we are handling.
69  Block *newBlock = newBlocks.front();
70  Block *oldBlock = oldBlocks.front();
71  Value casePred = constantI1(rewriter, loc, true);
72  I64BitSet caseBits =
73  op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
74  for (unsigned i : caseBits.bits()) {
75  SparseIterator *it = iters[i].get();
76  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
77  it->getCrd(), loopCrd);
78  casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
79  }
80  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
81  loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
82  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
83 
84  // Erase the empty block.
85  rewriter.eraseBlock(&ifOp.getThenRegion().front());
86  // Set up block arguments: user-provided values -> loop coord -> iterators.
87  SmallVector<Value> blockArgs(userReduc);
88  blockArgs.push_back(loopCrd);
89  for (unsigned idx : caseBits.bits())
90  llvm::append_range(blockArgs, iters[idx]->getCursor());
91 
92  // Map the old block arguments, because the dialect conversion driver does
93  // not immediately perform SSA value replacements. This function is still
94  // seeing the old uses.
95  IRMapping mapping;
96  for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
97  mapping.map(from, to);
98  }
99 
100  // Clone the region, we can not erase the region now because the same region
101  // might be a subcase for multiple lattice point.
102  rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
103  ifOp.getThenRegion().begin(), mapping);
104  // Remove the block arguments, they were already replaced via `mapping`.
105  ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
106 
107  // replace sparse_tensor::YieldOp -> scf::YieldOp
108  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
109  ValueRange yields = spY.getResults();
110  rewriter.eraseOp(spY);
111  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
112  rewriter.create<scf::YieldOp>(loc, yields);
113 
114  // Generates remaining case recursively.
115  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
116  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
117  newBlocks.drop_front(),
118  oldBlocks.drop_front(), userReduc);
119  if (!res.empty())
120  rewriter.create<scf::YieldOp>(loc, res);
121 
122  rewriter.setInsertionPointAfter(ifOp);
123  return ifOp.getResults();
124 }
125 
127  PatternRewriter &rewriter, Location loc, SparseIterator *it,
128  ValueRange reduc,
130  Region &loopBody, SparseIterator *it,
131  ValueRange reduc)>
132  bodyBuilder) {
133  if (it->iteratableByFor()) {
134  auto [lo, hi] = it->genForCond(rewriter, loc);
135  Value step = constantIndex(rewriter, loc, 1);
136  scf::ForOp forOp = rewriter.create<scf::ForOp>(
137  loc, lo, hi, step, reduc,
138  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139  // Empty builder function to ensure that no terminator is created.
140  });
141  {
142  OpBuilder::InsertionGuard guard(rewriter);
143  it->linkNewScope(forOp.getInductionVar());
144  rewriter.setInsertionPointToStart(forOp.getBody());
145  SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
146  it, forOp.getRegionIterArgs());
147 
148  rewriter.setInsertionPointToEnd(forOp.getBody());
149  rewriter.create<scf::YieldOp>(loc, ret);
150  }
151  return forOp.getResults();
152  }
153 
154  SmallVector<Value> ivs(reduc);
155  llvm::append_range(ivs, it->getCursor());
156 
157  TypeRange types = ValueRange(ivs).getTypes();
158  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
159  {
160  OpBuilder::InsertionGuard guard(rewriter);
161  // Generates loop conditions.
162  SmallVector<Location> l(types.size(), loc);
163  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
164  rewriter.setInsertionPointToStart(before);
165  ValueRange bArgs = before->getArguments();
166  auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
167  rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
168 
169  // Delegates loop body generation.
170  Region &dstRegion = whileOp.getAfter();
171  Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
172  ValueRange aArgs = whileOp.getAfterArguments();
173  it->linkNewScope(aArgs.drop_front(reduc.size()));
174  aArgs = aArgs.take_front(reduc.size());
175 
176  rewriter.setInsertionPointToStart(after);
177  SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
178  rewriter.setInsertionPointToEnd(after);
179 
180  // Forward loops
181  SmallVector<Value> yields;
182  llvm::append_range(yields, ret);
183  llvm::append_range(yields, it->forward(rewriter, loc));
184  rewriter.create<scf::YieldOp>(loc, yields);
185  }
186  return whileOp.getResults().drop_front(it->getCursor().size());
187 }
188 
189 namespace {
190 
191 /// Sparse codegen rule for number of entries operator.
192 class ExtractIterSpaceConverter
193  : public OpConversionPattern<ExtractIterSpaceOp> {
194 public:
196  LogicalResult
197  matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override {
199  Location loc = op.getLoc();
200 
201  // Construct the iteration space.
202  SparseIterationSpace space(loc, rewriter,
203  getSingleValue(adaptor.getTensor()), 0,
204  op.getLvlRange(), adaptor.getParentIter());
205 
206  SmallVector<Value> result = space.toValues();
207  rewriter.replaceOpWithMultiple(op, {result});
208  return success();
209  }
210 };
211 
212 /// Sparse codegen rule for number of entries operator.
213 class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
214 public:
216  LogicalResult
217  matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
218  ConversionPatternRewriter &rewriter) const override {
219  Location loc = op.getLoc();
220  Value pos = adaptor.getIterator().back();
221  Value valBuf =
222  rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
223  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
224  return success();
225  }
226 };
227 
228 class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
229 public:
231  LogicalResult
232  matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
233  ConversionPatternRewriter &rewriter) const override {
234  if (!op.getCrdUsedLvls().empty())
235  return rewriter.notifyMatchFailure(
236  op, "non-empty coordinates list not implemented.");
237 
238  Location loc = op.getLoc();
239 
240  auto iterSpace = SparseIterationSpace::fromValues(
241  op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
242 
243  std::unique_ptr<SparseIterator> it =
244  iterSpace.extractIterator(rewriter, loc);
245 
246  SmallVector<Value> ivs;
247  for (ValueRange inits : adaptor.getInitArgs())
248  llvm::append_range(ivs, inits);
249 
250  // Type conversion on iterate op block.
251  unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
252  TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
253  if (failed(typeConverter->convertSignatureArgs(
254  op.getBody()->getArgumentTypes(), signatureConversion)))
255  return rewriter.notifyMatchFailure(
256  op, "failed to convert iterate region argurment types");
257 
258  Block *block = rewriter.applySignatureConversion(
259  op.getBody(), signatureConversion, getTypeConverter());
261  rewriter, loc, it.get(), ivs,
262  [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
264  SmallVector<Value> blockArgs(reduc);
265  // TODO: Also appends coordinates if used.
266  // blockArgs.push_back(it->deref(rewriter, loc));
267  llvm::append_range(blockArgs, it->getCursor());
268 
269  Block *dstBlock = &loopBody.getBlocks().front();
270  rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
271  blockArgs);
272  auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
273  // We can not use ValueRange as the operation holding the values will
274  // be destoryed.
275  SmallVector<Value> result(yield.getResults());
276  rewriter.eraseOp(yield);
277  return result;
278  });
279 
280  rewriter.replaceOp(op, ret);
281  return success();
282  }
283 };
284 
285 class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
287 
288  LogicalResult
289  matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
290  ConversionPatternRewriter &rewriter) const override {
291  assert(op.getSpaceDim() == 1 && "Not implemented");
292  Location loc = op.getLoc();
293 
294  I64BitSet denseBits(0);
295  for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
296  if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
297  denseBits.set(idx);
298 
299  // If there exists a case that only contains dense spaces. I.e., case
300  // bits is a subset of dense bits, or when there is a full empty case (due
301  // to complements), we need a universal pointer to forward the coiteration
302  // loop.
303  bool needUniv =
304  any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
305  // A case for complement.
306  if (caseBits.count() == 0)
307  return true;
308  // An all-dense case.
309  return caseBits.isSubSetOf(denseBits);
310  });
311  assert(!needUniv && "Not implemented");
312  (void)needUniv;
313 
314  SmallVector<Block *> newBlocks;
315  DenseMap<Block *, Block *> newToOldBlockMap;
316  for (Region &region : op.getCaseRegions()) {
317  // Do a one-shot type conversion on all region blocks, since the same
318  // region might be used multiple time.
319  Block *block = &region.getBlocks().front();
320  TypeConverter::SignatureConversion blockTypeMapping(
321  block->getArgumentTypes().size());
322  if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
323  blockTypeMapping))) {
324  return rewriter.notifyMatchFailure(
325  op, "failed to convert coiterate region argurment types");
326  }
327 
328  newBlocks.push_back(rewriter.applySignatureConversion(
329  block, blockTypeMapping, getTypeConverter()));
330  newToOldBlockMap[newBlocks.back()] = block;
331  }
332 
335  for (auto [spaceTp, spaceVals] : llvm::zip_equal(
336  op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
337  // TODO: do we really need tid?
338  spaces.push_back(SparseIterationSpace::fromValues(
339  cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
340  // Extract the iterator.
341  iters.push_back(spaces.back().extractIterator(rewriter, loc));
342  }
343 
344  auto getFilteredIters = [&iters](I64BitSet caseBits) {
345  // Retrives a vector of pointers to the iterators used in the case.
347  for (auto idx : caseBits.bits())
348  validIters.push_back(iters[idx].get());
349  return validIters;
350  };
351 
352  // Get a flattened user-provided loop reduction values.
353  SmallVector<Value> userReduc;
354  for (ValueRange r : adaptor.getInitArgs())
355  llvm::append_range(userReduc, r);
356 
357  // TODO: we need to sort the cases such that they appears in lexical order.
358  // Although sparsification always generates cases in that order, it might
359  // not be the case for human-written code.
360 
361  // Generates a loop sequence, one loop per case.
362  for (auto [r, caseBits] :
363  llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
364  assert(caseBits.count() > 0 && "Complement space not implemented");
365 
366  // Retrives a vector of pointers to the iterators used in the case.
367  SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
368 
369  if (validIters.size() > 1) {
370  auto [loop, loopCrd] =
371  genCoIteration(rewriter, loc, validIters, userReduc,
372  /*uniIdx=*/nullptr, /*userReducFirst=*/true);
373 
374  // 1st. find all the cases that is a strict subset of the current case
375  // condition, for which we generate one branch per case inside the loop.
376  // The subcases are never empty, it must contains at least the current
377  // region itself.
378  // TODO: these cases should be sorted.
379  SmallVector<Region *> subCases =
380  op.getSubCasesOf(r->getParent()->getRegionNumber());
381  SmallVector<Block *> newBlocks, oldBlocks;
382  for (Region *r : subCases) {
383  newBlocks.push_back(&r->front());
384  oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
385  }
386  assert(!subCases.empty());
387 
389  rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
390 
391  SmallVector<Value> nextIterYields(res);
392  // 2nd. foward the loop.
393  for (SparseIterator *it : validIters) {
394  Value cmp = rewriter.create<arith::CmpIOp>(
395  loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
396  it->forwardIf(rewriter, loc, cmp);
397  llvm::append_range(nextIterYields, it->getCursor());
398  }
399  rewriter.create<scf::YieldOp>(loc, nextIterYields);
400 
401  // Exit the loop, relink the iterator SSA value.
402  rewriter.setInsertionPointAfter(loop);
403  ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
404  for (SparseIterator *it : validIters)
405  iterVals = it->linkNewScope(iterVals);
406  assert(iterVals.empty());
407 
408  ValueRange curResult = loop->getResults().take_front(userReduc.size());
409  userReduc.assign(curResult.begin(), curResult.end());
410  } else {
411  // This is a simple iteration loop.
412  assert(caseBits.count() == 1);
413 
414  Block *block = r;
415  ValueRange curResult = genLoopWithIterator(
416  rewriter, loc, validIters.front(), userReduc,
417  /*bodyBuilder=*/
418  [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
419  SparseIterator *it,
420  ValueRange reduc) -> SmallVector<Value> {
421  SmallVector<Value> blockArgs(reduc);
422  blockArgs.push_back(it->deref(rewriter, loc));
423  llvm::append_range(blockArgs, it->getCursor());
424 
425  Block *dstBlock = &dstRegion.getBlocks().front();
426  rewriter.inlineBlockBefore(
427  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
428  auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
429  SmallVector<Value> result(yield.getResults());
430  rewriter.eraseOp(yield);
431  return result;
432  });
433 
434  userReduc.assign(curResult.begin(), curResult.end());
435  }
436  }
437 
438  rewriter.replaceOp(op, userReduc);
439  return success();
440  }
441 };
442 
443 } // namespace
444 
446  addConversion([](Type type) { return type; });
447  addConversion(convertIteratorType);
448  addConversion(convertIterSpaceType);
449 
450  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
451  ValueRange inputs, Location loc) -> Value {
452  return builder
453  .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
454  .getResult(0);
455  });
456 }
457 
459  const TypeConverter &converter, RewritePatternSet &patterns) {
460 
461  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
462  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
463  SparseIterateOpConverter, SparseCoIterateOpConverter>(
464  converter, patterns.getContext());
465 }
static Value getSingleValue(ValueRange values)
Assert that the given value range contains a single value and return it.
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator >> iters, ArrayRef< Block * > newBlocks, ArrayRef< Block * > oldBlocks, ArrayRef< Value > userReduc)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void replaceOpWithMultiple(Operation *op, ArrayRef< ValueRange > newValues)
Replace the given operation with the new value ranges.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:615
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides all of the information necessary to convert a type signature.
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange forward(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:356
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...