//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}
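
// For instance (illustrative): with tileSizes = [8, 0], the map
//   affine_map<(d0, d1) -> (d0 + d1)>
// is tiled because `d0` has the non-zero tile size 8, while
//   affine_map<(d0, d1) -> (d1)>
// is not, since only `d1` appears and its tile size is 0.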

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!region.hasOneBlock())
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
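
// A region matching (schematic IR):
//   ^bb0(%a: i32, %b: i32):
//     %0 = arith.addi %a, %b : i32
//     linalg.yield %0 : i32
// is recognized as BinaryOpKind::IAdd.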

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerDimsPos` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
  return packInverseDestPermutation;
}

namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
                                            PackingMetadata &metadata) {

  int64_t packedRank = packOp.getDestType().getRank();
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  SmallVector<int64_t> packInvDestPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
  return packInvDestPerm;
}
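
// For example (schematic IR): given
//   linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// the returned vector is the permutation of the destination dims that moves
// each inner tile dim back next to the outer dim it was split from, i.e. it
// maps the (outer0, outer1, tile0, tile1) order to
// (outer0, tile0, outer1, tile1).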

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t packedRank = unpackOp.getSourceType().getRank();
  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
  SmallVector<int64_t> unpackInvSrcPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
  return unpackInvSrcPerm;
}

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!r.hasOneBlock())
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}
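
// For instance (illustrative), a region containing only
//   %c = arith.constant 1.0 : f32
//   %s = arith.addf %x, %c : f32
//   linalg.yield %s : f32
// qualifies, whereas one holding a nested region-bearing op or producing a
// vector-typed result does not.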

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}
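
// For example (schematic IR), the following generic is elementwise: all loops
// are parallel, the init map is a permutation, and the body is scalar:
//   #id = affine_map<(d0, d1) -> (d0, d1)>
//   linalg.generic {indexing_maps = [#id, #id],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%a : tensor<4x8xf32>) outs(%b : tensor<4x8xf32>) { ... }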

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

//===----------------------------------------------------------------------===//
// Convolution matcher utilities
//===----------------------------------------------------------------------===//

/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
/// ext*/sitofp ops.
static BlockArgument getBlockArgument(Value val) {
  BlockArgument blockArg = dyn_cast<BlockArgument>(val);
  if (blockArg)
    return blockArg;

  Operation *defOp = val.getDefiningOp();
  if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
    return nullptr;
  }
  return dyn_cast<BlockArgument>(defOp->getOperand(0));
}
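
// E.g., given (schematically)
//   %w = arith.extsi %arg1 : i8 to i32
// the helper returns %arg1 for %w, and %arg0 directly for %arg0; any other
// defining op yields a null BlockArgument.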

/// Utility function to match the zero point offset body of quantized
/// convolution ops.
///
/// Quantized convolutions have a body of the form:
///   %out + ((%input - %inputZp) * (%filter - %filterZp))
/// where:
///   - %input is the input tensor element (block arg 0)
///   - %filter is the filter tensor element (block arg 1)
///   - %inputZp is the input zero-point scalar (block arg 2)
///   - %filterZp is the filter zero-point scalar (block arg 3)
///   - %out is the output accumulator (block arg 4)
///
/// This function verifies that the multiplication operands are subtraction
/// operations matching this pattern.
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
                                           Block *body) {
  // The multiplication should have two subtraction operands:
  // one for (input - inputZp) and one for (filter - filterZp).
  Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
    return false;

  Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
    return false;

  // Extract block arguments from subtraction operands.
  BlockArgument inputBlockArg =
      getBlockArgument(inputSubOp->getOperand(0));
  BlockArgument inputZpBlockArg =
      getBlockArgument(inputSubOp->getOperand(1));
  BlockArgument filterBlockArg =
      getBlockArgument(filterSubOp->getOperand(0));
  BlockArgument filterZpBlockArg =
      getBlockArgument(filterSubOp->getOperand(1));
  BlockArgument outBlockArg =
      getBlockArgument(addOp->getOperand(0));

  // Verify all block arguments are valid.
  if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
      !filterZpBlockArg || !outBlockArg)
    return false;

  // Verify all block arguments belong to the convolution body.
  if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
      filterBlockArg.getOwner() != body ||
      filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body)
    return false;

  // Verify block arguments have expected indices:
  // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output
  if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 ||
      inputZpBlockArg.getArgNumber() != 2 ||
      filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4)
    return false;

  return true;
}
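
// A matching quantized body looks like (schematic IR):
//   ^bb0(%in: i8, %f: i8, %izp: i32, %fzp: i32, %acc: i32):
//     %0 = arith.extsi %in : i8 to i32
//     %1 = arith.subi %0, %izp : i32
//     %2 = arith.extsi %f : i8 to i32
//     %3 = arith.subi %2, %fzp : i32
//     %4 = arith.muli %1, %3 : i32
//     %5 = arith.addi %acc, %4 : i32
//     linalg.yield %5 : i32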

/// Utility to match block body for convolution ops.
/// The body is expected to yield:
///   %out + (%lhs * %rhs)
/// where %lhs, %rhs and %out are block arguments, and
/// %lhs and %rhs can have an optional upcast operation.
/// NOTE: In the case of zero point offset convolution ops, %lhs and %rhs are:
///   %input - %input_scalar
/// where %input_scalar can have an optional upcast operation.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
                                         bool containsZeroPointOffset = false) {
  Operation *addOp = yieldVal.getDefiningOp();
  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
    return false;

  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
    return false;

  if (containsZeroPointOffset) {
    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
  }
  BlockArgument lhsBlockArg =
      getBlockArgument(mulOp->getOperand(0));
  BlockArgument rhsBlockArg =
      getBlockArgument(mulOp->getOperand(1));
  BlockArgument outBlockArg =
      getBlockArgument(addOp->getOperand(0));
  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
    return false;
  return true;
}
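
// A matching (non-quantized) body is, e.g. (schematic IR):
//   ^bb0(%in: f32, %f: f32, %acc: f32):
//     %0 = arith.mulf %in, %f : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32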

/// Utility to match block body for linalg.pool* ops.
template <typename... OpTypes>
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
  Operation *defOp = yieldVal.getDefiningOp();
  if (!(isa_and_present<OpTypes>(defOp) || ...))
    return false;

  BlockArgument lhsArg =
      getBlockArgument(defOp->getOperand(0));
  BlockArgument rhsArg =
      getBlockArgument(defOp->getOperand(1));
  if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
      rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
      rhsArg.getArgNumber() != 0)
    return false;
  return true;
}
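
// A matching max-pool body is, e.g. (schematic IR; arg2 is the accumulator,
// arg0 the input element):
//   ^bb0(%in: f32, %w: f32, %acc: f32):
//     %0 = arith.maximumf %acc, %in : f32
//     linalg.yield %0 : f32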

static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
                                                                  body);
}

// max_unsigned ops should not allow float data type.
// TODO(#164800): Retire OPDSL logic.
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
                                                                  body);
}

static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
                                                                  body);
}

// min_unsigned ops should not allow float data type.
// TODO(#164800): Retire OPDSL logic.
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
                                                                  body);
}

static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::AddFOp, arith::AddIOp>(yieldVal, body);
}

static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
                                  uint32_t dimIndex) {
  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
  if (dimIndex < affineMap.getNumResults())
    return affineMap.getResult(dimIndex);
  return nullptr;
}

/// Check if `expr` is either:
///   - a dimension expr alone (implying multiplication by 1), or
///   - a multiplication of a dimension expr by any positive constant != 1.
/// In both cases capture the dimension expression into `dim` and return the
/// constant multiplier. Returns -1 in case of a match failure.
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
  if ((dim = dyn_cast<AffineDimExpr>(expr)))
    return 1;

  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
    return -1;

  AffineExpr lhs = mulExpr.getLHS();
  AffineExpr rhs = mulExpr.getRHS();

  AffineConstantExpr cst = nullptr;
  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(lhs))))
    return cst.getValue();
  return -1;
}
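
// Examples (illustrative): for `d1 * 2` this returns 2 with dim = d1; for
// `d3` it returns 1 with dim = d3; for `d0 + d1` or `d0 * d1` it returns -1.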

/// Given an array of AffineMaps `indexingMaps`, verify the following
/// commutatively:
///   indexingMaps[0].getResult(iDim) ==
///       indexingMaps[1].getResult(fDim) * <c0> +
///       indexingMaps[n-1].getResult(oDim) * <c1>
/// where:
///   - c0 and c1 can be any constant,
///   - n is the size of the indexingMaps' array,
///   - 0, 1 and n-1 are the input, filter and output map indices respectively,
///   - iDim, fDim and oDim are the input, filter and output dimension
///     indices in their respective indexing maps.
/// Example:
///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
///                              -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
///
/// Here,
///   #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
/// Therefore,
///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
/// would return true and update dilation = 3 and stride = 2.
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                       unsigned fDim, unsigned oDim,
                                       int64_t &dilation, int64_t &stride) {
  unsigned inputMapIdx = 0, filterMapIdx = 1,
           outputMapIdx = indexingMaps.size() - 1;
  AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
  auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
    return false;

  AffineExpr dim0, dim1;
  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);

  if (c0 == -1 || c1 == -1)
    return false;
  // Pattern matched with dims and constants extracted.
  AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
  AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
  if (dim0 == fExpr && dim1 == oExpr) {
    dilation = c0;
    stride = c1;
    return true;
  }
  if (dim1 == fExpr && dim0 == oExpr) {
    dilation = c1;
    stride = c0;
    return true;
  }
  return false;
}

/// Returns true if the given indexing maps match the expected indexing maps.
static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                              ArrayAttr indexingMaps, MLIRContext *context) {
  SmallVector<AffineMap, 4> expectedIndexingMaps =
      AffineMap::inferFromExprList(mapListExpected, context);
  return indexingMaps ==
         ArrayAttr::get(
             context, llvm::to_vector<4>(llvm::map_range(
                          expectedIndexingMaps, [&](AffineMap m) -> Attribute {
                            return AffineMapAttr::get(m);
                          })));
}

/// Enum representing pooling operation types used by ConvMatcherBuilder.
enum class PoolingType {
  None,
  MaxSigned,
  MaxUnsigned,
  MinSigned,
  MinUnsigned,
  Sum
};

/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
/// as Pooling ops.
///
/// Usage: Create an instance with the op, spatial rank, and output pointers for
/// extracted dilations/strides. Then chain matchStride() calls for each spatial
/// dimension, followed by matchMaps() to verify indexing maps, and finally
/// matchBody() to verify the operation body pattern.
///
/// The `matched` flag starts as `true` and is set to `false` if any match step
/// fails. This allows chaining multiple match calls; once any match fails, all
/// subsequent calls become no-ops and the final result is `false`.
///
/// The `dilations` and `strides` pointers are output parameters that get
/// populated with the extracted dilation and stride values from the operation's
/// indexing maps during matchStride() calls. These values are initially set to
/// 1 for each spatial dimension and updated as patterns are matched.
class ConvMatcherBuilder {
  LinalgOp op;
  MLIRContext *ctx;
  SmallVector<int64_t> *dilations, *strides;
  ArrayAttr indexingMaps;
  PoolingType poolingType;
  bool matched = true;

public:
  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
                     SmallVector<int64_t> *s,
                     PoolingType poolingType = PoolingType::None)
      : op(op), ctx(op->getContext()), dilations(d), strides(s),
        indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
    *dilations = SmallVector<int64_t>(spatialRank, 1);
    *strides = SmallVector<int64_t>(spatialRank, 1);
  }

  /// Get affine dimension expression for dimension `i`.
  AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }

  /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
    return base * (*strides)[idx] + kernel * (*dilations)[idx];
  }

  /// Match stride/dilation pattern for a spatial dimension.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                  unsigned idx) {
    if (matched) {
      matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
                                            (*dilations)[idx], (*strides)[idx]);
    }
    return *this;
  }

  /// Match expected indexing maps layout. Returns *this for method chaining.
  ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
    if (matched)
      matched &= convLayoutMatches(maps, indexingMaps, ctx);
    return *this;
  }

  /// Match body pattern. This should be called last.
  bool matchBody(bool containsZeroPointOffset = false) {
    if (!matched)
      return false;
    Block *body = op.getBlock();
    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
    switch (poolingType) {
    case PoolingType::None:
      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
                                          containsZeroPointOffset);
    case PoolingType::MaxSigned:
      return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::MaxUnsigned:
      return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::MinSigned:
      return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::MinUnsigned:
      return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::Sum:
      return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
    }
    return false;
  }
};
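
// Example usage (schematic), mirroring the Conv1D matcher below:
//   SmallVector<int64_t> dilations, strides;
//   ConvMatcherBuilder m(op, /*spatialRank=*/1, &dilations, &strides);
//   AffineExpr W = m.dim(0), w = m.dim(1);
//   bool isConv1D =
//       m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
//           .matchMaps({{m.strided(W, w, 0)}, {w}, {W}})
//           .matchBody();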

//===----------------------------------------------------------------------===//
// Matchers for specific convolution operations.
//===----------------------------------------------------------------------===//

template <>
bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr W = m.dim(0);
  AffineExpr w = m.dim(1);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
                  /*filterMap=*/{w},
                  /*outputMap=*/{W}})
      .matchBody();
}
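
// The maps verified above correspond to (schematically):
//   input:  affine_map<(W, w) -> (W * stride + w * dilation)>
//   filter: affine_map<(W, w) -> (w)>
//   output: affine_map<(W, w) -> (W)>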

template <>
bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DNwcWcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr w = m.dim(3);
  AffineExpr c = m.dim(4);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
                  /*filterMap=*/{w, c, F},
                  /*outputMap=*/{N, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DNcwFcwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr c = m.dim(3);
  AffineExpr w = m.dim(4);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
                  /*filterMap=*/{F, c, w},
                  /*outputMap=*/{N, F, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr H = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr h = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
      .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{h, w, c, F},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{h, w, c, F},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{F, h, w, c},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{F, h, w, c},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNchwFchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{F, c, h, w},
                  /*outputMap=*/{N, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNchwFchwQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{F, c, h, w},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, F, H, W}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{F, G, c, h, w},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{G, F, c, h, w},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{G, F, c, h, w},
           /*scalarMap=*/{},
           /*scalarMap=*/{},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
           /*filterMap=*/{G, F, h, w, c},
           /*outputMap=*/{N, H, W, G, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
           /*filterMap=*/{G, F, h, w, c},
           /*scalarMap=*/{},
           /*scalarMap=*/{},
           /*outputMap=*/{N, H, W, G, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr D = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr d = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
      .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{d, h, w},
                  /*outputMap=*/{D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), c},
                  /*filterMap=*/{d, h, w, c, F},
                  /*outputMap=*/{N, D, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), c},
                  /*filterMap=*/{d, h, w, c, F},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, D, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr D = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr d = m.dim(6);
  AffineExpr h = m.dim(7);
  AffineExpr w = m.dim(8);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{F, c, d, h, w},
                  /*outputMap=*/{N, F, D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
                  /*filterMap=*/{C, w},
                  /*outputMap=*/{N, C, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
                  /*filterMap=*/{w, C},
                  /*outputMap=*/{N, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr CM = m.dim(3);
  AffineExpr w = m.dim(4);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
                  /*filterMap=*/{w, C, CM},
                  /*outputMap=*/{N, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{C, h, w},
                  /*outputMap=*/{N, C, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C, CM},
                  /*outputMap=*/{N, H, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C, CM},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, C, CM}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), C},
                  /*filterMap=*/{d, h, w, C},
                  /*outputMap=*/{N, D, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
      .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{C, d, h, w},
                  /*outputMap=*/{N, C, D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr C = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), C},
                  /*filterMap=*/{d, h, w, C, CM},
                  /*outputMap=*/{N, D, H, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMaxOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MaxSigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}
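
// Pooling matchers share the conv map structure, but the filter map carries
// only the window dims (h, w), and the body reduces with max/min/sum, e.g.
//   %0 = arith.maximumf %acc, %in : f32
// instead of a multiply-accumulate.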

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMinOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MinSigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcSumOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::Sum);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MaxUnsigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MinUnsigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNchwSumOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::Sum);
  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, C, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNchwMaxOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MaxSigned);
  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, C, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNwcSumOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                       PoolingType::Sum);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
                  /*filterMap=*/{w},
                  /*outputMap=*/{N, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNcwSumOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
                       PoolingType::Sum);
  AffineExpr N = m.dim(0);
  AffineExpr C = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
                  /*filterMap=*/{w},
                  /*outputMap=*/{N, C, W}})
      .matchBody();
}
1666
1667template <>
1669 LinalgOp op, SmallVector<int64_t> *dilations,
1670 SmallVector<int64_t> *strides) {
1671 if (isa<linalg::PoolingNwcMaxOp>(op))
1672 return true;
1673
1674 assert(isaConvolutionOpInterface(op) &&
1675 "expected op to implement ConvolutionOpInterface");
1676
1677 ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
1679 AffineExpr N = m.dim(0);
1680 AffineExpr W = m.dim(1);
1681 AffineExpr C = m.dim(2);
1682 AffineExpr w = m.dim(3);
1683
1684 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1685 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1686 /*filterMap=*/{w},
1687 /*outputMap=*/{N, W, C}})
1688 .matchBody();
1689}
1690
1691template <>
1693 LinalgOp op, SmallVector<int64_t> *dilations,
1694 SmallVector<int64_t> *strides) {
1695 if (isa<linalg::PoolingNwcMaxUnsignedOp>(op))
1696 return true;
1697
1698 assert(isaConvolutionOpInterface(op) &&
1699 "expected op to implement ConvolutionOpInterface");
1700
1701 ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
1703 AffineExpr N = m.dim(0);
1704 AffineExpr W = m.dim(1);
1705 AffineExpr C = m.dim(2);
1706 AffineExpr w = m.dim(3);
1707
1708 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1709 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1710 /*filterMap=*/{w},
1711 /*outputMap=*/{N, W, C}})
1712 .matchBody();
1713}
1714
1715template <>
1717 LinalgOp op, SmallVector<int64_t> *dilations,
1718 SmallVector<int64_t> *strides) {
1719 if (isa<linalg::PoolingNcwMaxOp>(op))
1720 return true;
1721
1722 assert(isaConvolutionOpInterface(op) &&
1723 "expected op to implement ConvolutionOpInterface");
1724
1725 ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
1727 AffineExpr N = m.dim(0);
1728 AffineExpr C = m.dim(1);
1729 AffineExpr W = m.dim(2);
1730 AffineExpr w = m.dim(3);
1731
1732 return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1733 .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
1734 /*filterMap=*/{w},
1735 /*outputMap=*/{N, C, W}})
1736 .matchBody();
1737}
1738
1739template <>
1741 LinalgOp op, SmallVector<int64_t> *dilations,
1742 SmallVector<int64_t> *strides) {
1743 if (isa<linalg::PoolingNwcMinOp>(op))
1744 return true;
1745
1746 assert(isaConvolutionOpInterface(op) &&
1747 "expected op to implement ConvolutionOpInterface");
1748
1749 ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
1751 AffineExpr N = m.dim(0);
1752 AffineExpr W = m.dim(1);
1753 AffineExpr C = m.dim(2);
1754 AffineExpr w = m.dim(3);
1755
1756 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1757 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1758 /*filterMap=*/{w},
1759 /*outputMap=*/{N, W, C}})
1760 .matchBody();
1761}
1762
1763template <>
1765 LinalgOp op, SmallVector<int64_t> *dilations,
1766 SmallVector<int64_t> *strides) {
1767 if (isa<linalg::PoolingNwcMinUnsignedOp>(op))
1768 return true;
1769
1770 assert(isaConvolutionOpInterface(op) &&
1771 "expected op to implement ConvolutionOpInterface");
1772
1773 ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
1774 PoolingType::MinUnsigned);
1775 AffineExpr N = m.dim(0);
1776 AffineExpr W = m.dim(1);
1777 AffineExpr C = m.dim(2);
1778 AffineExpr w = m.dim(3);
1779
1780 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1781 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1782 /*filterMap=*/{w},
1783 /*outputMap=*/{N, W, C}})
1784 .matchBody();
1785}
1786
1787template <>
1788bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
1789 LinalgOp op, SmallVector<int64_t> *dilations,
1790 SmallVector<int64_t> *strides) {
1791 if (isa<linalg::PoolingNdhwcSumOp>(op))
1792 return true;
1793
1794 assert(isaConvolutionOpInterface(op) &&
1795 "expected op to implement ConvolutionOpInterface");
1796
1797 ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
1798 PoolingType::Sum);
1799 AffineExpr N = m.dim(0);
1800 AffineExpr D = m.dim(1);
1801 AffineExpr H = m.dim(2);
1802 AffineExpr W = m.dim(3);
1803 AffineExpr C = m.dim(4);
1804 AffineExpr d = m.dim(5);
1805 AffineExpr h = m.dim(6);
1806 AffineExpr w = m.dim(7);
1807
1808 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1809 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1810 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1811 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1812 m.strided(W, w, 2), C},
1813 /*filterMap=*/{d, h, w},
1814 /*outputMap=*/{N, D, H, W, C}})
1815 .matchBody();
1816}
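// Illustrative note (schematic): with loop dimensions
// (d0..d7) = (N, D, H, W, C, d, h, w), the NDHWC sum-pooling layout matched
// above expands to
//   input:  -> (N, D*s0 + d*dl0, H*s1 + h*dl1, W*s2 + w*dl2, C)
//   filter: -> (d, h, w)
//   output: -> (N, D, H, W, C)
// with one matchStride call per spatial dimension (idx 0, 1 and 2).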
1817
1818template <>
1819bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
1820 LinalgOp op, SmallVector<int64_t> *dilations,
1821 SmallVector<int64_t> *strides) {
1822 if (isa<linalg::PoolingNdhwcMaxOp>(op))
1823 return true;
1824
1825 assert(isaConvolutionOpInterface(op) &&
1826 "expected op to implement ConvolutionOpInterface");
1827
1828 ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
1829 PoolingType::Max);
1830 AffineExpr N = m.dim(0);
1831 AffineExpr D = m.dim(1);
1832 AffineExpr H = m.dim(2);
1833 AffineExpr W = m.dim(3);
1834 AffineExpr C = m.dim(4);
1835 AffineExpr d = m.dim(5);
1836 AffineExpr h = m.dim(6);
1837 AffineExpr w = m.dim(7);
1838
1839 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1840 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1841 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1842 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1843 m.strided(W, w, 2), C},
1844 /*filterMap=*/{d, h, w},
1845 /*outputMap=*/{N, D, H, W, C}})
1846 .matchBody();
1847}
1848
1849template <>
1850bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
1851 LinalgOp op, SmallVector<int64_t> *dilations,
1852 SmallVector<int64_t> *strides) {
1853 if (isa<linalg::PoolingNdhwcMinOp>(op))
1854 return true;
1855
1856 assert(isaConvolutionOpInterface(op) &&
1857 "expected op to implement ConvolutionOpInterface");
1858
1859 ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
1860 PoolingType::Min);
1861 AffineExpr N = m.dim(0);
1862 AffineExpr D = m.dim(1);
1863 AffineExpr H = m.dim(2);
1864 AffineExpr W = m.dim(3);
1865 AffineExpr C = m.dim(4);
1866 AffineExpr d = m.dim(5);
1867 AffineExpr h = m.dim(6);
1868 AffineExpr w = m.dim(7);
1869
1870 return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1871 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1872 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1873 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1874 m.strided(W, w, 2), C},
1875 /*filterMap=*/{d, h, w},
1876 /*outputMap=*/{N, D, H, W, C}})
1877 .matchBody();
1878}
1879
1880Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
1881 Value source, Value pad, bool nofold,
1882 ValueRange typeDynDims) {
1883 // Exit if `source` is not defined by an ExtractSliceOp.
1884 auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
1885 if (!sliceOp)
1886 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1887 typeDynDims);
1888
1889 // Search the `source` use-def chain for padded LinalgOps.
1890 Value current = sliceOp.getSource();
1891 while (current) {
1892 auto linalgOp = current.getDefiningOp<LinalgOp>();
1893 if (!linalgOp)
1894 break;
1895 OpResult opResult = cast<OpResult>(current);
1896 current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
1897 }
1898 auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
1899
1900 // Exit if the search fails to match a tensor::PadOp at the end of the matched
1901 // LinalgOp sequence.
1902 if (!padOp)
1903 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1904 typeDynDims);
1905
1906 // Exit if the padded result type does not match.
1907 if (sliceOp.getSource().getType() != type)
1908 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1909 typeDynDims);
1910
1911 // Exit if the LinalgOps are not high-padded.
1912 if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
1913 return getConstantIntValue(ofr) != static_cast<int64_t>(0);
1914 }))
1915 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1916 typeDynDims);
1917
1918 // Exit if `padOpSliceOp`, which defines the slice used by
1919 // `padOp`, is rank-reducing.
1920 auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
1921 if (!padOpSliceOp ||
1922 sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
1923 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1924 typeDynDims);
1925
1926 // Exit if the sizes of `sliceOp` do not match the sizes of the slice
1927 // padded by `padOp`.
1928 if (llvm::any_of(
1929 llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
1930 [](std::tuple<OpFoldResult, OpFoldResult> it) {
1931 return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
1932 }))
1933 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1934 typeDynDims);
1935
1936 // Exit if the padding values do not match.
1937 Attribute padOpPadAttr, padAttr;
1938 Value padOpPad = padOp.getConstantPaddingValue();
1939 if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
1940 !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
1941 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
1942 typeDynDims);
1943
1944 // Return the padded result if the padding values and sizes match.
1945 return sliceOp.getSource();
1946}
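// Illustrative note (SSA names hypothetical): the successful match above
// corresponds to a producer chain of roughly this shape:
//   %s      = tensor.extract_slice %t ...          // `padOpSliceOp`
//   %padded = tensor.pad %s low[0, 0] high[...]    // high padding only
//   %r      = linalg.generic ... outs(%padded)     // padded LinalgOp chain
//   %source = tensor.extract_slice %r ...          // `source` argument
// When the slice sizes and padding values line up, `%r` already has the
// requested padded type, so it is returned instead of creating a new pad.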
1947
1948GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
1949 auto memrefTypeTo = cast<MemRefType>(to.getType());
1950#ifndef NDEBUG
1951 auto memrefTypeFrom = cast<MemRefType>(from.getType());
1952 assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
1953 "`from` and `to` memref must have the same rank");
1954#endif // NDEBUG
1955
1956 AffineMap id =
1957 AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
1958 SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
1959 utils::IteratorType::parallel);
1960 return linalg::GenericOp::create(
1961 b, loc,
1962 /*inputs=*/from,
1963 /*outputs=*/to,
1964 /*indexingMaps=*/llvm::ArrayRef({id, id}),
1965 /*iteratorTypes=*/iteratorTypes,
1966 [](OpBuilder &b, Location loc, ValueRange args) {
1967 linalg::YieldOp::create(b, loc, args.front());
1968 });
1969}
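// Illustrative note (schematic, element type arbitrary): for rank 2 the op
// built above prints as
//   linalg.generic
//       {indexing_maps = [affine_map<(i, j) -> (i, j)>,
//                         affine_map<(i, j) -> (i, j)>],
//        iterator_types = ["parallel", "parallel"]}
//       ins(%from : memref<?x?xf32>) outs(%to : memref<?x?xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   }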
1970
1971/// Specialization to build an scf "for" nest.
1972template <>
1973void GenerateLoopNest<scf::ForOp>::doit(
1974 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
1975 ArrayRef<utils::IteratorType> iteratorTypes,
1976 function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
1977 ValueRange)>
1978 bodyBuilderFn,
1979 ArrayRef<linalg::ProcInfo> procInfo) {
1980 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
1981 "expected as many entries for proc info as number of loops, even if "
1982 "they are null entries");
1983 SmallVector<Value> iterArgInitValues;
1984 if (!linalgOp.hasPureBufferSemantics())
1985 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
1986 SmallVector<Value, 4> lbs, ubs, steps;
1987 unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
1988 LoopNest loopNest = mlir::scf::buildLoopNest(
1989 b, loc, lbs, ubs, steps, iterArgInitValues,
1990 [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
1991 assert(iterArgs.size() == iterArgInitValues.size() &&
1992 "expect the number of output tensors and iter args to match");
1993 SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
1994 if (!iterArgs.empty()) {
1995 operandValuesToUse = linalgOp.getDpsInputs();
1996 operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
1997 }
1998 return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
1999 });
2000
2001 if (loopNest.loops.empty() || procInfo.empty())
2002 return;
2003
2004 // Filter out scf.for loops that were created out of parallel dimensions.
2005 for (const auto &loop : llvm::enumerate(loopNest.loops)) {
2006 if (procInfo[loop.index()].distributionMethod ==
2007 DistributionMethod::Cyclic) {
2008 mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
2009 procInfo[loop.index()].nprocs);
2010 }
2011 }
2012}
2013
2014/// Specialization to build an affine "for" nest.
2015template <>
2016void GenerateLoopNest<AffineForOp>::doit(
2017 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2018 ArrayRef<utils::IteratorType> iteratorTypes,
2019 function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
2020 ValueRange)>
2021 bodyBuilderFn,
2022 ArrayRef<linalg::ProcInfo> /*procInfo*/) {
2023 SmallVector<Value> iterArgInitValues;
2024 if (!linalgOp.hasPureBufferSemantics())
2025 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2026 assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
2027 SmallVector<Value, 4> lbs, ubs, steps;
2028 unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
2029
2030 // Affine loops require constant steps.
2031 SmallVector<int64_t, 4> constantSteps;
2032 constantSteps.reserve(steps.size());
2033 for (Value v : steps) {
2034 auto constVal = getConstantIntValue(v);
2035 assert(constVal.has_value() && "Affine loops require constant steps");
2036 constantSteps.push_back(constVal.value());
2037 }
2038
2039 affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
2040 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2041 bodyBuilderFn(b, loc, ivs,
2042 linalgOp->getOperands());
2043 });
2044}
2045
2046/// Update the `lb`, `ub` and `step` to get per-processor `lb`, `ub` and `step`.
2047void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
2048 Value nprocs, Value &lb, Value &ub,
2049 Value &step) {
2050 AffineExpr d0, d1;
2051 bindDims(b.getContext(), d0, d1);
2052 AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
2053 lb =
2054 affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
2055 step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
2056}
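// Illustrative note: the two affine applies above compute
//   lb'   = lb + procId * step
//   step' = nprocs * step
// e.g. with lb = 0, step = 1 and nprocs = 4, processor p visits iterations
// {p, p + 4, p + 8, ...}, i.e. a cyclic (round-robin) distribution.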
2057
2058/// Generates a loop nest consisting of scf.parallel and scf.for, depending
2059/// on the `iteratorTypes`. Consecutive parallel loops create a single
2060/// scf.parallel operation; each sequential loop creates a new scf.for
2061/// operation. The body of the innermost loop is populated by
2062/// `bodyBuilderFn` that accepts a range of induction variables for all
2063/// loops. `ivStorage` is used to store the partial list of induction
2064/// variables.
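/// For example (schematically), iterator types
/// [parallel, parallel, reduction, parallel] yield a nest of the form
///   scf.parallel (%iv0, %iv1) { scf.for %iv2 { scf.parallel (%iv3) {...} } }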
2065// TODO: this function can be made iterative instead. However, it
2066// will have at most as many recursive calls as there are nested loops,
2067// which is rarely more than 10.
2068static void generateParallelLoopNest(
2069 OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
2070 ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
2071 ArrayRef<linalg::ProcInfo> procInfo,
2072 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2073 SmallVectorImpl<Value> &ivStorage) {
2074 assert(lbs.size() == ubs.size());
2075 assert(lbs.size() == steps.size());
2076 assert(lbs.size() == iteratorTypes.size());
2077 assert(procInfo.empty() || (lbs.size() == procInfo.size()));
2078
2079 // If there are no (more) loops to be generated, generate the body and be
2080 // done with it.
2081 if (iteratorTypes.empty()) {
2082 bodyBuilderFn(b, loc, ivStorage);
2083 return;
2084 }
2085
2086 // If there are no outer parallel loops, generate one sequential loop and
2087 // recurse.
2088 if (!isParallelIterator(iteratorTypes.front())) {
2089 LoopNest singleLoop = buildLoopNest(
2090 b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
2091 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2092 ivStorage.append(ivs.begin(), ivs.end());
2093 generateParallelLoopNest(
2094 b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
2095 iteratorTypes.drop_front(),
2096 procInfo.empty() ? procInfo : procInfo.drop_front(),
2097 bodyBuilderFn, ivStorage);
2098 });
2099 return;
2100 }
2101
2102 unsigned nLoops = iteratorTypes.size();
2103 unsigned numProcessed = 0;
2104 DistributionMethod distributionMethod = DistributionMethod::None;
2105 if (procInfo.empty()) {
2106 numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
2107 } else {
2108 distributionMethod = procInfo.front().distributionMethod;
2109 numProcessed =
2110 nLoops - procInfo
2111 .drop_while([&](linalg::ProcInfo p) {
2112 return p.distributionMethod == distributionMethod;
2113 })
2114 .size();
2115 }
2116
2117 auto remainderProcInfo =
2118 procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
2119 switch (distributionMethod) {
2120 case DistributionMethod::None: {
2121 // Generate a single parallel loop-nest operation for all outermost
2122 // parallel loops and recurse.
2123 scf::ParallelOp::create(
2124 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2125 steps.take_front(numProcessed),
2126 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2127 ivStorage.append(localIvs.begin(), localIvs.end());
2128 generateParallelLoopNest(
2129 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2130 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2131 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2132 bodyBuilderFn, ivStorage);
2133 });
2134 return;
2135 }
2136 case DistributionMethod::Cyclic: {
2137 // Generate a single parallel loop-nest operation for all outermost
2138 // parallel loops and recurse.
2139 scf::ParallelOp::create(
2140 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2141 steps.take_front(numProcessed),
2142 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2143 ivStorage.append(localIvs.begin(), localIvs.end());
2144 generateParallelLoopNest(
2145 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2146 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2147 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2148 bodyBuilderFn, ivStorage);
2149 });
2150 return;
2151 }
2152 case DistributionMethod::CyclicNumProcsGeNumIters: {
2153 // Check (for the processed loops) that the iteration is in-bounds.
2154 ArithBuilder ab(b, loc);
2155 Value cond = ab.slt(lbs[0], ubs[0]);
2156 for (unsigned i = 1; i < numProcessed; ++i)
2157 cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
2158 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2159 scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) {
2160 generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
2161 ubs.drop_front(numProcessed),
2162 steps.drop_front(numProcessed),
2163 iteratorTypes.drop_front(numProcessed),
2164 remainderProcInfo, bodyBuilderFn, ivStorage);
2165 scf::YieldOp::create(b, loc, ValueRange{});
2166 });
2167 return;
2168 }
2169 case DistributionMethod::CyclicNumProcsEqNumIters:
2170 // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
2171 // with inner loop generation.
2172 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2173 generateParallelLoopNest(
2174 b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
2175 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
2176 remainderProcInfo, bodyBuilderFn, ivStorage);
2177 return;
2178 }
2179}
2180
2181/// Specialization for generating a mix of parallel and sequential scf loops.
2182template <>
2183void GenerateLoopNest<scf::ParallelOp>::doit(
2184 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2185 ArrayRef<utils::IteratorType> iteratorTypes,
2186 function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
2187 ValueRange)>
2188 bodyBuilderFn,
2189 ArrayRef<linalg::ProcInfo> procInfo) {
2190 SmallVector<Value> iterArgInitValues;
2191 if (!linalgOp.hasPureBufferSemantics())
2192 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2193 assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
2194 // This function may be passed more iterator types than ranges.
2195 assert(iteratorTypes.size() >= loopRanges.size() &&
2196 "expected iterator type for all ranges");
2197 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2198 "expected proc information for all loops when present");
2199 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
2200 SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
2201 unsigned numLoops = iteratorTypes.size();
2202 ivs.reserve(numLoops);
2203 lbsStorage.reserve(numLoops);
2204 ubsStorage.reserve(numLoops);
2205 stepsStorage.reserve(numLoops);
2206
2207 // Get the loop lb, ub, and step.
2208 unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
2209
2210 // Modify the lb, ub, and step based on the distribution options.
2211 for (const auto &it : llvm::enumerate(procInfo)) {
2212 if (it.value().distributionMethod != linalg::DistributionMethod::None) {
2213 updateBoundsForCyclicDistribution(
2214 b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
2215 ubsStorage[it.index()], stepsStorage[it.index()]);
2216 }
2217 }
2218 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
2219 generateParallelLoopNest(
2220 b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
2221 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2222 bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
2223 },
2224 ivs);
2225
2226 assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
2227}
2228
2229static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
2230 Value valueToTile,
2231 const SliceParameters &sliceParams) {
2232 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2233 auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
2234 .Case([&](MemRefType) {
2235 return memref::SubViewOp::create(
2236 builder, loc, valueToTile, sliceParams.offsets,
2237 sliceParams.sizes, sliceParams.strides);
2238 })
2239 .Case([&](RankedTensorType) {
2240 return tensor::ExtractSliceOp::create(
2241 builder, loc, valueToTile, sliceParams.offsets,
2242 sliceParams.sizes, sliceParams.strides);
2243 })
2244 .DefaultUnreachable("Unexpected shaped type");
2245 return sliceOp;
2246}
2247
2248Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
2249 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2250 ArrayRef<OpFoldResult> lbs,
2251 ArrayRef<OpFoldResult> ubs,
2252 ArrayRef<OpFoldResult> subShapeSizes,
2253 bool omitPartialTileCheck) {
2254 SliceParameters sliceParams =
2255 computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
2256 ubs, subShapeSizes, omitPartialTileCheck);
2257 return materializeTiledShape(builder, loc, valueToTile, sliceParams);
2258}
2259
2260SliceParameters
2261computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
2262 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2263 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
2264 ArrayRef<OpFoldResult> subShapeSizes,
2265 bool omitPartialTileCheck) {
2266 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2267 assert(shapedType && "only shaped types can be tiled");
2268 ArrayRef<int64_t> shape = shapedType.getShape();
2269 int64_t rank = shapedType.getRank();
2270
2271 // Compute offsets/sizes/strides for the tile.
2272 SliceParameters sliceParams;
2273 sliceParams.offsets.reserve(rank);
2274 sliceParams.sizes.reserve(rank);
2275 sliceParams.strides.reserve(rank);
2276 for (unsigned r = 0; r < rank; ++r) {
2277 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
2278 if (!isTiled(map.getSubMap({r}), tileSizes)) {
2279 sliceParams.offsets.push_back(builder.getIndexAttr(0));
2280 OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
2281 sliceParams.sizes.push_back(dim);
2282 sliceParams.strides.push_back(builder.getIndexAttr(1));
2283 LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
2284 continue;
2285 }
2286 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
2287
2288 // Tiling creates a new slice at the proper index; the slice step is 1
2289 // (i.e. the op does not subsample, stepping occurs in the loop).
2290 auto m = map.getSubMap({r});
2291 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
2292 IRRewriter rewriter(builder);
2293 // The offset of the slice is m(lbs) - m(0).
2294 SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
2295 SmallVector<Attribute> mAtZero;
2296 [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
2297 assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
2298 int64_t mAtZeroInt =
2299 cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
2300 OpFoldResult offset = makeComposedFoldedAffineApply(
2301 rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
2302 sliceParams.offsets.push_back(offset);
2303
2304 OpFoldResult closedIntSize =
2305 makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
2306 // The resulting size needs to be made a half-open interval again.
2307 AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
2308 OpFoldResult size =
2309 makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
2310 LLVM_DEBUG(llvm::dbgs()
2311 << "computeSliceParameters: raw size: " << size << "\n");
2312 LLVM_DEBUG(llvm::dbgs()
2313 << "computeSliceParameters: new offset: " << offset << "\n");
2314 sliceParams.strides.push_back(builder.getIndexAttr(1));
2315
2316 if (omitPartialTileCheck) {
2317 // We statically know that the partial/boundary tile condition is
2318 // unnecessary.
2319 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2320 sliceParams.sizes.push_back(size);
2321 continue;
2322 }
2323
2324 // The size of the subview / extract_slice should be trimmed to avoid
2325 // out-of-bounds accesses, unless:
2326 // a. We statically know the subshape size divides the shape size evenly.
2327 // b. The subshape size is 1. According to the way the loops are set up,
2328 // tensors with "0" dimensions would never be constructed.
2329 int64_t shapeSize = shape[r];
2330 std::optional<int64_t> sizeCst = getConstantIntValue(size);
2331 auto hasTileSizeOne = sizeCst == 1;
2332 auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
2333 ((shapeSize % *sizeCst) == 0);
2334 if (!hasTileSizeOne && !dividesEvenly) {
2335 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
2336 << ", size: " << size
2337 << ": make sure in bound with affine.min\n");
2338
2339 AffineExpr dim0, dim1, dim2;
2340 MLIRContext *context = builder.getContext();
2341 bindDims(context, dim0, dim1, dim2);
2342
2343 // Get the dimension size for this dimension. We need to first compute the
2344 // max index and then add one. This is important because for convolution
2345 // ops, the input window dimension's affine map has the form
2346 // `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window dimension
2347 // and `s0` is the stride. Directly using the dimension size of the
2348 // output/filter window dimensions would give an incorrect result.
2349 AffineMap minusOneMap = AffineMap::inferFromExprList(
2350 {ArrayRef<AffineExpr>{dim0 - 1}}, context)
2351 .front();
2352 AffineMap plusOneMap = AffineMap::inferFromExprList(
2353 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
2354 .front();
2355 SmallVector<OpFoldResult> maxIndices =
2356 llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
2357 return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
2358 {ub});
2359 }));
2360 OpFoldResult maxIndex =
2361 makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
2362 OpFoldResult d =
2363 makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
2364
2365 // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
2366 AffineMap minMap = AffineMap::inferFromExprList(
2367 {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
2368 .front();
2369 size =
2370 makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
2371 }
2372 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2373 sliceParams.sizes.push_back(size);
2374 }
2375 return sliceParams;
2376}
2377
2378SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
2379 ArrayRef<OpFoldResult> ivs,
2380 ArrayRef<OpFoldResult> tileSizes) {
2381 SmallVector<OpFoldResult> offsets;
2382 for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
2383 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
2384 bool isTiled = !isZeroInteger(tileSizes[idx]);
2385 offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
2386 LLVM_DEBUG(llvm::dbgs()
2387 << "computeTileOffsets: " << offsets.back() << "\n");
2388 }
2389 return offsets;
2390}
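// Illustrative note: with ivs = {%i, %j} and tileSizes = {8, 0, 16}, the
// untiled middle dimension gets a zero offset and the result is
// {%i, 0, %j}; each iv is consumed by the next tiled dimension in order.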
2391
2392SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
2393 ArrayRef<OpFoldResult> tileSizes,
2394 ArrayRef<OpFoldResult> sizeBounds) {
2395 SmallVector<OpFoldResult> sizes;
2396 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
2397 bool isTiled = !isZeroInteger(tileSizes[idx]);
2398 // Before composing, we need to make the range a closed interval.
2399 OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
2400 AffineExpr d0 = getAffineDimExpr(0, b.getContext());
2401 IRRewriter rewriter(b);
2402 sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
2403 LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
2404 }
2405 return sizes;
2406}
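// Illustrative note: for tileSizes = {8, 0} and sizeBounds = {%d0, %d1} this
// yields {7, %d1 - 1}; every entry is the closed-interval size (size - 1) of
// either the tile or, for untiled dimensions, the full iteration range.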
2407
2408SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
2409 if (op.hasPureBufferSemantics())
2410 return {};
2411 return llvm::to_vector(
2412 llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
2413 return operands[opOperand.getOperandNumber()].getType();
2414 }));
2415}
2416
2417SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
2418 LinalgOp op, ValueRange operands,
2419 ValueRange results) {
2420 if (op.hasPureBufferSemantics())
2421 return {};
2422 SmallVector<Value> tensorResults;
2423 tensorResults.reserve(results.size());
2424 // Insert an insert_slice for each output tensor.
2425 unsigned resultIdx = 0;
2426 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2427 // TODO: use an interface/adaptor to avoid leaking position in
2428 // `tiledOperands`.
2429 Value outputTensor = operands[opOperand.getOperandNumber()];
2430 if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
2431 Value inserted = tensor::InsertSliceOp::create(
2432 builder, loc, sliceOp.getSource().getType(), results[resultIdx],
2433 sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
2434 sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2435 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2436 tensorResults.push_back(inserted);
2437 } else {
2438 tensorResults.push_back(results[resultIdx]);
2439 }
2440 ++resultIdx;
2441 }
2442 return tensorResults;
2443}
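// Illustrative note (SSA names hypothetical): when a tiled output came from
//   %tile = tensor.extract_slice %init[%offs][%szs][%strs]
// the matching partial result %res is written back as
//   %full = tensor.insert_slice %res into %init[%offs][%szs][%strs]
// reusing the extract_slice's offsets/sizes/strides verbatim; outputs that
// were not sliced are forwarded unchanged.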
2444
2445SmallVector<std::optional<SliceParameters>>
2446computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
2447 ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
2448 ArrayRef<OpFoldResult> tileSizes,
2449 ArrayRef<OpFoldResult> sizeBounds,
2450 bool omitPartialTileCheck) {
2451 assert(ivs.size() == static_cast<size_t>(llvm::count_if(
2452 llvm::make_range(tileSizes.begin(), tileSizes.end()),
2453 [](OpFoldResult v) { return !isZeroInteger(v); })) &&
2454 "expected as many ivs as non-zero sizes");
2455
2456 // Construct (potentially temporary) mins and maxes on which to apply maps
2457 // that define tile subshapes.
2458 SmallVector<OpFoldResult> lbs =
2459 computeTileOffsets(builder, loc, ivs, tileSizes);
2460 SmallVector<OpFoldResult> subShapeSizes =
2461 computeTileSizes(builder, loc, tileSizes, sizeBounds);
2462
2463 assert(static_cast<int64_t>(valuesToTile.size()) <=
2464 linalgOp->getNumOperands() &&
2465 "more value to tile than operands.");
2467 allSliceParams.reserve(valuesToTile.size());
2468 for (auto [opOperand, val] :
2469 llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
2470 Value shapedOp = val;
2471 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
2472 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
2473 // Use `opOperand` as is if it is not tiled and not an output tensor. Having
2474 // an extract/insert slice pair for all output tensors simplifies follow-up
2475 // transformations such as padding and bufferization since the
2476 // extract/insert slice pairs make the accessed iteration argument
2477 // subdomains explicit.
2478
2479 Type operandType = opOperand.get().getType();
2480 if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
2481 linalgOp.isDpsInit(&opOperand))) {
2482 allSliceParams.push_back(std::nullopt);
2483 LLVM_DEBUG(llvm::dbgs()
2484 << ": not tiled: use shape: " << operandType << "\n");
2485 continue;
2486 }
2487 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
2488
2489 allSliceParams.push_back(computeSliceParameters(
2490 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
2491 omitPartialTileCheck));
2492 }
2493
2494 return allSliceParams;
2495}
2496
2497SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
2498 LinalgOp linalgOp, ValueRange valuesToTile,
2499 ArrayRef<OpFoldResult> ivs,
2500 ArrayRef<OpFoldResult> tileSizes,
2501 ArrayRef<OpFoldResult> sizeBounds,
2502 bool omitPartialTileCheck) {
2503 SmallVector<std::optional<SliceParameters>> allSliceParameter =
2504 computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
2505 tileSizes, sizeBounds, omitPartialTileCheck);
2506 SmallVector<Value> tiledShapes;
2507 for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
2508 Value valueToTile = std::get<0>(item);
2509 std::optional<SliceParameters> sliceParams = std::get<1>(item);
2510 tiledShapes.push_back(
2511 sliceParams.has_value()
2512 ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
2513 ->getResult(0)
2514 : valueToTile);
2515 }
2516 return tiledShapes;
2517}
2518
2519void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
2520 ArrayRef<OpFoldResult> offsets) {
2521 IRRewriter rewriter(b);
2522 offsetIndices(rewriter, linalgOp, offsets);
2523}
2524
2525void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
2526 ArrayRef<OpFoldResult> offsets) {
2527 if (!linalgOp.hasIndexSemantics())
2528 return;
2529
2530 for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
2531 if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
2532 continue;
2533 OpBuilder::InsertionGuard guard(b);
2534 b.setInsertionPointAfter(indexOp);
2535 AffineExpr index, offset;
2536 bindDims(b.getContext(), index, offset);
2537 OpFoldResult applied = makeComposedFoldedAffineApply(
2538 b, indexOp.getLoc(), index + offset,
2539 {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
2540 Value materialized =
2541 getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
2542 b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
2543 return use.getOwner() != materialized.getDefiningOp();
2544 });
2545 }
2546}
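// Illustrative note (SSA names hypothetical): given an offset %off for
// dimension 0, every use of
//   %i = linalg.index 0 : index
// in the body is redirected to
//   %i_off = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %off)
// while the affine.apply itself keeps consuming the original %i.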
2547
2548/// Get the reassociation maps to fold the result of an extract_slice (or the
2549/// source of an insert_slice) operation with the given offsets and sizes to its
2550/// rank-reduced version. This is only done for the cases where the size is 1
2551/// and offset is 0. Strictly speaking the offset 0 is not required in general,
2552/// but non-zero offsets are not handled by SPIR-V backend at this point (and
2553/// potentially cannot be handled).
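/// For example, sizes [1, 4, 1, 8] yield the reassociation [[0, 1], [2, 3]]
/// (each unit dimension folds into the next non-unit one), while a trailing
/// unit dimension as in [4, 1] is appended to the last group, giving [[0, 1]].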
2554std::optional<SmallVector<ReassociationIndices>>
2555getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
2556 SmallVector<ReassociationIndices> reassociation;
2557 ReassociationIndices curr;
2558 for (const auto &it : llvm::enumerate(mixedSizes)) {
2559 auto dim = it.index();
2560 auto size = it.value();
2561 curr.push_back(dim);
2562 auto attr = llvm::dyn_cast_if_present<Attribute>(size);
2563 if (attr && cast<IntegerAttr>(attr).getInt() == 1)
2564 continue;
2565 reassociation.emplace_back(ReassociationIndices{});
2566 std::swap(reassociation.back(), curr);
2567 }
2568 // When the reassociation list is not empty, fold the remaining
2569 // unit dimensions into the last group. If the reassociation list so far
2570 // is empty, leave it empty; this will fold everything to a rank-0 tensor.
2571 if (!curr.empty() && !reassociation.empty())
2572 reassociation.back().append(curr.begin(), curr.end());
2573 return reassociation;
2574}
2575
2576} // namespace linalg
2577} // namespace mlir