MLIR  20.0.0git
SparseReinterpretMap.cpp
Go to the documentation of this file.
1 //===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Utils/CodegenUtils.h"
11 
21 #include "mlir/IR/AffineMap.h"
22 
23 using namespace mlir;
24 using namespace mlir::sparse_tensor;
25 
26 namespace {
27 
28 //===----------------------------------------------------------------------===//
29 // File Local Helper classes.
30 //===----------------------------------------------------------------------===//
31 
32 // CRTP to help implementing a rewriter that demaps all its inputs.
33 template <typename SubClass, typename SourceOp>
34 struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
36  using OpAdaptor = typename SourceOp::Adaptor;
37 
38  LogicalResult matchAndRewrite(SourceOp op,
39  PatternRewriter &rewriter) const override {
40  Location loc = op.getLoc();
41 
42  // Demaps non-trivial inputs.
43  bool changed = false;
44  SmallVector<Value> deMappedIns(op->getOperands());
45  for (Value &in : deMappedIns) {
46  if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
47  in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
48  changed = true;
49  }
50  }
51 
52  // CRTP call.
53  OpAdaptor adaptor(deMappedIns, op);
54  LogicalResult status =
55  static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56  return changed ? success() : status;
57  }
58 };
59 
60 // Flattens an affine expression into a list of AffineDimExprs.
61 struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
62  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
63  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
64  BitVector dims;
65 };
66 
67 // Flattens an affine expression into a list of AffineDimExprs.
68 struct AffineExprAdmissibleVisitor
69  : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
70  explicit AffineExprAdmissibleVisitor(bool isOutput)
71  : admissible(true), isOutput(isOutput){};
72 
73  // We only allow AffineDimExpr on output.
74  void visitAddExpr(AffineBinaryOpExpr expr) {
75  if (isOutput)
76  admissible = false;
77  }
78  void visitMulExpr(AffineBinaryOpExpr expr) {
79  if (isOutput)
80  admissible = false;
81  }
82 
83  // We disallow mod, floor div and ceil div on inputs.
84  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
85  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
86  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
87  operator bool() { return admissible; }
88 
89 private:
90  bool admissible;
91  bool isOutput;
92 };
93 
94 // The first BitVector marks the levels whose expressions are inadmissible.
95 // The second BitVector marks the dimensions (AffineDimExpr positions) that
96 // are used by those inadmissible expressions.
97 using InadmissInfo = std::pair<BitVector, BitVector>;
98 
99 } // namespace
100 
101 //===----------------------------------------------------------------------===//
102 // File Local Helper methods.
103 //===----------------------------------------------------------------------===//
104 
105 // Collects the inadmissible affine expression imposed on levels.
106 static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
107  auto ret = std::make_pair(BitVector(map.getNumResults()),
108  BitVector(map.getNumDims()));
109  AffineDimCollector collector(map.getNumDims());
110  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
111  AffineExprAdmissibleVisitor admissible(isOutput);
112  admissible.walkPostOrder(map.getResult(lvl));
113  if (!admissible) {
114  // Record the inadmissible level.
115  ret.first.set(lvl);
116  // Record the AffineDimExpr that is used in the inadmissible expr.
117  collector.walkPostOrder(map.getResult(lvl));
118  }
119  }
120  ret.second = collector.dims;
121  return ret;
122 }
123 
124 // Builds the AffineMap to replace the idx in idxMap to lvl such that all tht
125 // inadmissible affine expressions can be eliminated.
126 // For example, we can rewrite
127 // idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
128 // to
129 // idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
130 // by composing inverse(idxMap), that is
131 // inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
132 // -> ((l0 * 2 + l2) floordiv 2,
133 // (l1 * 3 + l3) floordiv 3,
134 // (l0 * 2 + l2) mod 2,
135 // (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
136 //
137 // This function builds the inverse(idxMap) that replace every dimensions used
138 // in `info` to levels, and updates the iterator type array `itTps` for the new
139 // index variable introduced.
140 //
141 // Note that the returned affine map does not retain the order of the input
142 // affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
143 // replaced levels, and remaining ones for unused dimensions.
144 // For example, to handle
145 // idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
146 // which is a typical map for block_2to4. The function returns:
147 // inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
148 // in which, (l0, l1) together replaces `d1`, yet they appear
149 // before `d0` in the resulting affine map.
150 // The index (loop) order can later be canonicalized by a topo sort.
151 static AffineMap
152 genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
154  MLIRContext *ctx = idxMap.getContext();
155  auto [inAdLvls, usedDims] = info;
156  // Note that idxMap does not equal to dim2Lvl map, it is computed by
157  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
158  // ID map.
159  // TODO: we might fail here, in those case we should really return
160  // failure instead of assertion error.
161  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
162 
163  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
164  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
165  // This could happen when some dimensions are projected.
166  // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
167  // ==> lvl2Idx = (j, k) -> (j, k)
168  // In this case, we append the unused dimesion at the end.
169  // ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
170  SmallVector<AffineExpr> results;
171  AffineDimCollector usedInLvl(idxMap.getNumDims());
172  for (auto e : idxMap.getResults())
173  usedInLvl.walkPostOrder(e);
174 
175  unsigned curUsedDimID = 0;
176  unsigned curUnusedDimID = lvl2Idx.getNumDims();
177 
178  BitVector unused = usedInLvl.dims.flip();
179  for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
180  if (unused.test(i))
181  results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
182  else
183  results.push_back(lvl2Idx.getResult(curUsedDimID++));
184  }
185  lvl2Idx =
186  AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
187  }
188  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
189 
190  // We do not need to replace the DimExpr that is not used in inadmissible
191  // level expressions. We use the first inAdLvl.count() dim to represent the
192  // replaced level, the remainings are reserved for unchanged ones.
193  // Note that results from the inverse map computed previously does not follow
194  // the convention we used, and we need to fix the mismatch below.
195  unsigned curRepID = 0;
196  unsigned curOriID = inAdLvls.count();
197  SmallVector<AffineExpr> results;
200 
201  for (unsigned l : inAdLvls.set_bits()) {
202  // By our convention, the inadmissible level `l` always appears in the
203  // leading part (accumulated by curRepID) of the affine map's parameter
204  // list. Record the mapping so that we can replace all the uses of `l` to
205  // the correct position after the translation.
206  dimRep[l] = getAffineDimExpr(curRepID++, ctx);
207  // A new index variable is introduced for the inadmissible level, inherit
208  // the iterator type. E.g., if l0 = d0 floordiv 2, the
209  // iterator type of l0 equals to the iterator type of d0.
210  AffineExpr lvlExp = idxMap.getResult(l);
211  AffineDimCollector collector(idxMap.getNumDims());
212  collector.walkPostOrder(lvlExp);
213  // We assumes a level can only be derived from one dimension.
214  assert(collector.dims.count() == 1);
215  transItTps.push_back(itTps[collector.dims.find_first()]);
216  }
217 
218  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
219  if (usedDims.test(d)) {
220  // The dimension is used in some of the inadmissible levels, and it need
221  // to be inversed. Get the inversion from the inverse map, and fix the
222  // mismatch captured by the above loop.
223  results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
224  } else {
225  // The dimension is not used in any of the inadmissible levels, and it
226  // does not need to be inversed. Fix the mismatch by mapping it to the
227  // trailing part of the affine map (accumulated by curOriID).
228  results.push_back(getAffineDimExpr(curOriID++, ctx));
229  transItTps.push_back(itTps[d]);
230  }
231  }
232  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
233  // Update iterator type.
234  itTps.assign(transItTps.begin(), transItTps.end());
235  return AffineMap::get(numDim, 0, results, ctx);
236 }
237 
238 // Translates the index map in the linalg::GenericOp from idx->dim map to
239 // idx->lvl map. Returns failure if the index map can not be translated to an
240 // admissible form.
241 // Returns the translated index map array and the iterator type array.
242 static std::optional<std::pair<ArrayAttr, ArrayAttr>>
243 translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
244  // idxMap is a idx2dim map before reinterpretation.
245  MLIRContext *ctx = op.getContext();
246  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
247  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
248  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
249  Value tensor = op->getOpOperand(i).get();
250  auto stt = tryGetSparseTensorType(tensor);
251  if (stt && !stt->isIdentity()) {
252  AffineMap dim2Lvl = stt->getDimToLvl();
253  // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
254  idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
255  }
256  }
257 
258  // A naive way to handle common constant expressions that arise during dim2lvl
259  // translation.
260  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
261  unsigned pos, int64_t lvlSz) {
262  if (!ShapedType::isDynamic(lvlSz)) {
263  auto c0 = getAffineConstantExpr(0, ctx);
264  auto lvlExp = getAffineDimExpr(pos, ctx);
265  auto szExp = getAffineConstantExpr(lvlSz, ctx);
266 
267  // lvl floordiv lvlSz = 0
268  auto divExp =
270  cstMapping.try_emplace(divExp, c0);
271 
272  // lvl mod lvlSz = lvl
273  auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
274  cstMapping.try_emplace(modExp, lvlExp);
275  }
276  };
277 
278  unsigned boundedNum = 0;
279  // A fixed-point algorithm.
280  bool changed = true;
281  while (changed) {
282  changed = false;
283  for (OpOperand &operand : op->getOpOperands()) {
284  auto stt = tryGetSparseTensorType(operand.get());
285  // Skip on dense operands.
286  if (!stt || !stt->getEncoding())
287  continue;
288 
289  unsigned tid = operand.getOperandNumber();
290  bool isOutput = &operand == op.getDpsInitOperand(0);
291  AffineMap idxMap = idxMapArray[tid];
292  InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
293  auto [inAdLvls, dimExprs] = inAdInfo;
294  for (unsigned d : dimExprs.set_bits()) {
295  // The first `boundedNum` used in the AffineMap is introduced to
296  // resolve previous inadmissible expressions. We can not replace them
297  // as it might bring back the inadmissible expressions.
298  if (d < boundedNum)
299  return std::nullopt;
300  }
301 
302  if (inAdLvls.count() != 0) {
303  // Naive constant progagation, should be sufficient to handle block
304  // sparsity in our cases.
305  SmallVector<int64_t> lvlShape = stt->getLvlShape();
307  unsigned position = 0;
308  for (unsigned lvl : inAdLvls.set_bits()) {
309  int64_t lvlSz = lvlShape[lvl];
310  populateCstMapping(cstMapping, position, lvlSz);
311  position++;
312  }
313 
314  AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
315  // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
316  // inadmissible expressions.
317  for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
318  AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
319  idxMapArray[tid] = transMap.replace(
320  cstMapping, /*numResultDims=*/transMap.getNumDims(),
321  /*numResultSyms=*/0);
322  }
323  changed = true;
324  boundedNum += inAdLvls.count();
325  }
326  }
327  };
328 
329  SmallVector<Attribute> iterAttr =
330  llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
331  return linalg::IteratorTypeAttr::get(ctx, itTp);
332  });
333 
334  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
335  rewriter.getArrayAttr(iterAttr));
336 }
337 
338 // Generates a "de"mapping reinterpretation of the map.
339 static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
340  Value val) {
341  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
342  val);
343 }
344 
345 // Generates a "re"mapping reinterpretation of the map.
346 static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
347  Value val) {
348  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
349 }
350 
352  ValueRange outs) {
353  SmallVector<Value> ret(outs);
354  assert(outs.size() == types.size());
355  for (auto [r, t] : llvm::zip(ret, types))
356  if (r.getType() != t)
357  r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
358  return ret;
359 }
360 
361 namespace {
362 
363 //===----------------------------------------------------------------------===//
364 // Rewriting rules for linalg generic ops.
365 //===----------------------------------------------------------------------===//
366 
367 /// Sparse rewriting rule for the generic `linalg` operation.
/// Replaces the indexing maps and iterator types of the op with the result of
/// `translateMap`, switches its inputs and inits to the demapped values
/// supplied through the adaptor, and re-creates the original sparse encoding
/// of the result with a reinterpret_map for all other users.
368 struct GenericOpReinterpretMap
369  : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
370 public:
371  using DemapInsRewriter::DemapInsRewriter;
372  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
373  PatternRewriter &rewriter) const {
374  // Only rewrite single output operations with pure (sparse) tensor
375  // semantics.
 // NOTE(review): this listing jumps from line 377 to 379, so the final clause
 // of the condition below (and its closing parenthesis) is not visible here.
 // Restore it from the original file before relying on this block.
376  if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
377  !hasAnySparseOperandOrResult(linalgOp) ||
379  return failure();
380 
381  // Try translating the index map.
382  auto transMap = translateMap(linalgOp, rewriter);
383  if (!transMap)
384  return rewriter.notifyMatchFailure(
385  linalgOp, "the sparse kernel can not be sparsified.");
386 
387  // On success, update the linalg operands and maps in place.
 // `stt` is captured before the result type is overwritten below, so it still
 // describes the original (mapped) result type.
388  Value res = linalgOp.getResult(0);
389  auto stt = tryGetSparseTensorType(res);
390  auto [idxMap, itTp] = *transMap;
391 
392  rewriter.startOpModification(linalgOp);
393  linalgOp.setIndexingMapsAttr(idxMap);
394  linalgOp.setIteratorTypesAttr(itTp);
395  // Use demapped arguments.
396  linalgOp.getInputsMutable().assign(adaptor.getInputs());
397  linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
398  res.setType(adaptor.getOutputs()[0].getType());
399  rewriter.finalizeOpModification(linalgOp);
400 
 // Remap the result back to its original encoding; the new reinterpret_map is
 // excluded from the replacement so that it keeps using `res` itself.
401  rewriter.setInsertionPointAfter(linalgOp);
402  if (stt && stt->hasEncoding()) {
403  Value t = genRemap(rewriter, stt->getEncoding(), res);
404  rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
405  }
406  return success();
407  }
408 };
409 
410 struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
 // Schedules the loops of a sparse linalg.generic: computes a loop order with
 // the iteration graph sorter, marks the op with a "sorted" attribute, and, if
 // the order is not the identity, permutes the iterator types and composes
 // the inverse permutation into every indexing map. If no order can be found,
 // falls back to `resolveCycle`.
 // NOTE(review): listing line 411, directly above this method, is missing from
 // this view; the struct presumably declares its constructors there — confirm
 // against the original file.
