MLIR 23.0.0git
SparseReinterpretMap.cpp
Go to the documentation of this file.
1//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===/
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
20#include "mlir/IR/AffineMap.h"
21
22using namespace mlir;
23using namespace mlir::sparse_tensor;
24
25namespace {
26
27//===----------------------------------------------------------------------===//
28// File Local Helper classes.
29//===----------------------------------------------------------------------===//
30
31// CRTP to help implementing a rewriter that demaps all its inputs.
32template <typename SubClass, typename SourceOp>
33struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
34 using OpRewritePattern<SourceOp>::OpRewritePattern;
35 using OpAdaptor = typename SourceOp::Adaptor;
36
37 LogicalResult matchAndRewrite(SourceOp op,
38 PatternRewriter &rewriter) const override {
39 Location loc = op.getLoc();
40
41 // Demaps non-trivial inputs.
42 bool changed = false;
43 SmallVector<Value> deMappedIns(op->getOperands());
44 for (Value &in : deMappedIns) {
45 if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
46 in =
47 ReinterpretMapOp::create(rewriter, loc, stt->getDemappedType(), in);
48 changed = true;
49 }
50 }
51
52 // CRTP call.
53 OpAdaptor adaptor(deMappedIns, op);
54 LogicalResult status =
55 static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
56 return changed ? success() : status;
57 }
58};
59
// Records which AffineDimExprs occur in the visited affine expression(s):
// bit `i` of `dims` is set iff dimension `i` appears somewhere in the walk.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  // `dimNum` is the total number of dimensions, i.e., the BitVector size.
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
  // Marks the position of every dimension expression encountered.
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  // One bit per dimension position; accumulated across walkPostOrder calls.
  BitVector dims;
};
66
67// Flattens an affine expression into a list of AffineDimExprs.
68struct AffineExprAdmissibleVisitor
69 : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
70 explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
71
72 // We only allow AffineDimExpr on output.
73 void visitAddExpr(AffineBinaryOpExpr expr) {
74 if (isOutput)
75 admissible = false;
76 }
77 void visitMulExpr(AffineBinaryOpExpr expr) {
78 if (isOutput)
79 admissible = false;
80 }
81
82 // We disallow mod, floor div and ceil div on inputs.
83 void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
84 void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
85 void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
86 operator bool() { return admissible; }
87
88private:
89 bool admissible = true;
90 bool isOutput;
91};
92
93// The first BitVector stores levels where inadmissible exprs are used.
94// The second BitVector stores the AffineDimExp that are used by the
95// inadmissible expressions.
96using InadmissInfo = std::pair<BitVector, BitVector>;
97
98} // namespace
99
100//===----------------------------------------------------------------------===//
101// File Local Helper methods.
102//===----------------------------------------------------------------------===//
103
104// Collects the inadmissible affine expression imposed on levels.
105static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
106 auto ret = std::make_pair(BitVector(map.getNumResults()),
107 BitVector(map.getNumDims()));
108 AffineDimCollector collector(map.getNumDims());
109 for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
110 AffineExprAdmissibleVisitor admissible(isOutput);
111 admissible.walkPostOrder(map.getResult(lvl));
112 if (!admissible) {
113 // Record the inadmissible level.
114 ret.first.set(lvl);
115 // Record the AffineDimExpr that is used in the inadmissible expr.
116 collector.walkPostOrder(map.getResult(lvl));
117 }
118 }
119 ret.second = collector.dims;
120 return ret;
121}
122
// Builds the AffineMap to replace the idx in idxMap to lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                         -> ((l0 * 2 + l2) floordiv 2,
//                             (l1 * 3 + l3) floordiv 3,
//                             (l0 * 2 + l2) mod 2,
//                             (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replace every dimensions used
// in `info` to levels, and updates the iterator type array `itTps` for the new
// index variable introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdlvls.count()` for the
// replaced levels, and remaining ones for unused dimensions.
// For example, to handle
// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d2 mod 4)
// which is a typical map for block_2to4. The function returns:
// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which, (l0, l1) together replaces `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal to dim2Lvl map, it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // ID map.
  // TODO: we might fail here, in those case we should really return
  // failure instead of assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    // Bit i set iff dimension i does not appear in any level expression.
    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvl.count() dim to represent the
  // replaced level, the remaining ones are reserved for unchanged ones.
  // Note that results from the inverse map computed previously does not follow
  // the convention we used, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` to
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level, inherit
    // the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals to the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inversed. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inversed. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  // Total dims: the untouched ones plus one fresh dim per inadmissible level.
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator type.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
236
// Translates the index map in the linalg::GenericOp from idx->dim map to
// idx->lvl map. Returns failure (std::nullopt) if the index map can not be
// translated to an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is a idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we got a idx2lvl Map
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during dim2lvl
  // translation: for a static level size, `lvl floordiv lvlSz` and
  // `lvl mod lvlSz` simplify to 0 and `lvl` respectively.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (ShapedType::isStatic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm: keep rewriting until no operand carries an
  // inadmissible level expression (or translation is found impossible).
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip on dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` used in the AffineMap is introduced to
        // resolve previous inadmissible expressions. We can not replace them
        // as it might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx Map to all AffineIdxMap to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  };

  // Wrap the (now admissible) iterator types and index maps as attributes.
  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
336
337// Generates a "de"mapping reinterpretation of the map.
338static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
339 Value val) {
340 return ReinterpretMapOp::create(builder, val.getLoc(), enc.withoutDimToLvl(),
341 val);
342}
343
344// Generates a "re"mapping reinterpretation of the map.
345static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
346 Value val) {
347 return ReinterpretMapOp::create(builder, val.getLoc(), enc, val);
348}
349
ValueRange outs) {
  // Reinterprets each value whose type differs from the corresponding entry
  // in `types`, so the returned values match `types` element-wise.
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = ReinterpretMapOp::create(rewriter, r.getLoc(), t, r);
  return ret;
}
359
360namespace {
361
362//===----------------------------------------------------------------------===//
363// Rewriting rules for linalg generic ops.
364//===----------------------------------------------------------------------===//
365
/// Sparse rewriting rule for the generic `linalg` operation: rewrites the op
/// to operate on demapped (level-space) operands, translating its indexing
/// maps from idx->dim form to idx->lvl form.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    // Remap the (sparse, encoded) result back so every existing user except
    // the remapping op itself keeps seeing the original mapped type.
    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
408
// Schedules (sorts) the loops of an already-demapped linalg.generic op into
// an admissible order using iteration-graph sorting. Marks the op with a
// "sorted" attribute to avoid rescheduling, and falls back to inserting an
// explicit sparse conversion when the iteration graph has a cycle.
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  GenericOpScheduler(MLIRContext *context,
      : OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    // Skip ops already processed by this pattern.
    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    // Pass strategy to IterationGraphSorter.
    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
    // can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is original loop -> sorted loop map.
    // Permute the iterator types into the sorted loop order.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get sorted loop -> original loop map
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    // Count the loops preceding the first reduction in the sorted order.
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  };

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute topological sort while leaving out every sparse input tensor in
    // succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expression are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling loop without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` is sorted.
      // For example,
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in order of d0->d1->d2->d3, to resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Inserting the transpose
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = ConvertOp::create(rewriter, tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      bufferization::DeallocTensorOp::create(rewriter, dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }

private:
};
593
594//===----------------------------------------------------------------------===//
595// Reinterpret Map Rewriters for operations other than linalg.generics
596//===----------------------------------------------------------------------===//
597
// Demaps sparse tensor allocation ops (instantiated below for
// bufferization::AllocTensorOp and tensor::EmptyOp): translates the dynamic
// dimension sizes into level sizes, allocates in the demapped (level-space)
// type, and remaps the result back to the original encoding.
template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    // Compute the maximum coordinate (size - 1) per dimension, consuming the
    // dynamic size operands in declaration order.
    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = arith::SubIOp::create(rewriter, loc, dynSz.front(),
                                             constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    // Translate dim coordinates to lvl coordinates and derive the dynamic
    // level sizes as (max level coordinate + 1).
    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = arith::AddIOp::create(rewriter, loc, maxLvlCrds[i],
                                         constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.

    // Create a new op to let the MLIR builder calculate the correct metadata.
    auto allocOp =
        AllocOp::create(rewriter, loc, stt.getDemappedType(), dynLvlSzs);

    Value t = genRemap(rewriter, stt.getEncoding(), allocOp.getResult());
    rewriter.replaceOp(op, t);
    return success();
  }
};
646
// Demaps tensor.insert into a sparse tensor: translates the insertion
// coordinates from dimension space to level space, inserts into the demapped
// destination, and remaps the result back to the original encoding.
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    // Translate the insertion coordinates from dim space to lvl space.
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = tensor::InsertOp::create(rewriter, loc, op.getScalar(),
                                             adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};
