//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in =
            ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};
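
// Editor's illustrative sketch (not part of the original file): a minimal
// subclass wiring into the CRTP hook above. `FooOp` and `FooOpDemapper` are
// hypothetical names; the real users are GenericOpReinterpretMap and the
// demappers further below.
//
//   struct FooOpDemapper : public DemapInsRewriter<FooOpDemapper, FooOp> {
//     using DemapInsRewriter::DemapInsRewriter;
//     LogicalResult rewriteOp(FooOp op, OpAdaptor adaptor,
//                             PatternRewriter &rewriter) const {
//       // `adaptor` already exposes the demapped operands; rewrite with
//       // them and return success(), or failure() to leave the op alone.
//       return failure();
//     }
//   };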

// Collects the AffineDimExprs used in an affine expression into a BitVector.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible for sparsification: on
// outputs only plain AffineDimExprs are allowed, and mod, floordiv and
// ceildiv are disallowed everywhere.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible = true;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
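
// Editor's illustrative example: for an input tensor with
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2),
// levels 0 and 2 use floordiv/mod and are inadmissible, so the returned
// pair is ({l0, l2}, {d0}): the levels that must be rewritten and the single
// dimension that feeds them.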

// Builds the AffineMap to replace the idx in idxMap with lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map: it is computed by
  // composing dim2Lvl(idx2Dim). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map cannot be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing dim2lvl(idx2dim), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during dim2lvl
  // translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (ShapedType::isStatic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap were introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map into all index maps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
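
// Editor's illustrative example: for a block_2to4 operand with
//   dim2lvl = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// and a static trailing level size of 4, the fixed-point loop above rewrites
// every indexing map over the new level variables and then folds the residual
//   l floordiv 4 -> 0   and   l mod 4 -> l
// expressions via `populateCstMapping`, leaving only admissible results.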

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
                                  val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
}

// Remaps the values in `outs` to the given `types`, inserting a
// reinterpret_map wherever the types differ.
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
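
// Editor's illustrative IR sketch (types elided, names hypothetical): after
// this rewrite, a generic op on a non-identity sparse operand %a computes on
// the demapped (level-space) type, and the result is remapped back:
//
//   %d = sparse_tensor.reinterpret_map %a : ... to <demapped type>
//   %r = linalg.generic ... ins(%d : ...) outs(...)
//   %t = sparse_tensor.reinterpret_map %r : <demapped type> to ...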

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from most strict to least strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel cannot be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be scheduled.");
    }

    // Mark the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
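
  // Editor's illustrative example: for a rank-2 sparse output scheduled as
  // (parallel d0, reduction d1), the parallel nest depth is 1 >= 2 - 1, so
  // the order is admissible; an order that puts the reduction loop first
  // (nest depth 0) would be rejected.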

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute topological sort while leaving out every sparse input tensor in
    // succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example, for
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in order of d0->d1->d2->d3. To resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generic
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
                                             constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
                                         constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
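
// Editor's illustrative example: for a dynamically sized 2x2 block-sparse
// tensor with
//   dim2lvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2),
// the maximal dimension coordinates (m - 1, n - 1) translate to level
// coordinates whose increments recover the dynamic level sizes
// ceil(m / 2) and ceil(n / 2); the static block levels need no operands.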

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};
661 
662 struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
664  LogicalResult matchAndRewrite(AssembleOp op,
665  PatternRewriter &rewriter) const override {
667  return failure();
668 
669  assert(hasAnySparseResult(op));
670  auto stt = getSparseTensorType(op.getResult());
671  rewriter.modifyOpInPlace(
672  op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
673  rewriter.setInsertionPointAfter(op);
674  Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
675  rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
676  return success();
677  }
678 };

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity dim2lvl
    // maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, demappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, demappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, demappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, demappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        YieldOp::create(rewriter, loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
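
// Editor's illustrative usage sketch: a pass body would typically populate
// these patterns and apply them greedily; the surrounding pass boilerplate
// (and the exact greedy-driver entry point, which has varied across MLIR
// versions) is omitted.
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));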