412  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
413  PatternRewriter &rewriter) const override {
414  if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
415  hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
416  !hasAnySparseOperandOrResult(linalgOp)) {
417  return failure();
418  }
419 
 // Ops that already carry the "sorted" attribute were handled before.
420  const StringRef sorted = "sorted";
421  if (linalgOp->hasAttr(sorted))
422  return failure();
423 
424  auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
425  bool isAdmissible = false;
426  AffineMap order;
427  // A const list of all masks that we used for iteration graph
428  // computation. Must be ordered from more strict to less strict.
429  // Ideally (though might not be guaranteed), the earlier a constraint mask
430  // can be satisfied, the faster the generated kernel will be.
 // NOTE(review): listing lines 432-434 are missing from this view, so the
 // remainder of the initializer list (and its closing brace) is not visible.
431  const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
435  for (const SortMask mask : allMasks) {
436  order = scheduler.sort(mask);
437  if (order) {
438  if (isAdmissibleOrder(linalgOp, order)) {
439  isAdmissible = true;
440  break;
441  }
442  // else try a set of less strict constraints.
443  }
444  }
445 
 // `order` holds the result of the last sort attempt; a null map means that
 // attempt found no valid order.
446  if (!order) {
447  // Cycles detected.
448  if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
449  return rewriter.notifyMatchFailure(
450  linalgOp, "the sparse kernel can not be scheduled: loop detected.");
451  }
452  return success();
453  }
454 
455  if (!isAdmissible) {
456  return rewriter.notifyMatchFailure(
457  linalgOp, "the sparse kernel can not be scheduled.");
458  }
459 
460  // Marks the GenericOp to avoid recursive matching.
461  rewriter.modifyOpInPlace(linalgOp, [&]() {
462  linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
463  });
464 
465  // Already sorted.
466  if (order.isIdentity())
467  return success();
468 
469  assert(order.isPermutation());
470  // `order` is original loop -> sorted loop map
471  ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
472  SmallVector<Attribute> curItTypes;
473  curItTypes.reserve(preItTypes.size());
474  for (AffineExpr expr : order.getResults()) {
475  unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
476  curItTypes.push_back(preItTypes[loopID]);
477  }
478 
479  // Inverse `order` to get sorted loop -> original loop map
480  order = inversePermutation(order);
481  SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
482  for (AffineMap &idxMap : idxMaps)
483  idxMap = idxMap.compose(order); // sorted loop -> lvl map
484 
485  rewriter.startOpModification(linalgOp);
486  linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
487  linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
488  rewriter.finalizeOpModification(linalgOp);
489 
490  return success();
491  }
492 
493 private:
494  /// Whether the loop order is admissible by sparsification.
495  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
496  if (!hasAnySparseResult(linalgOp))
497  return true;
498 
499  OpOperand *lhs = linalgOp.getDpsInitOperand(0);
500  unsigned nest = 0;
501  const auto iteratorTypes = linalgOp.getIteratorTypesArray();
502  for (const AffineExpr l : order.getResults()) {
503  unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
504  auto itTp =
505  cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
506  if (linalg::isReductionIterator(itTp.getValue()))
507  break; // terminate at first reduction
508  nest++;
509  }
510  // Determine admissible dynamic insertion situations:
511  // (1) fully injective, since there are no reductions,
512  // (2) admissible 1-d expansion in innermost dimension.
513  return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
514  };
515 
516  // Last resort cycle resolution.
517  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
518  linalg::LinalgOp linalgOp,
519  PatternRewriter &rewriter) {
520  // Compute topological sort while leaving out every sparse input tensor in
521  // succession until an acylic iteration graph results.
522  for (OpOperand *t : linalgOp.getDpsInputOperands()) {
523  Value tval = t->get();
524  auto srcEnc = getSparseTensorEncoding(tval.getType());
525  // The constraints introduced by compound index expression are
526  // complicated. Skip them.
527  AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
528  bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
529  return !llvm::isa<AffineDimExpr>(exp);
530  });
531  if (!srcEnc || hasCompExpr)
532  continue;
533 
534  // Try scheduling loop without constraints from `tval`.
535  AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
536  if (!order) // still cyclic
537  continue;
538 
539  // Found an input tensor that resolves the cycle by inserting a
540  // conversion into a sparse tensor that adheres to the iteration
541  // graph order.
542  auto stt = getSparseTensorType(tval);
543  assert(stt.isIdentity());
544  order = inversePermutation(order);
545  // sorted loop -> lvl map.
546  idxMap = idxMap.compose(order);
547 
548  // Found a permutation such that the results in `idxMap` is sorted.
549  // For example,
550  // (d0, d1, d2, d3) -> (d2, d1, d0)
551  // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
552  // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
553  // transposed tensor's levels are visited in the same order as the loop
554  // scheduling order.
556  for (AffineExpr expr : idxMap.getResults()) {
557  unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
558  lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
559  }
560  llvm::sort(lvlSeq, llvm::less_first());
561  SmallVector<unsigned> perm =
562  llvm::to_vector(llvm::make_second_range(lvlSeq));
563  auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
564  // The result of the idxMap must be unsorted.
565  assert(!dimToLvl.isIdentity());
566 
567  // Inserting the transpose
568  rewriter.setInsertionPoint(linalgOp);
569  RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
570  Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
571  rewriter.modifyOpInPlace(linalgOp, [&]() {
572  linalgOp->setOperand(t->getOperandNumber(), dst);
573  });
574 
575  // Release the transposed form afterwards.
576  // TODO: CSE when used in more than one following op?
577  rewriter.setInsertionPointAfter(linalgOp);
578  rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);
579 
580  return success();
581  }
582  // Cannot be resolved with a single conversion.
583  // TODO: convert more than one?
584  return failure();
585  }
586 };
587 
588 //===----------------------------------------------------------------------===//
589 // Reinterpret Map Rewriters for operations other than linalg.generics
590 //===----------------------------------------------------------------------===//
591 
// Rewrites a tensor allocation op in place so that it produces the demapped
// type of its sparse result: the dynamic size operands are replaced by dynamic
// level sizes, the result type is changed to the demapped type, and a
// reinterpret_map back to the original encoding is inserted for all other
// users of the result.
592 template <typename AllocOp>
593 struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
 // NOTE(review): listing lines 594 and 597 are missing from this view; the
 // constructor declaration and the condition guarding `return failure();`
 // below are therefore not visible — restore them from the original file.
