//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};
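
// Illustrative only: a subclass supplies `rewriteOp` and receives an adaptor
// whose sparse operands have already been demapped. A hypothetical pattern:
//
//   struct FooDemapper : public DemapInsRewriter<FooDemapper, FooOp> {
//     using DemapInsRewriter::DemapInsRewriter;
//     LogicalResult rewriteOp(FooOp op, OpAdaptor adaptor,
//                             PatternRewriter &rewriter) const {
//       // The adaptor operands are the demapped values; rewrite `op` here.
//       return failure();
//     }
//   };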

// Collects the positions of all AffineDimExprs used in an affine expression.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Determines whether an affine expression is admissible for the sparse
// compiler; output expressions are restricted to plain AffineDimExprs.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput) {}

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};
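
// For example, (d0 + d1) is admissible as an input expression but not as an
// output expression, while (d0 mod 2) and (d0 floordiv 2) are admissible in
// neither position.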

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
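
// For example, for map = (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2) on an
// input, levels 0 and 2 are inadmissible and both use d0, so the result is
// (inadmissible levels = {0, 2}, used dims = {d0}).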

// Builds the AffineMap to replace the dims in idxMap with lvls such that all
// the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with levels, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map; it is computed by
  // composing dim2Lvl(idx2Dim). The two are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of raising an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimensions at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dims.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, then the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update the iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
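
// For example (following the logic above): for
// idxMap = (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2) and
// itTps = [parallel, reduction], levels 0 and 2 are inadmissible and both
// derive from d0, so the two new index variables inherit "parallel" and
// `itTps` is updated to [parallel, parallel, reduction].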

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map cannot be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing dim2Lvl(idx2dim), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during the
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };
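
  // These identities hold because a level coordinate `lvl` always satisfies
  // 0 <= lvl < lvlSz; e.g., with lvlSz == 4, (lvl floordiv 4) folds to 0 and
  // (lvl mod 4) folds to lvl.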

  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose lvl2Idx with all the index maps to eliminate the
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
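
// A worked example (illustrative, for a block-sparse operand): with an
// identity idx2dim map and dim2lvl = (d0, d1) -> (d0, d1 floordiv 4,
// d1 mod 4), the composed idx2lvl map has inadmissible levels {1, 2}. The
// replacement map (l0, l1, d0) -> (d0, l0 * 4 + l1) composes to
// (l0, l1, d0) -> (d0, l0 + l1 floordiv 4, l1 mod 4), and the constant
// mapping for the static level size 4 (l1 floordiv 4 -> 0, l1 mod 4 -> l1)
// simplifies it to (l0, l1, d0) -> (d0, l0, l1).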

// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
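
// Illustrative effect on the IR (a sketch, with types and maps elided): a
// linalg.generic over a sparse operand with a non-identity dim2lvl map is
// rewritten to operate on demapped (level-space) operands, e.g.:
//
//   %d = sparse_tensor.reinterpret_map %arg0 : tensor<...> to tensor<...>
//   %0 = linalg.generic {indexing_maps = [/*idx2lvl maps*/], ...}
//          ins(%d : tensor<...>) outs(...)
//   %r = sparse_tensor.reinterpret_map %0 : tensor<...> to tensor<...>
//
// where the trailing reinterpret_map restores the original (mapped) type for
// all remaining uses of the result.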

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kIncludeUndef, SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // Else, try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel cannot be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel cannot be scheduled.");
    }

    // Mark the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible for sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
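
  // For example, with a sorted loop order (d1, d0, d2) and iterator types
  // [parallel, parallel, reduction] (indexed by original loop id), the
  // reduction on d2 stops the count at nest == 2, so the order is admissible
  // for any output of rank <= 3.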

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example, given results
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // with loops scheduled in the order d0->d1->d2->d3, to resolve the
      // cycle we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such
      // that the transposed tensor's levels are visited in the same order as
      // the loop scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      std::sort(lvlSeq.begin(), lvlSeq.end(), [](auto &lhs, auto &rhs) -> bool {
        return lhs.first < rhs.first;
      });
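
      // E.g., for results (d2, d1, d0): lvlSeq = [(2,0), (1,1), (0,2)]
      // sorts to [(0,2), (1,1), (2,0)], yielding perm = [2, 1, 0] below.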
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The results of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
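
// For example (illustrative): for an allocation of a tensor<?x8xf64> whose
// dim2lvl map is (d0, d1) -> (d1, d0), the maximum dim coordinates
// (%sz0 - 1, 7) are translated to level order (7, %sz0 - 1), and the
// demapped allocation then receives its dynamic sizes in level order.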

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, demappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, demappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, demappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, demappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in the
    // reinterpret_map inserted to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
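
// A minimal usage sketch (assuming the standard greedy driver): a pass that
// owns a ModuleOp `module` and an MLIRContext `ctx` could apply these
// patterns as follows.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));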