MLIR  20.0.0git
SparseIterationToScf.cpp
Go to the documentation of this file.
1 
2 #include "Utils/CodegenUtils.h"
3 #include "Utils/LoopEmitter.h"
5 
11 
12 using namespace mlir;
13 using namespace mlir::sparse_tensor;
14 
15 static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
16  SmallVectorImpl<Type> &fields) {
17  // Position and coordinate buffer in the sparse structure.
18  if (enc.getLvlType(lvl).isWithPosLT())
19  fields.push_back(enc.getPosMemRefType());
20  if (enc.getLvlType(lvl).isWithCrdLT())
21  fields.push_back(enc.getCrdMemRefType());
22  // One index for shape bound (result from lvlOp).
23  fields.push_back(IndexType::get(enc.getContext()));
24 }
25 
26 static std::optional<LogicalResult>
27 convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
28 
29  auto idxTp = IndexType::get(itSp.getContext());
30  for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
31  convertLevelType(itSp.getEncoding(), l, fields);
32 
33  // Two indices for lower and upper bound (we only need one pair for the last
34  // iteration space).
35  fields.append({idxTp, idxTp});
36  return success();
37 }
38 
39 static std::optional<LogicalResult>
40 convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
41  // The actually Iterator Values (that are updated every iteration).
42  auto idxTp = IndexType::get(itTp.getContext());
43  // TODO: handle batch dimension.
44  assert(itTp.getEncoding().getBatchLvlRank() == 0);
45  if (!itTp.isUnique()) {
46  // Segment high for non-unique iterator.
47  fields.push_back(idxTp);
48  }
49  fields.push_back(idxTp);
50  return success();
51 }
52 
53 static ValueRange
54 genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
55  Value loopCrd,
56  ArrayRef<std::unique_ptr<SparseIterator>> iters,
57  ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
58  if (subCases.empty())
59  return userReduc;
60 
61  // The current branch that we are handling.
62  Region *b = subCases.front();
63  Value casePred = constantI1(rewriter, loc, true);
64  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
65  for (unsigned i : caseBits.bits()) {
66  SparseIterator *it = iters[i].get();
67  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
68  it->getCrd(), loopCrd);
69  casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
70  }
71  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
72  loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
73  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
74 
75  // Erase the empty block.
76  rewriter.eraseBlock(&ifOp.getThenRegion().front());
77  // Set up block arguments: user-provided values -> loop coord -> iterators.
78  SmallVector<Value> blockArgs(userReduc);
79  blockArgs.push_back(loopCrd);
80  for (unsigned idx : caseBits.bits())
81  llvm::append_range(blockArgs, iters[idx]->getCursor());
82 
83  IRMapping mapping;
84  for (auto [from, to] :
85  llvm::zip_equal(b->front().getArguments(), blockArgs)) {
86  mapping.map(from, to);
87  }
88 
89  // Clone the region, we can not erase the region now because the same region
90  // might be a subcase for multiple lattice point.
91  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
92  ifOp.getThenRegion().begin(), mapping);
93 
94  // replace sparse_tensor::YieldOp -> scf::YieldOp
95  auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
96  ValueRange yields = spY.getResults();
97  rewriter.eraseOp(spY);
98  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
99  rewriter.create<scf::YieldOp>(loc, yields);
100 
101  // Generates remaining case recursively.
102  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
103  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
104  subCases.drop_front(), userReduc);
105  if (!res.empty())
106  rewriter.create<scf::YieldOp>(loc, res);
107 
108  rewriter.setInsertionPointAfter(ifOp);
109  return ifOp.getResults();
110 }
111 
113  PatternRewriter &rewriter, Location loc, SparseIterator *it,
114  ValueRange reduc,
116  Region &loopBody, SparseIterator *it,
117  ValueRange reduc)>
118  bodyBuilder) {
119  if (it->iteratableByFor()) {
120  auto [lo, hi] = it->genForCond(rewriter, loc);
121  Value step = constantIndex(rewriter, loc, 1);
122  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
123  {
124  OpBuilder::InsertionGuard guard(rewriter);
125  // Erase the implicit yield operation created by ForOp when there is no
126  // yielding values.
127  if (!forOp.getBody()->empty())
128  rewriter.eraseOp(&forOp.getBody()->front());
129  assert(forOp.getBody()->empty());
130 
131  it->linkNewScope(forOp.getInductionVar());
132  rewriter.setInsertionPointToStart(forOp.getBody());
133  SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
134  it, forOp.getRegionIterArgs());
135 
136  rewriter.setInsertionPointToEnd(forOp.getBody());
137  rewriter.create<scf::YieldOp>(loc, ret);
138  }
139  return forOp.getResults();
140  }
141 
142  SmallVector<Value> ivs(reduc);
143  llvm::append_range(ivs, it->getCursor());
144 
145  TypeRange types = ValueRange(ivs).getTypes();
146  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
147  {
148  OpBuilder::InsertionGuard guard(rewriter);
149  // Generates loop conditions.
150  SmallVector<Location> l(types.size(), loc);
151  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
152  rewriter.setInsertionPointToStart(before);
153  ValueRange bArgs = before->getArguments();
154  auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
155  rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
156 
157  // Delegates loop body generation.
158  Region &dstRegion = whileOp.getAfter();
159  Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
160  ValueRange aArgs = whileOp.getAfterArguments();
161  it->linkNewScope(aArgs.drop_front(reduc.size()));
162  aArgs = aArgs.take_front(reduc.size());
163 
164  rewriter.setInsertionPointToStart(after);
165  SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
166  rewriter.setInsertionPointToEnd(after);
167 
168  // Forward loops
169  SmallVector<Value> yields;
170  llvm::append_range(yields, ret);
171  llvm::append_range(yields, it->forward(rewriter, loc));
172  rewriter.create<scf::YieldOp>(loc, yields);
173  }
174  return whileOp.getResults().drop_front(it->getCursor().size());
175 }
176 
177 namespace {
178 
179 /// Sparse codegen rule for number of entries operator.
180 class ExtractIterSpaceConverter
181  : public OneToNOpConversionPattern<ExtractIterSpaceOp> {
182 public:
184  LogicalResult
185  matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
186  OneToNPatternRewriter &rewriter) const override {
187  Location loc = op.getLoc();
188  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
189 
190  // Construct the iteration space.
191  SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
192  op.getLvlRange(), adaptor.getParentIter());
193 
194  SmallVector<Value> result = space.toValues();
195  rewriter.replaceOp(op, result, resultMapping);
196  return success();
197  }
198 };
199 
200 /// Sparse codegen rule for number of entries operator.
201 class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
202 public:
204  LogicalResult
205  matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
206  OneToNPatternRewriter &rewriter) const override {
207  Location loc = op.getLoc();
208  Value pos = adaptor.getIterator().back();
209  Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
210  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
211  return success();
212  }
213 };
214 
215 class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
216 public:
218  LogicalResult
219  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
220  OneToNPatternRewriter &rewriter) const override {
221  if (!op.getCrdUsedLvls().empty())
222  return rewriter.notifyMatchFailure(
223  op, "non-empty coordinates list not implemented.");
224 
225  Location loc = op.getLoc();
226 
227  auto iterSpace = SparseIterationSpace::fromValues(
228  op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
229 
230  std::unique_ptr<SparseIterator> it =
231  iterSpace.extractIterator(rewriter, loc);
232 
233  SmallVector<Value> ivs;
234  for (ValueRange inits : adaptor.getInitArgs())
235  llvm::append_range(ivs, inits);
236 
237  // Type conversion on iterate op block.
238  OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
239  if (failed(typeConverter->convertSignatureArgs(
240  op.getBody()->getArgumentTypes(), blockTypeMapping)))
241  return rewriter.notifyMatchFailure(
242  op, "failed to convert iterate region argurment types");
243  rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
244 
245  Block *block = op.getBody();
247  rewriter, loc, it.get(), ivs,
248  [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
250  SmallVector<Value> blockArgs(reduc);
251  // TODO: Also appends coordinates if used.
252  // blockArgs.push_back(it->deref(rewriter, loc));
253  llvm::append_range(blockArgs, it->getCursor());
254 
255  Block *dstBlock = &loopBody.getBlocks().front();
256  rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
257  blockArgs);
258  auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
259  // We can not use ValueRange as the operation holding the values will
260  // be destoryed.
261  SmallVector<Value> result(yield.getResults());
262  rewriter.eraseOp(yield);
263  return result;
264  });
265 
266  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
267  rewriter.replaceOp(op, ret, resultMapping);
268  return success();
269  }
270 };
271 
272 class SparseCoIterateOpConverter
273  : public OneToNOpConversionPattern<CoIterateOp> {
275 
276  LogicalResult
277  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
278  OneToNPatternRewriter &rewriter) const override {
279  assert(op.getSpaceDim() == 1 && "Not implemented");
280  Location loc = op.getLoc();
281 
282  I64BitSet denseBits(0);
283  for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
284  if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
285  denseBits.set(idx);
286 
287  // If there exists a case that only contains dense spaces. I.e., case
288  // bits is a subset of dense bits, or when there is a full empty case (due
289  // to complements), we need a universal pointer to forward the coiteration
290  // loop.
291  bool needUniv =
292  any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
293  // A case for complement.
294  if (caseBits.count() == 0)
295  return true;
296  // An all-dense case.
297  return caseBits.isSubSetOf(denseBits);
298  });
299  assert(!needUniv && "Not implemented");
300  (void)needUniv;
301 
302  for (Region &region : op.getCaseRegions()) {
303  // Do a one-shot type conversion on all region blocks, since the same
304  // region might be used multiple time.
305  Block *block = &region.getBlocks().front();
306  OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
307  if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
308  blockTypeMapping))) {
309  return rewriter.notifyMatchFailure(
310  op, "failed to convert coiterate region argurment types");
311  }
312 
313  rewriter.applySignatureConversion(block, blockTypeMapping);
314  }
315 
318  for (auto [spaceTp, spaceVals] : llvm::zip_equal(
319  op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
320  // TODO: do we really need tid?
321  spaces.push_back(SparseIterationSpace::fromValues(
322  cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
323  // Extract the iterator.
324  iters.push_back(spaces.back().extractIterator(rewriter, loc));
325  }
326 
327  auto getFilteredIters = [&iters](I64BitSet caseBits) {
328  // Retrives a vector of pointers to the iterators used in the case.
330  for (auto idx : caseBits.bits())
331  validIters.push_back(iters[idx].get());
332  return validIters;
333  };
334 
335  // Get a flattened user-provided loop reduction values.
336  SmallVector<Value> userReduc;
337  for (ValueRange r : adaptor.getInitArgs())
338  llvm::append_range(userReduc, r);
339 
340  // TODO: we need to sort the cases such that they appears in lexical order.
341  // Although sparsification always generates cases in that order, it might
342  // not be the case for human-written code.
343 
344  // Generates a loop sequence, one loop per case.
345  for (auto [r, caseBits] :
346  llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
347  assert(caseBits.count() > 0 && "Complement space not implemented");
348 
349  // Retrives a vector of pointers to the iterators used in the case.
350  SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
351 
352  if (validIters.size() > 1) {
353  auto [loop, loopCrd] =
354  genCoIteration(rewriter, loc, validIters, userReduc,
355  /*uniIdx=*/nullptr, /*userReducFirst=*/true);
356 
357  // 1st. find all the cases that is a strict subset of the current case
358  // condition, for which we generate one branch per case inside the loop.
359  // The subcases are never empty, it must contains at least the current
360  // region itself.
361  // TODO: these cases should be sorted.
362  SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
363  assert(!subCases.empty());
364 
365  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
366  iters, subCases, userReduc);
367 
368  SmallVector<Value> nextIterYields(res);
369  // 2nd. foward the loop.
370  for (SparseIterator *it : validIters) {
371  Value cmp = rewriter.create<arith::CmpIOp>(
372  loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
373  it->forwardIf(rewriter, loc, cmp);
374  llvm::append_range(nextIterYields, it->getCursor());
375  }
376  rewriter.create<scf::YieldOp>(loc, nextIterYields);
377 
378  // Exit the loop, relink the iterator SSA value.
379  rewriter.setInsertionPointAfter(loop);
380  ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
381  for (SparseIterator *it : validIters)
382  iterVals = it->linkNewScope(iterVals);
383  assert(iterVals.empty());
384 
385  ValueRange curResult = loop->getResults().take_front(userReduc.size());
386  userReduc.assign(curResult.begin(), curResult.end());
387  } else {
388  // This is a simple iteration loop.
389  assert(caseBits.count() == 1);
390 
391  Block *block = &r.getBlocks().front();
392  ValueRange curResult = genLoopWithIterator(
393  rewriter, loc, validIters.front(), userReduc,
394  /*bodyBuilder=*/
395  [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
396  SparseIterator *it,
397  ValueRange reduc) -> SmallVector<Value> {
398  SmallVector<Value> blockArgs(reduc);
399  blockArgs.push_back(it->deref(rewriter, loc));
400  llvm::append_range(blockArgs, it->getCursor());
401 
402  Block *dstBlock = &dstRegion.getBlocks().front();
403  rewriter.inlineBlockBefore(
404  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
405  auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
406  SmallVector<Value> result(yield.getResults());
407  rewriter.eraseOp(yield);
408  return result;
409  });
410 
411  userReduc.assign(curResult.begin(), curResult.end());
412  }
413  }
414 
415  rewriter.replaceOp(op, userReduc);
416  return success();
417  }
418 };
419 
420 } // namespace
421 
423  addConversion([](Type type) { return type; });
424  addConversion(convertIteratorType);
425  addConversion(convertIterSpaceType);
426 
427  addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
428  ValueRange inputs, Location loc) -> Value {
429  return builder
430  .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
431  .getResult(0);
432  });
433 }
434 
436  const TypeConverter &converter, RewritePatternSet &patterns) {
437 
438  IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
439  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
440  SparseIterateOpConverter, SparseCoIterateOpConverter>(
441  converter, patterns.getContext());
442 }
static ValueRange genLoopWithIterator(PatternRewriter &rewriter, Location loc, SparseIterator *it, ValueRange reduc, function_ref< SmallVector< Value >(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> bodyBuilder)
static ValueRange genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, Value loopCrd, ArrayRef< std::unique_ptr< SparseIterator >> iters, ArrayRef< Region * > subCases, ArrayRef< Value > userReduc)
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIteratorType(IteratorType itTp, SmallVectorImpl< Type > &fields)
static std::optional< LogicalResult > convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl< Type > &fields)
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:151
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class is a wrapper around OneToNConversionPattern for matching against instances of a particular...
OneToNOpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Specialization of PatternRewriter that OneToNConversionPatterns use.
Block * applySignatureConversion(Block *block, OneToNTypeMapping &argumentConversion)
Applies the given argument conversion to the given block.
void replaceOp(Operation *op, ValueRange newValues, const OneToNTypeMapping &resultMapping)
Replaces the results of the operation with the specified list of values mapped back to the original t...
Stores a 1:N mapping of types and provides several useful accessors.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:615
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A simple wrapper to encode a bitset of (at most 64) levels, currently used by sparse_tensor....
Definition: SparseTensor.h:64
iterator_range< const_set_bits_iterator > bits() const
Definition: SparseTensor.h:75
A SparseIterationSpace represents a sparse set of coordinates defined by (possibly multiple) levels o...
static SparseIterationSpace fromValues(IterSpaceType dstTp, ValueRange values, unsigned tid)
Helper class that generates loop conditions, etc, to traverse a sparse tensor level.
virtual std::pair< Value, Value > genForCond(OpBuilder &b, Location l)
ValueRange forward(OpBuilder &b, Location l)
ValueRange linkNewScope(ValueRange pos)
std::pair< Value, ValueRange > genWhileCond(OpBuilder &b, Location l, ValueRange vs)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:356
std::pair< Operation *, Value > genCoIteration(OpBuilder &builder, Location loc, ArrayRef< SparseIterator * > iters, MutableArrayRef< Value > reduc, Value uniIdx, bool userReducFirst=false)
bool isDenseLT(LevelType lt)
Definition: Enums.h:413
Include the generated interface declarations.
void populateLowerSparseIterationToSCFPatterns(const TypeConverter &converter, RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...