595  LogicalResult matchAndRewrite(AllocOp op,
596  PatternRewriter &rewriter) const override {
598  return failure();
599 
600  Location loc = op.getLoc();
601  auto stt = getSparseTensorType(op.getResult());
602 
 // Build the largest coordinate (size - 1) of every dimension, consuming one
 // dynamic size operand per dynamic dimension.
603  SmallVector<Value> maxDimCrds;
604  maxDimCrds.reserve(stt.getDimRank());
605  ValueRange dynSz = op.getDynamicSizes();
606  for (int64_t dimSz : stt.getDimShape()) {
607  if (ShapedType::isDynamic(dimSz)) {
608  Value maxCrd = rewriter.create<arith::SubIOp>(
609  loc, dynSz.front(), constantIndex(rewriter, loc, 1));
610  maxDimCrds.push_back(maxCrd);
611  dynSz = dynSz.drop_front();
612  } else {
613  maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
614  }
615  }
616 
 // Translate the maximum dimension coordinates to level coordinates; every
 // dynamic level size is then the translated coordinate plus one.
617  ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
618  CrdTransDirectionKind::dim2lvl);
619  auto lvlShape = stt.getLvlShape();
620  SmallVector<Value> dynLvlSzs;
621  for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
622  if (ShapedType::isDynamic(lvlShape[i])) {
623  Value sz = rewriter.create<arith::AddIOp>(
624  loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
625  dynLvlSzs.push_back(sz);
626  }
627  }
628 
629  assert(dynSz.empty()); // should have consumed all.
630  rewriter.startOpModification(op);
631  op->setOperands(dynLvlSzs);
632  op.getResult().setType(stt.getDemappedType());
633  rewriter.finalizeOpModification(op);
634  rewriter.setInsertionPointAfter(op);
635 
 // Remap to the original encoding; the new reinterpret_map itself keeps using
 // the demapped result.
