MLIR 23.0.0git
Transforms.cpp
Go to the documentation of this file.
1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
31#include "mlir/Support/LLVM.h"
32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
38#include <type_traits>
39#include <utility>
40
41#define DEBUG_TYPE "linalg-transforms"
42
43using namespace mlir;
44using namespace mlir::linalg;
45
46//===----------------------------------------------------------------------===//
47// Transformations exposed as functional-style API calls.
48//===----------------------------------------------------------------------===//
49
50//===----------------------------------------------------------------------===//
51// peelLoop transformation.
52//===----------------------------------------------------------------------===//
53
54/// Try to peel and canonicalize loop `op` and return the new result.
55/// Also applies affine_min/max bounds simplification on the fly where relevant.
56// TODO: Add support for scf.parallel and affine.for loops.
58 Operation *op) {
60 .Case([&](scf::ForOp forOp) {
61 scf::ForOp partialIteration;
62 if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
63 partialIteration)))
64 return partialIteration->getResults();
65 assert(!partialIteration && "expected that loop was not peeled");
66 return forOp->getResults();
67 })
68 .Default([&](Operation *op) { return op->getResults(); });
69}
70
71/// Peel 'loops' and applies affine_min/max bounds simplification on the fly
72/// where relevant.
75 for (auto loopOp : loops)
76 peelLoop(rewriter, loopOp);
77}
78
79//===----------------------------------------------------------------------===//
80// pack transformation.
81//===----------------------------------------------------------------------===//
82
83#ifndef NDEBUG
84/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
86 bool found = false;
87 for (AffineExpr e : map.getResults()) {
88 if (!e.isFunctionOfDim(dim))
89 continue;
90 if (found)
91 return false;
92 found = true;
93 }
94 return true;
95}
96#endif // NDEBUG
97
99 return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
100}
101
102/// Return the index of the first result of `map` that is a function of
103/// AffineDimExpr(dim), std::nullopt otherwise.
104static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
105 int64_t dim) {
106 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
107 AffineExpr expr = map.getResult(i);
108 if (!expr.isFunctionOfDim(dim))
109 continue;
110 return i;
111 }
112 return std::nullopt;
113}
114
115/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
116/// `newDim` at `iteratorTypes.size()` by:
117/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
118/// 2. Appending a `newDim` to the domain of every indexing map.
119/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
120/// by potentially adding a `newDim` result to `map`.
121/// The preserved invariant is that `iteratorTypes.size()` is always equal to
122/// `map.getNumDims()` for every map in `indexingMaps`.
123///
124/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
125/// Return a vector that records the optional packing for each operand.
126/// Return failure if the packed indexing cannot be represented with a LinalgOp.
127///
128/// Further details:
129/// ================
130/// The current implementation of packing (i.e. data tiling) consists of
131/// rewriting a linearized strip-mined form into a higher-dimensional access.
132/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
133/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
134/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
135///
136/// This rewrite into higher dimensional access is not possible for general
137/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
138/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
139/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
140/// The rewrite of the access would be a form not representable in Linalg:
141/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
142/// Note however that as `J` and `ii` iterate, the accesses do not have a
143/// particular alignment, so packing does not achieve alignment in this case
144///
145/// In the future, we may want to consider a mixed-form that allows some
146/// alignment in the presence of multiple accesses:
147/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
148/// And would rewrite accesses as:
149/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
150static FailureOr<SmallVector<std::optional<int64_t>>>
153 int64_t dim) {
154 int64_t newDim = iteratorTypes.size();
155 iteratorTypes.push_back(iteratorTypes[dim]);
156
157 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
158 indexingMaps.size(), std::nullopt);
160 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
161 ++operandIdx) {
162 AffineMap map = indexingMaps[operandIdx];
163
164 // Add the `newDim` to map whatever the case.
165 assert(map.getNumDims() == newDim && "num dims invariant violation");
166 map = map.shiftDims(1, newDim);
167
168 // Get the at-most-1 index of the result that is a function of `dim`.
169 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
170 // logically chunks dimension `dim` into `K * dim + newDim`, where the
171 // packing factor `K` is specified separately.
172 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
173 "num results invariant violation");
174 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
175 if (!maybeOperandDimensionToPack.has_value()) {
176 newMaps.push_back(map);
177 continue;
178 }
179
180 // We can only pack AffineDimExpr atm.
181 if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
182 return failure();
183
184 // Add `newDim` to the results of the map.
185 map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
186 map.getNumResults());
187 newMaps.push_back(map);
188
189 // Record the that `operandIdx` is packed.
190 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
191 }
192 indexingMaps = newMaps;
193
194 return packedDimPerIndexingMap;
195}
196
197namespace {
198
199/// Helper struct to encode packing along one dimension of a LinalgOp.
200struct PackedOperandsDim {
201 OpFoldResult packedSize;
202 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
203};
204
205/// Helper struct to encode packing along all dimensions of a LinalgOp.
206struct PackedOperandsDimList {
207 void pushBack(PackedOperandsDim &&packedOperandsDims) {
208 spec.emplace_back(packedOperandsDims);
209 }
210 /// Return all the dims that have been packed for operand @ `operandPos`.
211 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
212 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
213 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
214
215private:
216 SmallVector<PackedOperandsDim> spec;
217};
218
219} // namespace
220
221FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
222 linalg::PackOp packOp,
223 bool lowerPadLikeWithInsertSlice) {
224 // TODO: Support Memref PackOp. Temporarily return failure.
225 if (!packOp.hasPureTensorSemantics())
226 return failure();
227
228 // 1. Filter out NYI cases.
229 auto packedTensorType =
230 cast<RankedTensorType>(packOp->getResultTypes().front());
231 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
232 return rewriter.notifyMatchFailure(
233 packOp,
234 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
235 }
236
237 Location loc = packOp->getLoc();
238 OpBuilder::InsertionGuard g(rewriter);
239 rewriter.setInsertionPoint(packOp);
240
241 // 2. Compute the permutation vector to shuffle packed shape into the shape
242 // before any outer or inner permutations have been applied.
243 PackingMetadata packingMetadata;
244 SmallVector<int64_t> packedToStripMinedShapePerm =
245 getPackInverseDestPerm(packOp, packingMetadata);
246
247 // 3. Compute the stripMinedShape: this is the packed shape before any outer
248 // or inner permutations have been applied.
249 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
250 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
251
252 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
253 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
254 rewriter.getIndexAttr(0));
255 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
256 rewriter.getIndexAttr(0));
257 for (auto [pos, innerSize] :
258 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
259 int outerPos =
260 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
261 OpFoldResult origSize =
262 tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
263 OpFoldResult outerSize =
264 tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
265 AffineExpr s0, d0, d1;
266 bindDims(rewriter.getContext(), d0, d1);
267 bindSymbols(rewriter.getContext(), s0);
268 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
270 rewriter, loc, map, {outerSize, origSize, innerSize});
271 }
272 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
273 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
274 packingMetadata.reassociations);
275 Value paddingValue = packOp.getPaddingValue();
276 if (!paddingValue) {
277 paddingValue = arith::ConstantOp::create(
278 rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
279 }
280 auto padOp =
281 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
282 highs, paddingValue, /*nofold=*/false);
283
284 LDBG() << "insertPositions: "
285 << llvm::interleaved(packingMetadata.insertPositions);
286 LDBG() << "outerPositions: "
287 << llvm::interleaved(packingMetadata.outerPositions);
288 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
289 LDBG() << "packedToStripMinedShapePerm: "
290 << llvm::interleaved(packedToStripMinedShapePerm);
291 LDBG() << "reassociations: "
292 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
294 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
295 LDBG() << "collapsed type: " << collapsed;
296
297 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
298 // Pack ops which operate as simple pads may not produce legal
299 // tensor.insert_slice operations when the packed type does not rank reduce
300 // to the padded type.
301 SliceVerificationResult rankReduces =
302 isRankReducedType(packedTensorType, padOp.getResultType());
303
304 if (rankReduces == SliceVerificationResult::Success) {
305 // This pack is just a plain pad.
306 // Just insert the pad in the higher ranked tensor.
307 // Offsets.
308 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
309 rewriter.getIndexAttr(0));
310 // Strides.
311 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
312 rewriter.getIndexAttr(1));
314 tensor::getMixedSizes(rewriter, loc, packOp.getDest());
315
316 auto insertSliceOp = tensor::InsertSliceOp::create(
317 rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
318 /*offsets=*/zeros, sizes, /*strides=*/ones);
319
320 LDBG() << "insert_slice op: " << insertSliceOp;
321
322 rewriter.replaceOp(packOp, insertSliceOp->getResults());
323
324 return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
325 /*transposeOp=*/nullptr};
326 }
327 }
328
329 // 5. Expand from the padded result to the stripMinedShape.
330 auto expandShapeResultType =
331 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
332 auto reshapeOp = tensor::ExpandShapeOp::create(
333 rewriter, loc, expandShapeResultType, padOp.getResult(),
334 packingMetadata.reassociations);
335
336 // 6. Transpose stripMinedShape to packedShape.
337 SmallVector<int64_t> transpPerm =
338 invertPermutationVector(packedToStripMinedShapePerm);
339 auto transposeOp = linalg::TransposeOp::create(
340 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
341
342 LDBG() << "reshape op: " << reshapeOp;
343 LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
344 LDBG() << "transpose op: " << transposeOp;
345
346 // 7. Replace packOp by transposeOp.
347 rewriter.replaceOp(packOp, transposeOp->getResults());
348
349 return LowerPackResult{padOp, reshapeOp, transposeOp};
350}
351
352FailureOr<LowerUnPackOpResult>
353linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
354 bool lowerUnpadLikeWithExtractSlice) {
355 // TODO: Support Memref UnPackOp. Temporarily return failure.
356 if (!unPackOp.hasPureTensorSemantics())
357 return failure();
358
359 Location loc = unPackOp->getLoc();
360 OpBuilder::InsertionGuard g(rewriter);
361 rewriter.setInsertionPoint(unPackOp);
362
363 auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
364 int64_t packedRank = packedTensorType.getRank();
365
366 OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
367 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
368 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
369 // This unpack is just a plain unpad.
370 // Just extract the slice from the higher ranked tensor.
371 ArrayRef<int64_t> destShape = destTensorType.getShape();
372 // The inner dimensions stay the same as the destination tensor, but the
373 // outer ones are additional 1s.
374 SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
375 sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
376
377 auto extractSliceOp = tensor::ExtractSliceOp::create(
378 rewriter, loc, destTensorType, unPackOp.getSource(),
379 SmallVector<OpFoldResult>(packedRank, zero), sizes,
380 SmallVector<OpFoldResult>(packedRank, one));
381
382 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
383
384 return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
385 /*reshapeOp=*/nullptr, extractSliceOp,
386 /*copyOp=*/nullptr};
387 }
388
389 // 1. Compute the permutation vector to shuffle packed shape into the shape
390 // before any outer or inner permutations have been applied.
391 PackingMetadata packingMetadata;
392 SmallVector<int64_t> packedToStripMinedShapePerm =
393 getUnPackInverseSrcPerm(unPackOp, packingMetadata);
394
395 // 2. Compute the stripMinedShape: this is the packed shape without outer and
396 // inner permutations.
397 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
398 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
399
400 // 3. Transpose packedShape to stripMinedShape.
401 RankedTensorType stripMinedTensorType =
402 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
403 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
404 stripMinedTensorType, packingMetadata.reassociations);
405
406 // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
407 // permutation.
409 tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
410 applyPermutationToVector(dims, packedToStripMinedShapePerm);
411 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
412 stripMinedTensorType.getElementType());
413 auto transposeOp =
414 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
415 packedToStripMinedShapePerm);
416
417 LDBG() << "insertPositions: "
418 << llvm::interleaved(packingMetadata.insertPositions);
419 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
420 LDBG() << "packedToStripMinedShapePerm: "
421 << llvm::interleaved(packedToStripMinedShapePerm);
422 LDBG() << "reassociations: "
423 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
425 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
426 LDBG() << "collapsed type: " << collapsedType;
427
428 // 4. Collapse from the stripMinedShape to the padded result.
429 auto reshapeOp = tensor::CollapseShapeOp::create(
430 rewriter, loc, collapsedType, transposeOp->getResult(0),
431 packingMetadata.reassociations);
432
433 // 5. ExtractSlice.
434 int64_t destRank = destTensorType.getRank();
435 auto extractSliceOp = tensor::ExtractSliceOp::create(
436 rewriter, loc, destTensorType, reshapeOp->getResult(0),
437 SmallVector<OpFoldResult>(destRank, zero),
438 tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
439 SmallVector<OpFoldResult>(destRank, one));
440
441 // 6. Inject a copy to preserve DPS.
442 auto copyOp = linalg::CopyOp::create(
443 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
444
445 // 7. Replace unPackOp by copyOp.
446 rewriter.replaceOp(unPackOp, copyOp->getResults());
447
448 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp,
449 copyOp};
450}
451
453PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
455 for (auto &i : spec) {
456 if (!i.packedDimForEachOperand[operandPos].has_value())
457 continue;
458 res.push_back(i.packedDimForEachOperand[operandPos].value());
459 }
460 return res;
461}
462
463SmallVector<OpFoldResult>
464PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
465 SmallVector<OpFoldResult> res;
466 for (auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
468 continue;
469 res.push_back(i.packedSize);
470 }
471 return res;
472}
473
474/// Implement packing of a single LinalgOp by performing packing by
475/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
476/// Return the packed Linalg op on success, failure otherwise.
477FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
478 linalg::LinalgOp linalgOp,
479 ArrayRef<OpFoldResult> packedSizes) {
480 if (packedSizes.size() != linalgOp.getNumLoops()) {
481 return rewriter.notifyMatchFailure(linalgOp,
482 "incorrect number of pack sizes");
483 }
484
485 Location loc = linalgOp->getLoc();
486 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
488 linalgOp.getIteratorTypesArray();
489 LDBG() << "Start packing: " << linalgOp;
490 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
491 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
492
495 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
496 PackedOperandsDimList listOfPackedOperandsDim;
497 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
498 std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
499 // Skip tile sizes explicitly set to 0.
500 if (maybeConstant.has_value() && maybeConstant.value() == 0)
501 continue;
502
503 PackedOperandsDim packedOperandsDims;
504 packedOperandsDims.packedSize = packedSizes[i];
505 FailureOr<SmallVector<std::optional<int64_t>>>
506 maybePackedDimForEachOperand =
507 packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
508 if (failed(maybePackedDimForEachOperand))
509 return failure();
510 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
511
512 LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
513 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
514 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
515 LDBG() << "packedDimForEachOperand: "
516 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
517
518 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
519 }
520
521 // Step 2. Propagate packing to all LinalgOp operands.
522 SmallVector<Value> inputsAndInits, results;
523 SmallVector<OpOperand *> initOperands =
524 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
525 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
526 for (const auto &operandsList : {inputOperands, initOperands}) {
527 for (OpOperand *opOperand : operandsList) {
528 int64_t pos = opOperand->getOperandNumber();
529 Value operand = opOperand->get();
530 SmallVector<int64_t> innerPos =
531 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
532 SmallVector<OpFoldResult> innerPackSizes =
533 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
534 LDBG() << "operand: " << operand;
535 LDBG() << "innerPos: " << llvm::interleaved(innerPos);
536 LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
537 if (innerPackSizes.empty()) {
538 inputsAndInits.push_back(operand);
539 continue;
540 }
541 Value dest = linalg::PackOp::createDestinationTensor(
542 rewriter, loc, operand, innerPackSizes, innerPos,
543 /*outerDimsPerm=*/{});
544 ShapedType operandType = cast<ShapedType>(operand.getType());
545 bool areConstantTiles =
546 llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
547 return getConstantIntValue(tile).has_value();
548 });
549 if (areConstantTiles && operandType.hasStaticShape() &&
550 !linalg::PackOp::requirePaddingValue(
551 operandType.getShape(), innerPos,
552 cast<ShapedType>(dest.getType()).getShape(), {},
553 innerPackSizes)) {
554 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
555 innerPos, innerPackSizes));
556 } else {
557 // TODO: value of the padding attribute should be determined by
558 // consumers.
559 auto zeroAttr =
560 rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
561 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
562 packOps.push_back(linalg::PackOp::create(
563 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
564 }
565 inputsAndInits.push_back(packOps.back().getResult());
566 }
567 }
568
569 // Step 3. Build the packed op, use the type of `inits` as result types.
570 ValueRange inputs =
571 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
572 ValueRange inits =
573 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
574 auto packedLinalgOp =
575 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
576 inputs, inits, indexingMaps, iteratorTypes);
577 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
578
579 // Step 4. Propagate packing to all the op results.
580 for (OpResult result : packedLinalgOp->getResults()) {
581 int64_t resultNum = result.getResultNumber();
582 linalg::PackOp maybePackedInit =
583 inits[resultNum].getDefiningOp<linalg::PackOp>();
584 if (!maybePackedInit) {
585 results.push_back(result);
586 continue;
587 }
588 // Build the symmetrical UnPackOp to the existing PackOp.
589 unPackOps.push_back(linalg::UnPackOp::create(
590 rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
591 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
592 results.push_back(unPackOps.back().getResult());
593 }
594
595 // Step 5. Replace `linalgOp`.
596 rewriter.replaceOp(linalgOp, results);
597
598 // Return packedLinalgOp.
599 return PackResult{packOps,
600 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
601 unPackOps};
602}
603
604//===----------------------------------------------------------------------===//
605// packTranspose transformation.
606//===----------------------------------------------------------------------===//
607
608/// Return a copy of `tensorType` after permutation by `permutationVector`.
609// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
610// but this would introduce a dependence on Dialect in IR.
611// TODO: Restructure.
612static RankedTensorType permuteShape(RankedTensorType tensorType,
613 ArrayRef<int64_t> permutationVector) {
614 SmallVector<int64_t> shape(tensorType.getShape());
615 applyPermutationToVector(shape, permutationVector);
616 return RankedTensorType::Builder(tensorType).setShape(shape);
617}
618
619/// Return a new GenericOp obtained by transposing opOperand by the permutation
620/// vector:
621/// - the corresponding indexing map is transposed by `permutation`
622/// - the corresponding operand value is replaced by `transposedValue`
623/// `linalgOp` is replaced by the return op in the process.
624/// Asserts that `transposedValue` is of the proper transposed ShapedType.
626 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
627 ArrayRef<int64_t> permutation, Value transposedValue) {
628 // Sanity check the operand.
629 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
630
631 // Sanity check of the expected transposed tensor type.
632 auto tensorType = permuteShape(
633 cast<RankedTensorType>(opOperand.get().getType()), permutation);
634 (void)tensorType;
635 assert(tensorType == transposedValue.getType() &&
636 "expected tensor type mismatch");
637
638 // Compute the transposed indexing map.
639 // Sigh unsigned pollution.
640 SmallVector<unsigned> tmpTransposition =
641 llvm::map_to_vector(permutation, [](int64_t i) -> unsigned { return i; });
642 AffineMap permutationMap =
643 AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
644 AffineMap transposedMap =
645 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
646
647 // Set the transposed indexing map in the proper position.
648 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
649 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
650 // Set the transposedValue in the proper operand position.
651 SmallVector<Value> operands = linalgOp->getOperands();
652 operands[opOperand.getOperandNumber()] = transposedValue;
653
654 ValueRange operandsRef(operands);
655 auto transposedGenericOp = linalg::GenericOp::create(
656 rewriter,
657 /*location=*/linalgOp->getLoc(),
658 /*resultTensorTypes=*/
659 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
660 /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
661 /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
662 /*indexingMaps=*/indexingMaps,
663 /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
664 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
665 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
666
667 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
668}
669
670FailureOr<PackTransposeResult>
671linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
672 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
673 ArrayRef<int64_t> outerPerm,
674 ArrayRef<int64_t> innerPerm) {
675 Location loc = linalgOp.getLoc();
676
677 // Step 1. Transpose packOp.
678 rewriter.setInsertionPoint(packOp);
679 linalg::PackOp transposedPackOp =
680 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
681
682 if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
683 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
684
685 OpOperand &packUse = *packOp->getUses().begin();
686 if (packUse.getOwner() != linalgOp) {
687 return rewriter.notifyMatchFailure(
688 linalgOp, "not a single use by the LinalgOp target");
689 }
690 if (maybeUnPackOp &&
691 (!linalgOp.isDpsInit(&packUse) ||
692 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
693 return rewriter.notifyMatchFailure(linalgOp,
694 "not produced by the LinalgOp target");
695 }
696
697 // Step 2. Transpose linalgOp.
698 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
699 // identity. Don't rely on it.
700 int64_t numLeadingDims = packOp.getSourceRank();
701 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
702 // Step 2.a. Compute the permutation on the whole operand.
703 // Leading part just reuse the outerPerm.
704 SmallVector<int64_t> permutation(outerPerm);
705 if (permutation.empty())
706 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
707 // Trailing part needs to reindex positions by `numLeadingDims`.
708 if (innerPerm.empty()) {
709 llvm::append_range(
710 permutation,
711 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
712 } else {
713 llvm::append_range(permutation,
714 llvm::map_range(innerPerm, [&](int64_t pos) {
715 return numLeadingDims + pos;
716 }));
717 }
718 if (!isPermutationVector(permutation))
719 return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
720
721 // Step 2.b. Save the transposedPackUse operand number in case we need to
722 // get the tied OpResult after `linalgOp` has been replaced.
723 int64_t packUseOperandNumber = packUse.getOperandNumber();
724 // Step 2.c. Actually perform the transposition.
725 rewriter.setInsertionPoint(linalgOp);
726 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
727 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
728
729 // Step 3. Maybe transpose unPackOp.
730 linalg::UnPackOp transposedUnPackOp;
731 if (maybeUnPackOp) {
732 OpOperand &opOperand =
733 transposedLinalgOp->getOpOperand(packUseOperandNumber);
734 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
735 rewriter.setInsertionPoint(maybeUnPackOp);
736 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
737 rewriter, loc, transposedResult, innerPerm, outerPerm);
738
739 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
740 }
741
742 // Step 4. Finally, replace packOp now that we don't need it anymore.
743 if (packOp.hasPureTensorSemantics())
744 rewriter.replaceOp(packOp, transposedPackOp->getResults());
745 else
746 rewriter.eraseOp(packOp);
747
748 return PackTransposeResult{transposedPackOp, transposedLinalgOp,
749 transposedUnPackOp};
750}
751
752//===----------------------------------------------------------------------===//
753// packMatmulGreedily transformation.
754//===----------------------------------------------------------------------===//
755
756/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
757/// and n are proper parallel dimensions and k is a proper reduction
758/// dimension. Packing occurs by rewriting the op as a linalg.generic and
759/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
760/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
761/// to reorder {m, n, k} into one of the 8 possible forms. The outer
762/// dimensions of the operands are not permuted at this time, this is left for
763/// future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  // Preconditions: exactly one packed size per matmul dim (m, n, k), an
  // optional triple of "pad to next multiple of" values, and a valid
  // permutation of {0, 1, 2} describing the desired (m, n, k) order.
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  // A matmul-like contraction needs at least 3 loops (m, n, k).
  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  // Loop positions the packed dims will occupy after interchange: the
  // `numPackedDims` most-minor iterators, reordered by `mnkOrder`.
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  // Scatter the caller-provided sizes into `mnkOrder` positions.
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  // 0 means "no padding to a multiple" for that dimension.
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to an kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where it would have
  // to plug a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  // 2.a. Rewrite as a generic. Named ops are generalized first so that the
  // interchange below (which only exists for linalg.generic) applies.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalized the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LDBG() << "perm: " << llvm::interleaved(permutation);
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  // - LHS{leading_lhs, kk, mm}
  // - RHS{leading_rhs, kk, nn}
  // - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
  // post interchange.
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    // Pad to the next multiple: ceilDiv(loopSize, multiple) * multiple,
    // folded where possible.
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
889
890//===----------------------------------------------------------------------===//
891// Transformations exposed as rewrite patterns.
892//===----------------------------------------------------------------------===//
893
  // Setting tile sizes twice is a programming error.
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Copy the sizes so the lambda owns them (`ts` is a non-owning ArrayRef).
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    // Materialize the constants at the top of the enclosing function so they
    // dominate all potential uses.
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::map_to_vector<4>(tileSizes, [&](int64_t s) {
      Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
      return v;
    });
  };
  return *this;
}
909
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  // Delegate to the shared copy-vectorization helper.
  return vectorizeCopy(rewriter, copyOp);
}
914
/// Filling `dest` using FillOp constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  // A constant padding value allows lowering to the cheaper linalg.fill.
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move the padding value defined inside the PadOp block to outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  // Copy region to new op: the pad op's body computes the per-element value.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}
936
LogicalResult
    PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
        rewriter, padOp.getLoc(),
        cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value: result dim = src + low + high.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor =
      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate a InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  // The source lands at the low-pad offsets inside the filled tensor.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
