MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
29#include "mlir/IR/AffineExpr.h"
31#include "mlir/IR/AffineMap.h"
32#include "mlir/IR/Matchers.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include <optional>
36
37#define DEBUG_TYPE "linalg-utils"
38
39using namespace mlir;
40using namespace presburger;
41using namespace mlir::affine;
42using namespace mlir::linalg;
43using namespace mlir::scf;
44
45namespace {
46
47// Helper visitor to determine whether an AffineExpr is tiled.
48// This is achieved by traversing every AffineDimExpr with position `pos` and
49// checking whether the corresponding `tileSizes[pos]` is non-zero.
50// This also enforces only positive coefficients occur in multiplications.
51//
52// Example:
53// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
54//
55struct TileCheck : public AffineExprVisitor<TileCheck> {
56 TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
57
58 void visitDimExpr(AffineDimExpr expr) {
59 isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
60 }
61 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
62 visit(expr.getLHS());
63 visit(expr.getRHS());
65 assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
66 "nonpositive multiplying coefficient");
67 }
68 bool isTiled = false;
69 ArrayRef<OpFoldResult> tileSizes;
70};
71
72} // namespace
73
74static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
75 if (!expr)
76 return false;
77 TileCheck t(tileSizes);
78 t.visit(expr);
79 return t.isTiled;
80}
81
82// Checks whether the `map` varies with respect to a non-zero `tileSize`.
83static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
84 if (!map)
85 return false;
86 for (unsigned r = 0; r < map.getNumResults(); ++r)
87 if (isTiled(map.getResult(r), tileSizes))
88 return true;
89 return false;
90}
91
92std::optional<RegionMatcher::BinaryOpKind>
94 auto &region = op.getRegion();
95 if (!region.hasOneBlock())
96 return std::nullopt;
97
98 Block &block = region.front();
99 if (block.getNumArguments() != 2 ||
102 return std::nullopt;
103
104 auto &ops = block.getOperations();
105 if (!llvm::hasSingleElement(block.without_terminator()))
106 return std::nullopt;
107
109 auto a = m_Val(block.getArgument(0));
110 auto b = m_Val(block.getArgument(1));
111
112 auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
113 if (addPattern.match(&ops.back()))
114 return BinaryOpKind::IAdd;
115
116 return std::nullopt;
117}
118
119/// Explicit instantiation of loop nest generator for different loop types.
123
124/// Given a list of subview ranges, extract individual values for lower, upper
125/// bounds and steps and put them into the corresponding vectors.
126static void unpackRanges(OpBuilder &builder, Location loc,
129 SmallVectorImpl<Value> &steps) {
130 for (Range range : ranges) {
131 lbs.emplace_back(
132 getValueOrCreateConstantIndexOp(builder, loc, range.offset));
133 ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
134 steps.emplace_back(
135 getValueOrCreateConstantIndexOp(builder, loc, range.stride));
136 }
137}
138
139//===----------------------------------------------------------------------===//
140// General utilities
141//===----------------------------------------------------------------------===//
142//
143/// The permutation can be obtained from two permutations:
144/// a) Compute the permutation vector to move the last `numPackedDims` into
145/// the `innerPosDims` of a shape of rank `rank`.
146/// b) Compute the permutation vector to move outer dims if the
147/// `outerPerm` parameter is not empty.
148/// Apply (b) permutation on (a) permutation to get the final permutation.
149static SmallVector<int64_t>
151 ArrayRef<int64_t> &outerPerm,
152 PackingMetadata &packingMetadata) {
153 int64_t numPackedDims = innerDimsPos.size();
154 auto lastDims =
155 llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
156 packingMetadata = computePackingMetadata(rank, innerDimsPos);
157 SmallVector<int64_t> innerPositionsPerm =
158 computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
159
160 SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
161 if (!outerPerm.empty())
162 applyPermutationToVector(outerPos, outerPerm);
163 SmallVector<int64_t> outerPositionPerm =
164 computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
165
166 SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
167 applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
168 return packInverseDestPermutation;
169}
170
171namespace mlir {
172namespace linalg {
173
175 PackingMetadata &metadata) {
176
177 int64_t packedRank = packOp.getDestType().getRank();
178 ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
179 ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
180 SmallVector<int64_t> packInvDestPerm =
181 computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
182 return packInvDestPerm;
183}
184
186 PackingMetadata &metadata) {
187 int64_t packedRank = unpackOp.getSourceType().getRank();
188 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
189 ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
190 SmallVector<int64_t> unpackInvSrcPerm =
191 computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
192 return unpackInvSrcPerm;
193}
194
196 return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
197 return m.isProjectedPermutation(/*allowZeroInResults=*/true);
198 });
199}
200
202 if (!r.hasOneBlock())
203 return false;
204 for (Operation &op : r.front()) {
205 if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
206 linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
208 llvm::any_of(op.getResultTypes(),
209 [](Type type) { return !type.isIntOrIndexOrFloat(); }))
210 return false;
211 }
212 return true;
213}
214
215bool isElementwise(LinalgOp op) {
216 if (op.getNumLoops() != op.getNumParallelLoops())
217 return false;
218
220 return false;
221
222 // TODO: relax the restrictions on indexing map.
223 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
224 if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
225 return false;
226 }
227 return hasOnlyScalarElementwiseOp(op->getRegion(0));
228}
229
230bool isParallelIterator(utils::IteratorType iteratorType) {
231 return iteratorType == utils::IteratorType::parallel;
232}
233
234bool isReductionIterator(utils::IteratorType iteratorType) {
235 return iteratorType == utils::IteratorType::reduction;
236}
237
238//===----------------------------------------------------------------------===//
239// Convolution matcher utilities
240//===----------------------------------------------------------------------===//
241
242/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
243/// ext*/sitofp ops.
// NOTE(review): the function signature (original line 244) was lost in the
// doc extraction. The body returns `val` directly when it already is a
// BlockArgument; otherwise it looks through exactly one
// extf/extsi/extui/sitofp cast and returns that op's operand (or null).
245 BlockArgument blockArg = dyn_cast<BlockArgument>(val);
246 if ((blockArg))
247 return blockArg;
248
249 Operation *defOp = val.getDefiningOp();
250 if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
251 !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
252 !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
253 !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
254 return nullptr;
255 }
256 return dyn_cast<BlockArgument>(defOp->getOperand(0));
257}
258
259/// Utility function to match the zero point offset body of quantized
260/// convolution ops.
261///
262/// Quantized convolutions have a body of the form:
263/// %out + ((%input - %inputZp) * (%filter - %filterZp))
264/// where:
265/// - %input is the input tensor element (block arg 0)
266/// - %filter is the filter tensor element (block arg 1)
267/// - %inputZp is the input zero-point scalar (block arg 2)
268/// - %filterZp is the filter zero-point scalar (block arg 3)
269/// - %out is the output accumulator (block arg 4)
270///
271/// This function verifies that the multiplication operands are subtraction
272/// operations matching this pattern.
// NOTE(review): the doc extraction dropped the signature (original line 273)
// and the initializer expressions of the five BlockArguments (original lines
// 287, 289, 291, 293, 295) — presumably calls to the block-argument helper
// above on the sub/acc operands; verify against upstream.
274 Block *body) {
275 // The multiplication should have two subtraction operands:
276 // one for (input - inputZp) and one for (filter - filterZp).
277 Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
278 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
279 return false;
280
281 Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
282 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
283 return false;
284
285 // Extract block arguments from subtraction operands.
286 BlockArgument inputBlockArg =
288 BlockArgument inputZpBlockArg =
290 BlockArgument filterBlockArg =
292 BlockArgument filterZpBlockArg =
294 BlockArgument outBlockArg =
296
297 // Verify all block arguments are valid.
298 if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
299 !filterZpBlockArg || !outBlockArg)
300 return false;
301
302 // Verify all block arguments belong to the convolution body.
303 if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
304 filterBlockArg.getOwner() != body ||
305 filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body)
306 return false;
307
308 // Verify block arguments have expected indices:
309 // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output
310 if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 ||
311 inputZpBlockArg.getArgNumber() != 2 ||
312 filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4)
313 return false;
314
315 return true;
316}
317
318/// Utility to match block body for convolution ops.
319/// The body is thus expected to yield :-
320/// %out + (%lhs * %rhs)
321/// where: %lhs, %rhs and %out are block arguments and
322/// %lhs and %rhs can have optional upcast operation.
323/// For i1 element types, the pattern matches:
324/// %out | (%lhs & %rhs)
325/// using arith.ori for accumulation and arith.andi for multiplication.
326/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
327/// %input - %input_scalar
328/// where, %input_scalar can have optional upcast operation.
// NOTE(review): the doc extraction dropped the BlockArgument initializer
// expressions (original lines 349, 351, 353) — presumably calls to the
// block-argument helper on the acc/mul operands; verify against upstream.
329static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
330 bool containsZeroPointOffset = false) {
331 bool isOrOp = false;
332 Operation *accOp = yieldVal.getDefiningOp();
333 if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
334 if (!isa_and_present<arith::OrIOp>(accOp))
335 return false;
336 isOrOp = true;
337 }
338
339 Operation *mulOp = accOp->getOperand(1).getDefiningOp();
340 if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
341 return false;
342 if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
343 return false;
344
345 if (containsZeroPointOffset) {
346 return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
347 }
348 BlockArgument lhsBlockArg =
350 BlockArgument rhsBlockArg =
352 BlockArgument outBlockArg =
354 if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
355 lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
356 outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
357 rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
358 return false;
359 return true;
360}
361
362/// Utility to match block body for linalg.pool* ops.
// NOTE(review): the doc extraction dropped the initializer expressions of
// lhsArg/rhsArg (original lines 370, 372); the checks below require the
// combining op's operands to be block args 2 and 0 of `body`.
363template <typename... OpTypes>
364static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
365 Operation *defOp = yieldVal.getDefiningOp();
366 if (!(isa_and_present<OpTypes>(defOp) || ...))
367 return false;
368
369 BlockArgument lhsArg =
371 BlockArgument rhsArg =
373 if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
374 rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
375 rhsArg.getArgNumber() != 0)
376 return false;
377 return true;
378}
379
// Signed/float max-pool body matcher. NOTE(review): the template-argument
// call line (original line 381) was lost in extraction.
380static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
382 body);
383}
384
385static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
386 return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
387}
388
// Signed/float min-pool body matcher. NOTE(review): the template-argument
// call line (original line 390) was lost in extraction.
389static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
391 body);
392}
393
394static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
395 return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
396}
397
398/// Matches sum pooling body pattern. For i1 element types, arith.ori is used
399/// instead of arith.addi/arith.addf for accumulation.
// NOTE(review): the template-argument call line (original line 401) was lost
// in extraction.
400static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
402 yieldVal, body);
403}
404
405static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
406 uint32_t dimIndex) {
407 auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
408 if (dimIndex < affineMap.getNumResults())
409 return affineMap.getResult(dimIndex);
410 return nullptr;
411}
412
413/// Check if `expr` is either:
414/// - a dimension expr alone (implying multiplication by 1), or
415/// - a multiplication of dimension expr by any positive constant != 1
416/// In both cases we will capture the dimension expression into `dim` and
417/// return the constant multiplier. Returns -1 in case of a match failure.
// NOTE(review): the signature line (original line 418) was lost in
// extraction; from the body, it takes an AffineExpr plus an out-parameter
// `dim` assignable from AffineDimExpr and returns an int64_t multiplier.
419 if ((dim = dyn_cast<AffineDimExpr>(expr)))
420 return 1;
421
422 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
423 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
424 return -1;
425
426 AffineExpr lhs = mulExpr.getLHS();
427 AffineExpr rhs = mulExpr.getRHS();
428
429 AffineConstantExpr cst = nullptr;
430 if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
431 (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
432 ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
433 (cst = dyn_cast<AffineConstantExpr>(lhs))))
434 return cst.getValue();
435 return -1;
436}
437
438/// Given an array of AffineMaps `indexingMaps` verify the following
439/// commutatively:-
440/// indexingMaps[0].getResult(iDim) ==
441/// indexingMaps[1].getResult(fDim) * <c0> +
442/// indexingMaps[n-1].getResult(oDim) * <c1>
443/// where,
444/// - c0 and c1 can be any constant,
445/// - n is the size of the indexingMaps' array,
446/// - 0, 1 and n-1 are input, filter and output map indices respectively,
447/// - iDim, fDim and oDim are the input, filter and output dimension
448/// indices in their respective indexing maps
449/// Example:
450/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
451/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
452/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
453/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
454///
455/// Here,
456/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
457/// Therefore,
458/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
459/// would return true and update dilation = 3 and stride = 2
460static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
461 unsigned fDim, unsigned oDim,
462 int64_t &dilation, int64_t &stride) {
463 unsigned inputMapIdx = 0, filterMapIdx = 1,
464 outputMapIdx = indexingMaps.size() - 1;
465 AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
466 auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
467 if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
468 return false;
469
470 AffineExpr dim0, dim1;
471 int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
472 int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
473
474 if (c0 == -1 || c1 == -1)
475 return false;
476 // Pattern matched with dims and constants extracted.
477 AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
478 AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
479 if (dim0 == fExpr && dim1 == oExpr) {
480 dilation = c0;
481 stride = c1;
482 return true;
483 }
484 if (dim1 == fExpr && dim0 == oExpr) {
485 dilation = c1;
486 stride = c0;
487 return true;
488 }
489 return false;
490}
491
492/// Returns true if the given indexing maps matches with the expected indexing
493/// maps.
// NOTE(review): the first signature line (original line 494) was lost in
// extraction; `mapListExpected` is a list of per-map result-expression lists
// fed to AffineMap::inferFromExprList below.
495 ArrayAttr indexingMaps, MLIRContext *context) {
496 SmallVector<AffineMap, 4> expectedIndexingMaps =
497 AffineMap::inferFromExprList(mapListExpected, context);
498 return indexingMaps ==
499 ArrayAttr::get(
500 context, llvm::to_vector<4>(llvm::map_range(
501 expectedIndexingMaps, [&](AffineMap m) -> Attribute {
502 return AffineMapAttr::get(m);
503 })));
504}
505
506/// Enum representing pooling operation types used by ConvMatcherBuilder.
515
516/// Helper class for building convolution op matchers with minimal boilerplate.
517/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
518/// as Pooling ops.
519///
520/// Usage: Create an instance with the op, spatial rank, and output pointers for
521/// extracted dilations/strides. Then chain matchStride() calls for each spatial
522/// dimension, followed by matchMaps() to verify indexing maps, and finally
523/// matchBody() to verify the operation body pattern.
524///
525/// The `matched` flag starts as `true` and is set to `false` if any match step
526/// fails. This allows chaining multiple match calls; once any match fails, all
527/// subsequent calls become no-ops and the final result is `false`.
528///
529/// The `dilations` and `strides` pointers are output parameters that get
530/// populated with the extracted dilation and stride values from the operation's
531/// indexing maps during matchStride() calls. These values are initially set to
532/// 1 for each spatial dimension and updated as patterns are matched.
// NOTE(review): the doc extraction dropped the class header
// `class ConvMatcherBuilder {` (original line 533), one constructor
// parameter line (543), the matchMaps() signature (571), and the
// `case PoolingType::...:` labels of matchBody() (584, 587, 589, 591, 593).
// Matchers chain matchStride()/matchMaps()/matchBody(); once `matched` is
// false every subsequent step is a no-op.
534 LinalgOp op;
535 MLIRContext *ctx;
536 SmallVector<int64_t> *dilations, *strides;
537 ArrayAttr indexingMaps;
538 PoolingType poolingType;
539 bool matched = true;
540
541public:
542 ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
544 PoolingType poolingType = PoolingType::None)
545 : op(op), ctx(op->getContext()), dilations(d), strides(s),
546 indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
// Both output vectors start as all-ones and are refined by matchStride().
547 *dilations = SmallVector<int64_t>(spatialRank, 1);
548 *strides = SmallVector<int64_t>(spatialRank, 1);
549 }
550
551 /// Get affine dimension expression for dimension `i`.
552 AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
553
554 /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
555 AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
556 return base * (*strides)[idx] + kernel * (*dilations)[idx];
557 }
558
559 /// Match stride/dilation pattern for a spatial dimension.
560 /// Returns *this for method chaining.
561 ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
562 unsigned idx) {
563 if (matched) {
564 matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
565 (*dilations)[idx], (*strides)[idx]);
566 }
567 return *this;
568 }
569
570 /// Match expected indexing maps layout. Returns *this for method chaining.
572 if (matched)
573 matched &= convLayoutMatches(maps, indexingMaps, ctx);
574 return *this;
575 }
576
577 /// Match body pattern. This should be called last.
578 bool matchBody(bool containsZeroPointOffset = false) {
579 if (!matched)
580 return false;
581 Block *body = op.getBlock();
582 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
583 switch (poolingType) {
585 return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
586 containsZeroPointOffset);
588 return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
590 return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
592 return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
594 return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
595 case PoolingType::Sum:
596 return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
597 }
598 return false;
599 }
600};
601
602//===----------------------------------------------------------------------===//
603// Matchers for specific convolution operation.
604//===----------------------------------------------------------------------===//
605
// NOTE(review): in the three Conv1D matcher specializations below, the doc
// extraction dropped each specialization's function-signature lines (e.g.
// original lines 608-609) and each early-exit guard condition (e.g. original
// line 617, presumably an interface check preceding `return std::nullopt;`).
// The remaining bodies are intact: attribute-carrying named ops return their
// stored dilations/strides; otherwise the generic ConvMatcherBuilder chain
// verifies the indexing maps and body.
606template <>
607std::optional<DilationsAndStrides>
610 if (isa<linalg::Conv1DOp>(op)) {
611 // Conv1DOp has no strides/dilations attributes, default to 1.
612 result.dilations = SmallVector<int64_t>(1, 1);
613 result.strides = SmallVector<int64_t>(1, 1);
614 return result;
615 }
616
618 return std::nullopt;
619
620 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
621 &result.strides);
622 AffineExpr W = m.dim(0);
623 AffineExpr w = m.dim(1);
624
625 if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
626 .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
627 /*filterMap=*/{w},
628 /*outputMap=*/{W}})
629 .matchBody())
630 return result;
631 return std::nullopt;
632}
633
634template <>
635std::optional<DilationsAndStrides>
638 if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
639 result.dilations =
640 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
641 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
642 return result;
643 }
644
646 return std::nullopt;
647
648 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
649 &result.strides);
650 AffineExpr N = m.dim(0);
651 AffineExpr W = m.dim(1);
652 AffineExpr F = m.dim(2);
653 AffineExpr w = m.dim(3);
654 AffineExpr c = m.dim(4);
655
656 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
657 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
658 /*filterMap=*/{w, c, F},
659 /*outputMap=*/{N, W, F}})
660 .matchBody())
661 return result;
662 return std::nullopt;
663}
664
665template <>
666std::optional<DilationsAndStrides>
669 if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
670 result.dilations =
671 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
672 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
673 return result;
674 }
675
677 return std::nullopt;
678
679 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
680 &result.strides);
681 AffineExpr N = m.dim(0);
682 AffineExpr F = m.dim(1);
683 AffineExpr W = m.dim(2);
684 AffineExpr c = m.dim(3);
685 AffineExpr w = m.dim(4);
686
687 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
688 .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
689 /*filterMap=*/{F, c, w},
690 /*outputMap=*/{N, F, W}})
691 .matchBody())
692 return result;
693 return std::nullopt;
694}
695
// NOTE(review): as above, each specialization below is missing its
// signature lines and an early-exit guard line (lost in doc extraction);
// the quantized (Q) variant additionally matches two scalar zero-point
// operands via empty scalar maps and matchBody(containsZeroPointOffset).
696template <>
697std::optional<DilationsAndStrides>
700 if (isa<linalg::Conv2DOp>(op)) {
701 // Conv2DOp has no strides/dilations attributes, default to 1.
702 result.dilations = SmallVector<int64_t>(2, 1);
703 result.strides = SmallVector<int64_t>(2, 1);
704 return result;
705 }
706
708 return std::nullopt;
709
710 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
711 &result.strides);
712 AffineExpr H = m.dim(0);
713 AffineExpr W = m.dim(1);
714 AffineExpr h = m.dim(2);
715 AffineExpr w = m.dim(3);
716
717 if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
718 .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
719 .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
720 /*filterMap=*/{h, w},
721 /*outputMap=*/{H, W}})
722 .matchBody())
723 return result;
724 return std::nullopt;
725}
726
727template <>
728std::optional<DilationsAndStrides>
731 if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
732 result.dilations =
733 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
734 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
735 return result;
736 }
737
739 return std::nullopt;
740
741 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
742 &result.strides);
743 AffineExpr N = m.dim(0);
744 AffineExpr H = m.dim(1);
745 AffineExpr W = m.dim(2);
746 AffineExpr F = m.dim(3);
747 AffineExpr h = m.dim(4);
748 AffineExpr w = m.dim(5);
749 AffineExpr c = m.dim(6);
750
751 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
752 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
753 .matchMaps(
754 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
755 /*filterMap=*/{h, w, c, F},
756 /*outputMap=*/{N, H, W, F}})
757 .matchBody())
758 return result;
759 return std::nullopt;
760}
761
762template <>
763std::optional<DilationsAndStrides>
766 if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
767 result.dilations =
768 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
769 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
770 return result;
771 }
772
774 return std::nullopt;
775
776 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
777 &result.strides);
778 AffineExpr N = m.dim(0);
779 AffineExpr H = m.dim(1);
780 AffineExpr W = m.dim(2);
781 AffineExpr F = m.dim(3);
782 AffineExpr h = m.dim(4);
783 AffineExpr w = m.dim(5);
784 AffineExpr c = m.dim(6);
785
786 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
787 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
788 .matchMaps(
789 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
790 /*filterMap=*/{h, w, c, F},
791 /*scalarMap=*/{},
792 /*scalarMap=*/{},
793 /*outputMap=*/{N, H, W, F}})
794 .matchBody(/*containsZeroPointOffset=*/true))
795 return result;
796 return std::nullopt;
797}
798
// NOTE(review): same extraction damage as the specializations above
// (missing signature and guard lines); bodies otherwise intact.
799template <>
800std::optional<DilationsAndStrides>
803 if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
804 result.dilations =
805 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
806 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
807 return result;
808 }
809
811 return std::nullopt;
812
813 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
814 &result.strides);
815 AffineExpr N = m.dim(0);
816 AffineExpr H = m.dim(1);
817 AffineExpr W = m.dim(2);
818 AffineExpr F = m.dim(3);
819 AffineExpr h = m.dim(4);
820 AffineExpr w = m.dim(5);
821 AffineExpr c = m.dim(6);
822
823 if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
824 .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
825 .matchMaps(
826 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
827 /*filterMap=*/{F, h, w, c},
828 /*outputMap=*/{N, H, W, F}})
829 .matchBody())
830 return result;
831 return std::nullopt;
832}
833
834template <>
835std::optional<DilationsAndStrides>
838 if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
839 result.dilations =
840 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
841 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
842 return result;
843 }
844
846 return std::nullopt;
847
848 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
849 &result.strides);
850 AffineExpr N = m.dim(0);
851 AffineExpr H = m.dim(1);
852 AffineExpr W = m.dim(2);
853 AffineExpr F = m.dim(3);
854 AffineExpr h = m.dim(4);
855 AffineExpr w = m.dim(5);
856 AffineExpr c = m.dim(6);
857
858 if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
859 .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
860 .matchMaps(
861 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
862 /*filterMap=*/{F, h, w, c},
863 /*scalarMap=*/{},
864 /*scalarMap=*/{},
865 /*outputMap=*/{N, H, W, F}})
866 .matchBody(/*containsZeroPointOffset=*/true))
867 return result;
868 return std::nullopt;
869}
870
871template <>
872std::optional<DilationsAndStrides>
875 if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
876 result.dilations =
877 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
878 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
879 return result;
880 }
881
883 return std::nullopt;
884
885 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
886 &result.strides);
887 AffineExpr N = m.dim(0);
888 AffineExpr F = m.dim(1);
889 AffineExpr H = m.dim(2);
890 AffineExpr W = m.dim(3);
891 AffineExpr c = m.dim(4);
892 AffineExpr h = m.dim(5);
893 AffineExpr w = m.dim(6);
894
895 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
896 .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
897 .matchMaps(
898 {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
899 /*filterMap=*/{F, c, h, w},
900 /*outputMap=*/{N, F, H, W}})
901 .matchBody())
902 return result;
903 return std::nullopt;
904}
905
// NOTE(review): same extraction damage as the specializations above
// (missing signature and guard lines); bodies otherwise intact. These cover
// the NCHW-quantized and grouped (NGCHW) layouts.
906template <>
907std::optional<DilationsAndStrides>
910 if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
911 result.dilations =
912 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
913 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
914 return result;
915 }
916
918 return std::nullopt;
919
920 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
921 &result.strides);
922 AffineExpr N = m.dim(0);
923 AffineExpr F = m.dim(1);
924 AffineExpr H = m.dim(2);
925 AffineExpr W = m.dim(3);
926 AffineExpr c = m.dim(4);
927 AffineExpr h = m.dim(5);
928 AffineExpr w = m.dim(6);
929
930 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
931 .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
932 .matchMaps(
933 {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
934 /*filterMap=*/{F, c, h, w},
935 /*scalarMap=*/{},
936 /*scalarMap=*/{},
937 /*outputMap=*/{N, F, H, W}})
938 .matchBody(/*containsZeroPointOffset=*/true))
939 return result;
940 return std::nullopt;
941}
942
943template <>
944std::optional<DilationsAndStrides>
947 if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
948 result.dilations =
949 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
950 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
951 return result;
952 }
953
955 return std::nullopt;
956
957 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
958 &result.strides);
959 AffineExpr N = m.dim(0);
960 AffineExpr G = m.dim(1);
961 AffineExpr F = m.dim(2);
962 AffineExpr H = m.dim(3);
963 AffineExpr W = m.dim(4);
964 AffineExpr c = m.dim(5);
965 AffineExpr h = m.dim(6);
966 AffineExpr w = m.dim(7);
967
968 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
969 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
970 .matchMaps(
971 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
972 /*filterMap=*/{F, G, c, h, w},
973 /*outputMap=*/{N, G, F, H, W}})
974 .matchBody())
975 return result;
976 return std::nullopt;
977}
978
979template <>
980std::optional<DilationsAndStrides>
983 if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
984 result.dilations =
985 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
986 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
987 return result;
988 }
989
991 return std::nullopt;
992
993 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
994 &result.strides);
995 AffineExpr N = m.dim(0);
996 AffineExpr G = m.dim(1);
997 AffineExpr F = m.dim(2);
998 AffineExpr H = m.dim(3);
999 AffineExpr W = m.dim(4);
1000 AffineExpr c = m.dim(5);
1001 AffineExpr h = m.dim(6);
1002 AffineExpr w = m.dim(7);
1003
1004 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
1005 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
1006 .matchMaps(
1007 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1008 /*filterMap=*/{G, F, c, h, w},
1009 /*outputMap=*/{N, G, F, H, W}})
1010 .matchBody())
1011 return result;
1012 return std::nullopt;
1013}
1014
1015template <>
1016std::optional<DilationsAndStrides>
1019 if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
1020 result.dilations =
1021 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1022 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1023 return result;
1024 }
1025
1027 return std::nullopt;
1028
1029 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1030 &result.strides);
1031 AffineExpr N = m.dim(0);
1032 AffineExpr G = m.dim(1);
1033 AffineExpr F = m.dim(2);
1034 AffineExpr H = m.dim(3);
1035 AffineExpr W = m.dim(4);
1036 AffineExpr c = m.dim(5);
1037 AffineExpr h = m.dim(6);
1038 AffineExpr w = m.dim(7);
1039
1040 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
1041 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
1042 .matchMaps(
1043 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1044 /*filterMap=*/{G, F, c, h, w},
1045 /*scalarMap=*/{},
1046 /*scalarMap=*/{},
1047 /*outputMap=*/{N, G, F, H, W}})
1048 .matchBody(/*containsZeroPointOffset=*/true))
1049 return result;
1050 return std::nullopt;
1051}
1052
1053template <>
1054std::optional<DilationsAndStrides>
1057 if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
1058 result.dilations =
1059 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1060 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1061 return result;
1062 }
1063
1065 return std::nullopt;
1066
1067 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1068 &result.strides);
1069 AffineExpr N = m.dim(0);
1070 AffineExpr H = m.dim(1);
1071 AffineExpr W = m.dim(2);
1072 AffineExpr G = m.dim(3);
1073 AffineExpr F = m.dim(4);
1074 AffineExpr h = m.dim(5);
1075 AffineExpr w = m.dim(6);
1076 AffineExpr c = m.dim(7);
1077
1078 if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
1079 .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
1080 .matchMaps(
1081 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1082 /*filterMap=*/{G, F, h, w, c},
1083 /*outputMap=*/{N, H, W, G, F}})
1084 .matchBody())
1085 return result;
1086 return std::nullopt;
1087}
1088
1089template <>
1090std::optional<DilationsAndStrides>
1093 if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
1094 result.dilations =
1095 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1096 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1097 return result;
1098 }
1099
1101 return std::nullopt;
1102
1103 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1104 &result.strides);
1105 AffineExpr N = m.dim(0);
1106 AffineExpr H = m.dim(1);
1107 AffineExpr W = m.dim(2);
1108 AffineExpr G = m.dim(3);
1109 AffineExpr F = m.dim(4);
1110 AffineExpr h = m.dim(5);
1111 AffineExpr w = m.dim(6);
1112 AffineExpr c = m.dim(7);
1113
1114 if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
1115 .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
1116 .matchMaps(
1117 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1118 /*filterMap=*/{G, F, h, w, c},
1119 /*scalarMap=*/{},
1120 /*scalarMap=*/{},
1121 /*outputMap=*/{N, H, W, G, F}})
1122 .matchBody(/*containsZeroPointOffset=*/true))
1123 return result;
1124 return std::nullopt;
1125}
1126
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. Comments describe only the visible code.

