1//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logic and helpers to expose Linalg transforms as rewrite
10// patterns.
11//
12//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/Support/LLVM.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/DebugLog.h"
33#include "llvm/Support/InterleavedRange.h"
34#include "llvm/Support/raw_ostream.h"
35#include <utility>
36
37#define DEBUG_TYPE "linalg-transforms"
38
39using namespace mlir;
40using namespace mlir::linalg;
41
42//===----------------------------------------------------------------------===//
43// Transformations exposed as functional-style API calls.
44//===----------------------------------------------------------------------===//
45
46//===----------------------------------------------------------------------===//
47// peelLoop transformation.
48//===----------------------------------------------------------------------===//
49
50/// Try to peel and canonicalize loop `op` and return the new result.
51/// Also applies affine_min/max bounds simplification on the fly where relevant.
52// TODO: Add support for scf.parallel and affine.for loops.
53SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
54 Operation *op) {
55 return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
56 .Case<scf::ForOp>([&](scf::ForOp forOp) {
57 scf::ForOp partialIteration;
58 if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
59 partialIteration)))
60 return partialIteration->getResults();
61 assert(!partialIteration && "expected that loop was not peeled");
62 return forOp->getResults();
63 })
64 .Default([&](Operation *op) { return op->getResults(); });
65}
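// A minimal sketch of the effect of peelLoop above (assumed IR, not taken
// from a test): peeling
//
//   scf.for %iv = %c0 to %c10 step %c4 {
//     %m = affine.min affine_map<(d0) -> (4, -d0 + 10)>(%iv)
//     ...
//   }
//
// produces a main loop over the range [0, 8) in which the affine.min folds to
// the constant 4, followed by a separate scf.for covering the remaining
// partial iteration over [8, 10).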
66
67/// Peel `loops` and apply affine_min/max bounds simplification on the fly
68/// where relevant.
69void mlir::linalg::peelLoops(RewriterBase &rewriter,
70 ArrayRef<scf::ForOp> loops) {
71 for (auto loopOp : loops)
72 peelLoop(rewriter, loopOp);
73}
74
75//===----------------------------------------------------------------------===//
76// pack transformation.
77//===----------------------------------------------------------------------===//
78
79#ifndef NDEBUG
80/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
82 bool found = false;
83 for (AffineExpr e : map.getResults()) {
84 if (!e.isFunctionOfDim(dim))
85 continue;
86 if (found)
87 return false;
88 found = true;
89 }
90 return true;
91}
92#endif // NDEBUG
93
94static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
95 return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
96}
97
98/// Return the index of the first result of `map` that is a function of
99/// AffineDimExpr(dim), or std::nullopt if there is none.
100static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
101 int64_t dim) {
102 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
103 AffineExpr expr = map.getResult(i);
104 if (!expr.isFunctionOfDim(dim))
105 continue;
106 return i;
107 }
108 return std::nullopt;
109}
110
111/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
112/// `newDim` at `iteratorTypes.size()` by:
113/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
114/// 2. Appending a `newDim` to the domain of every indexing map.
115/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
116/// by potentially adding a `newDim` result to `map`.
117/// The preserved invariant is that `iteratorTypes.size()` is always equal to
118/// `map.getNumDims()` for every map in `indexingMaps`.
119///
120/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
121/// Return a vector that records the optional packing for each operand.
122/// Return failure if the packed indexing cannot be represented with a LinalgOp.
123///
124/// Further details:
125/// ================
126/// The current implementation of packing (i.e. data tiling) consists of
127/// rewriting a linearized strip-mined form into a higher-dimensional access.
128/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
129/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
130/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
131///
132/// This rewrite into a higher-dimensional access is currently not possible
133/// for a general AffineExpr in Linalg; it is restricted to an AffineDimExpr:
134/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
135/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
136/// The rewrite of the access would be a form not representable in Linalg:
137/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
138/// Note however that as `J` and `ii` iterate, the accesses do not have a
139/// particular alignment, so packing does not achieve alignment in this case.
140///
141/// In the future, we may want to consider a mixed-form that allows some
142/// alignment in the presence of multiple accesses:
143/// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
144/// And would rewrite accesses as:
145/// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
146static FailureOr<SmallVector<std::optional<int64_t>>>
147packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
148 SmallVectorImpl<utils::IteratorType> &iteratorTypes,
149 int64_t dim) {
150 int64_t newDim = iteratorTypes.size();
151 iteratorTypes.push_back(iteratorTypes[dim]);
152
153 SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
154 indexingMaps.size(), std::nullopt);
155 SmallVector<AffineMap> newMaps;
156 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
157 ++operandIdx) {
158 AffineMap map = indexingMaps[operandIdx];
159
160 // Add the `newDim` to map whatever the case.
161 assert(map.getNumDims() == newDim && "num dims invariant violation");
162 map = map.shiftDims(1, newDim);
163
164 // Get the at-most-1 index of the result that is a function of `dim`.
165 // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
166 // logically chunks dimension `dim` into `K * dim + newDim`, where the
167 // packing factor `K` is specified separately.
168 assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
169 "num results invariant violation");
170 auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
171 if (!maybeOperandDimensionToPack.has_value()) {
172 newMaps.push_back(map);
173 continue;
174 }
175
176 // We can only pack AffineDimExpr atm.
177 if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
178 return failure();
179
180 // Add `newDim` to the results of the map.
181 map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
182 map.getNumResults());
183 newMaps.push_back(map);
184
185 // Record that `operandIdx` is packed.
186 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
187 }
188 indexingMaps = newMaps;
189
190 return packedDimPerIndexingMap;
191}
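// A small worked example for the helper above (a sketch, not from a test):
// packing `dim == 1` for an op with
//   map0 = affine_map<(d0, d1) -> (d0, d1)>
//   map1 = affine_map<(d0, d1) -> (d0)>
// appends a new iterator `d2` (copied from `d1`) and updates the maps to
//   map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>  // operand 0 gets packed
//   map1 = affine_map<(d0, d1, d2) -> (d0)>          // operand 1 is untouched
// and the returned vector is {1, std::nullopt}: result #1 of map0 was packed,
// and map1 had no result to pack.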
192
193namespace {
194
195/// Helper struct to encode packing along one dimension of a LinalgOp.
196struct PackedOperandsDim {
197 OpFoldResult packedSize;
198 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
199};
200
201/// Helper struct to encode packing along all dimensions of a LinalgOp.
202struct PackedOperandsDimList {
203 void pushBack(PackedOperandsDim &&packedOperandsDims) {
204 spec.emplace_back(packedOperandsDims);
205 }
206 /// Return all the dims that have been packed for operand @ `operandPos`.
207 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
208 /// Return all the pack sizes by which an operand @ `operandPos` is packed.
209 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
210
211private:
212 SmallVector<PackedOperandsDim> spec;
213};
214
215} // namespace
216
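// Illustrative sketch of the lowering implemented below (assumed shapes, not
// taken from a test):
//
//   %0 = linalg.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//       : tensor<17x31xf32> -> tensor<3x2x8x16xf32>
//
// becomes, roughly:
//
//   %padded = tensor.pad %src low[0, 0] high[7, 1] { ... yield %cst ... }
//       : tensor<17x31xf32> to tensor<24x32xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       output_shape [3, 8, 2, 16] : tensor<24x32xf32> into tensor<3x8x2x16xf32>
//   %1 = linalg.transpose ins(%expanded : tensor<3x8x2x16xf32>)
//       outs(%dest : tensor<3x2x8x16xf32>) permutation = [0, 2, 1, 3]
//
// Pad-like packs can instead be lowered to a single tensor.insert_slice when
// `lowerPadLikeWithInsertSlice` is set (see below).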
217FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
218 linalg::PackOp packOp,
219 bool lowerPadLikeWithInsertSlice) {
220 // 1. Filter out NYI cases.
221 auto packedTensorType =
222 cast<RankedTensorType>(packOp->getResultTypes().front());
223 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
224 return rewriter.notifyMatchFailure(
225 packOp,
226 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
227 }
228
229 Location loc = packOp->getLoc();
230 OpBuilder::InsertionGuard g(rewriter);
231 rewriter.setInsertionPoint(packOp);
232
233 // 2. Compute the permutation vector to shuffle packed shape into the shape
234 // before any outer or inner permutations have been applied.
235 PackingMetadata packingMetadata;
236 SmallVector<int64_t> packedToStripMinedShapePerm =
237 getPackInverseDestPerm(packOp, packingMetadata);
238
239 // 3. Compute the stripMinedShape: this is the packed shape before any outer
240 // or inner permutations have been applied.
241 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
242 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
243
244 // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
245 SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
246 rewriter.getIndexAttr(0));
247 SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
248 rewriter.getIndexAttr(0));
249 for (auto [pos, innerSize] :
250 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
251 int outerPos =
252 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
253 OpFoldResult origSize =
254 tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
255 OpFoldResult outerSize =
256 tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
257 AffineExpr s0, d0, d1;
258 bindDims(rewriter.getContext(), d0, d1);
259 bindSymbols(rewriter.getContext(), s0);
260 auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
261 highs[pos] = affine::makeComposedFoldedAffineApply(
262 rewriter, loc, map, {outerSize, origSize, innerSize});
263 }
264 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
265 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
266 packingMetadata.reassociations);
267 Value paddingValue = packOp.getPaddingValue();
268 if (!paddingValue) {
269 paddingValue = arith::ConstantOp::create(
270 rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
271 }
272 auto padOp =
273 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
274 highs, paddingValue, /*nofold=*/false);
275
276 LDBG() << "insertPositions: "
277 << llvm::interleaved(packingMetadata.insertPositions);
278 LDBG() << "outerPositions: "
279 << llvm::interleaved(packingMetadata.outerPositions);
280 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
281 LDBG() << "packedToStripMinedShapePerm: "
282 << llvm::interleaved(packedToStripMinedShapePerm);
283 LDBG() << "reassociations: "
284 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
285 stringifyReassocIndices));
286 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
287 LDBG() << "collapsed type: " << collapsed;
288
289 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
290 // Pack ops which operate as simple pads may not produce legal
291 // tensor.insert_slice operations when the packed type does not rank reduce
292 // to the padded type.
293 SliceVerificationResult rankReduces =
294 isRankReducedType(packedTensorType, padOp.getResultType());
295
296 if (rankReduces == SliceVerificationResult::Success) {
297 // This pack is just a plain pad.
298 // Just insert the pad in the higher ranked tensor.
299 // Offsets.
300 SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
301 rewriter.getIndexAttr(0));
302 // Strides.
303 SmallVector<OpFoldResult> ones(packOp.getDestRank(),
304 rewriter.getIndexAttr(1));
305 SmallVector<OpFoldResult> sizes =
306 tensor::getMixedSizes(rewriter, loc, packOp.getDest());
307
308 auto insertSliceOp = tensor::InsertSliceOp::create(
309 rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
310 /*offsets=*/zeros, sizes, /*strides=*/ones);
311
312 LDBG() << "insert_slice op: " << insertSliceOp;
313
314 rewriter.replaceOp(packOp, insertSliceOp->getResults());
315
316 return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
317 /*transposeOp=*/nullptr};
318 }
319 }
320
321 // 5. Expand from the padded result to the stripMinedShape.
322 auto expandShapeResultType =
323 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
324 auto reshapeOp = tensor::ExpandShapeOp::create(
325 rewriter, loc, expandShapeResultType, padOp.getResult(),
326 packingMetadata.reassociations);
327
328 // 6. Transpose stripMinedShape to packedShape.
329 SmallVector<int64_t> transpPerm =
330 invertPermutationVector(packedToStripMinedShapePerm);
331 auto transposeOp = linalg::TransposeOp::create(
332 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
333
334 LDBG() << "reshape op: " << reshapeOp;
335 LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
336 LDBG() << "transpose op: " << transposeOp;
337
338 // 7. Replace packOp by transposeOp.
339 rewriter.replaceOp(packOp, transposeOp->getResults());
340
341 return LowerPackResult{padOp, reshapeOp, transposeOp};
342}
343
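// Illustrative sketch of the lowering implemented below (assumed shapes, not
// taken from a test):
//
//   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<3x2x8x16xf32> -> tensor<17x31xf32>
//
// becomes, roughly:
//
//   %empty = tensor.empty() : tensor<3x8x2x16xf32>
//   %transposed = linalg.transpose ins(%src : tensor<3x2x8x16xf32>)
//       outs(%empty : tensor<3x8x2x16xf32>) permutation = [0, 2, 1, 3]
//   %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
//       : tensor<3x8x2x16xf32> into tensor<24x32xf32>
//   %extracted = tensor.extract_slice %collapsed[0, 0] [17, 31] [1, 1]
//       : tensor<24x32xf32> to tensor<17x31xf32>
//   %1 = linalg.copy ins(%extracted : tensor<17x31xf32>)
//       outs(%dest : tensor<17x31xf32>)
//
// Unpad-like unpacks can instead be lowered to a single tensor.extract_slice
// when `lowerUnpadLikeWithExtractSlice` is set.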
344FailureOr<LowerUnPackOpResult>
345linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
346 bool lowerUnpadLikeWithExtractSlice) {
347 Location loc = unPackOp->getLoc();
348 OpBuilder::InsertionGuard g(rewriter);
349 rewriter.setInsertionPoint(unPackOp);
350
351 RankedTensorType packedTensorType = unPackOp.getSourceType();
352 int64_t packedRank = packedTensorType.getRank();
353
354 OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
355 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
356 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
357 // This unpack is just a plain unpad.
358 // Just extract the slice from the higher ranked tensor.
359 ArrayRef<int64_t> destShape = destTensorType.getShape();
360 // The inner dimensions stay the same as the destination tensor, but the
361 // outer ones are additional 1s.
362 SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
363 sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
364
365 auto extractSliceOp = tensor::ExtractSliceOp::create(
366 rewriter, loc, destTensorType, unPackOp.getSource(),
367 SmallVector<OpFoldResult>(packedRank, zero), sizes,
368 SmallVector<OpFoldResult>(packedRank, one));
369
370 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
371
372 return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
373 /*reshapeOp=*/nullptr, extractSliceOp};
374 }
375
376 // 1. Compute the permutation vector to shuffle packed shape into the shape
377 // before any outer or inner permutations have been applied.
378 PackingMetadata packingMetadata;
379 SmallVector<int64_t> packedToStripMinedShapePerm =
380 getUnPackInverseSrcPerm(unPackOp, packingMetadata);
381
382 // 2. Compute the stripMinedShape: this is the packed shape without outer and
383 // inner permutations.
384 SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
385 applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
386
387 // 3. Transpose packedShape to stripMinedShape.
388 RankedTensorType stripMinedTensorType =
389 RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
390 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
391 stripMinedTensorType, packingMetadata.reassociations);
392
393 // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
394 // permutation.
395 SmallVector<OpFoldResult> dims =
396 tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
397 applyPermutationToVector(dims, packedToStripMinedShapePerm);
398 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
399 stripMinedTensorType.getElementType());
400 auto transposeOp =
401 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
402 packedToStripMinedShapePerm);
403
404 LDBG() << "insertPositions: "
405 << llvm::interleaved(packingMetadata.insertPositions);
406 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
407 LDBG() << "packedToStripMinedShapePerm: "
408 << llvm::interleaved(packedToStripMinedShapePerm);
409 LDBG() << "reassociations: "
410 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
411 stringifyReassocIndices));
412 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
413 LDBG() << "collapsed type: " << collapsedType;
414
415 // 4. Collapse from the stripMinedShape to the padded result.
416 auto reshapeOp = tensor::CollapseShapeOp::create(
417 rewriter, loc, collapsedType, transposeOp->getResult(0),
418 packingMetadata.reassociations);
419
420 // 5. ExtractSlice.
421 int64_t destRank = destTensorType.getRank();
422 auto extractSliceOp = tensor::ExtractSliceOp::create(
423 rewriter, loc, destTensorType, reshapeOp->getResult(0),
424 SmallVector<OpFoldResult>(destRank, zero),
425 tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
426 SmallVector<OpFoldResult>(destRank, one));
427
428 // 6. Inject a copy to preserve DPS.
429 auto copyOp = linalg::CopyOp::create(
430 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
431
432 // 7. Replace unPackOp by copyOp.
433 rewriter.replaceOp(unPackOp, copyOp->getResults());
434
435 return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
436}
437
438SmallVector<int64_t>
439PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
440 SmallVector<int64_t> res;
441 for (auto &i : spec) {
442 if (!i.packedDimForEachOperand[operandPos].has_value())
443 continue;
444 res.push_back(i.packedDimForEachOperand[operandPos].value());
445 }
446 return res;
447}
448
449SmallVector<OpFoldResult>
450PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
451 SmallVector<OpFoldResult> res;
452 for (auto &i : spec) {
453 if (!i.packedDimForEachOperand[operandPos].has_value())
454 continue;
455 res.push_back(i.packedSize);
456 }
457 return res;
458}
459
460/// Implement packing of a single LinalgOp by applying the packing specified by
461/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
462/// Return the packed Linalg op on success, failure otherwise.
463FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
464 linalg::LinalgOp linalgOp,
465 ArrayRef<OpFoldResult> packedSizes) {
466 if (packedSizes.size() != linalgOp.getNumLoops()) {
467 return rewriter.notifyMatchFailure(linalgOp,
468 "incorrect number of pack sizes");
469 }
470
471 Location loc = linalgOp->getLoc();
472 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
473 SmallVector<utils::IteratorType> iteratorTypes =
474 linalgOp.getIteratorTypesArray();
475 LDBG() << "Start packing: " << linalgOp;
476 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
477 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
478
479 SmallVector<linalg::PackOp> packOps;
480 SmallVector<linalg::UnPackOp> unPackOps;
481 // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
482 PackedOperandsDimList listOfPackedOperandsDim;
483 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
484 std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
485 // Skip tile sizes explicitly set to 0.
486 if (maybeConstant.has_value() && maybeConstant.value() == 0)
487 continue;
488
489 PackedOperandsDim packedOperandsDims;
490 packedOperandsDims.packedSize = packedSizes[i];
491 FailureOr<SmallVector<std::optional<int64_t>>>
492 maybePackedDimForEachOperand =
493 packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
494 if (failed(maybePackedDimForEachOperand))
495 return failure();
496 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
497
498 LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
499 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
500 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
501 LDBG() << "packedDimForEachOperand: "
502 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
503
504 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
505 }
506
507 // Step 2. Propagate packing to all LinalgOp operands.
508 SmallVector<Value> inputsAndInits, results;
509 SmallVector<OpOperand *> initOperands =
510 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
511 SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
512 for (const auto &operandsList : {inputOperands, initOperands}) {
513 for (OpOperand *opOperand : operandsList) {
514 int64_t pos = opOperand->getOperandNumber();
515 Value operand = opOperand->get();
516 SmallVector<int64_t> innerPos =
517 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
518 SmallVector<OpFoldResult> innerPackSizes =
519 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
520 LDBG() << "operand: " << operand;
521 LDBG() << "innerPos: " << llvm::interleaved(innerPos);
522 LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
523 if (innerPackSizes.empty()) {
524 inputsAndInits.push_back(operand);
525 continue;
526 }
527 Value dest = linalg::PackOp::createDestinationTensor(
528 rewriter, loc, operand, innerPackSizes, innerPos,
529 /*outerDimsPerm=*/{});
530 ShapedType operandType = cast<ShapedType>(operand.getType());
531 bool areConstantTiles =
532 llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
533 return getConstantIntValue(tile).has_value();
534 });
535 if (areConstantTiles && operandType.hasStaticShape() &&
536 !linalg::PackOp::requirePaddingValue(
537 operandType.getShape(), innerPos,
538 cast<ShapedType>(dest.getType()).getShape(), {},
539 innerPackSizes)) {
540 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
541 innerPos, innerPackSizes));
542 } else {
543 // TODO: value of the padding attribute should be determined by
544 // consumers.
545 auto zeroAttr =
546 rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
547 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
548 packOps.push_back(linalg::PackOp::create(
549 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
550 }
551 inputsAndInits.push_back(packOps.back());
552 }
553 }
554
555 // Step 3. Build the packed op, use the type of `inits` as result types.
556 ValueRange inputs =
557 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
558 ValueRange inits =
559 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
560 auto packedLinalgOp =
561 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
562 inputs, inits, indexingMaps, iteratorTypes);
563 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
564
565 // Step 4. Propagate packing to all the op results.
566 for (OpResult result : packedLinalgOp->getResults()) {
567 int64_t resultNum = result.getResultNumber();
568 linalg::PackOp maybePackedInit =
569 inits[resultNum].getDefiningOp<linalg::PackOp>();
570 if (!maybePackedInit) {
571 results.push_back(result);
572 continue;
573 }
574 // Build the symmetrical UnPackOp to the existing PackOp.
575 unPackOps.push_back(linalg::UnPackOp::create(
576 rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
577 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
578 results.push_back(unPackOps.back());
579 }
580
581 // Step 5. Replace `linalgOp`.
582 rewriter.replaceOp(linalgOp, results);
583
584 // Return packedLinalgOp.
585 return PackResult{packOps,
586 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
587 unPackOps};
588}
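// Illustrative sketch of the packing performed by `pack` above (assumed
// example, not from a test): packing only the `m` dimension of a matmul-like
// generic by 8, i.e. packedSizes = {8, 0, 0} (the 0 entries are skipped), with
// indexing maps
//   A: (m, n, k) -> (m, k),  B: (m, n, k) -> (k, n),  C: (m, n, k) -> (m, n)
// produces, roughly:
//   %A_packed = linalg.pack %A inner_dims_pos = [0] inner_tiles = [8] into ...
//   %C_packed = linalg.pack %C inner_dims_pos = [0] inner_tiles = [8] into ...
//   %packed = linalg.generic ... ins(%A_packed, %B) outs(%C_packed)
//       // maps: A: (m, n, k, mm) -> (m, k, mm)
//       //       B: (m, n, k, mm) -> (k, n)
//       //       C: (m, n, k, mm) -> (m, n, mm)
//   %res = linalg.unpack %packed inner_dims_pos = [0] inner_tiles = [8] into %C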
589
590//===----------------------------------------------------------------------===//
591// packTranspose transformation.
592//===----------------------------------------------------------------------===//
593
594/// Return a copy of `tensorType` after permutation by `permutationVector`.
595// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
596// but this would introduce a dependence on Dialect in IR.
597// TODO: Restructure.
598static RankedTensorType permuteShape(RankedTensorType tensorType,
599 ArrayRef<int64_t> permutationVector) {
600 SmallVector<int64_t> shape(tensorType.getShape());
601 applyPermutationToVector(shape, permutationVector);
602 return RankedTensorType::Builder(tensorType).setShape(shape);
603}
604
605/// Return a new GenericOp obtained by transposing opOperand by the permutation
606/// vector:
607/// - the corresponding indexing map is transposed by `permutation`
608/// - the corresponding operand value is replaced by `transposedValue`
609/// `linalgOp` is replaced by the return op in the process.
610/// Asserts that `transposedValue` is of the proper transposed ShapedType.
611static LinalgOp transposeOneLinalgOperandAndReplace(
612 RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
613 ArrayRef<int64_t> permutation, Value transposedValue) {
614 // Sanity check the operand.
615 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
616
617 // Sanity check of the expected transposed tensor type.
618 auto tensorType = permuteShape(
619 cast<RankedTensorType>(opOperand.get().getType()), permutation);
620 (void)tensorType;
621 assert(tensorType == transposedValue.getType() &&
622 "expected tensor type mismatch");
623
624 // Compute the transposed indexing map.
625 // Sigh unsigned pollution.
626 SmallVector<unsigned> tmpTransposition = llvm::to_vector(
627 llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
628 AffineMap permutationMap =
629 AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
630 AffineMap transposedMap =
631 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
632
633 // Set the transposed indexing map in the proper position.
634 SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
635 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
636 // Set the transposedValue in the proper operand position.
637 SmallVector<Value> operands = linalgOp->getOperands();
638 operands[opOperand.getOperandNumber()] = transposedValue;
639
640 ValueRange operandsRef(operands);
641 auto transposedGenericOp = linalg::GenericOp::create(
642 rewriter,
643 /*location=*/linalgOp->getLoc(),
644 /*resultTensorTypes=*/
645 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
646 /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
647 /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
648 /*indexingMaps=*/indexingMaps,
649 /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
650 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
651 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
652
653 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
654}
655
656FailureOr<PackTransposeResult>
657linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
658 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
659 ArrayRef<int64_t> outerPerm,
660 ArrayRef<int64_t> innerPerm) {
661 Location loc = linalgOp.getLoc();
662
663 // Step 1. Transpose packOp.
664 rewriter.setInsertionPoint(packOp);
665 linalg::PackOp transposedPackOp =
666 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
667
668 if (!packOp.getResult().hasOneUse())
669 return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
670
671 OpOperand &packUse = *packOp->getUses().begin();
672 if (packUse.getOwner() != linalgOp) {
673 return rewriter.notifyMatchFailure(
674 linalgOp, "not a single use by the LinalgOp target");
675 }
676 if (maybeUnPackOp &&
677 (!linalgOp.isDpsInit(&packUse) ||
678 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
679 return rewriter.notifyMatchFailure(linalgOp,
680 "not produced by the LinalgOp target");
681 }
682
683 // Step 2. Transpose linalgOp.
684 // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
685 // identity. Don't rely on it.
686 int64_t numLeadingDims = packOp.getSourceRank();
687 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
688 // Step 2.a. Compute the permutation on the whole operand.
689 // Leading part just reuse the outerPerm.
690 SmallVector<int64_t> permutation(outerPerm);
691 if (permutation.empty())
692 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
693 // Trailing part needs to reindex positions by `numLeadingDims`.
694 if (innerPerm.empty()) {
695 llvm::append_range(
696 permutation,
697 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
698 } else {
699 llvm::append_range(permutation,
700 llvm::map_range(innerPerm, [&](int64_t pos) {
701 return numLeadingDims + pos;
702 }));
703 }
704 if (!isPermutationVector(permutation))
705 return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
706
707 // Step 2.b. Save the transposedPackUse operand number in case we need to
708 // get the tied OpResult after `linalgOp` has been replaced.
709 int64_t packUseOperandNumber = packUse.getOperandNumber();
710 // Step 2.c. Actually perform the transposition.
711 rewriter.setInsertionPoint(linalgOp);
712 linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
713 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
714
715 // Step 3. Maybe transpose unPackOp.
716 linalg::UnPackOp transposedUnPackOp;
717 if (maybeUnPackOp) {
718 OpOperand &opOperand =
719 transposedLinalgOp->getOpOperand(packUseOperandNumber);
720 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
721 rewriter.setInsertionPoint(maybeUnPackOp);
722 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
723 rewriter, loc, transposedResult, innerPerm, outerPerm);
724
725 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
726 }
727
728 // Step 4. Finally, replace packOp now that we don't need it anymore.
729 rewriter.replaceOp(packOp, transposedPackOp->getResults());
730
731 return PackTransposeResult{transposedPackOp, transposedLinalgOp,
732 transposedUnPackOp};
733}
734
735//===----------------------------------------------------------------------===//
736// packMatmulGreedily transformation.
737//===----------------------------------------------------------------------===//
738
739/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
740/// and n are proper parallel dimensions and k is a proper reduction
741/// dimension. Packing occurs by rewriting the op as a linalg.generic and
742/// calling linalg::pack with `mnkPackedSizes`. The order of the packed
743/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
744/// to reorder {m, n, k} into one of the 6 possible orders. The outer
745/// dimensions of the operands are not permuted at this time, this is left for
746/// future work.
747FailureOr<PackResult>
748linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
749 ArrayRef<OpFoldResult> mnkPackedSizes,
750 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
751 ArrayRef<int64_t> mnkOrder) {
752 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
753 assert((mnkPaddedSizesNextMultipleOf.empty() ||
754 mnkPaddedSizesNextMultipleOf.size() == 3) &&
755 "num of packing sizes next multiple should be empty or of size 3");
756 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
757 assert(isPermutationVector(mnkOrder) && "expected a permutation");
758
759 int64_t numLoops = linalgOp.getNumLoops();
760 if (numLoops <= 2) {
761 LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
762 << " in: " << linalgOp;
763 return rewriter.notifyMatchFailure(
764 linalgOp, "need 3+ loops to find a matmul to pack");
765 }
766
767 // Locally adjust the desired iterator position of mnk and packing sizes.
768 int64_t numPackedDims = mnkPackedSizes.size();
769 SmallVector<int64_t> mmnnkkPos(numPackedDims);
770 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
771 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
772 SmallVector<OpFoldResult> packedSizes(numPackedDims);
773 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
774 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
775 SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
776 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
777 paddedSizesNextMultipleOf[mnkOrder[i]] =
778 mnkPaddedSizesNextMultipleOf.empty() ? 0
779 : mnkPaddedSizesNextMultipleOf[i];
780 }
781
782 // 1. Infer dims that are important for matmul.
783 FailureOr<ContractionDimensions> maybeDimensions =
784 inferContractionDims(linalgOp);
785 if (failed(maybeDimensions)) {
786 LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
787 return rewriter.notifyMatchFailure(linalgOp,
788 "couldn't infer matmul iterators");
789 }
790
791 // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] as the
792 // most minor iterators. In cases with multiple options for m, n, k, bias
793 // towards the most minor embedding.
794 // If we wanted a different normalization order, this is where a heuristic
795 // would have to be plugged in.
796 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
797 kPos = maybeDimensions->k.back();
798 LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
799 << nPos << ", k@" << kPos << "): " << linalgOp;
800
801 // 2.a. Rewrite as a generic.
802 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
803 if (!genericOp) {
804 FailureOr<GenericOp> generalizeResult =
805 generalizeNamedOp(rewriter, linalgOp);
806 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
807 genericOp = *generalizeResult;
808 }
809
810 // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
811 // iterators. Note that this only normalizes the iteration order and does
812 // not change the indexings of any operand.
813 SmallVector<int64_t> permutation =
814 computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
815 LDBG() << "perm: " << llvm::interleaved(permutation);
816 // Sign .. unsigned pollution.
817 SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
818 FailureOr<GenericOp> interchangeResult =
819 interchangeGenericOp(rewriter, genericOp, unsignedPerm);
820 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
821 genericOp = *interchangeResult;
822 LDBG() << "Generalized Op to pack: " << genericOp;
823
824 // At this point, the op iterators are normalized to {leading, k, m, n}.
825 // The layouts induced by packing will always be:
826 // - LHS{leading_lhs, kk, mm}
827 // - RHS{leading_rhs, kk, nn}
828 // - RES{leading_res, mm, nn}
829 // If we wanted to change the packed order, we would reorder (k, m, n) to
830 // something else above.
831 //
832 // Additional permutations of the outer dims of the operands (i.e.
833 // leading_lhs, leading_rhs and leading_res) could follow by computing the
834 // desired outerPerm for each operand.
835 // This is left for future work.
836
837 // TODO: this creates too much IR, go use reifyResultShapes.
838 SmallVector<Range, 4> loopRanges =
839 cast<LinalgOp>(genericOp.getOperation())
840 .createLoopRanges(rewriter, genericOp.getLoc());
841
842 // Add leading zeros to match numLoops; we only pack the last 3 dimensions
843 // post interchange.
844 LDBG() << "paddedSizesNextMultipleOf: "
845 << llvm::interleaved(paddedSizesNextMultipleOf);
846 LDBG() << "loopRanges: "
847 << llvm::interleaved(
848 llvm::map_range(loopRanges, [](Range r) { return r.size; }));
849 SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
850 rewriter.getIndexAttr(0));
851 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
852 if (paddedSizesNextMultipleOf[i] == 0) {
853 adjustedPackedSizes.push_back(packedSizes[i]);
854 continue;
855 }
856 AffineExpr d0, s0;
857 bindDims(rewriter.getContext(), d0);
858 bindSymbols(rewriter.getContext(), s0);
859 adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
860 rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
861 {loopRanges[adjustedPackedSizes.size()].size,
862 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
863 }
864 LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
865
866 // TODO: If we wanted to give the genericOp a name after packing, after
867 // calling `pack` would be a good time. One would still need to check that
868 // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
869 // also allow degenerate matmul cases (i.e. matvec, dot).
870 return pack(rewriter, genericOp, adjustedPackedSizes);
871}
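// Usage sketch for packMatmulGreedily above (hypothetical call site; `rewriter`
// and `linalgOp` are assumed to be in scope): pack with {m, n, k} tile sizes of
// {8, 16, 32} and no padding to the next multiple:
//
//   FailureOr<PackResult> res = linalg::packMatmulGreedily(
//       rewriter, linalgOp,
//       /*mnkPackedSizes=*/
//       getAsIndexOpFoldResult(rewriter.getContext(), {8, 16, 32}),
//       /*mnkPaddedSizesNextMultipleOf=*/{}, /*mnkOrder=*/{0, 1, 2});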
872
873//===----------------------------------------------------------------------===//
874// Transformations exposed as rewrite patterns.
875//===----------------------------------------------------------------------===//
876
877LinalgTilingOptions &
878mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
879 assert(!tileSizeComputationFunction && "tile sizes already set");
880 SmallVector<int64_t, 4> tileSizes(ts);
881 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
882 OpBuilder::InsertionGuard guard(b);
883 b.setInsertionPointToStart(
884 &op->getParentOfType<func::FuncOp>().getBody().front());
885 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
886 Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
887 return v;
888 }));
889 };
890 return *this;
891}
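// Usage sketch for setTileSizes above (hypothetical values): request constant
// tile sizes 8 and 16 for the first two loops; the sizes are materialized as
// index constants at the top of the enclosing function when tiling runs.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16});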
892
893LogicalResult CopyVectorizationPattern::matchAndRewrite(
894 memref::CopyOp copyOp, PatternRewriter &rewriter) const {
895 return vectorizeCopy(rewriter, copyOp);
896}
897
898/// Fill `dest` using a FillOp with the constant padding value if possible.
899/// Otherwise, generate a tensor::GenerateOp.
900Value DecomposePadOpPattern::createFillOrGenerateOp(
901 RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
902 const SmallVector<Value> &dynSizes) const {
903 auto padValue = padOp.getConstantPaddingValue();
904 if (padValue) {
905 // Move the padding value defined inside the PadOp block to outside.
906 if (padValue.getParentBlock() == &padOp.getRegion().front())
907 rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
908 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
909 }
910
911 // Fill could not be optimized: Lower to tensor::GenerateOp with region.
912 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
913 padOp.getResultType(), dynSizes);
914 // Copy region to new op.
915 IRMapping bvm;
916 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
917 return generateOp;
918}
919
920LogicalResult
921DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
922 PatternRewriter &rewriter) const {
923 // Given an OpFoldResult, return an index-typed value.
924 auto getIdxValue = [&](OpFoldResult ofr) {
925 if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
926 return val;
927 return arith::ConstantIndexOp::create(
928 rewriter, padOp.getLoc(),
929 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
930 .getResult();
931 };
932
933 auto resultType = padOp.getResultType();
934 // Compute size of EmptyOp. Any combination of static/dynamic is supported.
935 SmallVector<Value> dynSizes;
936 SmallVector<int64_t> staticSizes;
937 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
938 if (resultType.isDynamicDim(dim)) {
939 auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
940 padOp.getSource(), dim));
941 // Add low and high padding value.
942 auto plusLow = rewriter.createOrFold<arith::AddIOp>(
943 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
944 auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
945 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
946 dynSizes.push_back(plusHigh);
947 }
948 staticSizes.push_back(resultType.getDimSize(dim));
949 }
950
951 // Init tensor and fill it with padding.
952 Value emptyTensor =
953 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
954 resultType.getElementType(), dynSizes);
955 Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
956
957 // Generate an InsertSliceOp for copying the PadOp source.
958 auto sourceType = padOp.getSourceType();
959 // Compute size of source of tensor::PadOp.
960 SmallVector<OpFoldResult> srcSizes =
961 tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
962 // Strides of InsertSliceOp are all 1.
963 SmallVector<OpFoldResult> strides(sourceType.getRank(),
964 rewriter.getIndexAttr(1));
965 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
966 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
967 strides);
968
969 return success();
970}
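// Illustrative sketch of the decomposition performed above (assumed shapes,
// not taken from a test):
//
//   %p = tensor.pad %src low[1, 0] high[1, 2] { ... yield %cst ... }
//       : tensor<2x3xf32> to tensor<4x5xf32>
//
// becomes, roughly:
//
//   %empty = tensor.empty() : tensor<4x5xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x5xf32>)
//   %p = tensor.insert_slice %src into %fill[1, 0] [2, 3] [1, 1]
//       : tensor<2x3xf32> into tensor<4x5xf32>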
971
972LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
973 tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
974 if (!sliceOp.hasUnitStride())
975 return failure();
976
977 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
978 if (!padOp)
979 return failure();
980
981 bool zeroSliceGuard = true;
982 if (controlFn) {
983 if (std::optional<bool> control = controlFn(sliceOp))
984 zeroSliceGuard = *control;
985 else
986 return failure();
987 }
988
989 FailureOr<TilingResult> tilingResult =
990 tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
991 sliceOp.getMixedSizes(), zeroSliceGuard);
992 if (failed(tilingResult))
993 return failure();
994
995 RankedTensorType sourceType = sliceOp.getSourceType();
996 RankedTensorType resultType = sliceOp.getResultType();
997
998 // If the extract_slice is not rank-reduced, all shapes are static and the
999 // data source is actually used. Rewrite into pad(extract_slice(x)).
1000 if (sourceType.getRank() == resultType.getRank()) {
1001 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1002 return success();
1003 }
1004
1005 // Handle rank-reduced slice by creating another extract_slice op.
1006 Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
1007 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1008
1009 rewriter.replaceOp(sliceOp, rankReduced);
1010 return success();
1011}
1012
1013/// If padding value is set, returns a tensor.pad Op for the source tensor,
1014/// with the output shape matching the output of `packOp`. Otherwise, returns
1015/// the source directly.
1016///
1017/// This method assumes that all outer dims for this pack Op are 1.
1018static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1019 linalg::PackOp packOp) {
1020 Value input = packOp.getSource();
1021 if (!packOp.getPaddingValue()) {
1022 return input;
1023 }
1024
1025 assert(llvm::all_of(packOp.getAllOuterDims(),
1026 [](int64_t val) { return val == 1; }) &&
1027 "some outer dims are != 1");
1028
1029 Location loc = packOp.getLoc();
1030 ShapedType inputType = packOp.getSourceType();
1031 int64_t inputRank = inputType.getRank();
1032
1033 DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1034 packOp.getDimAndTileMapping();
1035
1036 // The sizes of dynamic tiles
1037 SmallVector<Value> dynamicTileSizes;
1038
1039 // Collect dims for the padded shape.
1040 SmallVector<int64_t> paddedShape;
1041 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1042 // 1. Non-tiled outer dims.
1043 // These dims should be 1 and we simply preserve them.
1044 if (!tileAndPosMapping.count(dimIdx)) {
1045 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1046 assert(inputDimSize == 1 &&
1047 "with all outer dims == 1, this non-tiled input dim should be 1!");
1048 paddedShape.push_back(inputDimSize);
1049 continue;
1050 }
1051
1052 // 2. Tiled outer dims
1053 // As all outer dims == 1, it is safe to use the tile size for the padded
1054 // shape.
1055 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1056
1057 // 2.1 Static tile sizes
1058 std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1059 if (cstTileSize.has_value()) {
1060 paddedShape.push_back(cstTileSize.value());
1061 continue;
1062 }
1063
1064 // 2.2 Dynamic tile sizes
1065 paddedShape.push_back(ShapedType::kDynamic);
1066
1067 // Get the value that holds the dynamic size.
1068 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1069 }
1070 auto resultType =
1071 RankedTensorType::get(paddedShape, inputType.getElementType());
1072 return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1073 /*nofold=*/false, loc, builder,
1074 dynamicTileSizes);
1075}
1076
1077// Normalizes a permutation on a higher rank space to its actual size, e.g.
1078// perm = [1, 4, 2]
1079// becomes
1080// norm = [0, 2, 1]
1081static SmallVector<int64_t>
1082getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1083 constexpr int64_t kNonTiledMarker = -1;
1084 SmallVector<int64_t> vec(rank, kNonTiledMarker);
1085 for (auto [index, value] : llvm::enumerate(perm))
1086 vec[value] = index;
1087 SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1088 vec, [&](int64_t v) { return v != kNonTiledMarker; });
1089 // This inverts the permutation in addition to normalizing so invert back.
1090 return invertPermutationVector(normalizedPerm);
1091}
1092
1093// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1094// assuming rank reduction of unit outer dims.
1095static SmallVector<int64_t>
1096getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1097 ArrayRef<int64_t> innerDimsPos,
1098 ArrayRef<int64_t> outerDimsPerm) {
1099 SmallVector<int64_t> rankReducedOuterDimsPerm;
1100 SmallVector<int64_t> outerDims;
1101 SmallVector<int64_t> innerDims;
1102 int64_t dim = 0;
1103 int64_t unpackedRank = shape.size();
1104 for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1105 if (llvm::is_contained(innerDimsPos, i)) {
1106 innerDims.push_back(dim++);
1107 continue;
1108 }
1109 if (shape[i] == 1)
1110 continue;
1111 outerDims.push_back(dim++);
1112 if (!outerDimsPerm.empty())
1113 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1114 }
1115
1116 // Get the position of the inner dims after permutation.
1117 SmallVector<int64_t> innerPerm =
1118 getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1119 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1120
1121 // Ditto for the outer dims.
1122 SmallVector<int64_t> perm = outerDims;
1123
1124 rankReducedOuterDimsPerm =
1125 getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1126 if (!rankReducedOuterDimsPerm.empty())
1127 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1128
1129 // The tile always ends up as the inner most dims after packing.
1130 perm.append(innerDims);
1131
1132 return perm;
1133}
1134
1135LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1136 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1137 if (llvm::any_of(packOp.getTiledOuterDims(),
1138 [](int64_t dim) { return dim != 1; })) {
1139 return rewriter.notifyMatchFailure(
1140 packOp, "not all outer dimensions of the result are 1s");
1141 }
1142
1143 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1144 auto outerDimsPerm = packOp.getOuterDimsPerm();
1145
1146 // Verify that there are no:
1147 // * non-unit + un-tiled-outer-dims,
1148 // that are permuted. Supporting such cases would require refining the logic
1149 // that generates the Transpose Op.
1150 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1151 static int prev = 0;
1152 // Skip tiled dims - these can be permuted.
1153 if (llvm::is_contained(innerDimsPos, dim))
1154 return true;
1155
1156 // Check whether this dim has been permuted. Permuting unit dims is fine
1157 // as that's effectively a no-op.
1158 if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
1159 packOp.getType().getShape()[dim] != 1))
1160 return false;
1161
1162 prev = dim;
1163 return true;
1164 })) {
1165 return rewriter.notifyMatchFailure(
1166 packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1167 "this is not supported ATM!");
1168 }
1169
1170 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1171 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1172 Location loc = packOp.getLoc();
1173
1174 int64_t srcRank = packOp.getSourceRank();
1175 int64_t destRank = packOp.getDestRank();
1176
1177 // 1. Get the input that is going to be packed. If the input requires padding,
1178 // add a padding operation and return that as the input.
1179 Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1180
1181 // 2. Transpose the input to match the inner tile order:
1182 // %init = tensor.empty()
1183 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1184 // outs(%init)
1185 // Assumptions made:
1186 // - All tiled outer dims are 1 - the corresponding transposition order
1187 // doesn't matter, but requires all dim indices to be present.
1188 // - Un-tiled outer dims remain un-permuted.
1189
1190 // 2.1 Get the permutation for linalg.transpose:
1191 // [ untiled-dims, inner-dims-pos ]
1192 // Note, this logic assumes that the untiled dims are not permuted.
1193 SmallVector<int64_t> srcPermForTranspose;
1194 for (int64_t i = 0; i < srcRank; i++) {
1195 // We assume the `k` dimensions of the inner dim position, where `k` is the
1196 // rank of the inner tiling, correspond to the last `k` indices of the
1197 // transpose permutation. This is done by adding the indices not contained
1198 // in the inner dimension position in order from 0 to `n`, where `n` is the
1199 // rank of the source tensor. For example if we have a source tensor with
1200 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1201 // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1202 if (llvm::is_contained(innerDimsPos, i))
1203 continue;
1204 srcPermForTranspose.push_back(i);
1205 }
1206 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1207
1208 // 2.2 Create the init tensor for linalg.transpose with the correct shape:
1209 // [ untiled-dims, tiled-dims ]
1210 ShapedType inputTy = cast<ShapedType>(input.getType());
1211 SmallVector<OpFoldResult> shapeForEmptyOp;
1212 for (int64_t i = 0; i < srcRank; i++) {
1213 if (llvm::is_contained(innerDimsPos, i)) {
1214 // The tiled dims are appended after this loop.
1215 continue;
1216 }
1217 if (inputTy.isStaticDim(i))
1218 shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
1219 else
1220 shapeForEmptyOp.emplace_back(
1221 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1222 }
1223 shapeForEmptyOp.append(packOp.getMixedTiles());
1224
1225 // getMixedTiles() may contain Values pointing to constant ops, not the
1226 // constant attributes. Replace them with a true OpFoldResult.
1227 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1228 [&](OpFoldResult ofr) {
1229 if (auto val = llvm::dyn_cast<Value>(ofr))
1230 return getAsOpFoldResult(val);
1231 return ofr;
1232 });
1233
1234 LDBG() << "Pack permutation: " << packOp;
1235 LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1236 LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1237
1238 Value empty = tensor::EmptyOp::create(
1239 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1240
1241 // 2.3 Create linalg.transpose
1242 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1243 srcPermForTranspose);
1244
1245 // 3. Insert the inner tile into the destination tensor:
1246 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1247
1248 // Compute the sizes attribute:
1249 // [ outer-dims, tile-sizes ]
1250 // Note that the output from the transpose Op excludes the tiled outer dims.
1251 // However, given the assumption that:
1252 // * all tiled outer dims == 1,
1253 // we can just use a rank-expanding tensor.insert_slice.
1254 SmallVector<OpFoldResult> writeSizes;
1255 for (auto size : packOp.getAllOuterDims()) {
1256 writeSizes.push_back(rewriter.getIndexAttr(size));
1257 }
1258
1259 for (auto tileSize : packOp.getMixedTiles()) {
1260 auto [_, tileSizeOfr] =
1261 getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1262 writeSizes.push_back(tileSizeOfr);
1263 }
1264
1265 // TODO: Add a constructor for tensor.insert_slice that doesn't require
1266 // strides nor offsets.
1267 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1268 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1269
1270 auto insert = tensor::InsertSliceOp::create(
1271 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
1272 writeOffsets, writeSizes, writeStrides);
1273
1274 // 4. Replace the linalg.pack op with the tensor.insert_slice created above.
1275 rewriter.replaceOp(packOp, insert.getResult());
1276
1277 return success();
1278}
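// Illustrative sketch of the decomposition performed above (assumed shapes,
// not taken from a test): with all outer dims equal to 1,
//
//   %0 = linalg.pack %src inner_dims_pos = [1, 0] inner_tiles = [16, 5]
//       into %dest : tensor<5x16xf32> -> tensor<1x1x16x5xf32>
//
// becomes, roughly:
//
//   %empty = tensor.empty() : tensor<16x5xf32>
//   %transposed = linalg.transpose ins(%src : tensor<5x16xf32>)
//       outs(%empty : tensor<16x5xf32>) permutation = [1, 0]
//   %0 = tensor.insert_slice %transposed into %dest[0, 0, 0, 0] [1, 1, 16, 5]
//       [1, 1, 1, 1] : tensor<16x5xf32> into tensor<1x1x16x5xf32>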
1279
1280LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1281 linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1282 int64_t srcRank = unpackOp.getSourceRank();
1283 int64_t destRank = unpackOp.getDestRank();
1284 ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1285 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1286 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1287 [](int64_t dim) { return dim != 1; })) {
1288 return rewriter.notifyMatchFailure(
1289 unpackOp,
1290 "require the tiled outer dimensions of the result are all 1s");
1291 }
1292
1293 // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1294 // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1295 Location loc = unpackOp.getLoc();
1296 Value source = unpackOp.getSource();
1297 DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1298 unpackOp.getDimAndTileMapping();
1299 Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1300 Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1301
1302 // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1303 // dims:
1304 // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1305 SmallVector<int64_t> readShapeForExtractSlice;
1306 // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1307 // outer-tiled-dims being all 1), this will be
1308 // [ outer-untiled-dims, tile-sizes ]
1309 SmallVector<OpFoldResult> extractSliceSizes;
1310 // The offset and strides attributes for ExtractSliceOp.
1311 SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1312 SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1313
1314 // Shape for EmptyOp that's used as the init value for TransposeOp below.
1315 // This should be:
1316 // [ outer-untiled-dims, tile-sizes ]
1317 // However, skip unit dims - TransposeOp (below) applies rank-reduced
1318 // permutation.
1319 SmallVector<OpFoldResult> shapeForEmptyOp;
1320
1321 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1322 // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1323 //
1324 // As all outer tiled dims are 1, the corresponding
1325 // slice size to read will also be 1. As this will be a rank-reducing "extract
1326 // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1327 // update:
1328 // * the output shape for ExtractSliceOp, nor
1329 // * the shape for EmptyOp.
1330 if (dimAndTileMapping.count(i)) {
1331 extractSliceSizes.push_back(oneIdxAttr);
1332 continue;
1333 }
1334
1335 // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1336 // outer-untiled-dims
1337 if (ShapedType::isDynamic(srcShape[i])) {
1338 OpFoldResult dynamicDim =
1339 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1340 extractSliceSizes.push_back(dynamicDim);
1341 shapeForEmptyOp.push_back(dynamicDim);
1342 } else {
1343 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1344 if (srcShape[i] != 1)
1345 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1346 }
1347 // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1348 // into account rank-reducing)
1349 if (srcShape[i] != 1) {
1350 readShapeForExtractSlice.push_back(srcShape[i]);
1351 }
1352 }
1353 // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1354 // shape for EmptyOp.
1355 auto mixedTiles = unpackOp.getMixedTiles();
1356 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1357 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1358
1359 // Explicitly create the type for extract_slice op because the inner tile
1360 // size could be 1. We want to represent the whole inner tile in this case.
1361 auto tileShape = srcShape.drop_front(destRank);
1362 // Append the inner tile shape to the permuted and rank-reduced outer shape.
1363 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1364 Type elemType = unpackOp.getSourceType().getElementType();
1365 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1366 Value innerTile = tensor::ExtractSliceOp::create(
1367 rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
1368 extractSliceSizes, extractSliceStrides);
1369
1370 // 2. Transpose the tile to match the outer corresponding tile order.
1371 SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1372 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1373 // Unpack is a transition out of packed space so we invert the permutation.
1374 perm = invertPermutationVector(perm);
1375 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1376
1377 Value empty =
1378 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1379 auto transposedOp =
1380 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1381
1382 // 3. Handle incomplete tiles if needed. This truncates trailing data from the
1383 // transposed tile.
1384 int numLoops = shapeForEmptyOp.size();
1385 SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1386 SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1387 SmallVector<OpFoldResult> tileSizes;
1388 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1389 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1390 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1391 tileSizes.push_back(
1392 tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1393 }
1394
1395 auto partialTile =
1396 tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
1397 tileOffsets, tileSizes, tileStrides);
1398
1399 // 4. Insert the result to the destination tensor.
1400 SmallVector<OpFoldResult> writeSizes;
1401 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1402 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1403 for (int i = 0, idx = 0; i < destRank; ++i) {
1404 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1405 writeSizes.push_back(tileSizes[idx++]);
1406 else
1407 writeSizes.push_back(oneIdxAttr);
1408 }
1409 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1410 unpackOp.getDest(), writeOffsets,
1411 writeSizes, writeStrides);
1412 rewriter.replaceOp(unpackOp, insert.getResult());
1413
1414 return success();
1415}
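// A minimal sketch of the kind of IR the decomposition above produces for an
// unpack whose tiled outer dims are all unit; the shapes and SSA names are
// illustrative only (not taken from this file or a test):
//
//   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 2]
//       into %dest : tensor<1x1x8x2xf32> -> tensor<5x1xf32>
//
// decomposes roughly into:
//
//   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
//       : tensor<1x1x8x2xf32> to tensor<8x2xf32>
//   %init = tensor.empty() : tensor<8x2xf32>
//   %t = linalg.transpose ins(%tile : tensor<8x2xf32>)
//       outs(%init : tensor<8x2xf32>) permutation = [0, 1]
//   %partial = tensor.extract_slice %t[0, 0] [5, 1] [1, 1]
//       : tensor<8x2xf32> to tensor<5x1xf32>
//   %res = tensor.insert_slice %partial into %dest[0, 0] [5, 1] [1, 1]
//       : tensor<5x1xf32> into tensor<5x1xf32>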
1416
1417// The following are patterns for downscaling convolution ops with size-1
1418// window dimensions.
1419//
1420// Note that we'd eventually want to write such transformations in a generic
1421// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1422// and then turning back to named ops. But for now it's fine to have a few
1423// patterns matching special ops to get started.
1424
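// As an illustration of the downscaling below (shapes made up for this
// example), a convolution whose kh/oh pair is unit:
//
//   linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
//                             strides = dense<1> : tensor<2xi64>}
//     ins(%in, %ker : tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>)
//     outs(%out : tensor<1x1x14x8xf32>) -> tensor<1x1x14x8xf32>
//
// becomes, after rank-reducing the unit H dimension from all operands:
//
//   linalg.conv_1d_nwc_wcf {dilations = dense<1> : tensor<1xi64>,
//                           strides = dense<1> : tensor<1xi64>}
//     ins(%in1d, %ker1d : tensor<1x16x4xf32>, tensor<3x4x8xf32>)
//     outs(%out1d : tensor<1x14x8xf32>) -> tensor<1x14x8xf32>
//
// followed by a rank-expanding insertion of the 1-D result back into the
// original 2-D output tensor.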
1425 template <typename Conv2DOp, typename Conv1DOp>
1426 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1427     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1428 if (convOp.hasPureBufferSemantics())
1429 return failure(); // To be implemented.
1430
1431 Value input = convOp.getInputs().front();
1432 Value kernel = convOp.getInputs().back();
1433 Value output = convOp.getOutputs().front();
1434
1435 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1436 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1437 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1438
1439 auto kernelShape = kernelType.getShape();
1440 auto outputShape = outputType.getShape();
1441
1442 // Get domain indices based on conv2D layout.
1443   auto [khIndex, kwIndex, ohIndex, owIndex] =
1444       TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1445           convOp)
1446 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1447 return std::make_tuple(0, 1, 1, 2);
1448 })
1449 .Case([&](linalg::Conv2DNchwFchwOp op) {
1450 return std::make_tuple(2, 3, 2, 3);
1451 })
1452 .Case([&](linalg::PoolingNhwcSumOp op) {
1453 return std::make_tuple(0, 1, 1, 2);
1454 })
1455 .Case([&](linalg::PoolingNchwSumOp op) {
1456 return std::make_tuple(0, 1, 2, 3);
1457 })
1458 .Case([&](linalg::PoolingNhwcMaxOp op) {
1459 return std::make_tuple(0, 1, 1, 2);
1460 })
1461 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1462 return std::make_tuple(0, 1, 1, 2);
1463 })
1464 .Case([&](linalg::PoolingNhwcMinOp op) {
1465 return std::make_tuple(0, 1, 1, 2);
1466 })
1467 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1468 return std::make_tuple(0, 1, 1, 2);
1469 })
1470 .Case([&](linalg::PoolingNchwMaxOp op) {
1471 return std::make_tuple(0, 1, 2, 3);
1472 })
1473 .DefaultUnreachable("unexpected conv2d/pool2d operation.");
1474
1475 // Only handle the case where at least one of the window dimensions is
1476 // of size 1. Other cases can rely on tiling to reduce to such cases.
1477 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1478 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1479 bool removeH = (khSize == 1 && ohSize == 1);
1480 bool removeW = (kwSize == 1 && owSize == 1);
1481 if (!removeH && !removeW)
1482 return failure();
1483
1484 // Get new shapes and types for all operands by removing the size-1
1485 // dimension.
1486 using RTTBuilder = RankedTensorType::Builder;
1487 RankedTensorType newInputType =
1488 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1489 RankedTensorType newKernelType =
1490 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1491 RankedTensorType newOutputType =
1492 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1493
1494 // Rank-reduce operands.
1495   Location loc = convOp.getLoc();
1496   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1497       rewriter, loc, input, newInputType);
1498   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1499       rewriter, loc, kernel, newKernelType);
1500   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1501       rewriter, loc, output, newOutputType);
1502
1503 // Rank-reduce strides and dilations too.
1504 // TODO: dropDim 1-liner helper.
1505 auto strides =
1506 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1507 strides.erase(strides.begin() + (removeH ? 0 : 1));
1508 auto stridesAttr = rewriter.getI64VectorAttr(strides);
1509
1510 auto dilations =
1511 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1512 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1513 auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
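  // For example, 2-D strides [2, 3] reduce to [3] when the unit H dimension is
  // removed and to [2] when W is removed; dilations are reduced the same way.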
1514
1515 auto conv1DOp = Conv1DOp::create(
1516 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1517 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1518
1519   // Insert back.
1520   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1521       rewriter, loc, conv1DOp.getResult(0), output);
1522 rewriter.replaceOp(convOp, inserted);
1523
1524 return conv1DOp;
1525}
1526
1527template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1528 Conv1DNwcWcfOp>;
1529template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1530 Conv1DNcwFcwOp>;
1531template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1532 PoolingNwcSumOp>;
1533template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1534 PoolingNcwSumOp>;
1535template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1536                                                              PoolingNwcMaxOp>;
1537 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1538     PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1539template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1540                                                              PoolingNwcMinOp>;
1541 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1542     PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1543template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1544 PoolingNcwMaxOp>;
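// A minimal usage sketch for the instantiations above (not part of this file);
// it assumes the usual `using namespace mlir; using namespace mlir::linalg;`
// plus the Func dialect and greedy-driver headers, and the helper name
// `decomposeUnitWindowConvs` is hypothetical. In-tree users normally register
// these patterns via populateDecomposeConvolutionPatterns instead.
//
//   #include "mlir/Dialect/Func/IR/FuncOps.h"
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
static LogicalResult decomposeUnitWindowConvs(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  // Register one of the 2-D -> 1-D downscaling instantiations above; the other
  // instantiations follow the same shape.
  patterns.add<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                     Conv1DNwcWcfOp>>(ctx);
  // Apply greedily until fixpoint (driver name per the current MLIR API;
  // older trees spell it applyPatternsAndFoldGreedily).
  return applyPatternsGreedily(funcOp, std::move(patterns));
}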
1545
1546 FailureOr<DepthwiseConv1DNwcWcOp>
1547 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1548     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1549 if (convOp.hasPureBufferSemantics())
1550 return failure(); // To be implemented.
1551
1552 Value input = convOp.getInputs().front();
1553 Value kernel = convOp.getInputs().back();
1554 Value output = convOp.getOutputs().front();
1555
1556 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1557 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1558 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1559
1560 auto kernelShape = kernelType.getShape();
1561 auto outputShape = outputType.getShape();
1562
1563 // Only handle the case where at least one of the window dimensions is
1564 // of size 1. Other cases can rely on tiling to reduce to such cases.
1565 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1566 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1567 bool removeH = (khSize == 1 && ohSize == 1);
1568 bool removeW = (kwSize == 1 && owSize == 1);
1569 if (!removeH && !removeW)
1570 return failure();
1571
1572 // Get new shapes and types for all operands by removing the size-1
1573 // dimension.
1574 using RTTBuilder = RankedTensorType::Builder;
1575 RankedTensorType newInputType =
1576 RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1577 RankedTensorType newKernelType =
1578 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1579 RankedTensorType newOutputType =
1580 RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1581
1582 // Rank-reduce operands.
1583   Location loc = convOp.getLoc();
1584   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1585       rewriter, loc, input, newInputType);
1586   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1587       rewriter, loc, kernel, newKernelType);
1588   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1589       rewriter, loc, output, newOutputType);
1590
1591 // Rank-reduce strides and dilations too.
1592 // TODO: dropDim 1-liner helper.
1593 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1594 strides.erase(strides.begin() + (removeH ? 0 : 1));
1595 auto stridesAttr = rewriter.getI64VectorAttr(strides);
1596
1597 auto dilations =
1598 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1599 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1600 auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1601
1602 auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
1603 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1604 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1605
1606   // Insert back.
1607   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1608       rewriter, loc, conv1DOp.getResult(0), output);
1609 rewriter.replaceOp(convOp, inserted);
1610
1611 return conv1DOp;
1612}
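// For the depthwise variant above, an illustrative rewrite (shapes made up)
// with a unit kh/oh pair:
//
//   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>,
//                                      strides = dense<1> : tensor<2xi64>}
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x3x3xf32>)
//     outs(%out : tensor<1x1x8x3xf32>) -> tensor<1x1x8x3xf32>
//
// becomes:
//
//   linalg.depthwise_conv_1d_nwc_wc {dilations = dense<1> : tensor<1xi64>,
//                                    strides = dense<1> : tensor<1xi64>}
//     ins(%in1d, %ker1d : tensor<1x10x3xf32>, tensor<3x3xf32>)
//     outs(%out1d : tensor<1x8x3xf32>) -> tensor<1x8x3xf32>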
1613
1614 FailureOr<Conv1DOp>
1615 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1616                                             PatternRewriter &rewriter) const {
1617 if (convOp.hasPureBufferSemantics())
1618 return failure(); // To be implemented.
1619
1620 Value input = convOp.getInputs().front();
1621 Value kernel = convOp.getInputs().back();
1622 Value output = convOp.getOutputs().front();
1623
1624 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1625 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1626 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1627
1628 auto kernelShape = kernelType.getShape();
1629 auto outputShape = outputType.getShape();
1630
1631 // Only handle the case where at least one of the window dimensions is
1632 // of size 1. Other cases can rely on tiling to reduce to such cases.
1633 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1634 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1635 bool removeH = (khSize == 1 && ohSize == 1);
1636 bool removeW = (kwSize == 1 && owSize == 1);
1637 if (!removeH && !removeW)
1638 return failure();
1639
1640 // Get new shapes and types for all operands by removing the size-1
1641 // dimension.
1642 using RTTBuilder = RankedTensorType::Builder;
1643 RankedTensorType newInputType =
1644 RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1645 RankedTensorType newKernelType =
1646 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1647 RankedTensorType newOutputType =
1648 RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1649
1650 // Rank-reduce operands.
1651   Location loc = convOp.getLoc();
1652   Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1653       rewriter, loc, input, newInputType);
1654   Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1655       rewriter, loc, kernel, newKernelType);
1656   Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1657       rewriter, loc, output, newOutputType);
1658
1659 auto conv1DOp =
1660 Conv1DOp::create(rewriter, loc, newOutputType,
1661 ValueRange{newInput, newKernel}, ValueRange{newOutput});
1662
1663   // Insert back.
1664   Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1665       rewriter, loc, conv1DOp.getResult(0), output);
1666 rewriter.replaceOp(convOp, inserted);
1667
1668 return conv1DOp;
1669}
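// Similarly for the unbatched, single-channel case handled above
// (illustrative shapes):
//
//   linalg.conv_2d ins(%in, %ker : tensor<1x10xf32>, tensor<1x3xf32>)
//     outs(%out : tensor<1x8xf32>) -> tensor<1x8xf32>
//
// becomes:
//
//   linalg.conv_1d ins(%in1d, %ker1d : tensor<10xf32>, tensor<3xf32>)
//     outs(%out1d : tensor<8xf32>) -> tensor<8xf32>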
1670
1691
1696