636  Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
637  rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
638  return success();
639  }
640 };
641 
// Replaces a tensor.insert into a sparse tensor by an insert into the demapped
// destination (taken from the adaptor) using level coordinates, followed by a
// reinterpret_map back to the original encoding.
642 struct TensorInsertDemapper
643  : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
644  using DemapInsRewriter::DemapInsRewriter;
645  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
646  PatternRewriter &rewriter) const {
 // NOTE(review): listing line 647 is missing from this view, so the condition
 // guarding `return failure();` is not visible — restore it from the original
 // file.
648  return failure();
649 
650  Location loc = op.getLoc();
651  auto stt = getSparseTensorType(op.getResult());
 // The original indices are dimension coordinates; translate them to level
 // coordinates for the demapped destination.
652  ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
653  CrdTransDirectionKind::dim2lvl);
654  auto insertOp = rewriter.create<tensor::InsertOp>(
655  loc, op.getScalar(), adaptor.getDest(), lvlCrd);
656 
657  Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
658  rewriter.replaceOp(op, out);
659  return success();
660  }
661 };
662 
// Changes the result type of an assemble op to its demapped type in place and
// inserts a reinterpret_map back to the original encoding for all other users.
663 struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
 // NOTE(review): listing lines 664 and 667 are missing from this view; the
 // constructor declaration and the condition guarding `return failure();`
 // below are therefore not visible — restore them from the original file.