// Specialization for linalg.conv_3d: plain 3-D convolution with only spatial
// dims (D, H, W, d, h, w). The named op carries no strides/dilations
// attributes, so both default to all-ones.
1127template <>
1128std::optional<DilationsAndStrides>
1131  if (isa<linalg::Conv3DOp>(op)) {
1132    // Conv3DOp has no strides/dilations attributes, default to 1.
1133    result.dilations = SmallVector<int64_t>(3, 1);
1134    result.strides = SmallVector<int64_t>(3, 1);
1135    return result;
1136  }
1137
1139    return std::nullopt;
1140
1141  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1142                       &result.strides);
1143  AffineExpr D = m.dim(0);
1144  AffineExpr H = m.dim(1);
1145  AffineExpr W = m.dim(2);
1146  AffineExpr d = m.dim(3);
1147  AffineExpr h = m.dim(4);
1148  AffineExpr w = m.dim(5);
1149
1150  if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
1151          .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
1152          .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
1153          .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
1154                                    m.strided(W, w, 2)},
1155                      /*filterMap=*/{d, h, w},
1156                      /*outputMap=*/{D, H, W}})
1157          .matchBody())
1158    return result;
1159  return std::nullopt;
1160}
1161
// Specialization for linalg.conv_3d_ndhwc_dhwcf (channels-last 3-D conv;
// dims N, D, H, W, F plus reduction dims d, h, w, c).
1162template <>
1163std::optional<DilationsAndStrides>
1166  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
1167    result.dilations =
1168        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1169    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1170    return result;
1171  }
1172
1174    return std::nullopt;
1175
1176  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1177                       &result.strides);
1178  AffineExpr N = m.dim(0);
1179  AffineExpr D = m.dim(1);
1180  AffineExpr H = m.dim(2);
1181  AffineExpr W = m.dim(3);
1182  AffineExpr F = m.dim(4);
1183  AffineExpr d = m.dim(5);
1184  AffineExpr h = m.dim(6);
1185  AffineExpr w = m.dim(7);
1186  AffineExpr c = m.dim(8);
1187
1188  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1189          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1190          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1191          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1192                                    m.strided(W, w, 2), c},
1193                      /*filterMap=*/{d, h, w, c, F},
1194                      /*outputMap=*/{N, D, H, W, F}})
1195          .matchBody())
1196    return result;
1197  return std::nullopt;
1198}
1199
// Specialization for linalg.conv_3d_ndhwc_dhwcf_q — quantized variant of the
// NDHWC case (two scalar zero-point operands, zero-point offset in the body).
1200template <>
1201std::optional<DilationsAndStrides>
1204  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
1205    result.dilations =
1206        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1207    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1208    return result;
1209  }
1210
1212    return std::nullopt;
1213
1214  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1215                       &result.strides);
1216  AffineExpr N = m.dim(0);
1217  AffineExpr D = m.dim(1);
1218  AffineExpr H = m.dim(2);
1219  AffineExpr W = m.dim(3);
1220  AffineExpr F = m.dim(4);
1221  AffineExpr d = m.dim(5);
1222  AffineExpr h = m.dim(6);
1223  AffineExpr w = m.dim(7);
1224  AffineExpr c = m.dim(8);
1225
1226  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1227          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1228          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1229          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1230                                    m.strided(W, w, 2), c},
1231                      /*filterMap=*/{d, h, w, c, F},
1232                      /*scalarMap=*/{},
1233                      /*scalarMap=*/{},
1234                      /*outputMap=*/{N, D, H, W, F}})
1235          .matchBody(/*containsZeroPointOffset=*/true))
1236    return result;
1237  return std::nullopt;
1238}
1239
// Specialization for linalg.conv_3d_ncdhw_fcdhw (channels-first 3-D conv;
// spatial strides at input dims 2/3/4).
1240template <>
1241std::optional<DilationsAndStrides>
1244  if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
1245    result.dilations =
1246        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1247    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1248    return result;
1249  }
1250
1252    return std::nullopt;
1253
1254  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1255                       &result.strides);
1256  AffineExpr N = m.dim(0);
1257  AffineExpr F = m.dim(1);
1258  AffineExpr D = m.dim(2);
1259  AffineExpr H = m.dim(3);
1260  AffineExpr W = m.dim(4);
1261  AffineExpr c = m.dim(5);
1262  AffineExpr d = m.dim(6);
1263  AffineExpr h = m.dim(7);
1264  AffineExpr w = m.dim(8);
1265
1266  if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
1267          .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
1268          .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
1269          .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0),
1270                                    m.strided(H, h, 1), m.strided(W, w, 2)},
1271                      /*filterMap=*/{F, c, d, h, w},
1272                      /*outputMap=*/{N, F, D, H, W}})
1273          .matchBody())
1274    return result;
1275  return std::nullopt;
1276}
1277
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. Comments describe only the visible code.

