Transforms.cpp
//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
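
// For example (an illustrative sketch, not taken from a test): peeling
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
// produces a main loop from %c0 to %c8 step %c4 plus a partial iteration
// covering [%c8, %c10), and affine.min/max ops inside the main loop that
// depend on %iv can then be simplified against the tighter bounds.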

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has 0 or 1 result that is a function of
/// AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}
#endif // NDEBUG

static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), or std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
/// 2. Appending a `newDim` to the domain of every indexing map.
/// 3. For each operand (i.e. for each map in `indexingMaps`), performing
///    packing by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
/// Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into higher dimensional access is not possible for general
/// AffineExpr in Linalg atm; it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// and would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
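
// For example (an illustrative sketch): starting from matmul-like metadata
//   indexingMaps  = [(d0, d1, d2) -> (d0, d2),
//                    (d0, d1, d2) -> (d2, d1),
//                    (d0, d1, d2) -> (d0, d1)]
//   iteratorTypes = [parallel, parallel, reduction]
// one step of packing along dim = 0 appends a new dim d3 with the same
// (parallel) iterator type and appends a d3 result to every map that has a
// result depending on d0:
//   indexingMaps  = [(d0, d1, d2, d3) -> (d0, d2, d3),
//                    (d0, d1, d2, d3) -> (d2, d1),
//                    (d0, d1, d2, d3) -> (d0, d1, d3)]
// and records the packed result indices {0, nullopt, 0} for the 3 operands.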

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             linalg::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp, packingMetadata);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, /*nofold=*/false);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = tensor::InsertSliceOp::create(
          rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LDBG() << "insert_slice op: " << insertSliceOp;

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
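
// For example (an illustrative sketch with made-up %src/%dst values):
//   %pack = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dst : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
// lowers to roughly:
//   %padded = tensor.pad %src (a no-op here, no padding is needed)
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<128x256xf32> into tensor<16x8x16x16xf32>
//   linalg.transpose ins(%expanded) outs(%dst) permutation = [0, 2, 1, 3]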

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  // TODO: Support Memref UnPackOp. Temporarily return failure.
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp,
                               /*copyOp=*/nullptr};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp,
                             copyOp};
}
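
// For example (an illustrative sketch, the inverse of the pack example
// above):
//   %unpack = linalg.unpack %src inner_dims_pos = [0, 1]
//       inner_tiles = [8, 16] into %dst
//       : tensor<16x16x8x16xf32> -> tensor<128x256xf32>
// lowers to roughly: a linalg.transpose (tensor<16x16x8x16xf32> ->
// tensor<16x8x16x16xf32>), a tensor.collapse_shape to tensor<128x256xf32>,
// a tensor.extract_slice of the destination size, and a final linalg.copy
// into %dst to preserve DPS.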

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by performing packing by
/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);

    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back().getResult());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back().getResult());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
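
// Example usage (a sketch; assumes a rewriter positioned at a 3-loop
// matmul-like `linalgOp`):
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
//                                            rewriter.getIndexAttr(16),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> res = linalg::pack(rewriter, linalgOp, packedSizes);
//   if (succeeded(res)) {
//     // res->packOps, res->packedLinalgOp and res->unPackOps hold the newly
//     // created ops.
//   }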

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method in one of MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "unexpected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition =
      llvm::map_to_vector(permutation, [](int64_t i) -> unsigned { return i; });
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = linalg::GenericOp::create(
      rewriter,
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // Leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  if (packOp.hasPureTensorSemantics())
    rewriter.replaceOp(packOp, transposedPackOp->getResults());
  else
    rewriter.eraseOp(packOp);

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
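
// Example usage (a sketch; `packOp` must have a single use feeding
// `linalgOp`, optionally followed by `maybeUnPackOp`): transpose the inner
// tile layout with innerPerm = [1, 0] and keep the outer dims unchanged:
//   FailureOr<PackTransposeResult> res = linalg::packTranspose(
//       rewriter, packOp, linalgOp, maybeUnPackOp,
//       /*outerPerm=*/{}, /*innerPerm=*/{1, 0});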

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 8 possible forms. The outer
/// dimensions of the operands are not permuted at this time, this is left for
/// future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k, bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where one would
  // have to plug in a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LDBG() << "perm: " << llvm::interleaved(permutation);
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
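
// Example usage (a sketch): pack m, n and k by 8, 16 and 32, without extra
// padding and in the default {m, n, k} order:
//   SmallVector<OpFoldResult> mnk = {rewriter.getIndexAttr(8),
//                                    rewriter.getIndexAttr(16),
//                                    rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> res = linalg::packMatmulGreedily(
//       rewriter, linalgOp, mnk, /*mnkPaddedSizesNextMultipleOf=*/{},
//       /*mnkOrder=*/{0, 1, 2});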

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::map_to_vector<4>(tileSizes, [&](int64_t s) {
      Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
      return v;
    });
  };
  return *this;
}
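
// Example usage (a sketch; the options object is typically handed to a
// tiling pattern or driver):
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 32});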

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move the padding value defined inside the PadOp block to outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return arith::ConstantIndexOp::create(
               rewriter, padOp.getLoc(),
               cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor =
      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
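
// For example (an illustrative sketch):
//   %pad = tensor.pad %src low[0, 1] high[2, 3] { ... tensor.yield %cst }
//       : tensor<4x5xf32> to tensor<6x9xf32>
// decomposes into roughly:
//   %empty = tensor.empty() : tensor<6x9xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<6x9xf32>)
//   %res = tensor.insert_slice %src into %fill[0, 1] [4, 5] [1, 1]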

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reduced, all shapes are static and the
  // data source is actually used. Rewrite into pad(extract_slice(x)).
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reduced slice by creating another extract_slice op.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);

  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
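
// In other words (an illustrative summary): extract_slice(pad(x)) is
// rewritten as pad(extract_slice(x)), so that the slice is taken from the
// original source and only the padding that actually overlaps the slice is
// re-applied, subject to `zeroSliceGuard`.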

/// If padding value is set, returns a tensor.pad op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  // TODO: Support Memref PackOp. Temporarily return just the op source.
  if (!packOp.hasPureTensorSemantics())
    return input;

  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
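
// For example (an illustrative sketch): for a pack with a padding value,
// source tensor<5x1xf32>, inner_dims_pos = [0] and inner_tiles = [8] (so
// that all outer dims of the result tensor<1x1x8xf32> are 1), the source is
// first padded high to tensor<8x1xf32>.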

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);

  return perm;
}
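
// For example (a hand-derived, illustrative case): shape = [5, 1, 7],
// innerDimsPos = [0], outerDimsPerm = [2, 1, 0] yields perm = [1, 0]: after
// rank-reducing away the unit outer dim, the remaining non-unit outer dim
// and the tiled dim swap positions.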

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    linalg::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();

  // Verify that there are no:
  //   * non-unit + un-tiled-outer-dims,
  // that are permuted. Supporting such cases would require refining the logic
  // that generates the Transpose Op.
  int64_t prev = 0;
  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp,
                                    &prev](int64_t dim) {
        // Skip tiled dims - these can be permuted.
        if (llvm::is_contained(innerDimsPos, dim))
          return true;

        // Check whether this dim has been permuted. Permuting unit dims is
        // fine as that's effectively a no-op.
        if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
                           packOp.getResult().getType().getShape()[dim] != 1))
          return false;

        prev = dim;
        return true;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");
  }

  Location loc = packOp.getLoc();

  int64_t srcRank = packOp.getSourceRank();

  // 1. Get the input that is going to be packed. If the input requires
  // padding, add a padding operation and return that as the input.
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

  // 2. Transpose the input to match the inner tile order:
  //      %init = tensor.empty()
  //      %transposed_tile = linalg.transpose ins(%source_or_padded_source),
  //                                          outs(%init)
  // Assumptions made:
  //   - All tiled outer dims are 1 - the corresponding transposition order
  //     doesn't matter, but requires all dim indices to be present.
  //   - Un-tiled outer dims remain un-permuted.

  // 2.1 Get the permutation for linalg.transpose:
  //   [ untiled-dims, inner-dims-pos ]
  // Note, this logic assumes that the untiled dims are not permuted.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < srcRank; i++) {
    // We assume the `k` dimensions of the inner dim position, where `k` is the
    // rank of the inner tiling, correspond to the last `k` indices of the
    // transpose permutation. This is done by adding the indices not contained
    // in the inner dimension position in order from 0 to `n`, where `n` is the
    // rank of the source tensor. For example, if we have a source tensor with
    // indices [0, 1, 2, 3] and an inner dim position of [3, 0], the remaining
    // indices are [1, 2] and the transpose will be [1, 2, 3, 0].
    if (llvm::is_contained(innerDimsPos, i))
      continue;
    srcPermForTranspose.push_back(i);
  }
  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

  // 2.2 Create the init tensor for linalg.transpose with the correct shape:
  //   [ untiled-dims, tiled-dims ]
  ShapedType inputTy = cast<ShapedType>(input.getType());
  SmallVector<OpFoldResult> shapeForEmptyOp;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i)) {
      // The tiled dims are appended after this loop.
      continue;
    }
    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
    else
      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());
  }
  shapeForEmptyOp.append(packOp.getMixedTiles());

  // getMixedTiles() may contain Values pointing to constant ops (as opposed to
  // constant attributes with the corresponding value). Replace those with
  // attributes. This is to match the behaviour in
  // `getPackOpSourceOrPaddedSource`, which replaces constant SSA values with
  // attributes.
  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
                  [&](OpFoldResult ofr) {
                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);
                    return ofr;
                  });

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.3 Create linalg.transpose.
  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  // 3. Insert the inner tile into the destination tensor:
  //      %inserted_tile = tensor.insert_slice(%transposed_tile)

  // Compute the sizes attribute:
  //   [ outer-dims, tile-sizes ]
  // Note that the output from the transpose Op excludes the tiled outer dims.
  // However, given the assumption that:
  //   * all tiled outer dims == 1,
  // we can just use a rank-expanding tensor.insert_slice.
  SmallVector<OpFoldResult> writeSizes;
  for (auto size : packOp.getAllOuterDims()) {
    writeSizes.push_back(rewriter.getIndexAttr(size));
  }

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
  }

  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);

  // 4. Replace the pack op with the tensor.insert_slice created above.
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
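
// For example (an illustrative sketch):
//   linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [4, 8]
//       into %dst : tensor<4x8xf32> -> tensor<1x1x4x8xf32>
// decomposes into an (identity, in this case) linalg.transpose followed by a
// rank-expanding insert:
//   tensor.insert_slice %transposed into %dst
//       [0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1]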
1298
1300 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1301 if (!unpackOp.hasPureTensorSemantics())
1302 return failure();
1303
1304 int64_t destRank = unpackOp.getDestRank();
1305 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1306 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1307 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1308 [](int64_t dim) { return dim != 1; })) {
1309 return rewriter.notifyMatchFailure(
1310 unpackOp,
1311 "require the tiled outer dimensions of the result are all 1s");
1312 }
1313
1314 // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1315 // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1316 Location loc = unpackOp.getLoc();
1317 Value source = unpackOp.getSource();
1318 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1319 unpackOp.getDimAndTileMapping();
1320 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1321
1322 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1323 // dims:
1324 // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1325 SmallVector<int64_t> readShapeForExtractSlice;
1326 // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1327 // outer-tiled-dims being all 1), this will be
1328 // [ outer-untiled-dims, tile-sizes ]
1329 SmallVector<OpFoldResult> extractSliceSizes;
1330
1331 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1332 // This should be:
1333 // [ outer-untiled-dims, tile-sizes ]
1334 // However, skip unit dims - TransposeOp (below) applies rank-reduced
1335 // permutation.
1336 SmallVector<OpFoldResult> shapeForEmptyOp;
1337
1338 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1339 // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1340 //
1341 // As all outer tiled dims are 1, so the corresponding
1342 // slice size to read will also 1. As this will be rank-reducing "extract
1343 // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1344 // update:
1345 // * the output shape for ExtractSliceOp, nor
1346 // * the shape for EmptyOp.
1347 if (dimAndTileMapping.count(i)) {
1348 extractSliceSizes.push_back(oneIdxAttr);
1349 continue;
1350 }
1351
1352 // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1353 // outer-untiled-dims
1354 if (ShapedType::isDynamic(srcShape[i])) {
1355 OpFoldResult dynamicDim =
1356 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1357 extractSliceSizes.push_back(dynamicDim);
1358 shapeForEmptyOp.push_back(dynamicDim);
1359 } else {
1360 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1361 if (srcShape[i] != 1)
1362 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1363 }
1364 // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1365 // into account rank-reducing)
1366 if (srcShape[i] != 1) {
1367 readShapeForExtractSlice.push_back(srcShape[i]);
1368 }
1369 }
1370 // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1371 // shape for EmptyOp.
1372 auto mixedTiles = unpackOp.getMixedTiles();
1373 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1374 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1375
1376 // Explicitly create the type for extract_slice op because the inner tile
1377 // size could be 1. We want to represent the whole inner tile in this case.
1378 auto tileShape = srcShape.drop_front(destRank);
1379 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1380 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1381 Type elemType = unpackOp.getSourceType().getElementType();
1382 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1383 Value innerTile = tensor::ExtractSliceOp::create(
1384 rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
1385
1386 // 2. Transpose the tile to match the outer corresponding tile order.
1388 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1389 // Unpack is a transition out of packed space so we invert the permutation.
1390 perm = invertPermutationVector(perm);
1391 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1392
1393 Value empty =
1394 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1395 auto transposedOp =
1396 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1397
1398 // 3. Handle incomplete tiles if needed, truncating trailing data from the
1399 // transposed tile.
1400 SmallVector<OpFoldResult> tileSizes;
1401 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1402 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1403 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1404 tileSizes.push_back(
1405 tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1406 }
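// Editorial illustration (continuing the hypothetical example): for a
// destination of tensor<5x6xf32> with dim 1 tiled by 8, tileSizes = [5, 6],
// so the extract_slice below drops the trailing 2 elements of the 8-wide
// tile.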
1407
1408 auto partialTile =
1409 tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1410 transposedOp.getResult()[0], tileSizes);
1411
1412 // 4. Insert the result to the destination tensor.
1413 SmallVector<OpFoldResult> writeSizes;
1414 for (int i = 0, idx = 0; i < destRank; ++i) {
1415 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1416 writeSizes.push_back(tileSizes[idx++]);
1417 else
1418 writeSizes.push_back(oneIdxAttr);
1419 }
1420 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1421 unpackOp.getDest(), writeSizes);
1422 rewriter.replaceOp(unpackOp, insert.getResult());
1423
1424 return success();
1425}
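// Editorial sketch of the decomposition above (hypothetical example, shown
// roughly):
//   %r = linalg.unpack %src inner_dims_pos = [1] inner_tiles = [8]
//          into %dst : tensor<5x1x8xf32> -> tensor<5x6xf32>
// becomes:
//   %tile = tensor.extract_slice %src[0, 0, 0] [5, 1, 8] [1, 1, 1]
//           : tensor<5x1x8xf32> to tensor<5x8xf32>
//   %init = tensor.empty() : tensor<5x8xf32>
//   %t    = linalg.transpose ins(%tile : tensor<5x8xf32>)
//           outs(%init : tensor<5x8xf32>) permutation = [0, 1]
//   %part = tensor.extract_slice %t[0, 0] [5, 6] [1, 1]
//           : tensor<5x8xf32> to tensor<5x6xf32>
//   %r    = tensor.insert_slice %part into %dst[0, 0] [5, 6] [1, 1]
//           : tensor<5x6xf32> into tensor<5x6xf32>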
1426
1427// The following are patterns for downscaling convolution ops with size-1
1428// window dimensions.
1429//
1430// Note that we'd eventually want to write such transformations in a generic
1431// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1432// and then turning back to named ops. But for now it's fine to have a few
1433// patterns matching special ops to get started.
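// As an editorial sketch (hypothetical shapes), a conv whose kh and oh are
// both 1:
//   linalg.conv_2d_nhwc_hwcf
//     {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x2x3x8xf32>)
//     outs(%out : tensor<1x1x9x8xf32>) -> tensor<1x1x9x8xf32>
// downscales to:
//   linalg.conv_1d_nwc_wcf
//     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
//     ins(%in1, %ker1 : tensor<1x10x3xf32>, tensor<2x3x8xf32>)
//     outs(%out1 : tensor<1x9x8xf32>) -> tensor<1x9x8xf32>
// with rank-reducing extract_slice/insert_slice ops around it.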
1434
1435 template <typename Conv2DOp, typename Conv1DOp>
1436 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1437 returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
1438 // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
1439 std::optional<DilationsAndStrides> convParams =
1440 matchConvolutionOpOfType<Conv2DOp>(convOp);
1441 if (!convParams)
1442 return failure();
1443 SmallVector<int64_t> dilations = std::move(convParams->dilations);
1444 SmallVector<int64_t> strides = std::move(convParams->strides);
1445
1446 if (convOp.hasPureBufferSemantics())
1447 return failure(); // To be implemented.
1448
1449 Value input = convOp.getDpsInputs().front();
1450 Value kernel = convOp.getDpsInputs().back();
1451 Value output = convOp.getDpsInits().front();
1452
1453 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1454 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1455 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1456
1457 auto kernelShape = kernelType.getShape();
1458 auto outputShape = outputType.getShape();
1459
1460 // Get domain indices based on Conv2DOp type. These are known at compile time.
1461 int64_t khIndex, kwIndex, ohIndex, owIndex;
1462 if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
1463 std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
1464 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
1465 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
1466 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
1467 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
1468 // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
1469 khIndex = 0;
1470 kwIndex = 1;
1471 ohIndex = 1;
1472 owIndex = 2;
1473 } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
1474 // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
1475 khIndex = 2;
1476 kwIndex = 3;
1477 ohIndex = 2;
1478 owIndex = 3;
1479 } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
1480 std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
1481 // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
1482 khIndex = 0;
1483 kwIndex = 1;
1484 ohIndex = 2;
1485 owIndex = 3;
1486 }
1487
1488 // Only handle the case where at least one of the window dimensions is
1489 // of size 1. Other cases can rely on tiling to reduce to such cases.
1490 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1491 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1492 bool removeH = (khSize == 1 && ohSize == 1);
1493 bool removeW = (kwSize == 1 && owSize == 1);
1494 if (!removeH && !removeW)
1495 return failure();
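// Editorial note (hypothetical shapes): kernelShape = [1, 3] with
// outputShape[ohIndex] = 1 yields removeH = true; a 3x3 kernel with
// non-unit output H and W matches neither case and fails here, leaving
// tiling to first reduce the problem to a size-1 window.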
1496
1497 // Get new shapes and types for all operands by removing the size-1
1498 // dimension.
1499 using RTTBuilder = RankedTensorType::Builder;
1500 RankedTensorType newInputType =
1501 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1502 RankedTensorType newKernelType =
1503 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1504 RankedTensorType newOutputType =
1505 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1506
1507 // Rank-reduce operands.
1508 Location loc = convOp.getLoc();
1509 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1510 rewriter, loc, input, newInputType);
1511 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1512 rewriter, loc, kernel, newKernelType);
1513 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1514 rewriter, loc, output, newOutputType);
1515
1516 // Rank-reduce strides and dilations too.
1517 // TODO: dropDim 1-liner helper.
1518 strides.erase(strides.begin() + (removeH ? 0 : 1));
1519 auto stridesAttr = rewriter.getI64VectorAttr(strides);
1520
1521 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1522 auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
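// Editorial note (hypothetical values): with removeH, strides = [2, 3]
// becomes [3] and dilations = [1, 2] becomes [2]; only the surviving
// window dimension's entry is kept.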
1523
1524 auto conv1DOp = Conv1DOp::create(
1525 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1526 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1527
1528 // Insert back.
1529 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1530 rewriter, loc, conv1DOp.getResult(0), output);
1531 rewriter.replaceOp(convOp, inserted);
1532
1533 return conv1DOp;
1534}
1535
1536template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1537 Conv1DNwcWcfOp>;
1538template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1539 Conv1DNcwFcwOp>;
1540template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1541 PoolingNwcSumOp>;
1542template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1543 PoolingNcwSumOp>;
1544template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1545 PoolingNwcMaxOp>;
1546 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1547 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1548template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1549 PoolingNwcMinOp>;
1550 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1551 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1552template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1553 PoolingNcwMaxOp>;
1554
1555 FailureOr<DepthwiseConv1DNwcWcOp>
1556 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1557 LinalgOp convOp, PatternRewriter &rewriter) const {
1558 // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
1559 std::optional<DilationsAndStrides> convParams =
1560 matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
1561 if (!convParams)
1562 return failure();
1563 SmallVector<int64_t> dilations = std::move(convParams->dilations);
1564 SmallVector<int64_t> strides = std::move(convParams->strides);
1565
1566 if (convOp.hasPureBufferSemantics())
1567 return failure(); // To be implemented.
1568
1569 Value input = convOp.getDpsInputs().front();
1570 Value kernel = convOp.getDpsInputs().back();
1571 Value output = convOp.getDpsInits().front();
1572
1573 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1574 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1575 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1576
1577 auto kernelShape = kernelType.getShape();
1578 auto outputShape = outputType.getShape();
1579
1580 // Only handle the case where at least one of the window dimensions is
1581 // of size 1. Other cases can rely on tiling to reduce to such cases.
1582 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1583 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1584 bool removeH = (khSize == 1 && ohSize == 1);
1585 bool removeW = (kwSize == 1 && owSize == 1);
1586 if (!removeH && !removeW)
1587 return failure();
1588
1589 // Get new shapes and types for all operands by removing the size-1
1590 // dimension.
1591 using RTTBuilder = RankedTensorType::Builder;
1592 RankedTensorType newInputType =
1593 RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1594 RankedTensorType newKernelType =
1595 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1596 RankedTensorType newOutputType =
1597 RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1598
1599 // Rank-reduce operands.
1600 Location loc = convOp.getLoc();
1601 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1602 rewriter, loc, input, newInputType);
1603 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1604 rewriter, loc, kernel, newKernelType);
1605 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1606 rewriter, loc, output, newOutputType);
1607
1608 // Rank-reduce strides and dilations too.
1609 // TODO: dropDim 1-liner helper.
1610 strides.erase(strides.begin() + (removeH ? 0 : 1));
1611 auto stridesAttr = rewriter.getI64VectorAttr(strides);
1612
1613 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1614 auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1615
1616 auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
1617 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1618 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1619
1620 // Insert back.
1621 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1622 rewriter, loc, conv1DOp.getResult(0), output);
1623 rewriter.replaceOp(convOp, inserted);
1624
1625 return conv1DOp;
1626}
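// Editorial sketch (hypothetical shapes): with a unit (kh, oh) pair,
//   linalg.depthwise_conv_2d_nhwc_hwc
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x2x3xf32>)
//     outs(%out : tensor<1x1x9x3xf32>) -> tensor<1x1x9x3xf32>
// downscales to linalg.depthwise_conv_1d_nwc_wc operating on
// tensor<1x10x3xf32>, tensor<2x3xf32> and tensor<1x9x3xf32>.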
1627
1628 FailureOr<Conv1DOp>
1629 DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
1630 PatternRewriter &rewriter) const {
1631 // Check if this LinalgOp is a Conv2DOp (named or generic).
1632 std::optional<DilationsAndStrides> convParams =
1633 matchConvolutionOpOfType<Conv2DOp>(convOp);
1634 if (!convParams)
1635 return failure();
1636
1637 if (convOp.hasPureBufferSemantics())
1638 return failure(); // To be implemented.
1639
1640 Value input = convOp.getDpsInputs().front();
1641 Value kernel = convOp.getDpsInputs().back();
1642 Value output = convOp.getDpsInits().front();
1643
1644 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1645 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1646 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1647
1648 auto kernelShape = kernelType.getShape();
1649 auto outputShape = outputType.getShape();
1650
1651 // Only handle the case where at least one of the window dimensions is
1652 // of size 1. Other cases can rely on tiling to reduce to such cases.
1653 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1654 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1655 bool removeH = (khSize == 1 && ohSize == 1);
1656 bool removeW = (kwSize == 1 && owSize == 1);
1657 if (!removeH && !removeW)
1658 return failure();
1659
1660 // Get new shapes and types for all operands by removing the size-1
1661 // dimension.
1662 using RTTBuilder = RankedTensorType::Builder;
1663 RankedTensorType newInputType =
1664 RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1665 RankedTensorType newKernelType =
1666 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1667 RankedTensorType newOutputType =
1668 RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1669
1670 // Rank-reduce operands.
1671 Location loc = convOp.getLoc();
1672 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1673 rewriter, loc, input, newInputType);
1674 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1675 rewriter, loc, kernel, newKernelType);
1676 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1677 rewriter, loc, output, newOutputType);
1678
1679 auto conv1DOp =
1680 Conv1DOp::create(rewriter, loc, newOutputType,
1681 ValueRange{newInput, newKernel}, ValueRange{newOutput});
1682
1683 // Insert back.
1684 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1685 rewriter, loc, conv1DOp.getResult(0), output);
1686 rewriter.replaceOp(convOp, inserted);
1687
1688 return conv1DOp;
1689}
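// Editorial sketch (hypothetical shapes): for the un-batched linalg.conv_2d,
//   linalg.conv_2d ins(%in, %ker : tensor<1x10xf32>, tensor<1x3xf32>)
//                  outs(%out : tensor<1x8xf32>) -> tensor<1x8xf32>
// downscales to:
//   linalg.conv_1d ins(%in1, %ker1 : tensor<10xf32>, tensor<3xf32>)
//                  outs(%out1 : tensor<8xf32>) -> tensor<8xf32>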
1690