665  LogicalResult matchAndRewrite(AssembleOp op,
666  PatternRewriter &rewriter) const override {
668  return failure();
669 
670  assert(hasAnySparseResult(op));
671  auto stt = getSparseTensorType(op.getResult());
672  rewriter.modifyOpInPlace(
673  op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
674  rewriter.setInsertionPointAfter(op);
 // `stt` was captured before the type change, so its encoding is still the
 // original one.
675  Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
676  rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
677  return success();
678  }
679 };
680 
// Redirects the tensor operand of a disassemble op to the demapped value
// supplied through the adaptor; nothing else on the op is changed here.
681 struct SparseDisassembleDemapper
682  : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
683  using DemapInsRewriter::DemapInsRewriter;
684  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
685  PatternRewriter &rewriter) const {
 // NOTE(review): listing line 686 is missing from this view, so the condition
 // guarding `return failure();` is not visible — restore it from the original
 // file.
687  return failure();
688 
689  assert(hasAnySparseOperandOrResult(op));
690  rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
691  op.getTensorMutable().assign(adaptor.getTensor());
692  });
693  return success();
694  }
695 };
696 
// Rewrites a foreach op in place to iterate over the demapped tensor: the
// tensor and init operands are replaced by their demapped versions, result
// types are demapped, the body's block arguments are rebuilt around level
// coordinates, the yielded value is demapped, and the results are remapped to
// their previous types for all other users.
697 struct ForeachOpDemapper
698  : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
699  using DemapInsRewriter::DemapInsRewriter;
700  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
701  PatternRewriter &rewriter) const {
702  // Only handle operations with sparse input/output with non-identity dim2lvl
703  // maps.
 // NOTE(review): listing line 704 is missing from this view, so the condition
 // guarding `return failure();` is not visible — restore it from the original
 // file.
705  return failure();
706 
707  // TODO: demap constant as well.
 // NOTE(review): `attr` is only tested, never used.
708  if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
709  if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
710  return failure();
711 
712  Location loc = op.getLoc();
713  // Cache the type information since we update the foreach op in-place.
714  auto srcStt = getSparseTensorType(op.getTensor());
715  SmallVector<Type> prevRetTps(op.getResultTypes());
716 
717  rewriter.startOpModification(op);
718  op.getTensorMutable().assign(adaptor.getTensor());
719  op.getInitArgsMutable().assign(adaptor.getInitArgs());
720  // Update results' types.
721  for (auto r : op.getResults())
722  if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
723  r.setType(stt->getDemappedType());
724 
725  Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
726  // Update the foreach body.
 // The new arguments are appended first: one index per level, the element
 // value, and one argument per (demapped) init argument.
727  SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
728  blockArgTps.push_back(srcStt.getElementType());
729  blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
730  adaptor.getInitArgs().getTypes().end());
731  Block *body = op.getBody();
732  // Block Args: [dimCrd, val, initArgs]
733  unsigned preArgNum = body->getNumArguments();
734  for (Type t : blockArgTps)
735  body->addArgument(t, loc);
736 
737  // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
738  rewriter.setInsertionPointToStart(body);
739  ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
740 
 // Recompute the dimension coordinates from the new level coordinates so the
 // existing body keeps seeing dimension coordinates.