667
// Demaps sparse_tensor.assemble: mutates the result type in place to the
// demapped (level-space) form and remaps the result for all remaining users.
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    // Remap after the op so every other user keeps the mapped type.
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};
685
// Demaps sparse_tensor.disassemble by swapping in the demapped source tensor
// in place.
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};
701
// Demaps sparse_tensor.foreach: moves the iteration from dimension space to
// level space, rewriting block arguments, yields, and results accordingly.
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity dim2lvl
    // maps.
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body: append the new (level-space) block arguments
    // before erasing the old (dimension-space) ones.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    // Replace uses of the old dim coordinates with translated lvl2dim ones.
    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        YieldOp::create(rewriter, loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the use in
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};
792
793} // namespace
794
    RewritePatternSet &patterns, ReinterpretMapScope scope,
  // Patterns for linalg.generic: first demap operands, then schedule loops.
  if (scope == ReinterpretMapScope::kAll ||
    patterns.add<GenericOpReinterpretMap>(patterns.getContext());
    patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
  }
  // Demapping patterns for all other sparse-tensor related operations.
  if (scope == ReinterpretMapScope::kAll ||
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
return success()
lhs
ArrayAttr()
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static SmallVector< Value > remapValueRange(OpBuilder &rewriter, TypeRange types, ValueRange outs)
static AffineMap genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap, SmallVector< utils::IteratorType > &itTps)
static std::optional< std::pair< ArrayAttr, ArrayAttr > > translateMap(linalg::GenericOp op, PatternRewriter &rewriter)
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc, Value val)
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput)
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:206
BlockArgListType getArguments()
Definition Block.h:97
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:198
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy)
Factory method that constructs an iteration graph sorter for the given linalg.generic operation with ...
AffineMap sort(SortMask mask, Value ignored=nullptr)
Returns a permutation that represents the scheduled loop order.
Level getLvlRank() const
Returns the level-rank.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
bool hasAnySparseOperandOrResult(Operation *op)
Returns true iff MLIR operand has any sparse operand or result.
uint64_t Level
The type of level identifiers and level-ranks.
LoopOrderingStrategy
Defines a strategy for loop ordering during sparse code generation.
Definition Passes.h:62
AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context)
Given the dimToLvl map, infers the lvlToDim map, or returns empty Affine map when inference fails.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
std::optional< SparseTensorType > tryGetSparseTensorType(Value val)
bool hasAnyNonIdentityOperandsOrResults(Operation *op)
Returns true iff MLIR operation has any sparse tensor with non-identity dim2lvl maps.
SparseTensorType getSparseTensorType(Value val)
Convenience methods to obtain a SparseTensorType from a Value.
SortMask
Iteration graph sorting mask,.
bool hasAnySparseResult(Operation *op)
Returns true iff MLIR operand has any sparse result.
Include the generated interface declarations.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
ReinterpretMapScope
Defines a scope for reinterpret map pass.
Definition Passes.h:45
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void populateSparseReinterpretMap(RewritePatternSet &patterns, ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy=sparse_tensor::LoopOrderingStrategy::kDefault)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...