// Specialization for linalg.depthwise_conv_1d_ncw_cw: depthwise 1-D conv,
// channels-first; no filter channel reduction (filter map {C, w}).
1278template <>
1279std::optional<DilationsAndStrides>
1282  if (auto convOp =
1283          dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
1284    result.dilations =
1285        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1286    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1287    return result;
1288  }
1289
1291    return std::nullopt;
1292
1293  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1294                       &result.strides);
1295  AffineExpr N = m.dim(0);
1296  AffineExpr W = m.dim(1);
1297  AffineExpr C = m.dim(2);
1298  AffineExpr w = m.dim(3);
1299
1300  if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1301          .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
1302                      /*filterMap=*/{C, w},
1303                      /*outputMap=*/{N, C, W}})
1304          .matchBody())
1305    return result;
1306  return std::nullopt;
1307}
1308
// Specialization for linalg.depthwise_conv_1d_nwc_wc: channels-last variant
// of the above (filter map {w, C}).
1309template <>
1310std::optional<DilationsAndStrides>
1313  if (auto convOp =
1314          dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
1315    result.dilations =
1316        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1317    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1318    return result;
1319  }
1320
1322    return std::nullopt;
1323
1324  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1325                       &result.strides);
1326  AffineExpr N = m.dim(0);
1327  AffineExpr W = m.dim(1);
1328  AffineExpr C = m.dim(2);
1329  AffineExpr w = m.dim(3);
1330
1331  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1332          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1333                      /*filterMap=*/{w, C},
1334                      /*outputMap=*/{N, W, C}})
1335          .matchBody())
1336    return result;
1337  return std::nullopt;
1338}
1339
// Specialization for linalg.depthwise_conv_1d_nwc_wcm: adds a channel
// multiplier dim CM to filter and output.
1340template <>
1341std::optional<DilationsAndStrides>
1344  if (auto convOp =
1345          dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
1346    result.dilations =
1347        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1348    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1349    return result;
1350  }
1351
1353    return std::nullopt;
1354
1355  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1356                       &result.strides);
1357  AffineExpr N = m.dim(0);
1358  AffineExpr W = m.dim(1);
1359  AffineExpr C = m.dim(2);
1360  AffineExpr CM = m.dim(3);
1361  AffineExpr w = m.dim(4);
1362
1363  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1364          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1365                      /*filterMap=*/{w, C, CM},
1366                      /*outputMap=*/{N, W, C, CM}})
1367          .matchBody())
1368    return result;
1369  return std::nullopt;
1370}
1371
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. Comments describe only the visible code.