741  ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
742  CrdTransDirectionKind::lvl2dim);
743  rewriter.replaceAllUsesWith(
744  body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
745  body->eraseArguments(0, srcStt.getDimRank());
746  // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
747  unsigned numInitArgs = op.getInitArgs().size();
748  rewriter.replaceAllUsesWith(body->getArgument(0),
749  body->getArgument(lvlRank + numInitArgs + 1));
750  body->eraseArgument(0);
751  // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
752  ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
753  ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
754  // Remap back before replacement.
755  SmallVector<Value> reMappedArgs =
756  remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
757  rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
758  body->eraseArguments(0, numInitArgs);
759  // Block Args: [lvlCrds, DemappedArgs] and we are done.
760 
761  // Update yield operations.
762  if (numInitArgs != 0) {
763  rewriter.setInsertionPointToEnd(body);
764  auto yield = llvm::cast<YieldOp>(body->getTerminator());
765  if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
766  stt && !stt->isIdentity()) {
767  Value y =
768  genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
769  rewriter.create<YieldOp>(loc, y);
770  rewriter.eraseOp(yield);
771  }
772  }
773  rewriter.finalizeOpModification(op);
774 
775  rewriter.setInsertionPointAfter(op);
776  SmallVector<Value> outs =
777  remapValueRange(rewriter, prevRetTps, op.getResults());
778 
779  // Replace all the uses of the foreach results, except the use in
780  // reinterpret_map used to remap the output.
781  for (auto [from, to] : llvm::zip(op.getResults(), outs))
782  rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
783 
784  return success();
785  }
786 };
787 
788 } // namespace
789 
// Registers the rewrite patterns of this file on `patterns`, selected by
// `scope`: one group adds the linalg.generic patterns (GenericOpReinterpretMap
// and GenericOpScheduler), the other adds the demappers for the remaining ops.
// `ReinterpretMapScope::kAll` enables both groups.
// NOTE(review): listing lines 790, 793 and 798 are missing from this view, so
// the start of the declarator and the second operand of each `||` condition
// are not visible — restore them from the original file.
791  ReinterpretMapScope scope) {
792  if (scope == ReinterpretMapScope::kAll ||
794  patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
795  patterns.getContext());
796  }
797  if (scope == ReinterpretMapScope::kAll ||
799  patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
800  TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
801  SparseDisassembleDemapper, TensorInsertDemapper,
802  ForeachOpDemapper>(patterns.getContext());
803  }
804 }
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static SmallVector< Value > remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector< utils::IteratorType > &itTps)
static std::optional< std::pair< ArrayAttr, ArrayAttr > > translateMap(linalg::GenericOp op, PatternRewriter &rewriter)
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
Affine binary operation expression.
Definition: AffineExpr.h:227
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:345
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
BlockArgListType getArguments()
Definition: Block.h:87
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:195
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:358
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
This class represents an operand of an operation.
Definition: Value.h:267
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:708
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition: Value.h:140
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation.
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
Level getLvlRank() const
Returns the level-rank.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition: Utils.cpp:188
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operation has any sparse operand or result.
Definition: SparseTensor.h:196
uint64_t Level
The type of level identifiers and level-ranks.
Definition: SparseTensor.h:42
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SortMask
Iteration graph sorting mask.
bool hasAnySparseResult(Operation *op)
Returns true iff MLIR operation has any sparse result.
Definition: SparseTensor.h:191
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
ReinterpretMapScope
Defines a scope for reinterpret map pass.
Definition: Passes.h:45
const FrozenRewritePatternSet & patterns
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362