988
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  // Only unit-stride slices are supported by bubbleUpPadSlice below.
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  // The control function may veto the rewrite (return nullopt) or choose
  // whether zero-sized slices must be guarded.
  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reduced, all shapes are static and the
  // data source is actually used. Rewrite into pad(extract_slice(x)).
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reduced slice by creating another extract_slice op.
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);

  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
1029
/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
    linalg::PackOp packOp) {
  Value input = packOp.getSource();
  // TODO: Support Memref PackOp. Temporarily return just Op Source.
  if (!packOp.hasPureTensorSemantics())
    return input;

  // No padding value means no padding is required.
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  // Pad on the high side only, up to the (padded) tile shape.
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
1097
// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
  constexpr int64_t kNonTiledMarker = -1;
  // Scatter each perm entry's index into its value slot; untouched slots
  // keep the marker and are filtered out below.
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}
1113
// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
    ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  // Partition dims into inner (tiled) and non-unit outer dims; unit outer
  // dims are dropped (rank reduction).
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner most dims after packing.
  perm.append(innerDims);

  return perm;
}
1155
1157 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1158 // TODO: Support Memref PackOp. Temporarily return failure.
1159 if (!packOp.hasPureTensorSemantics())
1160 return failure();
1161
1162 if (llvm::any_of(packOp.getTiledOuterDims(),
1163 [](int64_t dim) { return dim != 1; })) {
1164 return rewriter.notifyMatchFailure(
1165 packOp, "not all outer dimensions of the result are 1s");
1166 }
1167
1168 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1169 auto outerDimsPerm = packOp.getOuterDimsPerm();
1170
1171 // Verify that there are no:
1172 // * non-unit + un-tiled-outer-dims,
1173 // that are permuted. Supporting such cases would require refining the logic
1174 // that generates the Transpose Op.
1175 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1176 static int prev = 0;
1177 // Skip tiled dims - these can be permuted.
1178 if (llvm::is_contained(innerDimsPos, dim))
1179 return true;
1180
1181 // Check whether this dim has been permuted. Permuting unit dims is fine
1182 // as that's effectively a no-op.
1183 if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
1184 packOp.getResult().getType().getShape()[dim] != 1))
1185 return false;
1186
1187 prev = dim;
1188 return true;
1189 })) {
1190 return rewriter.notifyMatchFailure(
1191 packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1192 "this is not supported ATM!");
1193 }
1194
1195 Location loc = packOp.getLoc();
1196
1197 int64_t srcRank = packOp.getSourceRank();
1198
1199 // 1. Get the input that is going to be packed. If the input requires padding,
1200 // add a padding operation and return that as the input.
1201 Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1202
1203 // 2. Transpose the input to match the inner tile order:
1204 // %init = tensor.empty()
1205 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1206 // outs(%init)
1207 // Assumptions made:
1208 // - All tiled outer dims are 1 - the corresponding transposition order
1209 // doesn't matter, but requires all dim indices to be present.
1210 // - Un-tiled outer dims remain un-permuted.
1211
1212 // 2.1 Get the permutation for linalg.transpose:
1213 // [ untiled-dims, inner-dims-pos ]
1214 // Note, this logic assumes that the untiled dims are not permuted.
1215 SmallVector<int64_t> srcPermForTranspose;
1216 for (int64_t i = 0; i < srcRank; i++) {
1217 // We assume the `k` dimensions of the inner dim position, where `k` is the
1218 // rank of the inner tiling, correspond to the last `k` indices of the
1219 // transpose permutation. This is done by adding the indices not contained
1220 // in the inner dimension position in order from 0 to `n`. Where n is the
1221 // rank of the source tensor. For example if we have a source tensor with
1222 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1223 // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1224 if (llvm::is_contained(innerDimsPos, i))
1225 continue;
1226 srcPermForTranspose.push_back(i);
1227 }
1228 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1229
1230 // 2.2 Create the init tensor for linalg.transpose with the correct shape:
1231 // [ untiled-dims, tiled-dims ]
1232 ShapedType inputTy = cast<ShapedType>(input.getType());
1233 SmallVector<OpFoldResult> shapeForEmptyOp;
1234 for (int64_t i = 0; i < srcRank; i++) {
1235 if (llvm::is_contained(innerDimsPos, i)) {
1236 // The tiled dims are appended after this loop.
1237 continue;
1238 }
1239 if (inputTy.isStaticDim(i))
1240 shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
1241 else
1242 shapeForEmptyOp.emplace_back(
1243 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1244 }
1245 shapeForEmptyOp.append(packOp.getMixedTiles());
1246
1247 // getMixedTiles() may contain Values pointing to constant ops (as opposed to
1248 // constant attributes with the corresponding value). Replace those with
1249 // attributes. This is to match the behaviour in
1250 // `getPackOpSourceOrPaddedSource`, which replaces constant SSA values with
1251 // attributes.
1252 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1253 [&](OpFoldResult ofr) {
1254 if (auto val = llvm::dyn_cast<Value>(ofr))
1255 return getAsOpFoldResult(val);
1256 return ofr;
1257 });
1258
1259 LDBG() << "Pack permutation: " << packOp;
1260 LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1261 LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1262
1263 Value empty = tensor::EmptyOp::create(
1264 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1265
1266 // 2.3 Create linalg.transpose
1267 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1268 srcPermForTranspose);
1269
1270 // 3. Insert the inner tile into the destination tensor:
1271 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1272
1273 // Compute the sizes attribute:
1274 // [ outer-dims, tile-sizes ]
1275 // Note that the output from the transpose Op excludes the tiled outer dims.
1276 // However, given the assumption that:
1277 // * all tiled outer dims == 1,
1278 // we can just use a rank-expanding tensor.insert_slice.
1279 SmallVector<OpFoldResult> writeSizes;
1280 for (auto size : packOp.getAllOuterDims()) {
1281 writeSizes.push_back(rewriter.getIndexAttr(size));
1282 }
1283
1284 for (auto tileSize : packOp.getMixedTiles()) {
1285 auto [_, tileSizeOfr] =
1286 getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1287 writeSizes.push_back(tileSizeOfr);
1288 }
1289
1290 auto insert = tensor::InsertSliceOp::create(
1291 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
1292
1293 // 4. Replace tensor.packOp with tensor.insert_slice created above
1294 rewriter.replaceOp(packOp, insert.getResult());
1295
1296 return success();
1297}
1298
    linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  // Tensor semantics only; memref unpack is not supported here.
  if (!unpackOp.hasPureTensorSemantics())
    return failure();

  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
  //    %extracted_tile = tensor.extract_slice(%unpack_op_input)
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
  // dims:
  //    [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
  SmallVector<int64_t> readShapeForExtractSlice;
  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
  // outer-tiled-dims being all 1), this will be
  //    [ outer-untiled-dims, tile-sizes ]
  SmallVector<OpFoldResult> extractSliceSizes;

  // Shape for EmptyOp that's used as the init value for TransposeOp below.
  // This should be:
  //    [ outer-untiled-dims, tile-sizes ]
  // However, skip unit dims - TransposeOp (below) applies rank-reduced
  // permutation.
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
    //
    // As all outer tiled dims are 1, so the corresponding
    // slice size to read will also 1. As this will be rank-reducing "extract
    // slice" (i.e. the unit dims will be "collapsed"), there's no need to
    // update:
    //  * the output shape for ExtractSliceOp, nor
    //  * the shape for EmptyOp.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute sizes attribute for ExtractSliceOp + EmptyOp -
    // outer-untiled-dims
    if (ShapedType::isDynamic(srcShape[i])) {
      OpFoldResult dynamicDim =
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
    // into account rank-reducing)
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
  // shape for EmptyOp.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);

  // 2. Transpose the tile to match the outer corresponding tile order.
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  // 3. Handle in-complete tiles if needed. It truncates trailing data from the
  // transposed tile.
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile =
      tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
                                     transposedOp.getResult()[0], tileSizes);

  // 4. Insert the result to the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeSizes);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