// Specialization for linalg.depthwise_conv_2d_nchw_chw (channels-first,
// filter map {C, h, w}; strides at input dims 2/3).
1372template <>
1373std::optional<DilationsAndStrides>
1376  if (auto convOp =
1377          dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
1378    result.dilations =
1379        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1380    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1381    return result;
1382  }
1383
1385    return std::nullopt;
1386
1387  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1388                       &result.strides);
1389  AffineExpr N = m.dim(0);
1390  AffineExpr H = m.dim(1);
1391  AffineExpr W = m.dim(2);
1392  AffineExpr C = m.dim(3);
1393  AffineExpr h = m.dim(4);
1394  AffineExpr w = m.dim(5);
1395
1396  if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1397          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
1398          .matchMaps(
1399              {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1400               /*filterMap=*/{C, h, w},
1401               /*outputMap=*/{N, C, H, W}})
1402          .matchBody())
1403    return result;
1404  return std::nullopt;
1405}
1406
// Specialization for linalg.depthwise_conv_2d_nhwc_hwc (channels-last,
// filter map {h, w, C}).
1407template <>
1408std::optional<DilationsAndStrides>
1411  if (auto convOp =
1412          dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
1413    result.dilations =
1414        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1415    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1416    return result;
1417  }
1418
1420    return std::nullopt;
1421
1422  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1423                       &result.strides);
1424  AffineExpr N = m.dim(0);
1425  AffineExpr H = m.dim(1);
1426  AffineExpr W = m.dim(2);
1427  AffineExpr C = m.dim(3);
1428  AffineExpr h = m.dim(4);
1429  AffineExpr w = m.dim(5);
1430
1431  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1432          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1433          .matchMaps(
1434              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1435               /*filterMap=*/{h, w, C},
1436               /*outputMap=*/{N, H, W, C}})
1437          .matchBody())
1438    return result;
1439  return std::nullopt;
1440}
1441
// Specialization for linalg.depthwise_conv_2d_nhwc_hwc_q — quantized variant
// (two scalar zero-point operands, zero-point offset in the body).
1442template <>
1443std::optional<DilationsAndStrides>
1446  if (auto convOp =
1447          dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
1448    result.dilations =
1449        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1450    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1451    return result;
1452  }
1453
1455    return std::nullopt;
1456
1457  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1458                       &result.strides);
1459  AffineExpr N = m.dim(0);
1460  AffineExpr H = m.dim(1);
1461  AffineExpr W = m.dim(2);
1462  AffineExpr C = m.dim(3);
1463  AffineExpr h = m.dim(4);
1464  AffineExpr w = m.dim(5);
1465
1466  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1467          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1468          .matchMaps(
1469              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1470               /*filterMap=*/{h, w, C},
1471               /*scalarMap=*/{},
1472               /*scalarMap=*/{},
1473               /*outputMap=*/{N, H, W, C}})
1474          .matchBody(/*containsZeroPointOffset=*/true))
1475    return result;
1476  return std::nullopt;
1477}
1478
// Specialization for linalg.depthwise_conv_2d_nhwc_hwcm — channel-multiplier
// variant (extra dim CM on filter and output).
1479template <>
1480std::optional<DilationsAndStrides>
1483  if (auto convOp =
1484          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
1485    result.dilations =
1486        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1487    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1488    return result;
1489  }
1490
1492    return std::nullopt;
1493
1494  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1495                       &result.strides);
1496  AffineExpr N = m.dim(0);
1497  AffineExpr H = m.dim(1);
1498  AffineExpr W = m.dim(2);
1499  AffineExpr C = m.dim(3);
1500  AffineExpr CM = m.dim(4);
1501  AffineExpr h = m.dim(5);
1502  AffineExpr w = m.dim(6);
1503
1504  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1505          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1506          .matchMaps(
1507              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1508               /*filterMap=*/{h, w, C, CM},
1509               /*outputMap=*/{N, H, W, C, CM}})
1510          .matchBody())
1511    return result;
1512  return std::nullopt;
1513}
1514
// Specialization for linalg.depthwise_conv_2d_nhwc_hwcm_q — quantized
// channel-multiplier variant.
1515template <>
1516std::optional<DilationsAndStrides>
1519  if (auto convOp =
1520          dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
1521    result.dilations =
1522        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1523    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1524    return result;
1525  }
1526
1528    return std::nullopt;
1529
1530  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1531                       &result.strides);
1532  AffineExpr N = m.dim(0);
1533  AffineExpr H = m.dim(1);
1534  AffineExpr W = m.dim(2);
1535  AffineExpr C = m.dim(3);
1536  AffineExpr CM = m.dim(4);
1537  AffineExpr h = m.dim(5);
1538  AffineExpr w = m.dim(6);
1539
1540  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1541          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1542          .matchMaps(
1543              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1544               /*filterMap=*/{h, w, C, CM},
1545               /*scalarMap=*/{},
1546               /*scalarMap=*/{},
1547               /*outputMap=*/{N, H, W, C, CM}})
1548          .matchBody(/*containsZeroPointOffset=*/true))
1549    return result;
1550  return std::nullopt;
1551}
1552
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. Comments describe only the visible code.

// Specialization for linalg.depthwise_conv_3d_ndhwc_dhwc (channels-last,
// filter map {d, h, w, C}).
1553template <>
1554std::optional<DilationsAndStrides>
1557  if (auto convOp =
1558          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
1559    result.dilations =
1560        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1561    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1562    return result;
1563  }
1564
1566    return std::nullopt;
1567
1568  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1569                       &result.strides);
1570  AffineExpr N = m.dim(0);
1571  AffineExpr D = m.dim(1);
1572  AffineExpr H = m.dim(2);
1573  AffineExpr W = m.dim(3);
1574  AffineExpr d = m.dim(4);
1575  AffineExpr h = m.dim(5);
1576  AffineExpr w = m.dim(6);
1577  AffineExpr C = m.dim(7);
1578
1579  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1580          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1581          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1582          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1583                                    m.strided(W, w, 2), C},
1584                      /*filterMap=*/{d, h, w, C},
1585                      /*outputMap=*/{N, D, H, W, C}})
1586          .matchBody())
1587    return result;
1588  return std::nullopt;
1589}
1590
// Specialization for linalg.depthwise_conv_3d_ncdhw_cdhw (channels-first,
// filter map {C, d, h, w}; strides at input dims 2/3/4).
1591template <>
1592std::optional<DilationsAndStrides>
1595  if (auto convOp =
1596          dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
1597    result.dilations =
1598        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1599    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1600    return result;
1601  }
1602
1604    return std::nullopt;
1605
1606  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1607                       &result.strides);
1608  AffineExpr N = m.dim(0);
1609  AffineExpr D = m.dim(1);
1610  AffineExpr H = m.dim(2);
1611  AffineExpr W = m.dim(3);
1612  AffineExpr d = m.dim(4);
1613  AffineExpr h = m.dim(5);
1614  AffineExpr w = m.dim(6);
1615  AffineExpr C = m.dim(7);
1616
1617  if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1618          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
1619          .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
1620          .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0),
1621                                    m.strided(H, h, 1), m.strided(W, w, 2)},
1622                      /*filterMap=*/{C, d, h, w},
1623                      /*outputMap=*/{N, C, D, H, W}})
1624          .matchBody())
1625    return result;
1626  return std::nullopt;
1627}
1628
// Specialization for linalg.depthwise_conv_3d_ndhwc_dhwcm — channel-multiplier
// variant (extra dim CM on filter and output).
1629template <>
1630std::optional<DilationsAndStrides>
1633  if (auto convOp =
1634          dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
1635    result.dilations =
1636        llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1637    result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1638    return result;
1639  }
1640
1642    return std::nullopt;
1643
1644  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1645                       &result.strides);
1646  AffineExpr N = m.dim(0);
1647  AffineExpr D = m.dim(1);
1648  AffineExpr H = m.dim(2);
1649  AffineExpr W = m.dim(3);
1650  AffineExpr CM = m.dim(4);
1651  AffineExpr d = m.dim(5);
1652  AffineExpr h = m.dim(6);
1653  AffineExpr w = m.dim(7);
1654  AffineExpr C = m.dim(8);
1655
1656  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1657          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1658          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1659          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1660                                    m.strided(W, w, 2), C},
1661                      /*filterMap=*/{d, h, w, C, CM},
1662                      /*outputMap=*/{N, D, H, W, C, CM}})
1663          .matchBody())
1664    return result;
1665  return std::nullopt;
1666}
1667
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. In the *Unsigned variants the ConvMatcherBuilder
// constructor continuation line (the one naming the pooling kind) was also
// dropped. Comments describe only the visible code.

// Specialization for linalg.pooling_nhwc_max — 2-D signed-max pooling,
// channels-last; the filter contributes only spatial dims {h, w}.
1668template <>
1669std::optional<DilationsAndStrides>
1672  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
1673    result.dilations =
1674        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1675    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1676    return result;
1677  }
1678
1680    return std::nullopt;
1681
1682  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1683                       &result.strides, PoolingType::MaxSigned);
1684  AffineExpr N = m.dim(0);
1685  AffineExpr H = m.dim(1);
1686  AffineExpr W = m.dim(2);
1687  AffineExpr C = m.dim(3);
1688  AffineExpr h = m.dim(4);
1689  AffineExpr w = m.dim(5);
1690
1691  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1692          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1693          .matchMaps(
1694              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1695               /*filterMap=*/{h, w},
1696               /*outputMap=*/{N, H, W, C}})
1697          .matchBody())
1698    return result;
1699  return std::nullopt;
1700}
1701
// Specialization for linalg.pooling_nhwc_min (signed min).
1702template <>
1703std::optional<DilationsAndStrides>
1706  if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
1707    result.dilations =
1708        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1709    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1710    return result;
1711  }
1712
1714    return std::nullopt;
1715
1716  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1717                       &result.strides, PoolingType::MinSigned);
1718  AffineExpr N = m.dim(0);
1719  AffineExpr H = m.dim(1);
1720  AffineExpr W = m.dim(2);
1721  AffineExpr C = m.dim(3);
1722  AffineExpr h = m.dim(4);
1723  AffineExpr w = m.dim(5);
1724
1725  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1726          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1727          .matchMaps(
1728              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1729               /*filterMap=*/{h, w},
1730               /*outputMap=*/{N, H, W, C}})
1731          .matchBody())
1732    return result;
1733  return std::nullopt;
1734}
1735
// Specialization for linalg.pooling_nhwc_sum (sum pooling).
1736template <>
1737std::optional<DilationsAndStrides>
1740  if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
1741    result.dilations =
1742        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1743    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1744    return result;
1745  }
1746
1748    return std::nullopt;
1749
1750  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1751                       &result.strides, PoolingType::Sum);
1752  AffineExpr N = m.dim(0);
1753  AffineExpr H = m.dim(1);
1754  AffineExpr W = m.dim(2);
1755  AffineExpr C = m.dim(3);
1756  AffineExpr h = m.dim(4);
1757  AffineExpr w = m.dim(5);
1758
1759  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1760          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1761          .matchMaps(
1762              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1763               /*filterMap=*/{h, w},
1764               /*outputMap=*/{N, H, W, C}})
1765          .matchBody())
1766    return result;
1767  return std::nullopt;
1768}
1769
// Specialization for linalg.pooling_nhwc_max_unsigned. The builder's
// constructor continuation (presumably &result.strides plus an unsigned-max
// PoolingType — TODO confirm against the real source) was dropped by the
// scrape.
1770template <>
1771std::optional<DilationsAndStrides>
1774  if (auto poolOp =
1775          dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
1776    result.dilations =
1777        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1778    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1779    return result;
1780  }
1781
1783    return std::nullopt;
1784
1785  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1787  AffineExpr N = m.dim(0);
1788  AffineExpr H = m.dim(1);
1789  AffineExpr W = m.dim(2);
1790  AffineExpr C = m.dim(3);
1791  AffineExpr h = m.dim(4);
1792  AffineExpr w = m.dim(5);
1793
1794  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1795          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1796          .matchMaps(
1797              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1798               /*filterMap=*/{h, w},
1799               /*outputMap=*/{N, H, W, C}})
1800          .matchBody())
1801    return result;
1802  return std::nullopt;
1803}
1804
// Specialization for linalg.pooling_nhwc_min_unsigned — same scrape gap as the
// max-unsigned case above.
1805template <>
1806std::optional<DilationsAndStrides>
1809  if (auto poolOp =
1810          dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
1811    result.dilations =
1812        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1813    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1814    return result;
1815  }
1816
1818    return std::nullopt;
1819
1820  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1822  AffineExpr N = m.dim(0);
1823  AffineExpr H = m.dim(1);
1824  AffineExpr W = m.dim(2);
1825  AffineExpr C = m.dim(3);
1826  AffineExpr h = m.dim(4);
1827  AffineExpr w = m.dim(5);
1828
1829  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1830          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1831          .matchMaps(
1832              {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1833               /*filterMap=*/{h, w},
1834               /*outputMap=*/{N, H, W, C}})
1835          .matchBody())
1836    return result;
1837  return std::nullopt;
1838}
1839
// Specialization for linalg.pooling_nchw_sum (channels-first sum pooling;
// strides at input dims 2/3).
1840template <>
1841std::optional<DilationsAndStrides>
1844  if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
1845    result.dilations =
1846        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1847    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1848    return result;
1849  }
1850
1852    return std::nullopt;
1853
1854  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1855                       &result.strides, PoolingType::Sum);
1856  AffineExpr N = m.dim(0);
1857  AffineExpr C = m.dim(1);
1858  AffineExpr H = m.dim(2);
1859  AffineExpr W = m.dim(3);
1860  AffineExpr h = m.dim(4);
1861  AffineExpr w = m.dim(5);
1862
1863  if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1864          .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
1865          .matchMaps(
1866              {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1867               /*filterMap=*/{h, w},
1868               /*outputMap=*/{N, C, H, W}})
1869          .matchBody())
1870    return result;
1871  return std::nullopt;
1872}
1873
// Specialization for linalg.pooling_nchw_max (channels-first signed max).
1874template <>
1875std::optional<DilationsAndStrides>
1878  if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
1879    result.dilations =
1880        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1881    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1882    return result;
1883  }
1884
1886    return std::nullopt;
1887
1888  ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1889                       &result.strides, PoolingType::MaxSigned);
1890  AffineExpr N = m.dim(0);
1891  AffineExpr C = m.dim(1);
1892  AffineExpr H = m.dim(2);
1893  AffineExpr W = m.dim(3);
1894  AffineExpr h = m.dim(4);
1895  AffineExpr w = m.dim(5);
1896
1897  if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1898          .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
1899          .matchMaps(
1900              {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1901               /*filterMap=*/{h, w},
1902               /*outputMap=*/{N, C, H, W}})
1903          .matchBody())
1904    return result;
1905  return std::nullopt;
1906}
1907
// NOTE(review): scraped listing — each specialization below is missing its
// signature line, its `result` declaration, and the guard before the early
// `return std::nullopt;`. In the *Unsigned variants the ConvMatcherBuilder
// constructor continuation line (the one naming the pooling kind) was also
// dropped. Comments describe only the visible code.

