SparseReinterpretMap.cpp
1//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils/CodegenUtils.h"
10#include "Utils/IterationGraphSorter.h"
11
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Linalg/IR/Linalg.h"
14#include "mlir/Dialect/Linalg/Utils/Utils.h"
15#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
16#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
17#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/IR/AffineExprVisitor.h"
20#include "mlir/IR/AffineMap.h"
21
22using namespace mlir;
23using namespace mlir::sparse_tensor;
24
25namespace {
26
27//===----------------------------------------------------------------------===//
28// File Local Helper classes.
29//===----------------------------------------------------------------------===//
30
31// CRTP to help implement a rewriter that demaps all its inputs.
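// Here "demapping" means operating on the level-space view of a sparse tensor
// whose encoding carries a non-identity dim2lvl map, obtained by inserting a
// sparse_tensor.reinterpret_map. Illustrative IR (the encoding aliases #BSR
// and #DemappedBSR are placeholders, not defined in this file):
//   %d = sparse_tensor.reinterpret_map %t
//          : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #DemappedBSR>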
32template <typename SubClass, typename SourceOp>
33struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
34 using OpRewritePattern<SourceOp>::OpRewritePattern;
35 using OpAdaptor = typename SourceOp::Adaptor;
36
37 LogicalResult matchAndRewrite(SourceOp op,
38 PatternRewriter &rewriter) const override {
39 Location loc = op.getLoc();
40
41 // Demaps non-trivial inputs.
42 bool changed = false;
43 SmallVector<Value> deMappedIns(op->getOperands());
44 for (Value &in : deMappedIns) {
45 if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
46 in =
47 ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
48 changed = true;
49 }
50 }
51
52 // CRTP call.
53 OpAdaptor adaptor(deMappedIns, op);
54 LogicalResult status =
55 static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56 return changed ? success() : status;
57 }
58};
59
60// Collects the positions of all AffineDimExprs used in an affine expression.
61struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
62 explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
63 void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
64 BitVector dims;
65};
66
67// Determines whether an affine expression is admissible for sparsification.
68struct AffineExprAdmissibleVisitor
69 : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
70 explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
71
72 // We only allow AffineDimExpr on output.
73 void visitAddExpr(AffineBinaryOpExpr expr) {
74 if (isOutput)
75 admissible = false;
76 }
77 void visitMulExpr(AffineBinaryOpExpr expr) {
78 if (isOutput)
79 admissible = false;
80 }
81
82 // We disallow mod, floor div and ceil div on inputs.
83 void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
84 void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
85 void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
86 operator bool() { return admissible; }
87
88private:
89 bool admissible = true;
90 bool isOutput;
91};
92
93// The first BitVector stores levels where inadmissible exprs are used.
94// The second BitVector stores the AffineDimExprs that are used by the
95// inadmissible expressions.
96using InadmissInfo = std::pair<BitVector, BitVector>;
97
98} // namespace
99
100//===----------------------------------------------------------------------===//
101// File Local Helper methods.
102//===----------------------------------------------------------------------===//
103
104// Collects the inadmissible affine expression imposed on levels.
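// E.g., for a non-output map (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2),
// levels 0 and 2 are inadmissible and both stem from d0, so the returned
// pair marks levels {0, 2} and dimension {d0}.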
105static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
106 auto ret = std::make_pair(BitVector(map.getNumResults()),
107 BitVector(map.getNumDims()));
108 AffineDimCollector collector(map.getNumDims());
109 for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
110 AffineExprAdmissibleVisitor admissible(isOutput);
111 admissible.walkPostOrder(map.getResult(lvl));
112 if (!admissible) {
113 // Record the inadmissible level.
114 ret.first.set(lvl);
115 // Record the AffineDimExpr that is used in the inadmissible expr.
116 collector.walkPostOrder(map.getResult(lvl));
117 }
118 }
119 ret.second = collector.dims;
120 return ret;
121}
122
123// Builds the AffineMap to replace the dims in idxMap with lvls such that all
124// the inadmissible affine expressions can be eliminated.
125// For example, we can rewrite
126// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
127// to
128// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
129// by composing inverse(idxMap), that is
130// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
131// -> ((l0 * 2 + l2) floordiv 2,
132// (l1 * 3 + l3) floordiv 3,
133// (l0 * 2 + l2) mod 2,
134// (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
135//
136// This function builds the inverse(idxMap) that replaces every dimension used
137// in `info` with levels, and updates the iterator type array `itTps` for the
138// new index variables introduced.
139//
140// Note that the returned affine map does not retain the order of the input
141// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
142// for the replaced levels, and the remaining ones for unaffected dimensions.
143// For example, to handle
144// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
145// which is a typical map for block_2to4. The function returns:
146// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
147// in which (l0, l1) together replace `d1`, yet they appear
148// before `d0` in the resulting affine map.
149// The index (loop) order can later be canonicalized by a topo sort.
150static AffineMap
151genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
152 SmallVector<utils::IteratorType> &itTps) {
153 MLIRContext *ctx = idxMap.getContext();
154 auto [inAdLvls, usedDims] = info;
155 // Note that idxMap is not equal to the dim2Lvl map; it is computed by
156 // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
157 // ID map.
158 // TODO: we might fail here; in those cases we should really return
159 // failure instead of an assertion error.
160 auto lvl2Idx = inferLvlToDim(idxMap, ctx);
161
162 assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
163 if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
164 // This could happen when some dimensions are projected.
165 // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
166 // ==> lvl2Idx = (j, k) -> (j, k)
167 // In this case, we append the unused dimension at the end.
168 // ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
169 SmallVector<AffineExpr> results;
170 AffineDimCollector usedInLvl(idxMap.getNumDims());
171 for (auto e : idxMap.getResults())
172 usedInLvl.walkPostOrder(e);
173
174 unsigned curUsedDimID = 0;
175 unsigned curUnusedDimID = lvl2Idx.getNumDims();
176
177 BitVector unused = usedInLvl.dims.flip();
178 for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
179 if (unused.test(i))
180 results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
181 else
182 results.push_back(lvl2Idx.getResult(curUsedDimID++));
183 }
184 lvl2Idx =
185 AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
186 }
187 assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
188
189 // We do not need to replace the DimExprs that are not used in inadmissible
190 // level expressions. We use the first inAdLvls.count() dims to represent the
191 // replaced levels; the remaining ones are reserved for unchanged dimensions.
192 // Note that the results from the inverse map computed previously do not
193 // follow this convention, and we need to fix the mismatch below.
194 unsigned curRepID = 0;
195 unsigned curOriID = inAdLvls.count();
196 SmallVector<AffineExpr> results;
197 SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
198 SmallVector<utils::IteratorType> transItTps;
199
200 for (unsigned l : inAdLvls.set_bits()) {
201 // By our convention, the inadmissible level `l` always appears in the
202 // leading part (accumulated by curRepID) of the affine map's parameter
203 // list. Record the mapping so that we can replace all the uses of `l` to
204 // the correct position after the translation.
205 dimRep[l] = getAffineDimExpr(curRepID++, ctx);
206 // A new index variable is introduced for the inadmissible level, inheriting
207 // the iterator type. E.g., if l0 = d0 floordiv 2, the
208 // iterator type of l0 equals the iterator type of d0.
209 AffineExpr lvlExp = idxMap.getResult(l);
210 AffineDimCollector collector(idxMap.getNumDims());
211 collector.walkPostOrder(lvlExp);
212 // We assume a level can only be derived from one dimension.
213 assert(collector.dims.count() == 1);
214 transItTps.push_back(itTps[collector.dims.find_first()]);
215 }
216
217 for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
218 if (usedDims.test(d)) {
219 // The dimension is used in some of the inadmissible levels, and it needs
220 // to be inverted. Get the inversion from the inverse map, and fix the
221 // mismatch captured by the above loop.
222 results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
223 } else {
224 // The dimension is not used in any of the inadmissible levels, and it
225 // does not need to be inverted. Fix the mismatch by mapping it to the
226 // trailing part of the affine map (accumulated by curOriID).
227 results.push_back(getAffineDimExpr(curOriID++, ctx));
228 transItTps.push_back(itTps[d]);
229 }
230 }
231 unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
232 // Update iterator type.
233 itTps.assign(transItTps.begin(), transItTps.end());
234 return AffineMap::get(numDim, 0, results, ctx);
235}
236
237// Translates the index map in the linalg::GenericOp from idx->dim map to
238// idx->lvl map. Returns failure if the index map cannot be translated to an
239// admissible form.
240// Returns the translated index map array and the iterator type array.
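// E.g., a 2x3 block-sparse operand with
//   dim2lvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// turns a 2-loop kernel into a 4-loop kernel whose maps address the four
// levels directly; the two extra loops inherit the iterator types of the
// dimensions they were split from.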
241static std::optional<std::pair<ArrayAttr, ArrayAttr>>
242translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
243 // idxMap is an idx2dim map before reinterpretation.
244 MLIRContext *ctx = op.getContext();
245 SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
246 SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
247 for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
248 Value tensor = op->getOpOperand(i).get();
249 auto stt = tryGetSparseTensorType(tensor);
250 if (stt && !stt->isIdentity()) {
251 AffineMap dim2Lvl = stt->getDimToLvl();
252 // By composing idx2dim(dim2lvl), we get an idx2lvl map.
253 idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
254 }
255 }
256
257 // A naive way to handle common constant expressions that arise during dim2lvl
258 // translation.
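// E.g., with a static level size of 4, the two rules registered below are
//   l floordiv 4 -> 0 and l mod 4 -> l
// (valid because l < 4 within the level bounds); they cancel the divisions
// that reappear when the replacement map is composed back into the index maps.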
259 auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
260 unsigned pos, int64_t lvlSz) {
261 if (ShapedType::isStatic(lvlSz)) {
262 auto c0 = getAffineConstantExpr(0, ctx);
263 auto lvlExp = getAffineDimExpr(pos, ctx);
264 auto szExp = getAffineConstantExpr(lvlSz, ctx);
265
266 // lvl floordiv lvlSz = 0
267 auto divExp =
268 getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
269 cstMapping.try_emplace(divExp, c0);
270
271 // lvl mod lvlSz = lvl
272 auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
273 cstMapping.try_emplace(modExp, lvlExp);
274 }
275 };
276
277 unsigned boundedNum = 0;
278 // A fixed-point algorithm.
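// Each round may rewrite the index maps of *all* operands (the replacement
// map is composed into every map), which can expose new inadmissible
// expressions on other operands, so we iterate until nothing changes.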
279 bool changed = true;
280 while (changed) {
281 changed = false;
282 for (OpOperand &operand : op->getOpOperands()) {
283 auto stt = tryGetSparseTensorType(operand.get());
284 // Skip on dense operands.
285 if (!stt || !stt->getEncoding())
286 continue;
287
288 unsigned tid = operand.getOperandNumber();
289 bool isOutput = &operand == op.getDpsInitOperand(0);
290 AffineMap idxMap = idxMapArray[tid];
291 InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
292 auto [inAdLvls, dimExprs] = inAdInfo;
293 for (unsigned d : dimExprs.set_bits()) {
294 // The first `boundedNum` dims used in the AffineMap are introduced to
295 // resolve previous inadmissible expressions. We cannot replace them,
296 // as doing so might bring back the inadmissible expressions.
297 if (d < boundedNum)
298 return std::nullopt;
299 }
300
301 if (inAdLvls.count() != 0) {
302 // Naive constant propagation; should be sufficient to handle block
303 // sparsity in our cases.
304 SmallVector<int64_t> lvlShape = stt->getLvlShape();
305 DenseMap<AffineExpr, AffineExpr> cstMapping;
306 unsigned position = 0;
307 for (unsigned lvl : inAdLvls.set_bits()) {
308 int64_t lvlSz = lvlShape[lvl];
309 populateCstMapping(cstMapping, position, lvlSz);
310 position++;
311 }
312
313 AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
314 // Compose the lvl2Idx map into all index maps to eliminate
315 // inadmissible expressions.
316 for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
317 AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
318 idxMapArray[tid] = transMap.replace(
319 cstMapping, /*numResultDims=*/transMap.getNumDims(),
320 /*numResultSyms=*/0);
321 }
322 changed = true;
323 boundedNum += inAdLvls.count();
324 }
325 }
326 };
327
328 SmallVector<Attribute> iterAttr =
329 llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
330 return linalg::IteratorTypeAttr::get(ctx, itTp);
331 });
332
333 return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
334 rewriter.getArrayAttr(iterAttr));
335}
336
337// Generates a "de"mapping reinterpretation of the map.
338static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
339 Value val) {
340 return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
341 val);
342}
343
344// Generates a "re"mapping reinterpretation of the map.
345static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
346 Value val) {
347 return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
348}
349
350static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
351 ValueRange outs) {
352 SmallVector<Value> ret(outs);
353 assert(outs.size() == types.size());
354 for (auto [r, t] : llvm::zip(ret, types))
355 if (r.getType() != t)
356 r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
357 return ret;
358}
359
360namespace {
361
362//===----------------------------------------------------------------------===//
363// Rewriting rules for linalg generic ops.
364//===----------------------------------------------------------------------===//
365
366/// Sparse rewriting rule for the generic `linalg` operation.
367struct GenericOpReinterpretMap
368 : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
369public:
370 using DemapInsRewriter::DemapInsRewriter;
371 LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
372 PatternRewriter &rewriter) const {
373 // Only rewrite single output operations with pure (sparse) tensor
374 // semantics.
375 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
376 !hasAnySparseOperandOrResult(linalgOp) ||
377 !hasAnyNonIdentityOperandsOrResults(linalgOp))
378 return failure();
379
380 // Try translating the index map.
381 auto transMap = translateMap(linalgOp, rewriter);
382 if (!transMap)
383 return rewriter.notifyMatchFailure(
384 linalgOp, "the sparse kernel can not be sparsified.");
385
386 // On success, update the linalg operands and maps in place.
387 Value res = linalgOp.getResult(0);
388 auto stt = tryGetSparseTensorType(res);
389 auto [idxMap, itTp] = *transMap;
390
391 rewriter.startOpModification(linalgOp);
392 linalgOp.setIndexingMapsAttr(idxMap);
393 linalgOp.setIteratorTypesAttr(itTp);
394 // Use demapped arguments.
395 linalgOp.getInputsMutable().assign(adaptor.getInputs());
396 linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
397 res.setType(adaptor.getOutputs()[0].getType());
398 rewriter.finalizeOpModification(linalgOp);
399
400 rewriter.setInsertionPointAfter(linalgOp);
401 if (stt && stt->hasEncoding()) {
402 Value t = genRemap(rewriter, stt->getEncoding(), res);
403 rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
404 }
405 return success();
406 }
407};
408
409struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
410 GenericOpScheduler(MLIRContext *context,
411 sparse_tensor::LoopOrderingStrategy strategy)
412 : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
413
414 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
415 PatternRewriter &rewriter) const override {
416 if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
417 hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
418 !hasAnySparseOperandOrResult(linalgOp)) {
419 return failure();
420 }
421
422 const StringRef sorted = "sorted";
423 if (linalgOp->hasAttr(sorted))
424 return failure();
425
426 // Pass strategy to IterationGraphSorter.
427 auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
428 bool isAdmissible = false;
429 AffineMap order;
430 // A const list of all masks that we use for iteration graph
431 // computation. Must be ordered from more strict to less strict.
432 // Ideally (though might not be guaranteed), the earlier a constraint mask
433 // can be satisfied, the faster the generated kernel will be.
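// E.g., kIncludeAll also honors the ordering constraints implied by dense
// operands' index maps, while kSparseOnly keeps only those implied by the
// sparse tensors' level formats.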
434 const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
435 SortMask::kIncludeDenseInput,
436 SortMask::kIncludeDenseOutput,
437 SortMask::kSparseOnly};
438 for (const SortMask mask : allMasks) {
439 order = scheduler.sort(mask);
440 if (order) {
441 if (isAdmissibleOrder(linalgOp, order)) {
442 isAdmissible = true;
443 break;
444 }
445 // else try a set of less strict constraints.
446 }
447 }
448
449 if (!order) {
450 // Cycles detected.
451 if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
452 return rewriter.notifyMatchFailure(
453 linalgOp, "the sparse kernel can not be scheduled: loop detected.");
454 }
455 return success();
456 }
457
458 if (!isAdmissible) {
459 return rewriter.notifyMatchFailure(
460 linalgOp, "the sparse kernel can not be scheduled.");
461 }
462
463 // Marks the GenericOp to avoid recursive matching.
464 rewriter.modifyOpInPlace(linalgOp, [&]() {
465 linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
466 });
467
468 // Already sorted.
469 if (order.isIdentity())
470 return success();
471
472 assert(order.isPermutation());
473 // `order` is the original loop -> sorted loop map
474 ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
475 SmallVector<Attribute> curItTypes;
476 curItTypes.reserve(preItTypes.size());
477 for (AffineExpr expr : order.getResults()) {
478 unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
479 curItTypes.push_back(preItTypes[loopID]);
480 }
481
482 // Inverse `order` to get sorted loop -> original loop map
483 order = inversePermutation(order);
484 SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
485 for (AffineMap &idxMap : idxMaps)
486 idxMap = idxMap.compose(order); // sorted loop -> lvl map
487
488 rewriter.startOpModification(linalgOp);
489 linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
490 linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
491 rewriter.finalizeOpModification(linalgOp);
492
493 return success();
494 }
495
496private:
497 /// Whether the loop order is admissible by sparsification.
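/// E.g., with iterator types [parallel, parallel, reduction] and a rank-2
/// sparse output, a loop order that schedules both parallel loops before the
/// reduction gives nest == 2 >= rank - 1 and is thus admissible.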
498 static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
499 if (!hasAnySparseResult(linalgOp))
500 return true;
501
502 OpOperand *lhs = linalgOp.getDpsInitOperand(0);
503 unsigned nest = 0;
504 const auto iteratorTypes = linalgOp.getIteratorTypesArray();
505 for (const AffineExpr l : order.getResults()) {
506 unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
507 auto itTp =
508 cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
509 if (linalg::isReductionIterator(itTp.getValue()))
510 break; // terminate at first reduction
511 nest++;
512 }
513 // Determine admissible dynamic insertion situations:
514 // (1) fully injective, since there are no reductions,
515 // (2) admissible 1-d expansion in innermost dimension.
516 return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
517 };
518
519 // Last resort cycle resolution.
520 static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
521 linalg::LinalgOp linalgOp,
522 PatternRewriter &rewriter) {
523 // Compute topological sort while leaving out every sparse input tensor in
524 // succession until an acyclic iteration graph results.
525 for (OpOperand *t : linalgOp.getDpsInputOperands()) {
526 Value tval = t->get();
527 auto srcEnc = getSparseTensorEncoding(tval.getType());
528 // The constraints introduced by compound index expressions are
529 // complicated. Skip them.
530 AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
531 bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
532 return !llvm::isa<AffineDimExpr>(exp);
533 });
534 if (!srcEnc || hasCompExpr)
535 continue;
536
537 // Try scheduling loop without constraints from `tval`.
538 AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
539 if (!order) // still cyclic
540 continue;
541
542 // Found an input tensor that resolves the cycle by inserting a
543 // conversion into a sparse tensor that adheres to the iteration
544 // graph order.
545 auto stt = getSparseTensorType(tval);
546 assert(stt.isIdentity());
547 order = inversePermutation(order);
548 // sorted loop -> lvl map.
549 idxMap = idxMap.compose(order);
550
551 // Found a permutation such that the results in `idxMap` are sorted.
552 // For example,
553 // (d0, d1, d2, d3) -> (d2, d1, d0)
554 // loops are scheduled in order of d0->d1->d2->d3; to resolve the cycle,
555 // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
556 // transposed tensor's levels are visited in the same order as the loop
557 // scheduling order.
558 SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
559 for (AffineExpr expr : idxMap.getResults()) {
560 unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
561 lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
562 }
563 llvm::sort(lvlSeq, llvm::less_first());
564 SmallVector<unsigned> perm =
565 llvm::to_vector(llvm::make_second_range(lvlSeq));
566 auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
567 // The result of the idxMap must be unsorted.
568 assert(!dimToLvl.isIdentity());
569
570 // Inserting the transpose
571 rewriter.setInsertionPoint(linalgOp);
572 RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
573 Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
574 rewriter.modifyOpInPlace(linalgOp, [&]() {
575 linalgOp->setOperand(t->getOperandNumber(), dst);
576 });
577
578 // Release the transposed form afterwards.
579 // TODO: CSE when used in more than one following op?
580 rewriter.setInsertionPointAfter(linalgOp);
581 bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);
582
583 return success();
584 }
585 // Cannot be resolved with a single conversion.
586 // TODO: convert more than one?
587 return failure();
588 }
589
590private:
591 sparse_tensor::LoopOrderingStrategy strategy;
592};
593
594//===----------------------------------------------------------------------===//
595// Reinterpret Map Rewriters for operations other than linalg.generics
596//===----------------------------------------------------------------------===//
597
598template <typename AllocOp>
599struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
600 using OpRewritePattern<AllocOp>::OpRewritePattern;
601 LogicalResult matchAndRewrite(AllocOp op,
602 PatternRewriter &rewriter) const override {
603 if (!hasAnyNonIdentityOperandsOrResults(op))
604 return failure();
605
606 Location loc = op.getLoc();
607 auto stt = getSparseTensorType(op.getResult());
608
609 SmallVector<Value> maxDimCrds;
610 maxDimCrds.reserve(stt.getDimRank());
611 ValueRange dynSz = op.getDynamicSizes();
612 for (int64_t dimSz : stt.getDimShape()) {
613 if (ShapedType::isDynamic(dimSz)) {
614 Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
615 constantIndex(rewriter, loc, 1));
616 maxDimCrds.push_back(maxCrd);
617 dynSz = dynSz.drop_front();
618 } else {
619 maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
620 }
621 }
622
623 ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
624 CrdTransDirectionKind::dim2lvl);
625 auto lvlShape = stt.getLvlShape();
626 SmallVector<Value> dynLvlSzs;
627 for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
628 if (ShapedType::isDynamic(lvlShape[i])) {
629 Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
630 constantIndex(rewriter, loc, 1));
631 dynLvlSzs.push_back(sz);
632 }
633 }
634
635 assert(dynSz.empty()); // should have consumed all.
636 rewriter.startOpModification(op);
637 op->setOperands(dynLvlSzs);
638 op.getResult().setType(stt.getDemappedType());
639 rewriter.finalizeOpModification(op);
640 rewriter.setInsertionPointAfter(op);
641
642 Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
643 rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
644 return success();
645 }
646};
647
648struct TensorInsertDemapper
649 : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
650 using DemapInsRewriter::DemapInsRewriter;
651 LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
652 PatternRewriter &rewriter) const {
653 if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
654 return failure();
655
656 Location loc = op.getLoc();
657 auto stt = getSparseTensorType(op.getResult());
658 ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
659 CrdTransDirectionKind::dim2lvl);
660 auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
661 adaptor.getDest(), lvlCrd);
662
663 Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
664 rewriter.replaceOp(op, out);
665 return success();
666 }
667};
668
669struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
670 using OpRewritePattern::OpRewritePattern;
671 LogicalResult matchAndRewrite(AssembleOp op,
672 PatternRewriter &rewriter) const override {
673 if (!hasAnyNonIdentityOperandsOrResults(op))
674 return failure();
675
676 assert(hasAnySparseResult(op));
677 auto stt = getSparseTensorType(op.getResult());
678 rewriter.modifyOpInPlace(
679 op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
680 rewriter.setInsertionPointAfter(op);
681 Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
682 rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
683 return success();
684 }
685};
686
687struct SparseDisassembleDemapper
688 : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
689 using DemapInsRewriter::DemapInsRewriter;
690 LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
691 PatternRewriter &rewriter) const {
692 if (!hasAnyNonIdentityOperandsOrResults(op))
693 return failure();
694
695 assert(hasAnySparseOperandOrResult(op));
696 rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
697 op.getTensorMutable().assign(adaptor.getTensor());
698 });
699 return success();
700 }
701};
702
703struct ForeachOpDemapper
704 : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
705 using DemapInsRewriter::DemapInsRewriter;
706 LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
707 PatternRewriter &rewriter) const {
708 // Only handle operations with sparse input/output with non-identity dim2lvl
709 // maps.
710 if (!hasAnyNonIdentityOperandsOrResults(op))
711 return failure();
712
713 // TODO: demap constant as well.
714 if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
715 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
716 return failure();
717
718 Location loc = op.getLoc();
719 // Cache the type information since we update the foreach op in-place.
720 auto srcStt = getSparseTensorType(op.getTensor());
721 SmallVector<Type> prevRetTps(op.getResultTypes());
722
723 rewriter.startOpModification(op);
724 op.getTensorMutable().assign(adaptor.getTensor());
725 op.getInitArgsMutable().assign(adaptor.getInitArgs());
726 // Update results' types.
727 for (auto r : op.getResults())
728 if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
729 r.setType(stt->getDemappedType());
730
731 Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
732 // Update the foreach body.
733 SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
734 blockArgTps.push_back(srcStt.getElementType());
735 blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
736 adaptor.getInitArgs().getTypes().end());
737 Block *body = op.getBody();
738 // Block Args: [dimCrd, val, initArgs]
739 unsigned preArgNum = body->getNumArguments();
740 for (Type t : blockArgTps)
741 body->addArgument(t, loc);
742
743 // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
744 rewriter.setInsertionPointToStart(body);
745 ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);
746
747 ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
748 CrdTransDirectionKind::lvl2dim);
749 rewriter.replaceAllUsesWith(
750 body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
751 body->eraseArguments(0, srcStt.getDimRank());
752 // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
753 unsigned numInitArgs = op.getInitArgs().size();
754 rewriter.replaceAllUsesWith(body->getArgument(0),
755 body->getArgument(lvlRank + numInitArgs + 1));
756 body->eraseArgument(0);
757 // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
758 ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
759 ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
760 // Remap back before replacement.
761 SmallVector<Value> reMappedArgs =
762 remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
763 rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
764 body->eraseArguments(0, numInitArgs);
765 // Block Args: [lvlCrds, DemappedArgs] and we are done.
766
767 // Update yield operations.
768 if (numInitArgs != 0) {
769 rewriter.setInsertionPointToEnd(body);
770 auto yield = llvm::cast<YieldOp>(body->getTerminator());
771 if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
772 stt && !stt->isIdentity()) {
773 Value y =
774 genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
775 YieldOp::create(rewriter, loc, y);
776 rewriter.eraseOp(yield);
777 }
778 }
779 rewriter.finalizeOpModification(op);
780
781 rewriter.setInsertionPointAfter(op);
782 SmallVector<Value> outs =
783 remapValueRange(rewriter, prevRetTps, op.getResults());
784
785 // Replace all the uses of the foreach results, except the use in the
786 // reinterpret_map used to remap the output.
787 for (auto [from, to] : llvm::zip(op.getResults(), outs))
788 rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());
789
790 return success();
791 }
792};
793
794} // namespace
795
796void mlir::populateSparseReinterpretMap(
797 RewritePatternSet &patterns, ReinterpretMapScope scope,
798 sparse_tensor::LoopOrderingStrategy strategy) {
799 if (scope == ReinterpretMapScope::kAll ||
800 scope == ReinterpretMapScope::kGenericOnly) {
801 patterns.add<GenericOpReinterpretMap>(patterns.getContext());
802 patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
803 }
804 if (scope == ReinterpretMapScope::kAll ||
805 scope == ReinterpretMapScope::kExceptGeneric) {
806 patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
807 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
808 SparseDisassembleDemapper, TensorInsertDemapper,
809 ForeachOpDemapper>(patterns.getContext());
810 }
811}