MLIR 23.0.0git
Transforms.cpp
Go to the documentation of this file.
1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
38#include <type_traits>
39#include <utility>
40
41#define DEBUG_TYPE "linalg-transforms"
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46//===----------------------------------------------------------------------===//
47// Transformations exposed as functional-style API calls.
48//===----------------------------------------------------------------------===//
49
50//===----------------------------------------------------------------------===//
51// peelLoop transformation.
52//===----------------------------------------------------------------------===//
53
54/// Try to peel and canonicalize loop `op` and return the new result.
55/// Also applies affine_min/max bounds simplification on the fly where relevant.
56// TODO: Add support for scf.parallel and affine.for loops.
58 Operation *op) {
60 .Case([&](scf::ForOp forOp) {
61 scf::ForOp partialIteration;
62 if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
63 partialIteration)))
64 return partialIteration->getResults();
65 assert(!partialIteration && "expected that loop was not peeled");
66 return forOp->getResults();
67 })
68 .Default([&](Operation *op) { return op->getResults(); });
69}
70
71/// Peel `loops` and apply affine_min/max bounds simplification on the fly
72/// where relevant.
75 for (auto loopOp : loops)
76 peelLoop(rewriter, loopOp);
77}
78
79//===----------------------------------------------------------------------===//
80// pack transformation.
81//===----------------------------------------------------------------------===//
82
83#ifndef NDEBUG
84/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
86 bool found = false;
87 for (AffineExpr e : map.getResults()) {
88 if (!e.isFunctionOfDim(dim))
89 continue;
90 if (found)
91 return false;
92 found = true;
93 }
94 return true;
95}
96#endif // NDEBUG
97
99 return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
100}
101
102/// Return the index of the first result of `map` that is a function of
103/// AffineDimExpr(dim), std::nullopt otherwise.
104static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
105 int64_t dim) {
106 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
107 AffineExpr expr = map.getResult(i);
108 if (!expr.isFunctionOfDim(dim))
109 continue;
110 return i;
111 }
112 return std::nullopt;
113}
114
115/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
116/// `newDim` at `iteratorTypes.size()` by:
117/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
118/// 2. Appending a `newDim` to the domain of every indexing map.
119/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
120/// by potentially adding a `newDim` result to `map`.
121/// The preserved invariant is that `iteratorTypes.size()` is always equal to
122/// `map.getNumDims()` for every map in `indexingMaps`.
123///
124/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
125/// Return a vector that records the optional packing for each operand.
126/// Return failure if the packed indexing cannot be represented with a LinalgOp.
127///
128/// Further details:
129/// ================
130/// The current implementation of packing (i.e. data tiling) consists of
131/// rewriting a linearized strip-mined form into a higher-dimensional access.
132/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
133/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
134/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
135///
136/// This rewrite into higher dimensional access is not possible for general
137/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
138/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
139/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
140/// The rewrite of the access would be a form not representable in Linalg:
141/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
142/// Note however that as `J` and `ii` iterate, the accesses do not have a
143/// particular alignment, so packing does not achieve alignment in this case
144///
145/// In the future, we may want to consider a mixed-form that allows some
146/// alignment in the presence of multiple accesses:
147/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
148/// And would rewrite accesses as:
149/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
150static FailureOr<SmallVector<std::optional<int64_t>>>
153 int64_t dim) {
154 int64_t newDim = iteratorTypes.size();
155 iteratorTypes.push_back(iteratorTypes[dim]);
156
157 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
158 indexingMaps.size(), std::nullopt);
160 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
161 ++operandIdx) {
162 AffineMap map = indexingMaps[operandIdx];
163
164 // Add the `newDim` to map whatever the case.
165 assert(map.getNumDims() == newDim && "num dims invariant violation");
166 map = map.shiftDims(1, newDim);
167
168 // Get the at-most-1 index of the result that is a function of `dim`.
169 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
170 // logically chunks dimension `dim` into `K * dim + newDim`, where the
171 // packing factor `K` is specified separately.
172 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
173 "num results invariant violation");
174 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
175 if (!maybeOperandDimensionToPack.has_value()) {
176 newMaps.push_back(map);
177 continue;
178 }
179
180 // We can only pack AffineDimExpr atm.
181 if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
182 return failure();
183
184 // Add `newDim` to the results of the map.
185 map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
186 map.getNumResults());
187 newMaps.push_back(map);
188
189 // Record that `operandIdx` is packed.
190 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
191 }
192 indexingMaps = newMaps;
193
194 return packedDimPerIndexingMap;
195}
196
197namespace {
198
199/// Helper struct to encode packing along one dimension of a LinalgOp.
200struct PackedOperandsDim {
201 OpFoldResult packedSize;
202 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
203};
204
205/// Helper struct to encode packing along all dimensions of a LinalgOp.
206struct PackedOperandsDimList {
207 void pushBack(PackedOperandsDim &&packedOperandsDims) {
208 spec.emplace_back(packedOperandsDims);
209 }
210 /// Return all the dims that have been packed for operand @ `operandPos`.
211 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
212 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
213 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
214
215private:
216 SmallVector<PackedOperandsDim> spec;
217};
218
219} // namespace
220
221FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
222 linalg::PackOp packOp,
223 bool lowerPadLikeWithInsertSlice) {
224 // TODO: Support Memref PackOp. Temporarily return failure.
225 if (!packOp.hasPureTensorSemantics())
226 return failure();
227
228 auto packedTensorType =
229 cast<RankedTensorType>(packOp->getResultTypes().front());
230
231 Location loc = packOp->getLoc();
232 OpBuilder::InsertionGuard g(rewriter);
233 rewriter.setInsertionPoint(packOp);
234
235 // 2. Compute the permutation vector to shuffle packed shape into the shape
236 // before any outer or inner permutations have been applied.
237 PackingMetadata packingMetadata;
238 SmallVector<int64_t> packedToStripMinedShapePerm =
239 getPackInverseDestPerm(packOp, packingMetadata);
240
241 // 3. Compute the stripMinedShape: this is the packed shape before any outer
242 // or inner permutations have been applied.
243 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
244 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
245
246 // Also compute the mixed (static+dynamic) strip-mined sizes for the
247 // expand_shape output. This is needed to support dynamic inner tile sizes,
248 // since the shapes cannot be inferred automatically when multiple dynamic
249 // dims appear in a single reassociation group during ExpandShapeOp
250 // construction.
251 SmallVector<OpFoldResult> stripMinedMixedSizes =
252 tensor::getMixedSizes(rewriter, loc, packOp.getDest());
253 applyPermutationToVector(stripMinedMixedSizes, packedToStripMinedShapePerm);
254
255 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
256 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
257 rewriter.getIndexAttr(0));
258 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
259 rewriter.getIndexAttr(0));
260 for (auto [pos, innerSize] :
261 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
262 int outerPos =
263 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
264 OpFoldResult origSize =
265 tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
266 OpFoldResult outerSize =
267 tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
268 AffineExpr s0, d0, d1;
269 bindDims(rewriter.getContext(), d0, d1);
270 bindSymbols(rewriter.getContext(), s0);
271 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
273 rewriter, loc, map, {outerSize, origSize, innerSize});
274 }
275 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
276 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
277 packingMetadata.reassociations);
278 Value paddingValue = packOp.getPaddingValue();
279 if (!paddingValue) {
280 paddingValue = arith::ConstantOp::create(
281 rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
282 }
283 auto padOp =
284 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
285 highs, paddingValue, /*nofold=*/false);
286
287 LDBG() << "insertPositions: "
288 << llvm::interleaved(packingMetadata.insertPositions);
289 LDBG() << "outerPositions: "
290 << llvm::interleaved(packingMetadata.outerPositions);
291 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
292 LDBG() << "packedToStripMinedShapePerm: "
293 << llvm::interleaved(packedToStripMinedShapePerm);
294 LDBG() << "reassociations: "
295 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
297 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
298 LDBG() << "collapsed type: " << collapsed;
299
300 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
301 // Pack ops which operate as simple pads may not produce legal
302 // tensor.insert_slice operations when the packed type does not rank reduce
303 // to the padded type.
304 SliceVerificationResult rankReduces =
305 isRankReducedType(packedTensorType, padOp.getResultType());
306
307 if (rankReduces == SliceVerificationResult::Success) {
308 // This pack is just a plain pad.
309 // Just insert the pad in the higher ranked tensor.
310 // Offsets.
311 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
312 rewriter.getIndexAttr(0));
313 // Strides.
314 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
315 rewriter.getIndexAttr(1));
317 tensor::getMixedSizes(rewriter, loc, packOp.getDest());
318
319 auto insertSliceOp = tensor::InsertSliceOp::create(
320 rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
321 /*offsets=*/zeros, sizes, /*strides=*/ones);
322
323 LDBG() << "insert_slice op: " << insertSliceOp;
324
325 rewriter.replaceOp(packOp, insertSliceOp->getResults());
326
327 return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
328 /*transposeOp=*/nullptr};
329 }
330 }
331
332 // 5. Expand from the padded result to the stripMinedShape.
333 auto expandShapeResultType =
334 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
335 auto reshapeOp = tensor::ExpandShapeOp::create(
336 rewriter, loc, expandShapeResultType, padOp.getResult(),
337 packingMetadata.reassociations, stripMinedMixedSizes);
338
339 // 6. Transpose stripMinedShape to packedShape.
340 SmallVector<int64_t> transpPerm =
341 invertPermutationVector(packedToStripMinedShapePerm);
342 auto transposeOp = linalg::TransposeOp::create(
343 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
344
345 LDBG() << "reshape op: " << reshapeOp;
346 LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
347 LDBG() << "transpose op: " << transposeOp;
348
349 // 7. Replace packOp by transposeOp.
350 rewriter.replaceOp(packOp, transposeOp->getResults());
351
352 return LowerPackResult{padOp, reshapeOp, transposeOp};
353}
354
355FailureOr<LowerUnPackOpResult>
356linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
357 bool lowerUnpadLikeWithExtractSlice) {
358 // TODO: Support Memref UnPackOp. Temporarily return failure.
359 if (!unPackOp.hasPureTensorSemantics())
360 return failure();
361
362 Location loc = unPackOp->getLoc();
363 OpBuilder::InsertionGuard g(rewriter);
364 rewriter.setInsertionPoint(unPackOp);
365
366 auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
367 int64_t packedRank = packedTensorType.getRank();
368
369 OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
370 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
371 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
372 // This unpack is just a plain unpad.
373 // Just extract the slice from the higher ranked tensor.
374 ArrayRef<int64_t> destShape = destTensorType.getShape();
375 // The inner dimensions stay the same as the destination tensor, but the
376 // outer ones are additional 1s.
377 SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
378 sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
379
380 auto extractSliceOp = tensor::ExtractSliceOp::create(
381 rewriter, loc, destTensorType, unPackOp.getSource(),
382 SmallVector<OpFoldResult>(packedRank, zero), sizes,
383 SmallVector<OpFoldResult>(packedRank, one));
384
385 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
386
387 return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
388 /*reshapeOp=*/nullptr, extractSliceOp,
389 /*copyOp=*/nullptr};
390 }
391
392 // 1. Compute the permutation vector to shuffle packed shape into the shape
393 // before any outer or inner permutations have been applied.
394 PackingMetadata packingMetadata;
395 SmallVector<int64_t> packedToStripMinedShapePerm =
396 getUnPackInverseSrcPerm(unPackOp, packingMetadata);
397
398 // 2. Compute the stripMinedShape: this is the packed shape without outer and
399 // inner permutations.
400 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
401 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
402
403 // 3. Transpose packedShape to stripMinedShape.
404 RankedTensorType stripMinedTensorType =
405 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
406 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
407 stripMinedTensorType, packingMetadata.reassociations);
408
409 // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
410 // permutation.
412 tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
413 applyPermutationToVector(dims, packedToStripMinedShapePerm);
414 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
415 stripMinedTensorType.getElementType());
416 auto transposeOp =
417 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
418 packedToStripMinedShapePerm);
419
420 LDBG() << "insertPositions: "
421 << llvm::interleaved(packingMetadata.insertPositions);
422 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
423 LDBG() << "packedToStripMinedShapePerm: "
424 << llvm::interleaved(packedToStripMinedShapePerm);
425 LDBG() << "reassociations: "
426 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
428 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
429 LDBG() << "collapsed type: " << collapsedType;
430
431 // 4. Collapse from the stripMinedShape to the padded result.
432 auto reshapeOp = tensor::CollapseShapeOp::create(
433 rewriter, loc, collapsedType, transposeOp->getResult(0),
434 packingMetadata.reassociations);
435
436 // 5. ExtractSlice.
437 int64_t destRank = destTensorType.getRank();
438 auto extractSliceOp = tensor::ExtractSliceOp::create(
439 rewriter, loc, destTensorType, reshapeOp->getResult(0),
440 SmallVector<OpFoldResult>(destRank, zero),
441 tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
442 SmallVector<OpFoldResult>(destRank, one));
443
444 // 6. Inject a copy to preserve DPS.
445 auto copyOp = linalg::CopyOp::create(
446 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
447
448 // 7. Replace unPackOp by copyOp.
449 rewriter.replaceOp(unPackOp, copyOp->getResults());
450
451 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp,
452 copyOp};
453}
454
456PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
458 for (auto &i : spec) {
459 if (!i.packedDimForEachOperand[operandPos].has_value())
460 continue;
461 res.push_back(i.packedDimForEachOperand[operandPos].value());
462 }
463 return res;
464}
465
466SmallVector<OpFoldResult>
467PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
468 SmallVector<OpFoldResult> res;
469 for (auto &i : spec) {
470 if (!i.packedDimForEachOperand[operandPos].has_value())
471 continue;
472 res.push_back(i.packedSize);
473 }
474 return res;
475}
476
477/// Implement packing of a single LinalgOp by performing packing by
478/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
479/// Return the packed Linalg op on success, failure otherwise.
480FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
481 linalg::LinalgOp linalgOp,
482 ArrayRef<OpFoldResult> packedSizes) {
483 if (packedSizes.size() != linalgOp.getNumLoops()) {
484 return rewriter.notifyMatchFailure(linalgOp,
485 "incorrect number of pack sizes");
486 }
487
488 Location loc = linalgOp->getLoc();
489 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
491 linalgOp.getIteratorTypesArray();
492 LDBG() << "Start packing: " << linalgOp;
493 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
494 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
495
498 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
499 PackedOperandsDimList listOfPackedOperandsDim;
500 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
501 std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
502 // Skip tile sizes explicitly set to 0.
503 if (maybeConstant.has_value() && maybeConstant.value() == 0)
504 continue;
505
506 PackedOperandsDim packedOperandsDims;
507 packedOperandsDims.packedSize = packedSizes[i];
508 FailureOr<SmallVector<std::optional<int64_t>>>
509 maybePackedDimForEachOperand =
510 packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
511 if (failed(maybePackedDimForEachOperand))
512 return failure();
513 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
514
515 LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
516 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
517 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
518 LDBG() << "packedDimForEachOperand: "
519 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
520
521 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
522 }
523
524 // Step 2. Propagate packing to all LinalgOp operands.
525 SmallVector<Value> inputsAndInits, results;
526 SmallVector<OpOperand *> initOperands =
527 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
528 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
529 for (const auto &operandsList : {inputOperands, initOperands}) {
530 for (OpOperand *opOperand : operandsList) {
531 int64_t pos = opOperand->getOperandNumber();
532 Value operand = opOperand->get();
533 SmallVector<int64_t> innerPos =
534 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
535 SmallVector<OpFoldResult> innerPackSizes =
536 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
537 LDBG() << "operand: " << operand;
538 LDBG() << "innerPos: " << llvm::interleaved(innerPos);
539 LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
540 if (innerPackSizes.empty()) {
541 inputsAndInits.push_back(operand);
542 continue;
543 }
544 Value dest = linalg::PackOp::createDestinationTensor(
545 rewriter, loc, operand, innerPackSizes, innerPos,
546 /*outerDimsPerm=*/{});
547 ShapedType operandType = cast<ShapedType>(operand.getType());
548 bool areConstantTiles =
549 llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
550 return getConstantIntValue(tile).has_value();
551 });
552 if (areConstantTiles && operandType.hasStaticShape() &&
553 !linalg::PackOp::requirePaddingValue(
554 operandType.getShape(), innerPos,
555 cast<ShapedType>(dest.getType()).getShape(), {},
556 innerPackSizes)) {
557 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
558 innerPos, innerPackSizes));
559 } else {
560 // TODO: value of the padding attribute should be determined by
561 // consumers.
562 auto zeroAttr =
563 rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
564 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
565 packOps.push_back(linalg::PackOp::create(
566 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
567 }
568 inputsAndInits.push_back(packOps.back().getResult());
569 }
570 }
571
572 // Step 3. Build the packed op, use the type of `inits` as result types.
573 ValueRange inputs =
574 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
575 ValueRange inits =
576 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
577 auto packedLinalgOp =
578 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
579 inputs, inits, indexingMaps, iteratorTypes);
580 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
581
582 // Step 4. Propagate packing to all the op results.
583 for (OpResult result : packedLinalgOp->getResults()) {
584 int64_t resultNum = result.getResultNumber();
585 linalg::PackOp maybePackedInit =
586 inits[resultNum].getDefiningOp<linalg::PackOp>();
587 if (!maybePackedInit) {
588 results.push_back(result);
589 continue;
590 }
591 // Build the symmetrical UnPackOp to the existing PackOp.
592 unPackOps.push_back(linalg::UnPackOp::create(
593 rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
594 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
595 results.push_back(unPackOps.back().getResult());
596 }
597
598 // Step 5. Replace `linalgOp`.
599 rewriter.replaceOp(linalgOp, results);
600
601 // Return packedLinalgOp.
602 return PackResult{packOps,
603 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
604 unPackOps};
605}
606
607//===----------------------------------------------------------------------===//
608// packTranspose transformation.
609//===----------------------------------------------------------------------===//
610
611/// Return a copy of `tensorType` after permutation by `permutationVector`.
612// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
613// but this would introduce a dependence on Dialect in IR.
614// TODO: Restructure.
615static RankedTensorType permuteShape(RankedTensorType tensorType,
616 ArrayRef<int64_t> permutationVector) {
617 SmallVector<int64_t> shape(tensorType.getShape());
618 applyPermutationToVector(shape, permutationVector);
619 return RankedTensorType::Builder(tensorType).setShape(shape);
620}
621
622/// Return a new GenericOp obtained by transposing opOperand by the permutation
623/// vector:
624/// - the corresponding indexing map is transposed by `permutation`
625/// - the corresponding operand value is replaced by `transposedValue`
626/// `linalgOp` is replaced by the return op in the process.
627/// Asserts that `transposedValue` is of the proper transposed ShapedType.
629 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
630 ArrayRef<int64_t> permutation, Value transposedValue) {
631 // Sanity check the operand.
632 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
633
634 // Sanity check of the expected transposed tensor type.
635 auto tensorType = permuteShape(
636 cast<RankedTensorType>(opOperand.get().getType()), permutation);
637 (void)tensorType;
638 assert(tensorType == transposedValue.getType() &&
639 "expected tensor type mismatch");
640
641 // Compute the transposed indexing map.
642 // Sigh unsigned pollution.
643 SmallVector<unsigned> tmpTransposition =
644 llvm::map_to_vector(permutation, [](int64_t i) -> unsigned { return i; });
645 AffineMap permutationMap =
646 AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
647 AffineMap transposedMap =
648 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
649
650 // Set the transposed indexing map in the proper position.
651 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
652 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
653 // Set the transposedValue in the proper operand position.
654 SmallVector<Value> operands = linalgOp->getOperands();
655 operands[opOperand.getOperandNumber()] = transposedValue;
656
657 ValueRange operandsRef(operands);
658 auto transposedGenericOp = linalg::GenericOp::create(
659 rewriter,
660 /*location=*/linalgOp->getLoc(),
661 /*resultTensorTypes=*/
662 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
663 /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
664 /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
665 /*indexingMaps=*/indexingMaps,
666 /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
667 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
668 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
669
670 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
671}
672
673FailureOr<PackTransposeResult>
674linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
675 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
676 ArrayRef<int64_t> outerPerm,
677 ArrayRef<int64_t> innerPerm) {
678 Location loc = linalgOp.getLoc();
679
680 // Step 1. Transpose packOp.
681 rewriter.setInsertionPoint(packOp);
682 linalg::PackOp transposedPackOp =
683 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
684
685 if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
686 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
687
688 OpOperand &packUse = *packOp->getUses().begin();
689 if (packUse.getOwner() != linalgOp) {
690 return rewriter.notifyMatchFailure(
691 linalgOp, "not a single use by the LinalgOp target");
692 }
693 if (maybeUnPackOp &&
694 (!linalgOp.isDpsInit(&packUse) ||
695 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
696 return rewriter.notifyMatchFailure(linalgOp,
697 "not produced by the LinalgOp target");
698 }
699
700 // Step 2. Transpose linalgOp.
701 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
702 // identity. Don't rely on it.
703 int64_t numLeadingDims = packOp.getSourceRank();
704 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
705 // Step 2.a. Compute the permutation on the whole operand.
706 // Leading part just reuse the outerPerm.
707 SmallVector<int64_t> permutation(outerPerm);
708 if (permutation.empty())
709 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
710 // Trailing part needs to reindex positions by `numLeadingDims`.
711 if (innerPerm.empty()) {
712 llvm::append_range(
713 permutation,
714 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
715 } else {
716 llvm::append_range(permutation,
717 llvm::map_range(innerPerm, [&](int64_t pos) {
718 return numLeadingDims + pos;
719 }));
720 }
721 if (!isPermutationVector(permutation))
722 return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
723
724 // Step 2.b. Save the transposedPackUse operand number in case we need to
725 // get the tied OpResult after `linalgOp` has been replaced.
726 int64_t packUseOperandNumber = packUse.getOperandNumber();
727 // Step 2.c. Actually perform the transposition.
728 rewriter.setInsertionPoint(linalgOp);
729 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
730 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
731
732 // Step 3. Maybe transpose unPackOp.
733 linalg::UnPackOp transposedUnPackOp;
734 if (maybeUnPackOp) {
735 OpOperand &opOperand =
736 transposedLinalgOp->getOpOperand(packUseOperandNumber);
737 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
738 rewriter.setInsertionPoint(maybeUnPackOp);
739 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
740 rewriter, loc, transposedResult, innerPerm, outerPerm);
741
742 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
743 }
744
745 // Step 4. Finally, replace packOp now that we don't need it anymore.
746 if (packOp.hasPureTensorSemantics())
747 rewriter.replaceOp(packOp, transposedPackOp->getResults());
748 else
749 rewriter.eraseOp(packOp);
750
751 return PackTransposeResult{transposedPackOp, transposedLinalgOp,
752 transposedUnPackOp};
753}
754
755//===----------------------------------------------------------------------===//
756// packMatmulGreedily transformation.
757//===----------------------------------------------------------------------===//
758
759/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
760/// and n are proper parallel dimensions and k is a proper reduction
761/// dimension. Packing occurs by rewriting the op as a linalg.generic and
762/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
763/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
764/// to reorder {m, n, k} into one of the 6 possible forms. The outer
765/// dimensions of the operands are not permuted at this time, this is left for
766/// future work.
767FailureOr<PackResult>
768linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
769 ArrayRef<OpFoldResult> mnkPackedSizes,
770 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
771 ArrayRef<int64_t> mnkOrder) {
772 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
773 assert((mnkPaddedSizesNextMultipleOf.empty() ||
774 mnkPaddedSizesNextMultipleOf.size() == 3) &&
775 "num of packing sizes next multiple should be empty or of size 3");
776 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
777 assert(isPermutationVector(mnkOrder) && "expected a permutation");
778
779 int64_t numLoops = linalgOp.getNumLoops();
780 if (numLoops <= 2) {
781 LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
782 << " in: " << linalgOp;
783 return rewriter.notifyMatchFailure(
784 linalgOp, "need 3+ loops to find a matmul to pack");
785 }
786
787 // Locally adjust the desired iterator position of mnk and packing sizes.
788 int64_t numPackedDims = mnkPackedSizes.size();
789 SmallVector<int64_t> mmnnkkPos(numPackedDims);
790 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
791 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
792 SmallVector<OpFoldResult> packedSizes(numPackedDims);
793 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
794 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
795 SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
796 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
797 paddedSizesNextMultipleOf[mnkOrder[i]] =
798 mnkPaddedSizesNextMultipleOf.empty() ? 0
799 : mnkPaddedSizesNextMultipleOf[i];
800 }
801
802 // 1. Infer dims that are important for matmul.
803 FailureOr<ContractionDimensions> maybeDimensions =
804 inferContractionDims(linalgOp);
805 if (failed(maybeDimensions)) {
806 LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
807 return rewriter.notifyMatchFailure(linalgOp,
808 "couldn't infer matmul iterators");
809 }
810
811 // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
812 // minor iterators. In cases with multiple options for m, n, k bias towards
813 // the most minor embedding.
814 // If we wanted a different normalization order, this is where it would have
815 // to plug a heuristic.
816 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
817 kPos = maybeDimensions->k.back();
818 LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
819 << nPos << ", k@" << kPos << "): " << linalgOp;
820
821 // 2.a. Rewrite as a generic.
822 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
823 if (!genericOp) {
824 FailureOr<GenericOp> generalizeResult =
825 generalizeNamedOp(rewriter, linalgOp);
826 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
827 genericOp = *generalizeResult;
828 }
829
830 // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
831 // iterators. Note that this only normalized the iteration order and does
832 // not change the indexings of any operand.
833 SmallVector<int64_t> permutation =
834 computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
835 LDBG() << "perm: " << llvm::interleaved(permutation);
836 // Sign .. unsigned pollution.
837 SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
838 FailureOr<GenericOp> interchangeResult =
839 interchangeGenericOp(rewriter, genericOp, unsignedPerm);
840 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
841 genericOp = *interchangeResult;
842 LDBG() << "Generalized Op to pack: " << genericOp;
843
844 // At this point, the op iterators are normalized to {leading, k, m, n}.
845 // The layouts induced by packing will always be:
846 // - LHS{leading_lhs, kk, mm}
847 // - RHS{leading_rhs, kk, nn}
848 // - RES{leading_res, mm, nn}
849 // If we wanted to change the packed order, we would reorder (k, m, n) to
850 // something else above.
851 //
852 // Additional permutations of the outer dims of the operands (i.e.
853 // leading_lhs, leading_rhs and leading_res) could follow by computing the
854 // desired outerPerm for each operand.
855 // This is left for future work.
856
857 // TODO: this creates too much IR, go use reifyResultShapes.
858 SmallVector<Range, 4> loopRanges =
859 cast<LinalgOp>(genericOp.getOperation())
860 .createLoopRanges(rewriter, genericOp.getLoc());
861
862 // Add leading zeros to match numLoops, we only pack the last 3 dimensions
863 // post interchange.
864 LDBG() << "paddedSizesNextMultipleOf: "
865 << llvm::interleaved(paddedSizesNextMultipleOf);
866 LDBG() << "loopRanges: "
867 << llvm::interleaved(
868 llvm::map_range(loopRanges, [](Range r) { return r.size; }));
869 SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
870 rewriter.getIndexAttr(0));
871 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
872 if (paddedSizesNextMultipleOf[i] == 0) {
873 adjustedPackedSizes.push_back(packedSizes[i]);
874 continue;
875 }
876 AffineExpr d0, s0;
877 bindDims(rewriter.getContext(), d0);
878 bindSymbols(rewriter.getContext(), s0);
879 adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
880 rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
881 {loopRanges[adjustedPackedSizes.size()].size,
882 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
883 }
884 LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
885
886 // TODO: If we wanted to give the genericOp a name after packing, after
887 // calling `pack` would be a good time. One would still need to check that
888 // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
889 // also allow degenerate matmul cases (i.e. matvec, dot).
890 return pack(rewriter, genericOp, adjustedPackedSizes);
891}
892
893//===----------------------------------------------------------------------===//
894// Transformations exposed as rewrite patterns.
895//===----------------------------------------------------------------------===//
896
// Stores in `tileSizeComputationFunction` a callback that turns each entry of
// `ts` into an index `arith.constant` and returns the created Values. The
// callback may only be installed once (asserted below). Returns `*this`.
899 assert(!tileSizeComputationFunction && "tile sizes already set");
// Copy the sizes into an owning vector that the lambda captures by value, so
// the stored callback does not refer back to `ts`.
900 SmallVector<int64_t, 4> tileSizes(ts);
901 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
// The constants are created at the start of the first block of the body of
// the func::FuncOp enclosing `op`. NOTE(review): assumes `op` is nested in a
// func::FuncOp; getParentOfType returns a null op otherwise.
903 b.setInsertionPointToStart(
904 &op->getParentOfType<func::FuncOp>().getBody().front());
905 return llvm::map_to_vector<4>(tileSizes, [&](int64_t s) {
906 Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
907 return v;
908 });
909 };
910 return *this;
911}
912
914 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
// Rewrite callback for memref.copy: the whole rewrite is delegated to
// `vectorizeCopy`, and its result is returned as-is.
915 return vectorizeCopy(rewriter, copyOp);
916}
917
918/// Fills `dest` with the constant padding value of `padOp` using a FillOp
919/// when such a value exists. Otherwise, creates a tensor::GenerateOp of the
/// pad's result type, with `dynSizes` as its dynamic sizes and a clone of the
/// pad region as its body; `dest` is not used on that path.
921 RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
922 const SmallVector<Value> &dynSizes) const {
923 auto padValue = padOp.getConstantPaddingValue();
924 if (padValue) {
925 // Move the padding value defined inside the PadOp block to outside.
// Its defining op is moved before `padOp` so that the value no longer lives
// inside the pad region and can be used by the FillOp created below.
926 if (padValue.getParentBlock() == &padOp.getRegion().front())
927 rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
928 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
929 }
930
931 // Fill could not be optimized: Lower to tensor::GenerateOp with region.
932 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
933 padOp.getResultType(), dynSizes);
934 // Copy region to new op.
// The pad region is cloned (not moved), so `padOp` keeps its own region.
935 IRMapping bvm;
936 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
937 return generateOp;
938}
939
940LogicalResult
942 PatternRewriter &rewriter) const {
// Lowers `padOp` as follows:
//   1. Create a tensor.empty of the padded result shape.
//   2. Let createFillOrGenerateOp produce the padded-value tensor (a fill of
//      that empty tensor or a tensor.generate).
//   3. Write the pad source into it with a tensor.insert_slice at offsets
//      equal to the low padding, using unit strides.
943 // Given an OpFoldResult, return an index-typed value.
944 auto getIdxValue = [&](OpFoldResult ofr) {
// Values pass through unchanged; attributes are expected to be IntegerAttr
// (the casts below assert otherwise).
945 if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
946 return val;
948 rewriter, padOp.getLoc(),
949 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
950 .getResult();
951 };
952
953 auto resultType = padOp.getResultType();
954 // Compute size of EmptyOp. Any combination of static/dynamic is supported.
// Every result dim is recorded in `staticSizes` (dynamic ones as
// resultType.getDimSize(dim)). Each dynamic dim additionally gets an SSA
// size computed as source size + low pad + high pad.
955 SmallVector<Value> dynSizes;
956 SmallVector<int64_t> staticSizes;
957 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
958 if (resultType.isDynamicDim(dim)) {
959 auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
960 padOp.getSource(), dim));
961 // Add low and high padding value.
962 auto plusLow = rewriter.createOrFold<arith::AddIOp>(
963 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
964 auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
965 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
966 dynSizes.push_back(plusHigh);
967 }
968 staticSizes.push_back(resultType.getDimSize(dim));
969 }
970
971 // Init tensor and fill it with padding.
972 Value emptyTensor =
973 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
974 resultType.getElementType(), dynSizes);
975 Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
976
977 // Generate an InsertSliceOp for copying the PadOp source.
978 auto sourceType = padOp.getSourceType();
979 // Compute size of source of tensor::PadOp.
981 tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
982 // Strides of InsertSliceOp are all 1.
983 SmallVector<OpFoldResult> strides(sourceType.getRank(),
984 rewriter.getIndexAttr(1));
985 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
986 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
987 strides);
988
989 return success();
990}
991
993 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
// Rewrites tensor.extract_slice(tensor.pad(x)) by bubbling the slice above the
// pad with tensor::bubbleUpPadSlice. Only unit-stride slices whose source is
// produced by a tensor.pad are handled.
994 if (!sliceOp.hasUnitStride())
995 return failure();
996
997 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
998 if (!padOp)
999 return failure();
1000
// Without a `controlFn` the zero-slice guard is enabled. With one, a
// std::nullopt result rejects the match, and any other result selects the
// guard value passed to bubbleUpPadSlice.
1001 bool zeroSliceGuard = true;
1002 if (controlFn) {
1003 if (std::optional<bool> control = controlFn(sliceOp))
1004 zeroSliceGuard = *control;
1005 else
1006 return failure();
1007 }
1008
1009 FailureOr<TilingResult> tilingResult =
1010 tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
1011 sliceOp.getMixedSizes(), zeroSliceGuard);
1012 if (failed(tilingResult))
1013 return failure();
1014
1015 RankedTensorType sourceType = sliceOp.getSourceType();
1016 RankedTensorType resultType = sliceOp.getResultType();
1017
1018 // If the extract_slice is not rank-reduced, all shapes are static and the
1019 // data source is actually used. Rewrite into pad(extract_slice(x)).
1020 if (sourceType.getRank() == resultType.getRank()) {
1021 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1022 return success();
1023 }
1024
// Otherwise the bubbled-up value is first brought to `resultType` (the
// rank-reduced type) and only then replaces `sliceOp`.
1025 // Handle rank-reduced slice by creating another extract_slice op.
1027 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1028
1029 rewriter.replaceOp(sliceOp, rankReduced);
1030 return success();
1031}
1032
1033/// If padding value is set, returns a tensor.pad Op for the source tensor,
1034/// with the output shape matching the output of `packOp`. Otherwise, returns
1035/// the source directly.
1036///
1037/// This method assumes that all outer dims for this pack Op are 1.
///
/// The padding is high-only (tensor::createPadHighOp). The padded shape keeps
/// the (unit) size of every untiled dim and uses the tile size for every
/// tiled dim. Dynamic tile sizes are collected in dim order and passed to
/// createPadHighOp (presumably as the dynamic extents of the padded type).
/// Packs without pure tensor semantics also return the source unchanged.
1039 linalg::PackOp packOp) {
1040 Value input = packOp.getSource();
1041 // TODO: Support Memref PackOp. Temporarily return just Op Source.
1042 if (!packOp.hasPureTensorSemantics())
1043 return input;
1044
1045 if (!packOp.getPaddingValue()) {
1046 return input;
1047 }
1048
1049 assert(llvm::all_of(packOp.getAllOuterDims(),
1050 [](int64_t val) { return val == 1; }) &&
1051 "some outer dims are != 1");
1052
1053 Location loc = packOp.getLoc();
1054 ShapedType inputType = packOp.getSourceType();
1055 int64_t inputRank = inputType.getRank();
1056
1057 DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1058 packOp.getDimAndTileMapping();
1059
1060 // The sizes of dynamic tiles
1061 SmallVector<Value> dynamicTileSizes;
1062
1063 // Collect dims for the padded shape.
1064 SmallVector<int64_t> paddedShape;
1065 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1066 // 1. Non-tiled outer dims.
1067 // These dims should be 1 and we simply preserve them.
1068 if (!tileAndPosMapping.count(dimIdx)) {
1069 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1070 assert(inputDimSize == 1 &&
1071 "with all outer dims == 1, this non-tiled input dim should be 1!");
1072 paddedShape.push_back(inputDimSize);
1073 continue;
1074 }
1075
1076 // 2. Tiled outer dims
1077 // As all outer dims == 1, it is safe to use the tile size for the padded
1078 // shape.
1079 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1080
1081 // 2.1 Static tile sizes
1082 std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1083 if (cstTileSize.has_value()) {
1084 paddedShape.push_back(cstTileSize.value());
1085 continue;
1086 }
1087
1088 // 2.2 Dynamic tile sizes
1089 paddedShape.push_back(ShapedType::kDynamic);
1090
1091 // Get the value that holds the dynamic size.
// NOTE(review): `dyn_cast` silently yields a null Value if `tileSizeForDim`
// is an attribute that getConstantIntValue could not fold; `cast` would
// assert instead. Confirm that case cannot occur.
1092 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1093 }
1094 auto resultType =
1095 RankedTensorType::get(paddedShape, inputType.getElementType());
1096 return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1097 /*nofold=*/false, loc, builder,
1098 dynamicTileSizes);
1099}
1100
1101// Normalizes a permutation on a higher rank space to its actual size, e.g.
1102// perm = [1, 4, 2]
1103// becomes
1104// norm = [0, 2, 1]
//
// `perm` is assumed to hold distinct positions in [0, rank). Entry i of the
// result is the rank of perm[i] among the sorted values of `perm`, so the
// result is a permutation of [0, perm.size()).
1105static SmallVector<int64_t>
1107 constexpr int64_t kNonTiledMarker = -1;
1108 SmallVector<int64_t> vec(rank, kNonTiledMarker);
// vec[p] = index in `perm` of position p, or kNonTiledMarker if p is unused.
1109 for (auto [index, value] : llvm::enumerate(perm))
1110 vec[value] = index;
// Dropping the markers keeps the `perm` indices ordered by position value.
1111 SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1112 vec, [&](int64_t v) { return v != kNonTiledMarker; });
1113 // This inverts the permutation in addition to normalizing so invert back.
1114 return invertPermutationVector(normalizedPerm);
1115}
1116
1117// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1118// assuming rank reduction of unit outer dims.
//
// `shape` holds the unpacked-rank outer dims (the unpack decomposition passes
// the first destRank dims of its source). A dim is kept if it is tiled
// (listed in innerDimsPos) or its size is not 1. Kept dims are renumbered
// consecutively in order. The result lists the kept un-tiled dims, reordered
// by outerDimsPerm (normalized to those dims) when that is non-empty,
// followed by the tiled dims in normalized innerDimsPos order.
1119static SmallVector<int64_t>
1121 ArrayRef<int64_t> innerDimsPos,
1122 ArrayRef<int64_t> outerDimsPerm) {
1123 SmallVector<int64_t> rankReducedOuterDimsPerm;
1124 SmallVector<int64_t> outerDims;
1125 SmallVector<int64_t> innerDims;
1126 int64_t dim = 0;
1127 int64_t unpackedRank = shape.size();
// Number the kept dims; tiled dims are always kept, un-tiled unit dims are
// dropped (rank reduction).
1128 for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1129 if (llvm::is_contained(innerDimsPos, i)) {
1130 innerDims.push_back(dim++);
1131 continue;
1132 }
1133 if (shape[i] == 1)
1134 continue;
1135 outerDims.push_back(dim++);
1136 if (!outerDimsPerm.empty())
1137 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1138 }
1139
1140 // Get the position of the inner dims after permutation.
1141 SmallVector<int64_t> innerPerm =
1142 getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1143 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1144
1145 // Ditto for the outer dims.
1146 SmallVector<int64_t> perm = outerDims;
1147
1148 rankReducedOuterDimsPerm =
1149 getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1150 if (!rankReducedOuterDimsPerm.empty())
1151 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1152
1153 // The tile always ends up as the inner most dims after packing.
1154 perm.append(innerDims);
1155
1156 return perm;
1157}
1158
1160 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1161 // TODO: Support Memref PackOp. Temporarily return failure.
1162 if (!packOp.hasPureTensorSemantics())
1163 return failure();
1164
1165 if (llvm::any_of(packOp.getTiledOuterDims(),
1166 [](int64_t dim) { return dim != 1; })) {
1167 return rewriter.notifyMatchFailure(
1168 packOp, "not all outer dimensions of the result are 1s");
1169 }
1170
1171 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1172 auto outerDimsPerm = packOp.getOuterDimsPerm();
1173
1174 // Verify that there are no:
1175 // * non-unit + un-tiled-outer-dims,
1176 // that are permuted. Supporting such cases would require refining the logic
1177 // that generates the Transpose Op.
1178 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1179 static int prev = 0;
1180 // Skip tiled dims - these can be permuted.
1181 if (llvm::is_contained(innerDimsPos, dim))
1182 return true;
1183
1184 // Check whether this dim has been permuted. Permuting unit dims is fine
1185 // as that's effectively a no-op.
1186 if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
1187 packOp.getResult().getType().getShape()[dim] != 1))
1188 return false;
1189
1190 prev = dim;
1191 return true;
1192 })) {
1193 return rewriter.notifyMatchFailure(
1194 packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1195 "this is not supported ATM!");
1196 }
1197
1198 Location loc = packOp.getLoc();
1199
1200 int64_t srcRank = packOp.getSourceRank();
1201
1202 // 1. Get the input that is going to be packed. If the input requires padding,
1203 // add a padding operation and return that as the input.
1204 Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1205
1206 // 2. Transpose the input to match the inner tile order:
1207 // %init = tensor.empty()
1208 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1209 // outs(%init)
1210 // Assumptions made:
1211 // - All tiled outer dims are 1 - the corresponding transposition order
1212 // doesn't matter, but requires all dim indices to be present.
1213 // - Un-tiled outer dims remain un-permuted.
1214
1215 // 2.1 Get the permutation for linalg.transpose:
1216 // [ untiled-dims, inner-dims-pos ]
1217 // Note, this logic assumes that the untiled dims are not permuted.
1218 SmallVector<int64_t> srcPermForTranspose;
1219 for (int64_t i = 0; i < srcRank; i++) {
1220 // We assume the `k` dimensions of the inner dim position, where `k` is the
1221 // rank of the inner tiling, correspond to the last `k` indices of the
1222 // transpose permutation. This is done by adding the indices not contained
1223 // in the inner dimension position in order from 0 to `n`. Where n is the
1224 // rank of the source tensor. For example if we have a source tensor with
1225 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1226 // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1227 if (llvm::is_contained(innerDimsPos, i))
1228 continue;
1229 srcPermForTranspose.push_back(i);
1230 }
1231 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1232
1233 // 2.2 Create the init tensor for linalg.transpose with the correct shape:
1234 // [ untiled-dims, tiled-dims ]
1235 ShapedType inputTy = cast<ShapedType>(input.getType());
1236 SmallVector<OpFoldResult> shapeForEmptyOp;
1237 for (int64_t i = 0; i < srcRank; i++) {
1238 if (llvm::is_contained(innerDimsPos, i)) {
1239 // The tiled dims are appended after this loop.
1240 continue;
1241 }
1242 if (inputTy.isStaticDim(i))
1243 shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
1244 else
1245 shapeForEmptyOp.emplace_back(
1246 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1247 }
1248 shapeForEmptyOp.append(packOp.getMixedTiles());
1249
1250 // getMixedTiles() may contain Values pointing to constant ops (as opposed to
1251 // constant attributes with the corresponding value). Replace those with
1252 // attributes. This is to match the behaviour in
1253 // `getPackOpSourceOrPaddedSource`, which replaces constant SSA values with
1254 // attributes.
1255 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1256 [&](OpFoldResult ofr) {
1257 if (auto val = llvm::dyn_cast<Value>(ofr))
1258 return getAsOpFoldResult(val);
1259 return ofr;
1260 });
1261
1262 LDBG() << "Pack permutation: " << packOp;
1263 LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1264 LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1265
1266 Value empty = tensor::EmptyOp::create(
1267 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1268
1269 // 2.3 Create linalg.transpose
1270 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1271 srcPermForTranspose);
1272
1273 // 3. Insert the inner tile into the destination tensor:
1274 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1275
1276 // Compute the sizes attribute:
1277 // [ outer-dims, tile-sizes ]
1278 // Note that the output from the transpose Op excludes the tiled outer dims.
1279 // However, given the assumption that:
1280 // * all tiled outer dims == 1,
1281 // we can just use a rank-expanding tensor.insert_slice.
1282 SmallVector<OpFoldResult> writeSizes;
1283 for (auto size : packOp.getAllOuterDims()) {
1284 writeSizes.push_back(rewriter.getIndexAttr(size));
1285 }
1286
1287 for (auto tileSize : packOp.getMixedTiles()) {
1288 auto [_, tileSizeOfr] =
1289 getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1290 writeSizes.push_back(tileSizeOfr);
1291 }
1292
1293 auto insert = tensor::InsertSliceOp::create(
1294 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
1295
1296 // 4. Replace tensor.packOp with tensor.insert_slice created above
1297 rewriter.replaceOp(packOp, insert.getResult());
1298
1299 return success();
1300}
1301
1303 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
// Decomposes a linalg.unpack whose tiled outer dims are all 1 into:
//   rank-reducing tensor.extract_slice (the single tile)
//   -> linalg.transpose (back to the destination dim order)
//   -> tensor.extract_slice (drops trailing data of incomplete tiles)
//   -> tensor.insert_slice into the destination.
// Unpacks without pure tensor semantics are rejected.
1304 if (!unpackOp.hasPureTensorSemantics())
1305 return failure();
1306
1307 int64_t destRank = unpackOp.getDestRank();
1308 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1309 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1310 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1311 [](int64_t dim) { return dim != 1; })) {
1312 return rewriter.notifyMatchFailure(
1313 unpackOp,
1314 "require the tiled outer dimensions of the result are all 1s");
1315 }
1316
1317 // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1318 // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1319 Location loc = unpackOp.getLoc();
1320 Value source = unpackOp.getSource();
1321 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1322 unpackOp.getDimAndTileMapping();
1323 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1324
1325 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1326 // dims:
1327 // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1328 SmallVector<int64_t> readShapeForExtractSlice;
1329 // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1330 // outer-tiled-dims being all 1), this will be
1331 // [ outer-untiled-dims, tile-sizes ]
1332 SmallVector<OpFoldResult> extractSliceSizes;
1333
1334 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1335 // This should be:
1336 // [ outer-untiled-dims, tile-sizes ]
1337 // However, skip unit dims - TransposeOp (below) applies rank-reduced
1338 // permutation.
1339 SmallVector<OpFoldResult> shapeForEmptyOp;
// NOTE(review): in the loop below, `i` is used both as a destination dim (the
// dimAndTileMapping lookup, keyed by inner_dims_pos) and as a source outer
// dim (srcShape[i]). With a non-identity outer_dims_perm these name
// different dims. Confirm this is only reached when that mismatch is harmless.
1340
1341 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1342 // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1343 //
1344 // As all outer tiled dims are 1, so the corresponding
1345 // slice size to read will also be 1. As this will be rank-reducing "extract
1346 // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1347 // update:
1348 // * the output shape for ExtractSliceOp, nor
1349 // * the shape for EmptyOp.
1350 if (dimAndTileMapping.count(i)) {
1351 extractSliceSizes.push_back(oneIdxAttr);
1352 continue;
1353 }
1354
1355 // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1356 // outer-untiled-dims
1357 if (ShapedType::isDynamic(srcShape[i])) {
1358 OpFoldResult dynamicDim =
1359 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1360 extractSliceSizes.push_back(dynamicDim);
1361 shapeForEmptyOp.push_back(dynamicDim);
1362 } else {
1363 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1364 if (srcShape[i] != 1)
1365 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1366 }
1367 // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1368 // into account rank-reducing)
1369 if (srcShape[i] != 1) {
1370 readShapeForExtractSlice.push_back(srcShape[i]);
1371 }
1372 }
1373 // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1374 // shape for EmptyOp.
1375 auto mixedTiles = unpackOp.getMixedTiles();
1376 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1377 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1378
1379 // Explicitly create the type for extract_slice op because the inner tile
1380 // size could be 1. We want to represent the whole inner tile in this case.
1381 auto tileShape = srcShape.drop_front(destRank);
1382 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1383 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1384 Type elemType = unpackOp.getSourceType().getElementType();
1385 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1386 Value innerTile = tensor::ExtractSliceOp::create(
1387 rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
1388
1389 // 2. Transpose the tile to match the outer corresponding tile order.
1391 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1392 // Unpack is a transition out of packed space so we invert the permutation.
1393 perm = invertPermutationVector(perm);
1394 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1395
1396 Value empty =
1397 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1398 auto transposedOp =
1399 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1400
1401 // 3. Handle incomplete tiles if needed. It truncates trailing data from the
1402 // transposed tile.
// Only destination dims that are tiled or not of size 1 appear in the
// transposed tile, so only those get a size here.
1403 SmallVector<OpFoldResult> tileSizes;
1404 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1405 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1406 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1407 tileSizes.push_back(
1408 tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1409 }
1410
1411 auto partialTile =
1412 tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1413 transposedOp.getResult()[0], tileSizes);
1414
1415 // 4. Insert the result to the destination tensor.
// Re-expand to the full destination rank: the dims dropped above get size 1.
1416 SmallVector<OpFoldResult> writeSizes;
1417 for (int i = 0, idx = 0; i < destRank; ++i) {
1418 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1419 writeSizes.push_back(tileSizes[idx++]);
1420 else
1421 writeSizes.push_back(oneIdxAttr);
1422 }
1423 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1424 unpackOp.getDest(), writeSizes);
1425 rewriter.replaceOp(unpackOp, insert.getResult());
1426
1427 return success();
1428}
1429
1430//===----------------------------------------------------------------------===//
1431// Generic DownscaleSizeOneWindowedConvolution
1432//===----------------------------------------------------------------------===//
1433//
1434/// Returns, for each dim in `dims` (in that order), the index of the first
1435/// result of `map` that is a function of that dim. Dims referenced by no
/// result contribute nothing. The same index appears more than once when a
/// single result references several of `dims` (e.g. `d1 + d4` with both
/// dims listed).
1438 SmallVector<unsigned> resultIndices;
1439 for (unsigned dim : dims) {
1440 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1441 AffineExpr expr = map.getResult(i);
1442 if (expr.isFunctionOfDim(dim)) {
1443 resultIndices.push_back(i);
// Only the first matching result is recorded for this dim.
1444 break;
1445 }
1446 }
1447 }
1448 return resultIndices;
1449}
1450
1451/// Helper to create a rank-reducing extract_slice that removes specific
1452/// dimensions from a tensor.
///
/// `tensor` must have a RankedTensorType (it is cast unconditionally). The
/// target type keeps the element type and the sizes of every dim whose index
/// is not in `dimsToRemove`. NOTE(review): the removed dims presumably must
/// be unit-sized for the slice to be a valid rank reduction; this is not
/// checked here.
1454 Location loc, Value tensor,
1455 ArrayRef<unsigned> dimsToRemove) {
1456 auto tensorType = cast<RankedTensorType>(tensor.getType());
1457 int64_t rank = tensorType.getRank();
1458
1459 // Compute new shape by removing the specified dimensions.
1460 SmallVector<int64_t> newShape;
1461 for (int64_t i = 0; i < rank; ++i) {
1462 if (!llvm::is_contained(dimsToRemove, i))
1463 newShape.push_back(tensorType.getDimSize(i));
1464 }
1465
1466 auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
1468 tensor, newType);
1469}
1470
1471/// Drops specified dimensions from an AffineExpr and compresses remaining
1472/// dimension indices. Returns std::nullopt if the expression only references
1473/// the dropped dimensions.
///
/// Dropped dims are replaced by the constant 0. Every kept dim is renumbered
/// to its position among the kept dims. For example, with dimsToDrop = {1}
/// over 3 dims, `d0 + d1` becomes `d0` and `d2` becomes `d1`. Expressions
/// referencing no dim at all (constants) are kept. Assumes `dimsToDrop`
/// holds distinct indices below `newNumDims + dimsToDrop.size()` (the old
/// number of dims). Replacing with 0 preserves meaning only when each dropped
/// dim has a single iteration.
1474static std::optional<AffineExpr>
1476 unsigned newNumDims, MLIRContext *ctx) {
1477 // Check if expr only references dimensions to be dropped.
1478 bool onlyReferencesDroppedDims = true;
1479 for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1480 if (expr.isFunctionOfDim(d) && !llvm::is_contained(dimsToDrop, d)) {
1481 onlyReferencesDroppedDims = false;
1482 break;
1483 }
1484 }
// Return nullopt only if the expression references at least one dropped dim
// and no kept dim.
1485 if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](unsigned d) {
1486 return expr.isFunctionOfDim(d);
1487 }))
1488 return std::nullopt;
1489
1490 // Replace dimensions: compute new index for each old dimension.
1491 // Dropped dimensions get mapped to constant 0, others get compressed.
1492 SmallVector<AffineExpr> dimReplacements;
1493 unsigned newDimIdx = 0;
1494 for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1495 if (llvm::is_contained(dimsToDrop, d)) {
1496 dimReplacements.push_back(getAffineConstantExpr(0, ctx));
1497 } else {
1498 dimReplacements.push_back(getAffineDimExpr(newDimIdx++, ctx));
1499 }
1500 }
1501
1502 return expr.replaceDims(dimReplacements);
1503}
1504
1505FailureOr<LinalgOp>
1507 LinalgOp op) {
// Rewrites a 2-D convolution-like LinalgOp, as classified by
// inferConvolutionDims, into a lower-rank linalg.generic. This applies when
// an output spatial loop and its matching filter loop both have static size
// 1. The steps are:
//   1. Remove both loops from the iteration space (spatial dim 0 is
//      preferred when both spatial dims qualify).
//   2. Rank-reduce the operands with extract_slice.
//   3. Insert the new result back into the original init.
//   4. If `op` was not a GenericOp, specialize the new generic back to a
//      named op when possible.
// Returns failure without creating IR when the op is not such a convolution,
// has pure buffer semantics, or has no removable spatial dim.
1508 auto maybeDims = inferConvolutionDims(op);
1509 if (failed(maybeDims))
1510 return failure();
1511
1512 // Currently supports only 2D convolutions.
1513 if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
1514 return failure();
1515
1516 if (op.hasPureBufferSemantics())
1517 return failure();
1518
1519 // Get loop domain indices for spatial dimensions.
1520 unsigned outSpatial0 = maybeDims->outputImage[0];
1521 unsigned outSpatial1 = maybeDims->outputImage[1];
1522 unsigned filterSpatial0 = maybeDims->filterLoop[0];
1523 unsigned filterSpatial1 = maybeDims->filterLoop[1];
1524
1525 // Get sizes from loop bounds.
1526 SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
1527 int64_t outSize0 = loopRanges[outSpatial0];
1528 int64_t outSize1 = loopRanges[outSpatial1];
1529 int64_t filterSize0 = loopRanges[filterSpatial0];
1530 int64_t filterSize1 = loopRanges[filterSpatial1];
1531
1532 // Check if we can downscale by removing a spatial dimension.
1533 bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
1534 bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
1535 if (!canRemoveSpatial0 && !canRemoveSpatial1)
1536 return failure();
1537
1538 // Determine which loop dims to remove (output spatial + corresponding filter)
1539 // and sort for correct index compression when removing dimensions from affine
1540 // maps.
1541 SmallVector<unsigned> loopDimsToRemove;
1542 if (canRemoveSpatial0) {
1543 loopDimsToRemove.push_back(outSpatial0);
1544 loopDimsToRemove.push_back(filterSpatial0);
1545 } else {
1546 loopDimsToRemove.push_back(outSpatial1);
1547 loopDimsToRemove.push_back(filterSpatial1);
1548 }
1549 llvm::sort(loopDimsToRemove);
1550
1551 // Create new indexing maps with dimensions removed.
// Map results that only reference removed loops are dropped entirely.
1552 SmallVector<AffineMap> newMaps;
1553 MLIRContext *ctx = op.getContext();
1554 unsigned numDims = op.getNumLoops();
1555 unsigned newNumDims = numDims - loopDimsToRemove.size();
1556 for (AffineMap map : op.getIndexingMapsArray()) {
1557 SmallVector<AffineExpr> newResults;
1558 for (AffineExpr expr : map.getResults()) {
1559 auto newExpr =
1560 dropDimsAndCompress(expr, loopDimsToRemove, newNumDims, ctx);
1561 if (newExpr)
1562 newResults.push_back(*newExpr);
1563 }
1564 newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
1565 }
1566
1567 // Create new iterator types.
1569 auto iterTypes = op.getIteratorTypesArray();
1570 for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
1571 if (!llvm::is_contained(loopDimsToRemove, idx))
1572 newIterTypes.push_back(iterTypes[idx]);
1573 }
1574
1575 // Rank-reduce operands using extract_slice.
1576 Location loc = op.getLoc();
1577 SmallVector<Value> newInputs;
1578 for (OpOperand *input : op.getDpsInputOperands()) {
1579 AffineMap map = op.getMatchingIndexingMap(input);
1580 SmallVector<unsigned> tensorDimsToRemove =
1581 getResultIndicesReferencingDims(map, loopDimsToRemove);
1582 Value reduced = createRankReducingExtractSlice(rewriter, loc, input->get(),
1583 tensorDimsToRemove);
1584 newInputs.push_back(reduced);
1585 }
1586
// NOTE(review): only the first DPS init is rank-reduced and a single result
// type is created below; presumably the matched convolutions have exactly one
// init. Confirm.
1587 OpOperand &output = *op.getDpsInitsMutable().begin();
1588 AffineMap outputMap = op.getMatchingIndexingMap(&output);
1589 SmallVector<unsigned> outputDimsToRemove =
1590 getResultIndicesReferencingDims(outputMap, loopDimsToRemove);
1591 Value newOutput = createRankReducingExtractSlice(rewriter, loc, output.get(),
1592 outputDimsToRemove);
1593
1594 // Create new linalg.generic with reduced dimensions.
1595 auto newOp =
1596 linalg::GenericOp::create(rewriter, loc, TypeRange{newOutput.getType()},
1597 newInputs, newOutput, newMaps, newIterTypes);
1598 rewriter.inlineRegionBefore(op->getRegion(0), newOp.getRegion(),
1599 newOp.getRegion().begin());
// NOTE(review): the body region is moved over unchanged. If it contains ops
// that refer to loop dimensions by number (e.g. linalg.index), they are not
// renumbered for the reduced iteration space. Confirm such bodies cannot
// reach this point.
1600
1601 // Try to specialize the generic back to a named op only if the input was
1602 // already a specialized (named) op.
1603 LinalgOp resultOp = newOp;
1604 if (!isa<GenericOp>(op)) {
1605 FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
1606 if (succeeded(specializedOp))
1607 resultOp = *specializedOp;
1608 }
1609
1610 // Insert result back into original shape.
1612 rewriter, loc, resultOp->getResult(0), output.get());
1613
1614 rewriter.replaceOp(op, result);
1615 return resultOp;
1616}
1617
1618namespace {
1619/// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
/// It is an OpInterfaceRewritePattern over LinalgOp, so it is attempted on
/// every op implementing the LinalgOp interface. The benefit defaults to 1.
1620struct DownscaleSizeOneWindowedConvolution final
1621 : public OpInterfaceRewritePattern<LinalgOp> {
1622 DownscaleSizeOneWindowedConvolution(MLIRContext *context,
1623 PatternBenefit benefit = 1)
1624 : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
1625
1626 LogicalResult matchAndRewrite(LinalgOp op,
1627 PatternRewriter &rewriter) const override {
1629 }
1630};
1631} // namespace
1632
1634 PatternBenefit benefit) {
// Registers the DownscaleSizeOneWindowedConvolution pattern in `patterns`
// with the given `benefit`.
1635 patterns.add<DownscaleSizeOneWindowedConvolution>(patterns.getContext(),
1636 benefit);
1637}
1638
1643
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static RankedTensorType permuteShape(RankedTensorType tensorType, ArrayRef< int64_t > permutationVector)
Return a copy of tensorType after permutation by permutationVector.
static std::optional< AffineExpr > dropDimsAndCompress(AffineExpr expr, ArrayRef< unsigned > dimsToDrop, unsigned newNumDims, MLIRContext *ctx)
Drops specified dimensions from an AffineExpr and compresses remaining dimension indices.
static SmallVector< int64_t > getPackUnpackRankReducedPerm(ArrayRef< int64_t > shape, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
static Value createRankReducingExtractSlice(RewriterBase &rewriter, Location loc, Value tensor, ArrayRef< unsigned > dimsToRemove)
Helper to create a rank-reducing extract_slice that removes specific dimensions from a tensor.
static FailureOr< SmallVector< std::optional< int64_t > > > packLinalgMetadataOnce(SmallVectorImpl< AffineMap > &indexingMaps, SmallVectorImpl< utils::IteratorType > &iteratorTypes, int64_t dim)
Perform one step of packing of a LinalgOp's metadata along dim into the newDim at iteratorTypes....
static LinalgOp transposeOneLinalgOperandAndReplace(RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, ArrayRef< int64_t > permutation, Value transposedValue)
Return a new GenericOp obtained by transposing opOperand by the permutation vector:
static SmallVector< int64_t > getPackUnpackNormalizedPerm(int rank, ArrayRef< int64_t > perm)
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim)
Return true if map has 0 or 1 result function of AffineDimExpr(dim).
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp)
If padding value is set, returns a tensor.pad Op for the source tensor, with the output shape matchin...
static SmallVector< unsigned > getResultIndicesReferencingDims(AffineMap map, ArrayRef< unsigned > dims)
Returns the indices of affine map results that reference any of the given dimensions.
static std::string stringifyReassocIndices(ReassociationIndicesRef ri)
static std::optional< int64_t > getFirstResultIndexFunctionOf(AffineMap map, int64_t dim)
Return the index of the first result of map that is a function of AffineDimExpr(dim),...
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
Definition AffineMap.h:315
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
result_range getResults()
Definition Operation.h:440
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice + copy.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and apply affine_min/max bounds simplification on the fly where relevant.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Definition Utils.cpp:23
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Definition Transforms.h:204
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Definition Transforms.h:194
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.