// Specialization for linalg.pooling_nwc_sum — 1-D sum pooling, channels-last
// (iteration dims N, W, C plus reduction dim w; filter map {w}).
1908template <>
1909std::optional<DilationsAndStrides>
1912  if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
1913    result.dilations =
1914        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1915    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1916    return result;
1917  }
1918
1920    return std::nullopt;
1921
1922  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1923                       &result.strides, PoolingType::Sum);
1924  AffineExpr N = m.dim(0);
1925  AffineExpr W = m.dim(1);
1926  AffineExpr C = m.dim(2);
1927  AffineExpr w = m.dim(3);
1928
1929  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1930          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1931                      /*filterMap=*/{w},
1932                      /*outputMap=*/{N, W, C}})
1933          .matchBody())
1934    return result;
1935  return std::nullopt;
1936}
1937
// Specialization for linalg.pooling_ncw_sum (channels-first; stride at input
// dim 2).
1938template <>
1939std::optional<DilationsAndStrides>
1942  if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
1943    result.dilations =
1944        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1945    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1946    return result;
1947  }
1948
1950    return std::nullopt;
1951
1952  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1953                       &result.strides, PoolingType::Sum);
1954  AffineExpr N = m.dim(0);
1955  AffineExpr C = m.dim(1);
1956  AffineExpr W = m.dim(2);
1957  AffineExpr w = m.dim(3);
1958
1959  if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1960          .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
1961                      /*filterMap=*/{w},
1962                      /*outputMap=*/{N, C, W}})
1963          .matchBody())
1964    return result;
1965  return std::nullopt;
1966}
1967
// Specialization for linalg.pooling_nwc_max (signed max).
1968template <>
1969std::optional<DilationsAndStrides>
1972  if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
1973    result.dilations =
1974        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1975    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1976    return result;
1977  }
1978
1980    return std::nullopt;
1981
1982  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1983                       &result.strides, PoolingType::MaxSigned);
1984  AffineExpr N = m.dim(0);
1985  AffineExpr W = m.dim(1);
1986  AffineExpr C = m.dim(2);
1987  AffineExpr w = m.dim(3);
1988
1989  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1990          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1991                      /*filterMap=*/{w},
1992                      /*outputMap=*/{N, W, C}})
1993          .matchBody())
1994    return result;
1995  return std::nullopt;
1996}
1997
// Specialization for linalg.pooling_nwc_max_unsigned — the builder's
// constructor continuation (presumably &result.strides plus an unsigned-max
// PoolingType — TODO confirm against the real source) was dropped by the
// scrape.
1998template <>
1999std::optional<DilationsAndStrides>
2002  if (auto poolOp =
2003          dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
2004    result.dilations =
2005        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2006    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2007    return result;
2008  }
2009
2011    return std::nullopt;
2012
2013  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2015  AffineExpr N = m.dim(0);
2016  AffineExpr W = m.dim(1);
2017  AffineExpr C = m.dim(2);
2018  AffineExpr w = m.dim(3);
2019
2020  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2021          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2022                      /*filterMap=*/{w},
2023                      /*outputMap=*/{N, W, C}})
2024          .matchBody())
2025    return result;
2026  return std::nullopt;
2027}
2028
// Specialization for linalg.pooling_ncw_max (channels-first signed max).
2029template <>
2030std::optional<DilationsAndStrides>
2033  if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
2034    result.dilations =
2035        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2036    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2037    return result;
2038  }
2039
2041    return std::nullopt;
2042
2043  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2044                       &result.strides, PoolingType::MaxSigned);
2045  AffineExpr N = m.dim(0);
2046  AffineExpr C = m.dim(1);
2047  AffineExpr W = m.dim(2);
2048  AffineExpr w = m.dim(3);
2049
2050  if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
2051          .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
2052                      /*filterMap=*/{w},
2053                      /*outputMap=*/{N, C, W}})
2054          .matchBody())
2055    return result;
2056  return std::nullopt;
2057}
2058
// Specialization for linalg.pooling_nwc_min (signed min).
2059template <>
2060std::optional<DilationsAndStrides>
2063  if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
2064    result.dilations =
2065        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2066    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2067    return result;
2068  }
2069
2071    return std::nullopt;
2072
2073  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2074                       &result.strides, PoolingType::MinSigned);
2075  AffineExpr N = m.dim(0);
2076  AffineExpr W = m.dim(1);
2077  AffineExpr C = m.dim(2);
2078  AffineExpr w = m.dim(3);
2079
2080  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2081          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2082                      /*filterMap=*/{w},
2083                      /*outputMap=*/{N, W, C}})
2084          .matchBody())
2085    return result;
2086  return std::nullopt;
2087}
2088
// Specialization for linalg.pooling_nwc_min_unsigned — same scrape gap as the
// max-unsigned case above.
2089template <>
2090std::optional<DilationsAndStrides>
2093  if (auto poolOp =
2094          dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
2095    result.dilations =
2096        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2097    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2098    return result;
2099  }
2100
2102    return std::nullopt;
2103
2104  ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2106  AffineExpr N = m.dim(0);
2107  AffineExpr W = m.dim(1);
2108  AffineExpr C = m.dim(2);
2109  AffineExpr w = m.dim(3);
2110
2111  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2112          .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2113                      /*filterMap=*/{w},
2114                      /*outputMap=*/{N, W, C}})
2115          .matchBody())
2116    return result;
2117  return std::nullopt;
2118}
2119
2120template <>
2121std::optional<DilationsAndStrides>
// NOTE(review): declarator lines (orig. 2122-2123) lost in extraction;
// presumably the linalg::PoolingNdhwcSumOp specialization -- confirm upstream.
// Fast path: read dilations/strides straight off the named op.
2124  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
2125    result.dilations =
2126        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2127    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2128    return result;
2129  }
2130
// NOTE(review): the guard condition (orig. line 2131) is missing from this
// view.
2132    return std::nullopt;
2133
// Structural match for NDHWC sum pooling over three spatial dims: iteration
// dims (N, D, H, W, C, d, h, w); each spatial input access is strided by the
// matching window dim.
2134  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2135                       &result.strides, PoolingType::Sum);
2136  AffineExpr N = m.dim(0);
2137  AffineExpr D = m.dim(1);
2138  AffineExpr H = m.dim(2);
2139  AffineExpr W = m.dim(3);
2140  AffineExpr C = m.dim(4);
2141  AffineExpr d = m.dim(5);
2142  AffineExpr h = m.dim(6);
2143  AffineExpr w = m.dim(7);
2144
2145  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2146          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2147          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2148          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2149                                    m.strided(W, w, 2), C},
2150                      /*filterMap=*/{d, h, w},
2151                      /*outputMap=*/{N, D, H, W, C}})
2152          .matchBody())
2153    return result;
2154  return std::nullopt;
2155}
2156
2157template <>
2158std::optional<DilationsAndStrides>
// NOTE(review): declarator lines (orig. 2159-2160) lost in extraction;
// presumably the linalg::PoolingNdhwcMaxOp specialization -- confirm upstream.
// Fast path: read dilations/strides straight off the named op.
2161  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
2162    result.dilations =
2163        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2164    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2165    return result;
2166  }
2167
// NOTE(review): the guard condition (orig. line 2168) is missing from this
// view.
2169    return std::nullopt;
2170
// Structural match for NDHWC max pooling, same layout as the sum variant but
// with a signed-max combiner.
2171  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2172                       &result.strides, PoolingType::MaxSigned);
2173  AffineExpr N = m.dim(0);
2174  AffineExpr D = m.dim(1);
2175  AffineExpr H = m.dim(2);
2176  AffineExpr W = m.dim(3);
2177  AffineExpr C = m.dim(4);
2178  AffineExpr d = m.dim(5);
2179  AffineExpr h = m.dim(6);
2180  AffineExpr w = m.dim(7);
2181
2182  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2183          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2184          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2185          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2186                                    m.strided(W, w, 2), C},
2187                      /*filterMap=*/{d, h, w},
2188                      /*outputMap=*/{N, D, H, W, C}})
2189          .matchBody())
2190    return result;
2191  return std::nullopt;
2192}
2193
2194template <>
2195std::optional<DilationsAndStrides>
// NOTE(review): declarator lines (orig. 2196-2197) lost in extraction;
// presumably the linalg::PoolingNdhwcMinOp specialization -- confirm upstream.
// Fast path: read dilations/strides straight off the named op.
2198  if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
2199    result.dilations =
2200        llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2201    result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2202    return result;
2203  }
2204
// NOTE(review): the guard condition (orig. line 2205) is missing from this
// view.
2206    return std::nullopt;
2207
// Structural match for NDHWC min pooling, same layout as the sum/max variants
// but with a signed-min combiner.
2208  ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2209                       &result.strides, PoolingType::MinSigned);
2210  AffineExpr N = m.dim(0);
2211  AffineExpr D = m.dim(1);
2212  AffineExpr H = m.dim(2);
2213  AffineExpr W = m.dim(3);
2214  AffineExpr C = m.dim(4);
2215  AffineExpr d = m.dim(5);
2216  AffineExpr h = m.dim(6);
2217  AffineExpr w = m.dim(7);
2218
2219  if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2220          .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2221          .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2222          .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2223                                    m.strided(W, w, 2), C},
2224                      /*filterMap=*/{d, h, w},
2225                      /*outputMap=*/{N, D, H, W, C}})
2226          .matchBody())
2227    return result;
2228  return std::nullopt;
2229}
2230
2231Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
2232 Value source, Value pad, bool nofold,
2233 ValueRange typeDynDims) {
2234 // Exit if `source` is not defined by an ExtractSliceOp.
2235 auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
2236 if (!sliceOp)
2237 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2238 typeDynDims);
2239
2240 // Search the `source` use-def chain for padded LinalgOps.
2241 Value current = sliceOp.getSource();
2242 while (current) {
2243 auto linalgOp = current.getDefiningOp<LinalgOp>();
2244 if (!linalgOp)
2245 break;
2246 OpResult opResult = cast<OpResult>(current);
2247 current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
2248 }
2249 auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
2250
2251 // Exit if the search fails to match a tensor::PadOp at the end of the matched
2252 // LinalgOp sequence.
2253 if (!padOp)
2254 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2255 typeDynDims);
2256
2257 // Exit if the padded result type does not match.
2258 if (sliceOp.getSource().getType() != type)
2259 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2260 typeDynDims);
2261
2262 // Exit if the LinalgOps are not high padded.
2263 if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
2264 return getConstantIntValue(ofr) != static_cast<int64_t>(0);
2265 }))
2266 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2267 typeDynDims);
2268
2269 // Exit if `padOpSliceOp`, which defines the slice used by
2270 // `padOp`, is rank-reducing.
2271 auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
2272 if (!padOpSliceOp ||
2273 sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
2274 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2275 typeDynDims);
2276
2277 // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
2278 // of the slice padded by `padOp`.
2279 if (llvm::any_of(
2280 llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
2281 [](std::tuple<OpFoldResult, OpFoldResult> it) {
2282 return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
2283 }))
2284 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2285 typeDynDims);
2286
2287 // Exit if the padding values do not match.
2288 Attribute padOpPadAttr, padAttr;
2289 Value padOpPad = padOp.getConstantPaddingValue();
2290 if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
2291 !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
2292 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2293 typeDynDims);
2294
2295 // Return the padded result if the padding values and sizes match.
2296 return sliceOp.getSource();
2297}
2298
2299GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
2300 auto memrefTypeTo = cast<MemRefType>(to.getType());
2301#ifndef NDEBUG
2302 auto memrefTypeFrom = cast<MemRefType>(from.getType());
2303 assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
2304 "`from` and `to` memref must have the same rank");
2305#endif // NDEBUG
2306
2307 AffineMap id =
2308 AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
2309 SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
2310 utils::IteratorType::parallel);
2311 return linalg::GenericOp::create(
2312 b, loc,
2313 /*inputs=*/from,
2314 /*outputs=*/to,
2315 /*indexingMaps=*/llvm::ArrayRef({id, id}),
2316 /*iteratorTypes=*/iteratorTypes,
2317 [](OpBuilder &b, Location loc, ValueRange args) {
2318 linalg::YieldOp::create(b, loc, args.front());
2319 });
2320}
2321
2322/// Specialization to build an scf "for" nest.
2323template <>
// NOTE(review): the declarator line (orig. 2324, presumably
// `void GenerateLoopNest<scf::ForOp>::doit(`) was lost during extraction --
// confirm against upstream Utils.cpp.
2325    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2326    ArrayRef<utils::IteratorType> iteratorTypes,
2328                      ValueRange)>
2329        bodyBuilderFn,
2330    ArrayRef<linalg::ProcInfo> procInfo) {
2331  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2332         "expected as many entries for proc info as number of loops, even if "
2333         "they are null entries");
// With tensor semantics, the DPS inits become loop-carried iter_args.
2334  SmallVector<Value> iterArgInitValues;
2335  if (!linalgOp.hasPureBufferSemantics())
2336    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2337  SmallVector<Value, 4> lbs, ubs, steps;
2338  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
// NOTE(review): orig. line 2339 (presumably the scf loop-nest builder call
// initializing `loopNest`) is missing from this view.
2340      b, loc, lbs, ubs, steps, iterArgInitValues,
2341      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
2342        assert(iterArgs.size() == iterArgInitValues.size() &&
2343               "expect the number of output tensors and iter args to match");
        // Substitute the current iteration's iter_args for the tensor inits so
        // the body operates on loop-carried values.