1426
1427//===----------------------------------------------------------------------===//
1428// Generic DownscaleSizeOneWindowedConvolution
1429//===----------------------------------------------------------------------===//
1430//
/// Returns the indices of affine map results that reference any of the given
/// dimensions.
  SmallVector<unsigned> resultIndices;
  // For each dim, record the first map result that uses it. Duplicate result
  // indices are possible when two dims feed the same result; callers test
  // membership with is_contained, so duplicates are harmless.
  for (unsigned dim : dims) {
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      AffineExpr expr = map.getResult(i);
      if (expr.isFunctionOfDim(dim)) {
        resultIndices.push_back(i);
        break;
      }
    }
  }
  return resultIndices;
}
1447
/// Helper to create a rank-reducing extract_slice that removes specific
/// dimensions from a tensor.
    Location loc, Value tensor,
    ArrayRef<unsigned> dimsToRemove) {
  auto tensorType = cast<RankedTensorType>(tensor.getType());
  int64_t rank = tensorType.getRank();

  // Compute new shape by removing the specified dimensions.
  SmallVector<int64_t> newShape;
  for (int64_t i = 0; i < rank; ++i) {
    if (!llvm::is_contained(dimsToRemove, i))
      newShape.push_back(tensorType.getDimSize(i));
  }

  auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
      tensor, newType);
}
1467
/// Drops specified dimensions from an AffineExpr and compresses remaining
/// dimension indices. Returns std::nullopt if the expression only references
/// the dropped dimensions.
static std::optional<AffineExpr>
    unsigned newNumDims, MLIRContext *ctx) {
  // Check if expr only references dimensions to be dropped.
  // Note: the old dim count is newNumDims + dimsToDrop.size().
  bool onlyReferencesDroppedDims = true;
  for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
    if (expr.isFunctionOfDim(d) && !llvm::is_contained(dimsToDrop, d)) {
      onlyReferencesDroppedDims = false;
      break;
    }
  }
  // An expression made exclusively of dropped dims carries no information in
  // the reduced space - signal the caller to discard it.
  if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](unsigned d) {
        return expr.isFunctionOfDim(d);
      }))
    return std::nullopt;

  // Replace dimensions: compute new index for each old dimension.
  // Dropped dimensions get mapped to constant 0, others get compressed.
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDimIdx = 0;
  for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
    if (llvm::is_contained(dimsToDrop, d)) {
      dimReplacements.push_back(getAffineConstantExpr(0, ctx));
    } else {
      dimReplacements.push_back(getAffineDimExpr(newDimIdx++, ctx));
    }
  }

  return expr.replaceDims(dimReplacements);
}
1501
FailureOr<LinalgOp>
    LinalgOp op) {
  auto maybeDims = inferConvolutionDims(op);
  if (failed(maybeDims))
    return failure();

  // Currently supports only 2D convolutions.
  if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
    return failure();

  // Tensor semantics required: the rewrite builds extract/insert_slice ops.
  if (op.hasPureBufferSemantics())
    return failure();

  // Get loop domain indices for spatial dimensions.
  unsigned outSpatial0 = maybeDims->outputImage[0];
  unsigned outSpatial1 = maybeDims->outputImage[1];
  unsigned filterSpatial0 = maybeDims->filterLoop[0];
  unsigned filterSpatial1 = maybeDims->filterLoop[1];

  // Get sizes from loop bounds.
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t outSize0 = loopRanges[outSpatial0];
  int64_t outSize1 = loopRanges[outSpatial1];
  int64_t filterSize0 = loopRanges[filterSpatial0];
  int64_t filterSize1 = loopRanges[filterSpatial1];

  // Check if we can downscale by removing a spatial dimension: both the
  // window and the output extent along that dim must be statically 1.
  bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
  bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
  if (!canRemoveSpatial0 && !canRemoveSpatial1)
    return failure();

  // Determine which loop dims to remove (output spatial + corresponding filter)
  // and sort for correct index compression when removing dimensions from affine
  // maps.
  SmallVector<unsigned> loopDimsToRemove;
  if (canRemoveSpatial0) {
    loopDimsToRemove.push_back(outSpatial0);
    loopDimsToRemove.push_back(filterSpatial0);
  } else {
    loopDimsToRemove.push_back(outSpatial1);
    loopDimsToRemove.push_back(filterSpatial1);
  }
  llvm::sort(loopDimsToRemove);

  // Create new indexing maps with dimensions removed.
  SmallVector<AffineMap> newMaps;
  MLIRContext *ctx = op.getContext();
  unsigned numDims = op.getNumLoops();
  unsigned newNumDims = numDims - loopDimsToRemove.size();
  for (AffineMap map : op.getIndexingMapsArray()) {
    SmallVector<AffineExpr> newResults;
    for (AffineExpr expr : map.getResults()) {
      // Results made only of dropped dims are discarded (nullopt).
      auto newExpr =
          dropDimsAndCompress(expr, loopDimsToRemove, newNumDims, ctx);
      if (newExpr)
        newResults.push_back(*newExpr);
    }
    newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
  }

  // Create new iterator types.
  auto iterTypes = op.getIteratorTypesArray();
  for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
    if (!llvm::is_contained(loopDimsToRemove, idx))
      newIterTypes.push_back(iterTypes[idx]);
  }

  // Rank-reduce operands using extract_slice.
  Location loc = op.getLoc();
  SmallVector<Value> newInputs;
  for (OpOperand *input : op.getDpsInputOperands()) {
    AffineMap map = op.getMatchingIndexingMap(input);
    SmallVector<unsigned> tensorDimsToRemove =
        getResultIndicesReferencingDims(map, loopDimsToRemove);
    Value reduced = createRankReducingExtractSlice(rewriter, loc, input->get(),
                                                   tensorDimsToRemove);
    newInputs.push_back(reduced);
  }

  OpOperand &output = *op.getDpsInitsMutable().begin();
  AffineMap outputMap = op.getMatchingIndexingMap(&output);
  SmallVector<unsigned> outputDimsToRemove =
      getResultIndicesReferencingDims(outputMap, loopDimsToRemove);
  Value newOutput = createRankReducingExtractSlice(rewriter, loc, output.get(),
                                                   outputDimsToRemove);

  // Create new linalg.generic with reduced dimensions.
  auto newOp =
      linalg::GenericOp::create(rewriter, loc, TypeRange{newOutput.getType()},
                                newInputs, newOutput, newMaps, newIterTypes);
  rewriter.inlineRegionBefore(op->getRegion(0), newOp.getRegion(),
                              newOp.getRegion().begin());

  // Try to specialize the generic back to a named op only if the input was
  // already a specialized (named) op.
  LinalgOp resultOp = newOp;
  if (!isa<GenericOp>(op)) {
    FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
    if (succeeded(specializedOp))
      resultOp = *specializedOp;
  }

  // Insert result back into original shape.
      rewriter, loc, resultOp->getResult(0), output.get());

  rewriter.replaceOp(op, result);
  return resultOp;
}
1614
namespace {
/// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
struct DownscaleSizeOneWindowedConvolution final
    : public OpInterfaceRewritePattern<LinalgOp> {
  DownscaleSizeOneWindowedConvolution(MLIRContext *context,
                                      PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}

  // Thin adapter: delegates to the functional-style entry point.
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
  }
};
} // namespace
1629
    PatternBenefit benefit) {
  // Register the size-one-window downscaling pattern with the given benefit.
  patterns.add<DownscaleSizeOneWindowedConvolution>(patterns.getContext(),
                                                    benefit);
}
1635
1640
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static RankedTensorType permuteShape(RankedTensorType tensorType, ArrayRef< int64_t > permutationVector)
Return a copy of tensorType after permutation by permutationVector.
static std::optional< AffineExpr > dropDimsAndCompress(AffineExpr expr, ArrayRef< unsigned > dimsToDrop, unsigned newNumDims, MLIRContext *ctx)
Drops specified dimensions from an AffineExpr and compresses remaining dimension indices.
static SmallVector< int64_t > getPackUnpackRankReducedPerm(ArrayRef< int64_t > shape, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
static Value createRankReducingExtractSlice(RewriterBase &rewriter, Location loc, Value tensor, ArrayRef< unsigned > dimsToRemove)
Helper to create a rank-reducing extract_slice that removes specific dimensions from a tensor.
static FailureOr< SmallVector< std::optional< int64_t > > > packLinalgMetadataOnce(SmallVectorImpl< AffineMap > &indexingMaps, SmallVectorImpl< utils::IteratorType > &iteratorTypes, int64_t dim)
Perform one step of packing of a LinalgOp's metadata along dim into the newDim at iteratorTypes....
static LinalgOp transposeOneLinalgOperandAndReplace(RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, ArrayRef< int64_t > permutation, Value transposedValue)
Return a new GenericOp obtained by transposing opOperand by the permutation vector:
static SmallVector< int64_t > getPackUnpackNormalizedPerm(int rank, ArrayRef< int64_t > perm)
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim)
Return true if map has 0 or 1 result function of AffineDimExpr(dim).
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp)
If padding value is set, returns a tensor.pad Op for the source tensor, with the output shape matchin...
static SmallVector< unsigned > getResultIndicesReferencingDims(AffineMap map, ArrayRef< unsigned > dims)
Returns the indices of affine map results that reference any of the given dimensions.
static std::string stringifyReassocIndices(ReassociationIndicesRef ri)
static std::optional< int64_t > getFirstResultIndexFunctionOf(AffineMap map, int64_t dim)
Return the index of the first result of map that is a function of AffineDimExpr(dim),...
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
Definition AffineMap.h:315
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:441
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Definition Utils.cpp:23
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Definition Transforms.h:204
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Definition Transforms.h:194
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.