2344        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
2345        if (!iterArgs.empty()) {
2346          operandValuesToUse = linalgOp.getDpsInputs();
2347          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
2348        }
2349        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
2350      });
2351
2352  if (loopNest.loops.empty() || procInfo.empty())
2353    return;
2354
2355  // Filter out scf.for loops that were created out of parallel dimensions.
2356  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
2357    if (procInfo[loop.index()].distributionMethod ==
// NOTE(review): the distribution-method enumerator compared against
// (orig. line 2358) is missing from this view.
2359      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
2360                            procInfo[loop.index()].nprocs);
2361    }
2362  }
2363}
2364
2365/// Specialization to build affine "for" nest.
2366template <>
// NOTE(review): the declarator line (orig. 2367, presumably
// `void GenerateLoopNest<affine::AffineForOp>::doit(`) was lost during
// extraction -- confirm against upstream Utils.cpp.
2368    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2369    ArrayRef<utils::IteratorType> iteratorTypes,
2371                      ValueRange)>
2372        bodyBuilderFn,
2373    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
// Affine loops carry no iter_args, so tensor semantics (non-empty inits)
// are rejected by the assert below.
2374  SmallVector<Value> iterArgInitValues;
2375  if (!linalgOp.hasPureBufferSemantics())
2376    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2377  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
2378  SmallVector<Value, 4> lbs, ubs, steps;
2379  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
2380
2381  // Affine loops require constant steps.
2382  SmallVector<int64_t, 4> constantSteps;
2383  constantSteps.reserve(steps.size());
2384  for (Value v : steps) {
2385    auto constVal = getConstantIntValue(v);
2386    assert(constVal.has_value() && "Affine loops require constant steps");
2387    constantSteps.push_back(constVal.value());
2388  }
2389
2390  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
2391                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
2392                                bodyBuilderFn(b, loc, ivs,
2393                                              linalgOp->getOperands());
2394                              });
2395}
2396
2397/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
// NOTE(review): the signature line (orig. 2398, presumably
// `static void updateBoundsForCyclicDistribution(OpBuilder &b, Location loc,
// Value procId,`) was lost during extraction -- confirm upstream.
2399                                               Value nprocs, Value &lb, Value &ub,
2400                                               Value &step) {
2401  AffineExpr d0, d1;
2402  bindDims(b.getContext(), d0, d1);
2403  AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
// Cyclic distribution: each processor starts at lb + procId * step and
// advances by nprocs * step; `ub` is left unchanged.
2404  lb =
2405      affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
2406  step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
2407}
2408
2409/// Generates a loop nest consisting of scf.parallel and scf.for, depending
2410/// on the `iteratorTypes.` Consecutive parallel loops create a single
2411/// scf.parallel operation; each sequential loop creates a new scf.for
2412/// operation. The body of the innermost loop is populated by
2413/// `bodyBuilderFn` that accepts a range of induction variables for all
2414/// loops. `ivStorage` is used to store the partial list of induction
2415/// variables.
2416// TODO: this function can be made iterative instead. However, it
2417// will have at most as many recursive calls as nested loops, which rarely
2418// exceeds 10.
// NOTE(review): the signature line (orig. 2419, presumably
// `static void generateParallelLoopNest(`) was lost during extraction.
2420    OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
2421    ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
// NOTE(review): orig. line 2422 (presumably the ArrayRef<linalg::ProcInfo>
// `procInfo` parameter) is missing from this view.
2423    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2424    SmallVectorImpl<Value> &ivStorage) {
2425  assert(lbs.size() == ubs.size());
2426  assert(lbs.size() == steps.size());
2427  assert(lbs.size() == iteratorTypes.size());
2428  assert(procInfo.empty() || (lbs.size() == procInfo.size()));
2429
2430  // If there are no (more) loops to be generated, generate the body and be
2431  // done with it.
2432  if (iteratorTypes.empty()) {
2433    bodyBuilderFn(b, loc, ivStorage);
2434    return;
2435  }
2436
2437  // If there are no outer parallel loops, generate one sequential loop and
2438  // recurse.
2439  if (!isParallelIterator(iteratorTypes.front())) {
2440    LoopNest singleLoop = buildLoopNest(
2441        b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
2442        [&](OpBuilder &b, Location loc, ValueRange ivs) {
2443          ivStorage.append(ivs.begin(), ivs.end());
2444          generateParallelLoopNest(
2445              b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
2446              iteratorTypes.drop_front(),
2447              procInfo.empty() ? procInfo : procInfo.drop_front(),
2448              bodyBuilderFn, ivStorage);
2449        });
2450    return;
2451  }
2452
// Count the run of leading loops that share the same distribution method (or,
// without proc info, the run of leading parallel iterators); those are
// emitted together below.
2453  unsigned nLoops = iteratorTypes.size();
2454  unsigned numProcessed = 0;
2455  DistributionMethod distributionMethod = DistributionMethod::None;
2456  if (procInfo.empty()) {
2457    numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
2458  } else {
2459    distributionMethod = procInfo.front().distributionMethod;
2460    numProcessed =
2461        nLoops - procInfo
2462                     .drop_while([&](linalg::ProcInfo p) {
2463                       return p.distributionMethod == distributionMethod;
2464                     })
2465                     .size();
2466  }
2467
2468  auto remainderProcInfo =
2469      procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
2470  switch (distributionMethod) {
// NOTE(review): the `case` label on orig. line 2471 is missing from this view.
2472    // Generate a single parallel loop-nest operation for all outermost
2473    // parallel loops and recurse.
2474    scf::ParallelOp::create(
2475        b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2476        steps.take_front(numProcessed),
2477        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2478          ivStorage.append(localIvs.begin(), localIvs.end());
2479          generateParallelLoopNest(
2480              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2481              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2482              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2483              bodyBuilderFn, ivStorage);
2484        });
2485    return;
2486  }
// NOTE(review): the `case` label on orig. line 2487 is missing from this view.
2488    // Generate a single parallel loop-nest operation for all outermost
2489    // parallel loops and recurse.
2490    scf::ParallelOp::create(
2491        b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2492        steps.take_front(numProcessed),
2493        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2494          ivStorage.append(localIvs.begin(), localIvs.end());
2495          generateParallelLoopNest(
2496              nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2497              ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2498              iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2499              bodyBuilderFn, ivStorage);
2500        });
2501    return;
2502  }
// NOTE(review): the `case` label on orig. line 2503 is missing from this view.
2504    // Check (for the processed loops) that the iteration is in-bounds.
2505    ArithBuilder ab(b, loc);
2506    Value cond = ab.slt(lbs[0], ubs[0]);
2507    for (unsigned i = 1; i < numProcessed; ++i)
2508      cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
2509    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2510    scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) {
2511      generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
2512                               ubs.drop_front(numProcessed),
2513                               steps.drop_front(numProcessed),
2514                               iteratorTypes.drop_front(numProcessed),
2515                               remainderProcInfo, bodyBuilderFn, ivStorage);
2516      scf::YieldOp::create(b, loc, ValueRange{});
2517    });
2518    return;
2519  }
// NOTE(review): the `case` label on orig. line 2520 is missing from this view.
2521    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
2522    // with inner loop generation.
2523    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
// NOTE(review): orig. line 2524 (presumably the recursive
// `generateParallelLoopNest(` call line) is missing from this view.
2525        b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
2526        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
2527        remainderProcInfo, bodyBuilderFn, ivStorage);
2528    return;
2529  }
2530}
2531
2532/// Specialization for generating a mix of parallel and sequential scf loops.
2533template <>
// NOTE(review): the declarator line (orig. 2534, presumably
// `void GenerateLoopNest<scf::ParallelOp>::doit(`) was lost during
// extraction -- confirm upstream.
2535    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2536    ArrayRef<utils::IteratorType> iteratorTypes,
2538                      ValueRange)>
2539        bodyBuilderFn,
2540    ArrayRef<linalg::ProcInfo> procInfo) {
// Parallel loops carry no iter_args, so tensor semantics are rejected below.
2541  SmallVector<Value> iterArgInitValues;
2542  if (!linalgOp.hasPureBufferSemantics())
2543    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2544  assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
2545  // This function may be passed more iterator types than ranges.
2546  assert(iteratorTypes.size() >= loopRanges.size() &&
2547         "expected iterator type for all ranges");
2548  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2549         "expected proc information for all loops when present");
2550  iteratorTypes = iteratorTypes.take_front(loopRanges.size());
2551  SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
2552  unsigned numLoops = iteratorTypes.size();
2553  ivs.reserve(numLoops);
2554  lbsStorage.reserve(numLoops);
2555  ubsStorage.reserve(numLoops);
2556  stepsStorage.reserve(numLoops);
2557
2558  // Get the loop lb, ub, and step.
2559  unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
2560
2561  // Modify the lb, ub, and step based on the distribution options.
2562  for (const auto &it : llvm::enumerate(procInfo)) {
2563    if (it.value().distributionMethod != linalg::DistributionMethod::None) {
// NOTE(review): orig. line 2564 (presumably the
// `updateBoundsForCyclicDistribution(` call line) is missing from this view.
2565          b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
2566          ubsStorage[it.index()], stepsStorage[it.index()]);
2567    }
2568  }
2569  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
// NOTE(review): orig. line 2570 (presumably the `generateParallelLoopNest(`
// call line) is missing from this view.
2571      b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
2572      [&](OpBuilder &b, Location loc, ValueRange ivs) {
2573        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
2574      },
2575      ivs);
2576
2577  assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
2578}
2579
// Materialize `sliceParams` as a slice of `valueToTile`: memref.subview for
// memrefs, tensor.extract_slice for ranked tensors; any other shaped type is
// unreachable.
// NOTE(review): the signature line (orig. 2580, presumably
// `static Operation *materializeTiledShape(OpBuilder &builder, Location
// loc,`) was lost during extraction -- confirm upstream.
2581                                       Value valueToTile,
2582                                       const SliceParameters &sliceParams) {
2583  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2584  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
2585                      .Case([&](MemRefType) {
2586                        return memref::SubViewOp::create(
2587                            builder, loc, valueToTile, sliceParams.offsets,
2588                            sliceParams.sizes, sliceParams.strides);
2589                      })
2590                      .Case([&](RankedTensorType) {
2591                        return tensor::ExtractSliceOp::create(
2592                            builder, loc, valueToTile, sliceParams.offsets,
2593                            sliceParams.sizes, sliceParams.strides);
2594                      })
2595                      .DefaultUnreachable("Unexpected shaped type");
2596  return sliceOp;
2597}
2598
// Convenience wrapper: compute the slice parameters for `valueToTile` and
// immediately materialize them as a subview/extract_slice.
// NOTE(review): the signature lines (orig. 2599, 2601-2602) were lost during
// extraction; presumably
// `Operation *makeTiledShape(OpBuilder &builder, Location loc, Value
// valueToTile,` plus the `lbs`/`ubs` parameters -- confirm upstream.
2600                          ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2603                          ArrayRef<OpFoldResult> subShapeSizes,
2604                          bool omitPartialTileCheck) {
2605  SliceParameters sliceParams =
2606      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
2607                             ubs, subShapeSizes, omitPartialTileCheck);
2608  return materializeTiledShape(builder, loc, valueToTile, sliceParams);
2609}
2610
// Compute, per dimension of `valueToTile`, the offset/size/stride of the tile
// slice induced by `map` and `tileSizes`. Untiled dimensions get the full
// extent; tiled dimensions get an affine-composed offset and a size that is
// optionally clamped with affine.min to stay in bounds.
// NOTE(review): the signature lines (orig. 2611-2612 and 2614) were lost
// during extraction; presumably
// `SliceParameters computeSliceParameters(OpBuilder &builder, Location loc,
// Value valueToTile,` plus the `lbs`/`ubs` parameters -- confirm upstream.
2613                       ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2615                       ArrayRef<OpFoldResult> subShapeSizes,
2616                       bool omitPartialTileCheck) {
2617  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2618  assert(shapedType && "only shaped types can be tiled");
2619  ArrayRef<int64_t> shape = shapedType.getShape();
2620  int64_t rank = shapedType.getRank();
2621
2622  // Compute offsets/sizes/strides for the tile.
2623  SliceParameters sliceParams;
2624  sliceParams.offsets.reserve(rank);
2625  sliceParams.sizes.reserve(rank);
2626  sliceParams.strides.reserve(rank);
2627  for (unsigned r = 0; r < rank; ++r) {
2628    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
2629    if (!isTiled(map.getSubMap({r}), tileSizes)) {
      // Untiled dimension: take the whole extent with offset 0, stride 1.
2630      sliceParams.offsets.push_back(builder.getIndexAttr(0));
2631      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
2632      sliceParams.sizes.push_back(dim);
2633      sliceParams.strides.push_back(builder.getIndexAttr(1));
2634      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
2635      continue;
2636    }
2637    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
2638
2639    // Tiling creates a new slice at the proper index, the slice step is 1
2640    // (i.e. the op does not subsample, stepping occurs in the loop).
2641    auto m = map.getSubMap({r});
2642    LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
2643    IRRewriter rewriter(builder);
2644    // The offset of the slice is m(lbs) - m(0).
2645    SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
2646    SmallVector<Attribute> mAtZero;
2647    [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
2648    assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
2649    int64_t mAtZeroInt =
2650        cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
// NOTE(review): orig. line 2651 (presumably `OpFoldResult offset =
// makeComposedFoldedAffineApply(`) is missing from this view.
2652        rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
2653    sliceParams.offsets.push_back(offset);
2654
2655    OpFoldResult closedIntSize =
2656        makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
2657    // Resulting size needs to be made half open interval again.
2658    AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
2659    OpFoldResult size =
2660        makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
2661    LLVM_DEBUG(llvm::dbgs()
2662               << "computeSliceParameters: raw size: " << size << "\n");
2663    LLVM_DEBUG(llvm::dbgs()
2664               << "computeSliceParameters: new offset: " << offset << "\n");
2665    sliceParams.strides.push_back(builder.getIndexAttr(1));
2666
2667    if (omitPartialTileCheck) {
2668      // We statically know that the partial/boundary tile condition is
2669      // unnecessary.
2670      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2671      sliceParams.sizes.push_back(size);
2672      continue;
2673    }
2674
2675    // The size of the subview / extract_slice should be trimmed to avoid
2676    // out-of-bounds accesses, unless:
2677    // a. We statically know the subshape size divides the shape size evenly.
2678    // b. The subshape size is 1. According to the way the loops are set up,
2679    //    tensors with "0" dimensions would never be constructed.
2680    int64_t shapeSize = shape[r];
2681    std::optional<int64_t> sizeCst = getConstantIntValue(size);
2682    auto hasTileSizeOne = sizeCst == 1;
2683    auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
2684                         ((shapeSize % *sizeCst) == 0);
2685    if (!hasTileSizeOne && !dividesEvenly) {
2686      LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
                              << ", size: " << size
2688                              << ": make sure in bound with affine.min\n");
2689
2690      AffineExpr dim0, dim1, dim2;
2691      MLIRContext *context = builder.getContext();
2692      bindDims(context, dim0, dim1, dim2);
2693
2694      // Get the dimension size for this dimension. We need to first calculate
2695      // the max index and then plus one. This is important because for
2696      // convolution ops, we have its input window dimension's affine map of the
2697      // form `(d0 * s0 + d1)`, where `d0`/`d1 is an output/filter window
2698      // dimension and `s0` is stride. Directly use the dimension size of
2699      // output/filer window dimensions will cause incorrect calculation.
// NOTE(review): orig. line 2700 (presumably the AffineMap construction for
// `minusOneMap`) is missing from this view.
2701                                {ArrayRef<AffineExpr>{dim0 - 1}}, context)
2702                                .front();
// NOTE(review): orig. line 2703 (presumably the AffineMap construction for
// `plusOneMap`) is missing from this view.
2704                               {ArrayRef<AffineExpr>{dim0 + 1}}, context)
2705                               .front();
2706      SmallVector<OpFoldResult> maxIndices =
2707          llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
2708            return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
2709                                                 {ub});
2710          }));
2711      OpFoldResult maxIndex =
2712          makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
2713      OpFoldResult d =
2714          makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
2715
2716      // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
// NOTE(review): orig. line 2717 (presumably the AffineMap construction for
// `minMap`) is missing from this view.
2718                           {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
2719                           .front();
2720      size =
2721          makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
2722    }
2723    LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2724    sliceParams.sizes.push_back(size);
2725  }
2726  return sliceParams;
2727}
2728
// Build per-loop tile offsets: tiled loops (non-zero tile size) consume the
// next induction variable from `ivs`; untiled loops get a constant 0.
// NOTE(review): the signature lines (orig. 2729-2730) and the declaration of
// the `offsets` result vector (orig. 2732) were lost during extraction --
// confirm upstream.
2731                                            ArrayRef<OpFoldResult> tileSizes) {
2733  for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
2734    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
2735    bool isTiled = !isZeroInteger(tileSizes[idx]);
2736    offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
2737    LLVM_DEBUG(llvm::dbgs()
2738               << "computeTileOffsets: " << offsets.back() << "\n");
2739  }
2740  return offsets;
2741}
2742
// Build per-loop closed-interval sizes: tile size for tiled loops, the full
// bound for untiled ones, each reduced by 1 so affine composition operates on
// a closed interval.
// NOTE(review): the signature line (orig. 2743) and the declaration of the
// `sizes` result vector (orig. 2746) were lost during extraction -- confirm
// upstream.
2744                                          ArrayRef<OpFoldResult> tileSizes,
2745                                          ArrayRef<OpFoldResult> sizeBounds) {
2747  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
2748    bool isTiled = !isZeroInteger(tileSizes[idx]);
2749    // Before composing, we need to make range a closed interval.
2750    OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
2751    AffineExpr d0 = getAffineDimExpr(0, b.getContext());
2752    IRRewriter rewriter(b);
2753    sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
2754    LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
2755  }
2756  return sizes;
2757}
2758
// Collect the types of the (possibly substituted) init operands of `op`:
// empty for pure buffer semantics, otherwise the type of each DPS init taken
// from the corresponding entry in `operands`.
// NOTE(review): the signature line (orig. 2759) was lost during extraction --
// confirm upstream.
2760  if (op.hasPureBufferSemantics())
2761    return {};
2762  return llvm::to_vector(
2763      llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
2764        return operands[opOperand.getOperandNumber()].getType();
2765      }));
2766}
2767
// Re-insert each tiled result into its full output tensor: when the output
// operand came from an extract_slice, mirror it with an insert_slice of the
// result at the same offsets/sizes/strides; otherwise pass the result through.
// NOTE(review): the first signature line (orig. 2768, presumably
// `SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,`)
// was lost during extraction -- confirm upstream.
2769                                    LinalgOp op, ValueRange operands,
2770                                    ValueRange results) {
2771  if (op.hasPureBufferSemantics())
2772    return {};
2773  SmallVector<Value> tensorResults;
2774  tensorResults.reserve(results.size());
2775  // Insert an insert_slice for each output tensor.
2776  unsigned resultIdx = 0;
2777  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2778    // TODO: use an interface/adaptor to avoid leaking position in
2779    // `tiledOperands`.
2780    Value outputTensor = operands[opOperand.getOperandNumber()];
2781    if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
2782      Value inserted = tensor::InsertSliceOp::create(
2783          builder, loc, sliceOp.getSource().getType(), results[resultIdx],
2784          sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
2785          sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2786          sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2787      tensorResults.push_back(inserted);
2788    } else {
2789      tensorResults.push_back(results[resultIdx]);
2790    }
2791    ++resultIdx;
2792  }
2793  return tensorResults;
2794}
2795
// For every operand of `linalgOp` in `valuesToTile`, compute the slice
// parameters of its tile (or std::nullopt when the operand is untiled and is
// not an output tensor).
// NOTE(review): the return-type line (orig. 2796, presumably
// `SmallVector<std::optional<SliceParameters>>`) was lost during extraction
// -- confirm upstream.
2797computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
2798                          ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
2799                          ArrayRef<OpFoldResult> tileSizes,
2800                          ArrayRef<OpFoldResult> sizeBounds,
2801                          bool omitPartialTileCheck) {
2802  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
2803                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
2804                           [](OpFoldResult v) { return !isZeroInteger(v); })) &&
2805         "expected as many ivs as non-zero sizes");
2806
2807  // Construct (potentially temporary) mins and maxes on which to apply maps
2808  // that define tile subshapes.
// NOTE(review): orig. line 2809 (presumably `SmallVector<OpFoldResult> lbs =`)
// is missing from this view.
2810      computeTileOffsets(builder, loc, ivs, tileSizes);
2811  SmallVector<OpFoldResult> subShapeSizes =
2812      computeTileSizes(builder, loc, tileSizes, sizeBounds);
2813
2814  assert(static_cast<int64_t>(valuesToTile.size()) <=
2815             linalgOp->getNumOperands() &&
2816         "more value to tile than operands.");
// NOTE(review): orig. line 2817 (presumably the declaration of
// `allSliceParams`) is missing from this view.
2818  allSliceParams.reserve(valuesToTile.size());
2819  for (auto [opOperand, val] :
2820       llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
2821    Value shapedOp = val;
2822    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
2823    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
2824    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
2825    // an extract/insert slice pair for all output tensors simplifies follow up
2826    // transformations such as padding and bufferization since the
2827    // extract/insert slice pairs make the accessed iteration argument
2828    // subdomains explicit.
2829
2830    Type operandType = opOperand.get().getType();
2831    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
2832                                      linalgOp.isDpsInit(&opOperand))) {
2833      allSliceParams.push_back(std::nullopt);
2834      LLVM_DEBUG(llvm::dbgs()
2835                 << ": not tiled: use shape: " << operandType << "\n");
2836      continue;
2837    }
2838    LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
2839
2840    allSliceParams.push_back(computeSliceParameters(
2841        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
2842        omitPartialTileCheck));
2843  }
2844
2845  return allSliceParams;
2846}
2847
// Materialize a tiled value for every operand: slice parameters computed by
// computeAllSliceParameters are turned into subview/extract_slice ops;
// operands without slice parameters are used unchanged.
// NOTE(review): orig. lines 2848 (first signature line), 2850 (the `ivs`
// parameter) and 2854 (the declaration of `allSliceParameter`) were lost
// during extraction -- confirm upstream.
2849                                 LinalgOp linalgOp, ValueRange valuesToTile,
2851                                 ArrayRef<OpFoldResult> tileSizes,
2852                                 ArrayRef<OpFoldResult> sizeBounds,
2853                                 bool omitPartialTileCheck) {
2855      computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
2856                                tileSizes, sizeBounds, omitPartialTileCheck);
2857  SmallVector<Value> tiledShapes;
2858  for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
2859    Value valueToTile = std::get<0>(item);
2860    std::optional<SliceParameters> sliceParams = std::get<1>(item);
2861    tiledShapes.push_back(
2862        sliceParams.has_value()
2863            ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
2864                  ->getResult(0)
2865            : valueToTile);
2866  }
2867  return tiledShapes;
2868}
2869
2870void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
2871 ArrayRef<OpFoldResult> offsets) {
2872 IRRewriter rewriter(b);
2873 offsetIndices(rewriter, linalgOp, offsets);
2874}
2875
// Shift every linalg.index result inside `linalgOp` by the corresponding
// entry of `offsets` (skipping dims with no offset), replacing all uses
// except the one inside the op that materializes the shifted value.
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
2877                   ArrayRef<OpFoldResult> offsets) {
2878  if (!linalgOp.hasIndexSemantics())
2879    return;
2880
2881  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
2882    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
2883      continue;
// NOTE(review): orig. line 2884 (presumably an OpBuilder::InsertionGuard) is
// missing from this view.
2885    b.setInsertionPointAfter(indexOp);
2886    AffineExpr index, offset;
2887    bindDims(b.getContext(), index, offset);
// NOTE(review): orig. line 2888 (presumably `OpFoldResult applied =
// makeComposedFoldedAffineApply(`) is missing from this view.
2889        b, indexOp.getLoc(), index + offset,
2890        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
2891    Value materialized =
2892        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    // Keep the use feeding the materialized add itself; replace all others.
2893    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
2894      return use.getOwner() != materialized.getDefiningOp();
2895    });
2896  }
2897}
2898
2899/// Get the reassociation maps to fold the result of an extract_slice (or
2900/// source of an insert_slice) operation with given offsets, and sizes to its
2901/// rank-reduced version. This is only done for the cases where the size is 1
2902/// and offset is 0. Strictly speaking the offset 0 is not required in general,
2903/// but non-zero offsets are not handled by SPIR-V backend at this point (and
2904/// potentially cannot be handled).
2905std::optional<SmallVector<ReassociationIndices>>
// NOTE(review): orig. lines 2906-2908 (the function declarator and the
// declarations of `reassociation` and `curr`) were lost during extraction --
// confirm upstream.
2909  for (const auto &it : llvm::enumerate(mixedSizes)) {
2910    auto dim = it.index();
2911    auto size = it.value();
    // Accumulate this dim; static size-1 dims stay in `curr` so they get
    // folded into the next non-unit group.
2912    curr.push_back(dim);
2913    auto attr = llvm::dyn_cast_if_present<Attribute>(size);
2914    if (attr && cast<IntegerAttr>(attr).getInt() == 1)
2915      continue;
2916    reassociation.emplace_back(ReassociationIndices{});
2917    std::swap(reassociation.back(), curr);
2918  }
2919  // When the reassociations are not empty, then fold the remaining
2920  // unit-dimensions into the last dimension. If the reassociations so far is
2921  // empty, then leave it empty. This will fold everything to a rank-0 tensor.
2922  if (!curr.empty() && !reassociation.empty())
2923    reassociation.back().append(curr.begin(), curr.end());
2924  return reassociation;
2925}
2926
2927} // namespace linalg
2928} // namespace mlir
static SmallVector< int64_t > computePackUnPackPerm(int64_t rank, ArrayRef< int64_t > &innerDimsPos, ArrayRef< int64_t > &outerPerm, PackingMetadata &packingMetadata)
The permutation can be obtained from two permutations: a) Compute the permutation vector to move the ...
Definition Utils.cpp:150
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
Definition Utils.cpp:74
static void unpackRanges(OpBuilder &builder, Location loc, ArrayRef< Range > ranges, SmallVectorImpl< Value > &lbs, SmallVectorImpl< Value > &ubs, SmallVectorImpl< Value > &steps)
Given a list of subview ranges, extract individual values for lower, upper bounds and steps and put t...
Definition Utils.cpp:126
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition PDL.cpp:62
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
If copies could not be generated due to yet-unimplemented cases, copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point), since `begin` can itself be invalidated due to the memref rewriting done from this method.
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
int64_t getValue() const
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
MLIRContext * getContext() const
Definition Builders.h:56
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessIntOrFloat() const
Return true of this is a signless integer or a float type.
Definition Types.cpp:108
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
Helper class for building convolution op matchers with minimal boilerplate.
Definition Utils.cpp:533
ConvMatcherBuilder & matchStride(unsigned iDim, unsigned fDim, unsigned oDim, unsigned idx)
Match stride/dilation pattern for a spatial dimension.
Definition Utils.cpp:561
bool matchBody(bool containsZeroPointOffset=false)
Match body pattern. This should be called last.
Definition Utils.cpp:578
AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx)
Build strided expression: base * stride[idx] + kernel * dilation[idx].
Definition Utils.cpp:555
AffineExpr dim(unsigned i)
Get affine dimension expression for dimension i.
Definition Utils.cpp:552
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector< int64_t > *d, SmallVector< int64_t > *s, PoolingType poolingType=PoolingType::None)
Definition Utils.cpp:542
ConvMatcherBuilder & matchMaps(ArrayRef< ArrayRef< AffineExpr > > maps)
Match expected indexing maps layout. Returns *this for method chaining.
Definition Utils.cpp:571
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmQOp >(LinalgOp op)
Definition Utils.cpp:1517
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNchwChwOp >(LinalgOp op)
Definition Utils.cpp:1374
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcOp >(LinalgOp op)
Definition Utils.cpp:1409
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcmOp >(LinalgOp op)
Definition Utils.cpp:1631
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwMaxOp >(LinalgOp op)
Definition Utils.cpp:2031
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2848
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinOp >(LinalgOp op)
Definition Utils.cpp:2061
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition Utils.cpp:195
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, bool containsZeroPointOffset=false)
Utility to match block body for convolution ops.
Definition Utils.cpp:329
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcmOp >(LinalgOp op)
Definition Utils.cpp:1342
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:230
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNcwFcwOp >(LinalgOp op)
Definition Utils.cpp:667
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwSumOp >(LinalgOp op)
Definition Utils.cpp:1842
SmallVector< OpFoldResult > computeTileSizes(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds)
Computes tile sizes, given a list of tileSizes and dimension sizes (sizeBounds).
Definition Utils.cpp:2743
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxUnsignedOp >(LinalgOp op)
Definition Utils.cpp:1772
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to)
Returns GenericOp that copies an n-D memref.
Definition Utils.cpp:2299
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcOp >(LinalgOp op)
Definition Utils.cpp:1055
static void generateParallelLoopNest(OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef< utils::IteratorType > iteratorTypes, ArrayRef< linalg::ProcInfo > procInfo, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, SmallVectorImpl< Value > &ivStorage)
Generates a loop nest consisting of scf.parallel and scf.for, depending on the iteratorTypes.
Definition Utils.cpp:2419
SmallVector< OpFoldResult > computeTileOffsets(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes)
Computes tile offsets, given a list of loop ivs and tileSizes.
Definition Utils.cpp:2729
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwQOp >(LinalgOp op)
Definition Utils.cpp:1017
PoolingType
Enum representing pooling operation types used by ConvMatcherBuilder.
Definition Utils.cpp:507
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNcwCwOp >(LinalgOp op)
Definition Utils.cpp:1280
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:394
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, Block *body)
Utility function to match the zero point offset body of quantized convolution ops.
Definition Utils.cpp:273
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:380
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmOp >(LinalgOp op)
Definition Utils.cpp:1481
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcOp >(LinalgOp op)
Definition Utils.cpp:1311
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:234
bool hasOnlyScalarElementwiseOp(Region &r)
Detect whether r has only ConstantOp, ElementwiseMappable and YieldOp.
Definition Utils.cpp:201
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex)
Definition Utils.cpp:405
static BlockArgument getBlockArgumentWithOptionalCastOps(Value val)
Returns the BlockArgument that leads to val, if any.
Definition Utils.cpp:244
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body)
Utility to match block body for linalg.pool* ops.
Definition Utils.cpp:364
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwQOp >(LinalgOp op)
Definition Utils.cpp:908
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNwcWcfOp >(LinalgOp op)
Definition Utils.cpp:636
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
Definition Utils.cpp:2906
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinUnsignedOp >(LinalgOp op)
Definition Utils.cpp:2091
DistributionMethod
Scheme used to distribute loops to processors.
Definition Utils.h:276
@ None
No Distribution.
Definition Utils.h:321
@ CyclicNumProcsGeNumIters
Cyclic distribution where the number of processors can be assumed to be more than or equal to the num...
Definition Utils.h:306
@ Cyclic
Cyclic distribution where no assumption is made about the dynamic relationship between number of proc...
Definition Utils.h:288
@ CyclicNumProcsEqNumIters
Cyclic distribution where the number of processors can be assumed to be equal to the number of iterat...
Definition Utils.h:318
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:385
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
Definition Utils.cpp:2768
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNcdhwCdhwOp >(LinalgOp op)
Definition Utils.cpp:1593
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:215
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcOp >(LinalgOp op)
Definition Utils.cpp:1555
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxUnsignedOp >(LinalgOp op)
Definition Utils.cpp:2000
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offsets)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2870
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body)
Matches sum pooling body pattern.
Definition Utils.cpp:400
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfOp >(LinalgOp op)
Definition Utils.cpp:729
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNcdhwFcdhwOp >(LinalgOp op)
Definition Utils.cpp:1242
SmallVector< std::optional< SliceParameters > > computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Computes SliceParameters for all valuesToTile of the given linalgOp, assuming linalgOp is being fused...
Definition Utils.cpp:2797
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DOp >(LinalgOp op)
Definition Utils.cpp:698
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfQOp >(LinalgOp op)
Definition Utils.cpp:764
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwOp >(LinalgOp op)
Definition Utils.cpp:873
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcQOp >(LinalgOp op)
Definition Utils.cpp:836
Operation * makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Creates an extract_slice/subview op for a single valueToTile with builder.
Definition Utils.cpp:2599
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DOp >(LinalgOp op)
Definition Utils.cpp:1129
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwMaxOp >(LinalgOp op)
Definition Utils.cpp:1876
static bool convLayoutMatches(ArrayRef< ArrayRef< AffineExpr > > mapListExpected, ArrayAttr indexingMaps, MLIRContext *context)
Returns true if the given indexing maps matches with the expected indexing maps.
Definition Utils.cpp:494
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:389
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride)
Given an array of AffineMaps indexingMaps verify the following commutatively:- indexingMaps[0]....
Definition Utils.cpp:460
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcSumOp >(LinalgOp op)
Definition Utils.cpp:1738
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfQOp >(LinalgOp op)
Definition Utils.cpp:1202
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcQOp >(LinalgOp op)
Definition Utils.cpp:1091
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinOp >(LinalgOp op)
Definition Utils.cpp:1704
static Operation * materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams)
Definition Utils.cpp:2580
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwFgchwOp >(LinalgOp op)
Definition Utils.cpp:945
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwOp >(LinalgOp op)
Definition Utils.cpp:981
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxOp >(LinalgOp op)
Definition Utils.cpp:1670
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMinOp >(LinalgOp op)
Definition Utils.cpp:2196
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinUnsignedOp >(LinalgOp op)
Definition Utils.cpp:1807
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
Definition Utils.cpp:2231
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfOp >(LinalgOp op)
Definition Utils.cpp:1164
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcSumOp >(LinalgOp op)
Definition Utils.cpp:1910
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DOp >(LinalgOp op)
Definition Utils.cpp:608
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim)
Check if expr is either:
Definition Utils.cpp:418
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcSumOp >(LinalgOp op)
Definition Utils.cpp:2122
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcQOp >(LinalgOp op)
Definition Utils.cpp:1444
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcOp >(LinalgOp op)
Definition Utils.cpp:801
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxOp >(LinalgOp op)
Definition Utils.cpp:1970
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step)
Update the lb, ub and step to get per processor lb, ub and step.
Definition Utils.cpp:2398
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:2759
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwSumOp >(LinalgOp op)
Definition Utils.cpp:1940
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
Definition Utils.cpp:2612
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMaxOp >(LinalgOp op)
Definition Utils.cpp:2159
auto m_Val(Value v)
Definition Matchers.h:539
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:777
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Definition Utils.cpp:23
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
Definition Matchers.h:379
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Value _and(Value lhs, Value rhs)
Definition Utils.cpp:311
Value slt(Value lhs, Value rhs)
Definition Utils.cpp:334
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A struct containing dilations and strides inferred from convolution ops.
Definition Utils.h:110
Utility class used to generate nested loops with ranges described by loopRanges and loop type describ...
Definition Utils.h:390
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
Callback function type used to get processor ID, and number of processors used for distribution for a...
Definition Utils.h:326
DistributionMethod distributionMethod
Definition Utils.h:329
static std::optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing binary operation on int or float scalar values ...
Definition Utils.cpp:93
A struct containing offsets-sizes-strides arguments of the tiled shape.
Definition Utils.h:172
SmallVector< OpFoldResult > strides
Definition Utils.h:175
SmallVector< OpFoldResult > sizes
Definition Utils.h:174
SmallVector< OpFoldResult > offsets
Definition Utils.h:173
LoopVector loops
Definition SCF.h:67